All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.netflix.rewrite.ast.visitor.AstVisitor.kt Maven / Gradle / Ivy

/**
 * Copyright 2016 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.netflix.rewrite.ast.visitor

import com.netflix.rewrite.ast.*
import java.util.*

open class AstVisitor {
    var default: (Tree?) -> R

    constructor(default: R) {
        this.default = { default }
    }

    constructor(default: (Tree?) -> R) {
        this.default = default
    }

    /**
     * Some sensible defaults for reduce (boolean OR, list concatenation, or else just the value of r1).
     * Override if your particular visitor needs to reduce values in a different way.
     */
    @Suppress("UNUSED_PARAMETER", "UNCHECKED_CAST")
    open fun reduce(r1: R, r2: R): R = when (r1) {
        is Boolean -> (r1 || r2 as Boolean) as R
        is Iterable<*> -> r1.plus(r2 as Iterable<*>) as R
        else -> r1 ?: r2
    }

    private val cursorStack = Stack()

    fun cursor(): Cursor = Cursor(cursorStack.toList())

    fun visit(tree: Tree?): R =
            if (tree != null) {
                cursorStack.push(tree)
                val t = reduce(tree.accept(this), visitTree(tree))
                cursorStack.pop()
                t
            } else default(tree)

    private fun R.andThen(nodes: Iterable?): R =
            if (nodes != null) reduce(this, visit(nodes)) else this

    private fun R.andThen(node: Tree?): R = if (node != null) reduce(this, visit(node)) else this

    private fun R.andThenVisitTypeNames(names: Iterable?): R =
            if (names != null) names.fold(this) { acc, name -> reduce(acc, visitTypeName(name)) } else this

    private fun R.andThenVisitTypeName(name: NameTree?): R =
            if (name != null) reduce(this, visitTypeName(name)) else this

    fun visit(nodes: Iterable?): R =
            nodes?.let {
                var r: R = default(null)
                var first = true
                for (node in nodes) {
                    r = if (first) visit(node) else r.andThen(node)
                    first = false
                }
                r
            } ?: default(null)

    open fun visitTree(t: Tree): R = default(t)

    open fun visitAnnotation(annotation: Tr.Annotation): R =
            visit(annotation.annotationType)
                    .andThen(annotation.args?.args)
                    .andThenVisitTypeName(annotation.annotationType)

    open fun visitArrayAccess(arrayAccess: Tr.ArrayAccess): R =
            visit(arrayAccess.indexed)
                    .andThen(arrayAccess.dimension.index)

    open fun visitArrayType(arrayType: Tr.ArrayType): R =
            visit(arrayType.elementType)
                    .andThenVisitTypeName(arrayType.elementType)

    open fun visitAssert(assert: Tr.Assert): R =
            visit(assert.condition)

    open fun visitAssign(assign: Tr.Assign): R =
            visit(assign.variable)
                    .andThen(assign.assignment)

    open fun visitAssignOp(assign: Tr.AssignOp): R =
            visit(assign.variable)
                    .andThen(assign.assignment)

    open fun visitBinary(binary: Tr.Binary): R =
            visit(binary.left)
                    .andThen(binary.right)

    open fun visitBlock(block: Tr.Block): R =
            visit(block.statements)

    open fun visitBreak(breakStatement: Tr.Break): R =
            visit(breakStatement.label)

    open fun visitCase(case: Tr.Case): R =
            visit(case.pattern)
                    .andThen(case.statements)

    open fun visitCatch(catch: Tr.Catch): R =
            visit(catch.param)
                    .andThen(catch.body)

    open fun visitClassDecl(classDecl: Tr.ClassDecl): R =
            visit(classDecl.annotations)
                    .andThen(classDecl.modifiers)
                    .andThen(classDecl.name)
                    .andThen(classDecl.typeParams?.params)
                    .andThen(classDecl.extends)
                    .andThen(classDecl.implements)
                    .andThen(classDecl.body)
                    .andThenVisitTypeName(classDecl.extends)
                    .andThenVisitTypeNames(classDecl.implements)

    open fun visitCompilationUnit(cu: Tr.CompilationUnit): R = reduce(
            visit(cu.imports)
                    .andThen(cu.packageDecl)
                    .andThen(cu.classes),
            visitEnd()
    )

    open fun visitContinue(continueStatement: Tr.Continue): R =
            visit(continueStatement.label)

    open fun visitDoWhileLoop(doWhileLoop: Tr.DoWhileLoop): R =
            visit(doWhileLoop.condition)
                    .andThen(doWhileLoop.body)

    open fun visitEmpty(empty: Tr.Empty): R = default(empty)

    open fun visitEnd() = default(null)

    open fun visitEnumValue(enum: Tr.EnumValue): R =
            visit(enum.name)
                    .andThen(enum.initializer?.args)

    open fun visitEnumValueSet(enums: Tr.EnumValueSet): R =
            visit(enums.enums)

    open fun visitExpression(expr: Expression): R = default(expr)

    open fun visitFieldAccess(field: Tr.FieldAccess): R =
            visit(field.target)
                    .andThen(field.name)
                    .andThenVisitTypeName(field.asClassReference())

    open fun visitForLoop(forLoop: Tr.ForLoop) =
            visit(forLoop.control.init)
                    .andThen(forLoop.control.condition)
                    .andThen(forLoop.control.update)
                    .andThen(forLoop.body)

    open fun visitForEachLoop(forEachLoop: Tr.ForEachLoop): R =
            visit(forEachLoop.control.variable)
                    .andThen(forEachLoop.control.iterable)
                    .andThen(forEachLoop.body)

    open fun visitIdentifier(ident: Tr.Ident): R = default(ident)

    open fun visitIf(iff: Tr.If): R =
            visit(iff.ifCondition)
                    .andThen(iff.thenPart)
                    .andThen(iff.elsePart?.statement)

    open fun visitImport(import: Tr.Import): R =
            visit(import.qualid)

    open fun visitInstanceOf(instanceOf: Tr.InstanceOf): R =
            visit(instanceOf.expr)
                    .andThen(instanceOf.clazz)

    open fun visitLabel(label: Tr.Label): R =
            visit(label.label)
                    .andThen(label.statement)

    open fun visitLambda(lambda: Tr.Lambda): R =
            visit(lambda.paramSet.params)
                    .andThen(lambda.body)

    open fun visitLiteral(literal: Tr.Literal): R = default(literal)

    open fun visitMemberReference(memberRef: Tr.MemberReference): R =
            visit(memberRef.containing)
                    .andThen(memberRef.reference)

    open fun visitMethod(method: Tr.MethodDecl): R =
            visit(method.annotations)
                    .andThen(method.modifiers)
                    .andThen(method.typeParameters?.params)
                    .andThen(method.returnTypeExpr)
                    .andThen(method.name)
                    .andThen(method.params.params)
                    .andThen(method.throws?.exceptions)
                    .andThen(method.body)
                    .andThen(method.defaultValue)
                    .andThenVisitTypeName(method.returnTypeExpr)
                    .andThenVisitTypeNames(method.throws?.exceptions)

    open fun visitMethodInvocation(meth: Tr.MethodInvocation): R {
        val selectTypeVisit = if (meth.select is NameTree && meth.type?.hasFlags(Flag.Static) ?: false)
            visitTypeName(meth.select)
        else default(meth)

        return reduce(visit(meth.select)
                .andThen(meth.typeParameters?.params)
                .andThen(meth.name)
                .andThen(meth.args.args)
                .andThenVisitTypeNames(meth.typeParameters?.params), selectTypeVisit)
    }

    open fun visitMultiCatch(multiCatch: Tr.MultiCatch): R =
            visit(multiCatch.alternatives)
                    .andThenVisitTypeNames(multiCatch.alternatives)

    open fun visitMultiVariable(multiVariable: Tr.VariableDecls): R {
        val varTypeVisit = if (multiVariable.typeExpr !is Tr.MultiCatch) {
            multiVariable.typeExpr?.let { visitTypeName(it) } ?: default(multiVariable)
        } else default(multiVariable)

        return reduce(visit(multiVariable.annotations)
                .andThen(multiVariable.modifiers)
                .andThen(multiVariable.typeExpr)
                .andThen(multiVariable.vars), varTypeVisit)
    }

    open fun visitNewArray(newArray: Tr.NewArray): R =
            visit(newArray.typeExpr)
                    .andThen(newArray.dimensions.map { it.size })
                    .andThen(newArray.initializer?.elements)
                    .andThenVisitTypeName(newArray.typeExpr)

    open fun visitNewClass(newClass: Tr.NewClass): R =
            visit(newClass.clazz)
                    .andThen(newClass.args.args)
                    .andThen(newClass.classBody)
                    .andThenVisitTypeName(newClass.clazz)

    open fun visitPackage(pkg: Tr.Package): R =
            visit(pkg.expr)

    open fun visitParameterizedType(type: Tr.ParameterizedType): R =
            visit(type.clazz)
                    .andThen(type.typeArguments?.args)
                    .andThenVisitTypeName(type.clazz)
                    .andThenVisitTypeNames(type.typeArguments?.args?.filterIsInstance())

    open fun  visitParentheses(parens: Tr.Parentheses): R =
            visit(parens.tree)

    open fun visitPrimitive(primitive: Tr.Primitive): R =
            default(primitive)

    open fun visitReturn(retrn: Tr.Return): R =
            visit(retrn.expr)

    open fun visitSwitch(switch: Tr.Switch): R =
            visit(switch.selector)
                    .andThen(switch.cases)

    open fun visitSynchronized(synch: Tr.Synchronized): R =
            visit(synch.lock)
                    .andThen(synch.body)

    open fun visitTernary(ternary: Tr.Ternary): R =
            visit(ternary.condition)
                    .andThen(ternary.truePart)
                    .andThen(ternary.falsePart)

    open fun visitThrow(thrown: Tr.Throw): R =
            visit(thrown.exception)

    open fun visitTry(tryable: Tr.Try): R =
            visit(tryable.resources?.decls)
                    .andThen(tryable.body)
                    .andThen(tryable.catches)
                    .andThen(tryable.finally?.block)

    open fun visitTypeCast(typeCast: Tr.TypeCast): R =
            visit(typeCast.clazz)
                    .andThen(typeCast.expr)
                    .andThenVisitTypeName(typeCast.clazz.tree)

    open fun visitTypeParameter(typeParameter: Tr.TypeParameter): R =
            visit(typeParameter.annotations)
                    .andThen(typeParameter.name)
                    .andThen(typeParameter.bounds?.types)
                    .andThenVisitTypeNames(typeParameter.bounds?.types)

    open fun visitTypeParameters(typeParameters: Tr.TypeParameters): R =
            visit(typeParameters.params)

    open fun visitTypeName(name: NameTree): R = default(name)

    open fun visitUnary(unary: Tr.Unary): R = visit(unary.expr)

    open fun visitUnparsedSource(unparsed: Tr.UnparsedSource): R =
            default(unparsed)

    open fun visitVariable(variable: Tr.VariableDecls.NamedVar): R =
            visit(variable.name)
                    .andThen(variable.dimensionsAfterName)
                    .andThen(variable.initializer)

    open fun visitWhileLoop(whileLoop: Tr.WhileLoop): R =
            visit(whileLoop.condition)
                    .andThen(whileLoop.body)

    open fun visitWildcard(wildcard: Tr.Wildcard): R {
        return visit(wildcard.bound)
                .andThen(wildcard.boundedType)
                .andThenVisitTypeName(wildcard.boundedType)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy