All downloads are free. Search and download functionality uses the official Maven repository.

com.netflix.rewrite.ast.Tree.kt Maven / Gradle / Ivy

/**
 * Copyright 2016 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.netflix.rewrite.ast

import com.fasterxml.jackson.annotation.*
import com.koloboke.collect.map.hash.HashObjObjMaps
import com.netflix.rewrite.ast.visitor.AstVisitor
import com.netflix.rewrite.ast.visitor.PrintVisitor
import com.netflix.rewrite.ast.visitor.RetrieveCursorVisitor
import com.netflix.rewrite.refactor.Refactor
import com.netflix.rewrite.search.*
import java.io.Serializable
import java.lang.IllegalArgumentException
import java.util.*
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import java.util.regex.Pattern
import kotlin.reflect.KClass

/**
 * Root of the AST node hierarchy.
 *
 * Jackson serializes implementations polymorphically, writing the minimal class name into an
 * `@c` property.
 *
 * NOTE(review): type parameters appear to have been lost in this copy of the file (the `T` in
 * `changeFormatting` and the `R` in `accept` are used but never declared); restore them from
 * the upstream source before compiling.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.MINIMAL_CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@c")
interface Tree {
    /** Formatting attached to this element (see [Formatting]). */
    val formatting: Formatting

    /**
     * An overload that allows us to create a copy of any Tree element, optionally
     * changing formatting
     */
    fun  changeFormatting(fmt: Formatting = formatting): T
    
    /**
     * An id that can be used to identify a particular AST element, even after transformations have taken place on it
     */
    val id: Long

    /**
     * Entry point for an [AstVisitor]. Node types in this file override it to call the visitor
     * method for their own kind; this default just returns `v.default(null)`.
     */
    fun  accept(v: AstVisitor): R = v.default(null)
    /** Not supported by default: always throws [NotImplementedError]. */
    fun changeFormatting(): Tree = throw NotImplementedError()
    /** The output of [print] with common indentation removed (`trimIndent`) and outer whitespace trimmed. */
    fun printTrimmed() = print().trimIndent().trim()
    /** Prints this element with a fresh [PrintVisitor]. */
    fun print() = PrintVisitor().visit(this)
}

/** Marker interface for AST nodes that are statements. */
interface Statement: Tree

/**
 * An AST node that is an expression.
 *
 * [type] is the expression's type; it is nullable and may be absent.
 */
interface Expression : Tree {
    val type: Type?
}

/**
 * A tree representing a simple or fully qualified name.
 *
 * [type] is the type associated with the name; it is nullable and may be absent.
 */
interface NameTree : Tree {
    val type: Type?
}

/**
 * A tree identifying a type (e.g. a simple or fully qualified class name, a primitive, array,
 * or parameterized type). Every type tree is also a [NameTree].
 */
interface TypeTree : NameTree

@JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator::class, property = "@ref")
sealed class Tr : Serializable, Tree {
    
    companion object {
        /**
         * Generates an identifier for a new AST element.
         *
         * The value is the low 64 bits of a freshly generated random [UUID], so identifiers are
         * random rather than sequential.
         */
        fun id(): Long {
            val randomUuid = UUID.randomUUID()
            return randomUuid.leastSignificantBits
        }
    }
    /**
     * An annotation usage: the annotation's name plus an optional argument list.
     *
     * Note that [annotationType] and [args] are mutable (`var`).
     */
    data class Annotation(
            var annotationType: NameTree,
            var args: Arguments?,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node both as an annotation and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asAnnotation = v.visitAnnotation(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asAnnotation, asExpression)
        }

        /** The argument list of an annotation. */
        data class Arguments(
                val args: List,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /**
     * An array element access: the [indexed] expression followed by a bracketed [dimension].
     */
    data class ArrayAccess(
            val indexed: Expression,
            val dimension: Dimension,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node as an array access and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asArrayAccess = v.visitArrayAccess(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asArrayAccess, asExpression)
        }

        /** The index portion of an array access. */
        data class Dimension(
                val index: Expression,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /**
     * An array type: an [elementType] followed by one or more [dimensions].
     */
    data class ArrayType(
            val elementType: TypeTree,
            val dimensions: List,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : TypeTree, Expression, Tr() {

        /** Copied from [elementType] when this node is constructed; the backing field is JVM-transient. */
        @Transient
        override val type = elementType.type

        /** Visits this node as an array type only. */
        override fun  accept(v: AstVisitor): R {
            return v.visitArrayType(this)
        }

        /** One pair of brackets in an array type. */
        data class Dimension(
                val inner: Empty,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /** An `assert` statement with its [condition]. */
    data class Assert(
            val condition: Expression,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitAssert(this)
        }
    }

    /** A simple assignment of [assignment] to [variable]; usable both as an expression and as a statement. */
    data class Assign(
            val variable: Expression,
            val assignment: Expression,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Statement, Tr() {

        /** Visits this node as an assignment and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asAssign = v.visitAssign(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asAssign, asExpression)
        }
    }

    /**
     * A compound assignment: [variable] combined with [assignment] through [operator].
     */
    data class AssignOp(
            val variable: Expression,
            val operator: Operator,
            val assignment: Expression,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Statement, Tr() {

        /** Visits this node as a compound assignment and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asAssignOp = v.visitAssignOp(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asAssignOp, asExpression)
        }

        /** The operator of a compound assignment; each case carries only its own formatting and id. */
        sealed class Operator : Tr() {
            // Arithmetic
            data class Addition(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Subtraction(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Multiplication(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Division(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Modulo(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()

            // Bitwise
            data class BitAnd(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class BitOr(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class BitXor(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class LeftShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class RightShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class UnsignedRightShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
        }
    }

    /**
     * A binary expression: [left] and [right] combined with [operator].
     */
    data class Binary(
            val left: Expression,
            val operator: Operator,
            val right: Expression,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node as a binary expression and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asBinary = v.visitBinary(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asBinary, asExpression)
        }

        /** The operator of a binary expression; each case carries only its own formatting and id. */
        sealed class Operator : Tr() {
            // Arithmetic
            data class Addition(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Subtraction(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Multiplication(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Division(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Modulo(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()

            // Relational
            data class LessThan(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class GreaterThan(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class LessThanOrEqual(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class GreaterThanOrEqual(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class Equal(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class NotEqual(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()

            // Bitwise
            data class BitAnd(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class BitOr(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class BitXor(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class LeftShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class RightShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class UnsignedRightShift(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()

            // Boolean
            data class Or(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
            data class And(
                    override val formatting: Formatting = Formatting.Empty,
                    override val id: Long = id()
            ) : Operator()
        }
    }

    /**
     * A block of statements.
     *
     * @property static NOTE(review): presumably non-null for a `static { ... }` initializer block — confirm
     * @property endOfBlockSuffix NOTE(review): presumably the text printed before the closing brace — confirm
     */
    data class Block(
            val static: Tr.Empty?,
            val statements: List,
            override val formatting: Formatting = Formatting.Empty,
            val endOfBlockSuffix: String,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitBlock(this)
        }
    }

    /** A `break` statement with an optional target [label]. */
    data class Break(
            val label: Ident?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitBreak(this)
        }
    }

    /** One `case` (or `default`) arm of a switch, with the statements that follow it. */
    data class Case(
            val pattern: Expression?, // null for the default case
            val statements: List,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitCase(this)
        }
    }

    /** A `catch` clause: its parenthesized parameter and its [body]. */
    data class Catch(
            val param: Parentheses,
            val body: Block,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitCatch(this)
        }
    }

    /**
     * A class-like declaration. Whether it is a class, enum, interface, or annotation type is
     * given by [kind] (see [isClass], [isEnum], [isInterface], [isAnnotation]).
     */
    data class ClassDecl(val annotations: List,
                         val modifiers: List,
                         val kind: Kind,
                         val name: Ident,
                         val typeParams: TypeParameters?,
                         val extends: TypeTree?,
                         val implements: List,
                         val body: Block,
                         val type: Type?,
                         override val formatting: Formatting = Formatting.Empty,
                         override val id: Long = id()) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R = v.visitClassDecl(this)

        /**
         * The first [EnumValueSet] among the body's statements, or null if there is none.
         *
         * Values will always occur before any fields, constructors, or methods
         */
        fun enumValues(): EnumValueSet? = body.statements.filterIsInstance<EnumValueSet>().firstOrNull()

        fun fields(): List = body.statements.filterIsInstance()
        fun methods(): List = body.statements.filterIsInstance()

        sealed class Kind: Tr() {
            data class Class(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Kind()
            data class Enum(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Kind()
            data class Interface(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Kind()
            data class Annotation(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Kind()
        }

        /**
         * Find fields defined on this class, but do not include inherited fields up the type hierarchy
         */
        fun findFields(clazz: Class<*>): List = FindFields(clazz.name).visit(this)

        fun findFields(clazz: String): List = FindFields(clazz).visit(this)

        /**
         * Find fields defined up the type hierarchy, but do not include fields defined directly on this class
         */
        fun findInheritedFields(clazz: Class<*>): List = FindInheritedFields(clazz.name).visit(this)

        fun findInheritedFields(clazz: String): List = FindInheritedFields(clazz).visit(this)

        fun findMethodCalls(signature: String): List = FindMethods(signature).visit(this)

        fun findType(clazz: Class<*>): List = FindType(clazz.name).visit(this)
        fun findType(clazz: String): List = FindType(clazz).visit(this)

        fun findAnnotations(signature: String): List = FindAnnotations(signature).visit(this)

        fun hasType(clazz: Class<*>): Boolean = HasType(clazz.name).visit(this)
        fun hasType(clazz: String): Boolean = HasType(clazz).visit(this)

        /** True if one of [modifiers] has exactly the runtime class [modifier] (subclasses do not count). */
        fun  hasModifier(modifier: Class) = modifiers.any { it::class.java == modifier }

        /**
         * True if this declaration carries the modifier whose [Modifier] subclass has the simple
         * name [modifier] (e.g. `"public"`, `"Static"`).
         *
         * The comparison is case-insensitive and independent of the default locale. The previous
         * `toLowerCase()` comparison failed under e.g. a Turkish locale, where "STATIC".toLowerCase()
         * yields a dotless ı.
         */
        fun hasModifier(modifier: String) = Modifier::class.nestedClasses
                .filter { it.simpleName.equals(modifier, ignoreCase = true) }
                .filterIsInstance>()
                .any { hasModifier(it.java) }

        val isEnum: Boolean @JsonIgnore get() = kind is Kind.Enum
        val isClass: Boolean @JsonIgnore get() = kind is Kind.Class
        val isInterface: Boolean @JsonIgnore get() = kind is Kind.Interface
        val isAnnotation: Boolean @JsonIgnore get() = kind is Kind.Annotation

        @Transient val simpleName: String = name.simpleName
    }

    /**
     * Root node for one source file: its path, optional package declaration, imports, and
     * top-level classes. Also offers search helpers and entry points for building a [Refactor].
     */
    data class CompilationUnit(
            val sourcePath: String,
            val packageDecl: Package?,
            val imports: List,
            val classes: List,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitCompilationUnit(this)
        }

        fun hasImport(clazz: Class<*>): Boolean = hasImport(clazz.name)
        fun hasImport(clazz: String): Boolean = HasImport(clazz).visit(this)

        fun hasType(clazz: Class<*>): Boolean = hasType(clazz.name)
        fun hasType(clazz: String): Boolean = HasType(clazz).visit(this)

        fun findMethodCalls(signature: String): List = FindMethods(signature).visit(this)

        fun findType(clazz: Class<*>): List = findType(clazz.name)
        fun findType(clazz: String): List = FindType(clazz).visit(this)

        /** Starts a new [Refactor] over this compilation unit. */
        fun refactor() = Refactor(this)

        /** Starts a new [Refactor], configures it with [ops] (Kotlin receiver-lambda form), and returns it. */
        fun refactor(ops: Refactor.() -> Unit): Refactor = refactor().apply(ops)

        /** Starts a new [Refactor], configures it with [ops] (Java [Consumer] form), and returns it. */
        fun refactor(ops: Consumer): Refactor = refactor().apply { ops.accept(this) }

        /** The first top-level class, or null when [classes] is empty. */
        fun firstClass() = classes.firstOrNull()

        /** Looks up a [Cursor] for the element [t] using [RetrieveCursorVisitor]; may return null. */
        fun cursor(t: Tree?): Cursor? = RetrieveCursorVisitor(t).visit(this)

        /** Looks up a [Cursor] for the element with the given [id] using [RetrieveCursorVisitor]; may return null. */
        fun cursor(id: Long): Cursor? = RetrieveCursorVisitor(id).visit(this)

        /**
         * Explicit `@c` type tag for this node. It is needed because Jackson does not place a
         * polymorphic type tag on the root of the AST when a list of ASTs is serialized together.
         */
        @Suppress("ProtectedInFinal")
        @get:JsonProperty("@c")
        protected val jacksonPolymorphicTypeTag = ".Tr\$CompilationUnit"
    }

    /** A `continue` statement with an optional target [label]. */
    data class Continue(
            val label: Ident?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitContinue(this)
        }
    }

    /** A `do { ... } while (...)` loop: the [body] followed by its parenthesized [condition]. */
    data class DoWhileLoop(
            val body: Statement,
            val condition: Parentheses,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitDoWhileLoop(this)
        }
    }

    /**
     * An empty node, usable wherever a statement, expression, or type/name tree is expected.
     * Its [type] is always null.
     */
    data class Empty(
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Expression, TypeTree, NameTree, Tr() {

        override val type: Type? = null

        /** Visits this node as an empty node and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asEmpty = v.visitEmpty(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asEmpty, asExpression)
        }
    }

    /** A single enum constant, with an optional constructor argument list. */
    data class EnumValue(
            val name: Ident,
            val initializer: Arguments?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitEnumValue(this)
        }

        /** The arguments passed to an enum constant. */
        data class Arguments(
                val args: List,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()

        /** Captured from [name] at construction time; the backing field is JVM-transient. */
        @Transient val simpleName: String = name.simpleName
    }

    /** The list of enum constants in an enum body, and whether it ends with a semicolon. */
    data class EnumValueSet(
            val enums: List,
            val terminatedWithSemicolon: Boolean,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitEnumValueSet(this)
        }
    }

    /** A member selection `target.name`; also used for qualified names. */
    data class FieldAccess(
            val target: Expression,
            val name: Ident,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, NameTree, TypeTree, Tr() {

        /** Visits this node as a field access and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asFieldAccess = v.visitFieldAccess(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asFieldAccess, asExpression)
        }

        /**
         * Make debugging a bit easier
         */
        override fun toString(): String = "FieldAccess(${printTrimmed()})"

        /** Captured from [name] at construction time; the backing field is JVM-transient. */
        @Transient val simpleName: String = name.simpleName

        /**
         * @return For expressions like String.class, this casts the target expression to a `NameTree`.
         * If the field access is not a reference to a class type, returns null.
         */
        fun asClassReference(): NameTree? {
            val nameTarget = target as? NameTree ?: return null
            val fqn = when (type) {
                is Type.Class -> type.fullyQualifiedName
                is Type.ShallowClass -> type.fullyQualifiedName
                else -> null
            }
            return if (fqn == "java.lang.Class") nameTarget else null
        }
    }

    /** An enhanced `for (variable : iterable)` loop. */
    data class ForEachLoop(
            val control: Control,
            val body: Statement,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitForEachLoop(this)
        }

        /** The parenthesized header: the loop variable declaration and the iterated expression. */
        data class Control(
                val variable: VariableDecls,
                val iterable: Expression,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /** A classic `for (init; condition; update)` loop. */
    data class ForLoop(
            val control: Control,
            val body: Statement,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitForLoop(this)
        }

        /** The parenthesized loop header. */
        data class Control(
                val init: Statement, // either Tr.Empty, Tr.VariableDecls, or Tr.Assign
                val condition: Expression,
                val update: List,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /**
     * An identifier.
     *
     * The name and type live in an [IdentFlyweight] held in a shared cache keyed by (name, type).
     * Identifiers created through [build] or [copy] therefore share one flyweight per pair; only
     * [formatting] and [id] are stored per node. The constructor is private, so [build] is the
     * way to create instances.
     */
    data class Ident private constructor(private val ident: IdentFlyweight,
                                         override val formatting: Formatting,
                                         override val id: Long) : Expression, NameTree, TypeTree, Tr() {

        override val type: Type? get() = ident.type
        val simpleName: String get() = ident.simpleName

        /**
         * Copies this identifier, optionally replacing any of its parts.
         *
         * Goes through [build] so the copy reuses the cached flyweight for the resulting
         * (name, type) pair. Previously it called `ident.copy(...)`, which allocated a fresh,
         * uncached flyweight on every copy.
         */
        fun copy(simpleName: String = this.simpleName, type: Type? = this.type, formatting: Formatting = this.formatting, id: Long = this.id) =
            build(simpleName, type, formatting, id)

        companion object {
            // Flyweight cache keyed by simple name, then by type; only touched while holding its own monitor.
            private val flyweights = HashObjObjMaps.newMutableMap>()

            /**
             * Creates an identifier, reusing the cached flyweight for ([simpleName], [type]) when
             * one exists and caching a new one otherwise. Also serves as the Jackson creator.
             */
            @JvmStatic @JsonCreator
            fun build(simpleName: String, type: Type? = null, formatting: Formatting = Formatting.Empty, id: Long = id()): Ident {
                val fly = synchronized(flyweights) {
                    flyweights
                            .getOrPut(simpleName) { HashObjObjMaps.newMutableMap(mapOf(type to IdentFlyweight(simpleName, type))) }
                            .getOrPut(type) { IdentFlyweight(simpleName, type) }
                }
                return Ident(fly, formatting, id)
            }
        }

        private data class IdentFlyweight(val simpleName: String, val type: Type?): Serializable

        override fun  accept(v: AstVisitor): R = v.reduce(v.visitIdentifier(this), v.visitExpression(this))

        /**
         * Make debugging a bit easier
         */
        override fun toString(): String = "Ident(${printTrimmed()})"
    }

    /** An `if` statement with an optional `else` part. */
    data class If(
            val ifCondition: Parentheses,
            val thenPart: Statement,
            val elsePart: Else?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitIf(this)
        }

        /** The `else` branch of an [If]. */
        data class Else(
                val statement: Statement,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /** An import declaration: its qualified name and whether it is static. */
    data class Import(val qualid: FieldAccess,
                      val static: Boolean,
                      override val formatting: Formatting = Formatting.Empty,
                      override val id: Long = id()) : Tr() {

        override fun  accept(v: AstVisitor): R = v.visitImport(this)

        /**
         * Whether this import matches the fully qualified class name [clazz].
         *
         * For a star import (`...*`), the imported prefix is compared with the leading segments
         * of [clazz] that begin with a lowercase letter (e.g. `java.util` for `java.util.List`).
         * An empty segment (as in `""` or `"a..B"`) ends that prefix instead of throwing
         * `StringIndexOutOfBoundsException`. Otherwise, the printed import must equal [clazz].
         *
         * NOTE(review): [static] is not consulted, so a static star import (whose target is a
         * type rather than a package) is compared against a package prefix — confirm intent.
         */
        fun matches(clazz: String): Boolean = when (qualid.simpleName) {
            "*" -> qualid.target.printTrimmed() ==
                    clazz.split('.').takeWhile { it.firstOrNull()?.isLowerCase() == true }.joinToString(".")
            else -> qualid.printTrimmed() == clazz
        }
    }

    /** An `expr instanceof clazz` test. */
    data class InstanceOf(
            val expr: Expression,
            val clazz: Tree,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node as an instanceof test and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asInstanceOf = v.visitInstanceOf(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asInstanceOf, asExpression)
        }
    }

    /** A labeled statement `label: statement`. */
    data class Label(
            val label: Ident,
            val statement: Statement,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Statement, Tr() {

        override fun  accept(v: AstVisitor): R {
            return v.visitLabel(this)
        }
    }

    /** A lambda expression: parameters, the arrow token, and a body. */
    data class Lambda(
            val paramSet: Parameters,
            val arrow: Arrow,
            val body: Tree,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node as a lambda and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asLambda = v.visitLambda(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asLambda, asExpression)
        }

        /** The arrow token between a lambda's parameters and its body. */
        data class Arrow(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()

        /** A lambda's parameter list, and whether it was written in parentheses. */
        data class Parameters(
                val parenthesized: Boolean,
                val params: List, // Tr.VariableDecls or Tr.Empty
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()
    }

    /**
     * A literal value. Its [type] is a [Type.Primitive]; strings are included.
     *
     * @property value the literal's value (nullable)
     * @property valueSource NOTE(review): presumably the literal's text as written in source — confirm
     */
    data class Literal(val value: Any?,
                       val valueSource: String,
                       override val type: Type.Primitive, // Strings are included
                       override val formatting: Formatting = Formatting.Empty,
                       override val id: Long = id()) : Expression, Tr() {

        override fun  accept(v: AstVisitor): R = v.reduce(v.visitLiteral(this), v.visitExpression(this))

        /**
         * Returns this literal's printed text with [value] replaced by `transform(value)`.
         *
         * Primitive values sometimes contain a prefix and suffix that hold the special characters,
         * e.g. the "" around String, the L at the end of a long, etc. The printed literal (with
         * all backslashes removed) is searched for `value.toString()`, and the text before and
         * after that match is kept around the transformed value.
         *
         * NOTE(review): the greedy `(.*)` prefix picks the last occurrence when the value text
         * appears more than once (e.g. a `'\''` char literal), which can split prefix and suffix
         * incorrectly.
         *
         * @throws IllegalStateException if `value.toString()` does not occur in the printed literal
         */
        fun  transformValue(transform: (T) -> Any): String {
            val printed = printTrimmed().replace("\\", "")
            val match = "(.*)${Pattern.quote(value.toString())}(.*)".toRegex().find(printed)
                    ?: error("Encountered a literal `$this` that could not be transformed")
            val (prefix, suffix) = match.destructured
            @Suppress("UNCHECKED_CAST")
            return "$prefix${transform(value as T)}$suffix"
        }
    }

    /** A member reference `containing::reference`. */
    data class MemberReference(
            val containing: Expression,
            val reference: Ident,
            override val type: Type?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Tr() {

        /** Visits this node as a member reference only (the expression visit is not invoked here). */
        override fun  accept(v: AstVisitor): R {
            return v.visitMemberReference(this)
        }
    }

    /**
     * A method or constructor declaration.
     *
     * @property returnTypeExpr the declared return type; null for constructors
     * @property body the method body, or null when absent
     * @property defaultValue NOTE(review): presumably the `default` value of an annotation-type element — confirm
     */
    data class MethodDecl(val annotations: List,
                          val modifiers: List,
                          val typeParameters: TypeParameters?,
                          val returnTypeExpr: TypeTree?, // null for constructors
                          val name: Ident,
                          val params: Parameters,
                          val throws: Throws?,
                          val body: Block?,
                          val defaultValue: Default?,
                          override val formatting: Formatting = Formatting.Empty,
                          override val id: Long = id()) : Tr() {

        override fun  accept(v: AstVisitor): R = v.visitMethod(this)

        fun hasType(clazz: Class<*>): Boolean = HasType(clazz.name).visit(this)
        fun hasType(clazz: String): Boolean = HasType(clazz).visit(this)

        data class Parameters(val params: List, override val formatting: Formatting = Formatting.Empty,
                              override val id: Long = id()): Tr()

        data class Throws(val exceptions: List,
                          override val formatting: Formatting = Formatting.Empty,
                          override val id: Long = id()): Tr()

        data class Default(val value: Expression,
                           override val formatting: Formatting = Formatting.Empty,
                           override val id: Long = id()): Tr()

        /** True if one of [modifiers] has exactly the runtime class [modifier] (subclasses do not count). */
        fun  hasModifier(modifier: Class) = modifiers.any { it::class.java == modifier }

        /**
         * True if this method carries the modifier whose [Modifier] subclass has the simple name
         * [modifier] (e.g. `"public"`, `"Static"`).
         *
         * The comparison is case-insensitive and independent of the default locale. The previous
         * `toLowerCase()` comparison failed under e.g. a Turkish locale, where "STATIC".toLowerCase()
         * yields a dotless ı.
         */
        fun hasModifier(modifier: String) = Modifier::class.nestedClasses
                .filter { it.simpleName.equals(modifier, ignoreCase = true) }
                .filterIsInstance>()
                .any { hasModifier(it.java) }

        fun findAnnotations(signature: String): List = FindAnnotations(signature).visit(this)

        @Transient val simpleName: String = name.simpleName
    }

    /** A method call, optionally on a [select] receiver, with explicit type arguments if present. */
    data class MethodInvocation(
            val select: Expression?,
            val typeParameters: TypeParameters?,
            val name: Ident,
            val args: Arguments,
            override val type: Type.Method?,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : Expression, Statement, Tr() {

        /** Visits this node as a method invocation and as an expression, combining the results with `reduce`. */
        override fun  accept(v: AstVisitor): R {
            val asInvocation = v.visitMethodInvocation(this)
            val asExpression = v.visitExpression(this)
            return v.reduce(asInvocation, asExpression)
        }

        /** The return type of the resolved method signature, or null if any part is unknown. */
        fun returnType(): Type? = type?.resolvedSignature?.returnType

        /**
         * Follows [select] while it is itself a [MethodInvocation] and returns the innermost one,
         * i.e. the first call in a chain like `a().b().c()`; returns this call if it is not chained.
         */
        fun firstMethodInChain(): MethodInvocation {
            var innermost = this
            var next = innermost.select as? MethodInvocation
            while (next != null) {
                innermost = next
                next = innermost.select as? MethodInvocation
            }
            return innermost
        }

        /** The call's arguments, excluding [Tr.Empty] placeholders. */
        fun argExpressions() = args.args.filterNot { it is Tr.Empty }

        data class Arguments(
                val args: List,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()

        data class TypeParameters(
                val params: List,
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Tr()

        /** Captured from [name] at construction time; the backing field is JVM-transient. */
        @Transient val simpleName: String = name.simpleName
    }

    /**
     * A declaration modifier keyword. The simple name of each subclass is what
     * `hasModifier(String)` matches against.
     */
    @JsonTypeInfo(use = JsonTypeInfo.Id.MINIMAL_CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@c")
    sealed class Modifier : Tr() {
        data class Default(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Public(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Protected(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Private(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Abstract(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Static(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Final(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Native(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Strictfp(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Synchronized(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Transient(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
        data class Volatile(
                override val formatting: Formatting = Formatting.Empty,
                override val id: Long = id()
        ) : Modifier()
    }

    /** The alternatives of a multi-catch parameter type (`A | B`). */
    data class MultiCatch(
            val alternatives: List,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()
    ) : TypeTree, Tr() {

        /**
         * A [Type.MultiCatchType] built from the alternatives' non-null types. It is computed
         * lazily and excluded from JSON.
         */
        @get:JsonIgnore
        override val type: Type by lazy { Type.MultiCatchType(alternatives.mapNotNull { it.type }) }

        override fun  accept(v: AstVisitor): R {
            return v.visitMultiCatch(this)
        }
    }

    /**
     * Array creation expression.
     *
     * @property typeExpr the array's type expression; `null` in the case of an array as an annotation parameter
     * @property dimensions the dimension size expressions
     * @property initializer the element initializer, if one is present
     * @property type the attributed type, if known
     */
    data class NewArray(val typeExpr: TypeTree?, // null in the case of an array as an annotation parameter
                        val dimensions: List<Dimension>,
                        val initializer: Initializer?,
                        override val type: Type?,
                        override val formatting: Formatting = Formatting.Empty,
                        override val id: Long = id()) : Expression, Tr() {

        /** Combines [AstVisitor.visitNewArray] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitNewArray(this), v.visitExpression(this))

        /** One dimension of the array, holding its [size] expression. */
        data class Dimension(val size: Expression, override val formatting: Formatting = Formatting.Empty,
                             override val id: Long = id()): Tr()

        /** The initializer's element expressions. */
        data class Initializer(val elements: List<Expression>, override val formatting: Formatting = Formatting.Empty,
                               override val id: Long = id()): Tr()
    }

    /**
     * Object instantiation expression.
     *
     * @property clazz the instantiated type
     * @property args the constructor arguments
     * @property classBody non-null for anonymous classes
     * @property type the attributed type, if known
     */
    data class NewClass(val clazz: TypeTree,
                        val args: Arguments,
                        val classBody: Block<Tree>?, // non-null for anonymous classes
                        override val type: Type?,
                        override val formatting: Formatting = Formatting.Empty,
                        override val id: Long = id()) : Expression, Statement, Tr() {

        /** Combines [AstVisitor.visitNewClass] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitNewClass(this), v.visitExpression(this))

        /** The constructor argument expressions. */
        data class Arguments(val args: List<Expression>, override val formatting: Formatting = Formatting.Empty,
                             override val id: Long = id()): Tr()
    }

    /**
     * Package declaration; [expr] is the declared package expression.
     */
    data class Package(val expr: Expression,
                       override val formatting: Formatting = Formatting.Empty,
                       override val id: Long = id()) : Tr() {

        /** Dispatches to [AstVisitor.visitPackage]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitPackage(this)
    }

    /**
     * A type name together with optional type arguments.
     *
     * [type] is taken from [clazz]'s type; its backing field is marked `@Transient`.
     */
    data class ParameterizedType(val clazz: NameTree,
                                 val typeArguments: TypeArguments?,
                                 override val formatting: Formatting = Formatting.Empty,
                                 override val id: Long = id()): TypeTree, Expression, Tr() {

        @Transient override val type = clazz.type

        /** Dispatches to [AstVisitor.visitParameterizedType]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitParameterizedType(this)

        /** The type argument list; each element is a TypeTree or a Wildcard. */
        data class TypeArguments(val args: List<Expression>, /* TypeTree or Wildcard */
                                 override val formatting: Formatting = Formatting.Empty,
                                 override val id: Long = id()): Tr()
    }

    /**
     * A parenthesized tree.
     *
     * [type] is [tree]'s type when [tree] is an [Expression], otherwise `null`; its backing field
     * is marked `@Transient`.
     */
    data class Parentheses<out T : Tree>(val tree: T,
                                         override val formatting: Formatting = Formatting.Empty,
                                         override val id: Long = id()) : Expression, Tr() {

        @Transient
        override val type: Type? = (tree as? Expression)?.type

        /** Combines [AstVisitor.visitParentheses] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitParentheses(this), v.visitExpression(this))
    }

    /**
     * A primitive type used as an expression or type tree; [type] is the [Type.Primitive] it denotes.
     */
    data class Primitive(override val type: Type.Primitive,
                         override val formatting: Formatting = Formatting.Empty,
                         override val id: Long = id()) : Expression, NameTree, TypeTree, Tr() {

        /** Combines [AstVisitor.visitPrimitive] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitPrimitive(this), v.visitExpression(this))
    }

    /**
     * `return` statement; [expr] is `null` when no value is returned.
     */
    data class Return(val expr: Expression?,
                      override val formatting: Formatting = Formatting.Empty,
                      override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitReturn]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitReturn(this)
    }

    /**
     * `switch` statement: a parenthesized [selector] and a block of [cases].
     */
    data class Switch(val selector: Parentheses<Expression>,
                      val cases: Block<Case>,
                      override val formatting: Formatting = Formatting.Empty,
                      override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitSwitch]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitSwitch(this)
    }

    /**
     * `synchronized` block statement: a parenthesized [lock] expression and its [body].
     */
    data class Synchronized(val lock: Parentheses<Expression>,
                            val body: Block<Statement>,
                            override val formatting: Formatting = Formatting.Empty,
                            override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitSynchronized]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitSynchronized(this)
    }

    /**
     * Conditional expression `condition ? truePart : falsePart`; [type] is the attributed result type, if known.
     */
    data class Ternary(val condition: Expression,
                       val truePart: Expression,
                       val falsePart: Expression,
                       override val type: Type?,
                       override val formatting: Formatting = Formatting.Empty,
                       override val id: Long = id()) : Expression, Tr() {

        /** Combines [AstVisitor.visitTernary] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitTernary(this), v.visitExpression(this))
    }

    /**
     * `throw` statement; [exception] is the thrown expression.
     */
    data class Throw(val exception: Expression,
                     override val formatting: Formatting = Formatting.Empty,
                     override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitThrow]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitThrow(this)
    }

    /**
     * `try` statement.
     *
     * @property resources the resource declarations, or `null` when there are none
     * @property body the guarded block
     * @property catches the catch clauses
     * @property finally the finally clause, or `null` when absent
     */
    data class Try(val resources: Resources?,
                   val body: Block<Statement>,
                   val catches: List<Catch>,
                   val finally: Finally?,
                   override val formatting: Formatting = Formatting.Empty,
                   override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitTry]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitTry(this)

        /** The variable declarations of the resource specification. */
        data class Resources(val decls: List<VariableDecls>, override val formatting: Formatting = Formatting.Empty,
                             override val id: Long = id()): Tr()

        /** The finally clause and its [block]. */
        data class Finally(val block: Block<Statement>, override val formatting: Formatting = Formatting.Empty,
                           override val id: Long = id()): Tr()
    }

    /**
     * Cast expression: a parenthesized target type [clazz] applied to [expr].
     *
     * [type] is taken from [clazz]'s type; its backing field is marked `@Transient`.
     */
    data class TypeCast(val clazz: Parentheses<TypeTree>,
                        val expr: Expression,
                        override val formatting: Formatting = Formatting.Empty,
                        override val id: Long = id()): Expression, Tr() {

        @Transient
        override val type = clazz.type

        /** Combines [AstVisitor.visitTypeCast] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.reduce(v.visitTypeCast(this), v.visitExpression(this))
    }

    /**
     * A single type parameter declaration: its [annotations], [name] and optional [bounds].
     */
    data class TypeParameter(val annotations: List<Annotation>,
                             val name: NameTree,
                             val bounds: Bounds?,
                             override val formatting: Formatting = Formatting.Empty,
                             override val id: Long = id()) : Tr() {

        /** Dispatches to [AstVisitor.visitTypeParameter]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitTypeParameter(this)

        /** The bounding types of the type parameter. */
        data class Bounds(val types: List<TypeTree>,
                          override val formatting: Formatting = Formatting.Empty,
                          override val id: Long = id()) : Tr()
    }

    /**
     * A list of type parameter declarations.
     */
    data class TypeParameters(val params: List<TypeParameter>,
                              override val formatting: Formatting = Formatting.Empty,
                              override val id: Long = id()): Tr() {

        /** Dispatches to [AstVisitor.visitTypeParameters]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitTypeParameters(this)
    }

    /**
     * Unary operation: [operator] applied to [expr]; [type] is the attributed result type, if known.
     *
     * Increment and decrement operations are valid statements, other operations are not
     * (this class implements [Statement] for every operator regardless).
     */
    data class Unary(val operator: Operator,
                     val expr: Expression,
                     override val type: Type?,
                     override val formatting: Formatting = Formatting.Empty,
                     override val id: Long = id()) : Expression, Statement, Tr() {

        /** Combines [AstVisitor.visitUnary] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.reduce(v.visitUnary(this), v.visitExpression(this))

        sealed class Operator: Tr() {
            // Formatting.Empty unless a subclass overrides it. PreIncrement, PreDecrement, Positive and
            // Negative have no formatting parameter and always use Formatting.Empty; PostIncrement,
            // PostDecrement, Complement and Not accept a formatting through their constructors.
            override val formatting: Formatting = Formatting.Empty

            // Arithmetic
            data class PreIncrement(override val id: Long = id()): Operator()
            data class PreDecrement(override val id: Long = id()): Operator()
            data class PostIncrement(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Operator()
            data class PostDecrement(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Operator()
            data class Positive(override val id: Long = id()): Operator()
            data class Negative(override val id: Long = id()): Operator()

            // Bitwise
            data class Complement(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Operator()

            // Boolean
            data class Not(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Operator()
        }
    }

    /**
     * Raw [source] text carried in the tree as both an expression and a statement.
     *
     * [type] is always `null`; its backing field is marked `@Transient`.
     */
    data class UnparsedSource(val source: String, override val formatting: Formatting = Formatting.Empty,
                              override val id: Long = id()): Expression, Statement, Tr() {

        @Transient
        override val type: Type? = null

        /** Combines [AstVisitor.visitUnparsedSource] and [AstVisitor.visitExpression] via [AstVisitor.reduce]. */
        override fun <R> accept(v: AstVisitor<R>): R =
                v.reduce(v.visitUnparsedSource(this), v.visitExpression(this))
    }

    /**
     * A variable declaration statement declaring one or more [vars] that share a single list of
     * [annotations], [modifiers] and an optional [typeExpr].
     */
    data class VariableDecls(
            val annotations: List<Annotation>,
            val modifiers: List<Modifier>,
            val typeExpr: TypeTree?, // can be null when this is a lambda parameter with an inferred type expression
            val varArgs: Varargs?,
            val dimensionsBeforeName: List<Dimension>,
            val vars: List<NamedVar>,
            override val formatting: Formatting = Formatting.Empty,
            override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitMultiVariable]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitMultiVariable(this)

        /** Varargs marker on the declaration; `null` in [varArgs] when absent. */
        data class Varargs(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Tr()

        /** An array dimension, holding the [whitespace] inside it as an [Tr.Empty]. */
        data class Dimension(val whitespace: Tr.Empty, override val formatting: Formatting = Formatting.Empty,
                             override val id: Long = id()): Tr()

        /**
         * One declared variable.
         *
         * @property name the variable's identifier
         * @property dimensionsAfterName array dimensions written after the name
         * @property initializer the initializer expression, or `null` when there is none
         * @property type the attributed type, if known
         */
        data class NamedVar(val name: Ident,
                            val dimensionsAfterName: List<Dimension>, // thanks for making it hard, Java
                            val initializer: Expression?,
                            val type: Type?,
                            override val formatting: Formatting = Formatting.Empty,
                            override val id: Long = id()): Tr() {

            /** Dispatches to [AstVisitor.visitVariable]. */
            override fun <R> accept(v: AstVisitor<R>): R = v.visitVariable(this)

            @Transient val simpleName: String = name.simpleName
        }

        /**
         * True when [modifiers] contains an instance whose runtime class is exactly [modifier].
         */
        fun <T : Modifier> hasModifier(modifier: Class<T>) = modifiers.any { it::class.java == modifier }

        /**
         * True when [modifiers] contains a modifier whose class simple name equals [modifier]
         * ignoring case; e.g. `hasModifier("public")` matches a [Modifier.Public].
         *
         * The comparison uses `equals(ignoreCase = true)`, so the result does not depend on the JVM
         * default locale. A default-locale `toLowerCase()` would break under a Turkish locale,
         * where "PRIVATE" lowercases to a dotless "ı". The check also reads each modifier's own class
         * name instead of enumerating `Modifier::class.nestedClasses`, which needs full Kotlin reflection.
         */
        fun hasModifier(modifier: String) =
                modifiers.any { it::class.java.simpleName.equals(modifier, ignoreCase = true) }

        /** Runs [FindAnnotations] for [signature] over this declaration. */
        fun findAnnotations(signature: String): List<Annotation> = FindAnnotations(signature).visit(this)
    }

    /**
     * `while` loop: a parenthesized [condition] and the loop [body] statement.
     */
    data class WhileLoop(val condition: Parentheses<Expression>,
                         val body: Statement,
                         override val formatting: Formatting = Formatting.Empty,
                         override val id: Long = id()) : Statement, Tr() {

        /** Dispatches to [AstVisitor.visitWhileLoop]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitWhileLoop(this)
    }

    /**
     * Wildcard type argument with an optional [bound] kind (super/extends) and [boundedType].
     *
     * [type] is always `null` (declared as `Nothing?`); its backing field is marked `@Transient`.
     */
    data class Wildcard(val bound: Bound?,
                        val boundedType: NameTree?,
                        override val formatting: Formatting = Formatting.Empty,
                        override val id: Long = id()): Tr(), Expression {

        @Transient
        override val type = null

        /** Dispatches to [AstVisitor.visitWildcard]. */
        override fun <R> accept(v: AstVisitor<R>): R = v.visitWildcard(this)

        /** The bound keyword of a wildcard. */
        sealed class Bound: Tr() {
            data class Super(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Bound()
            data class Extends(override val formatting: Formatting = Formatting.Empty, override val id: Long = id()): Bound()
        }
    }

    /**
     * An overload that allows us to create a copy of any Tree element, optionally
     * changing formatting.
     *
     * Every concrete [Tr] subtype is handled by calling its data-class `copy(formatting = fmt)`.
     * The exception is [Tr.Unary.Operator.PreIncrement], [Tr.Unary.Operator.PreDecrement],
     * [Tr.Unary.Operator.Positive] and [Tr.Unary.Operator.Negative], which have no formatting
     * parameter; they return a plain `copy()` and [fmt] is ignored.
     *
     * The result is cast to the caller-chosen [T] without a runtime check (the cast is erased),
     * so choosing a [T] that the receiver is not an instance of fails later, at the use site.
     *
     * @param fmt the formatting the copy should carry
     * @return a copy of this element, typed as [T]
     */
    override fun <T : Tree> changeFormatting(fmt: Formatting): T {
        @Suppress("UNCHECKED_CAST")
        return when(this) {
            is Tr.Annotation -> copy(formatting = fmt)
            is Tr.Annotation.Arguments -> copy(formatting = fmt)
            is Tr.ArrayAccess -> copy(formatting = fmt)
            is Tr.ArrayAccess.Dimension -> copy(formatting = fmt)
            is Tr.ArrayType -> copy(formatting = fmt)
            is Tr.ArrayType.Dimension -> copy(formatting = fmt)
            is Tr.Assert -> copy(formatting = fmt)
            is Tr.Assign -> copy(formatting = fmt)
            is Tr.AssignOp -> copy(formatting = fmt)
            is Tr.AssignOp.Operator -> when(this) {
                is Tr.AssignOp.Operator.Addition -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.Subtraction -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.Multiplication -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.Division -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.Modulo -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.BitAnd -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.BitOr -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.BitXor -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.LeftShift -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.RightShift -> copy(formatting = fmt)
                is Tr.AssignOp.Operator.UnsignedRightShift -> copy(formatting = fmt)
            }
            is Tr.Binary -> copy(formatting = fmt)
            is Tr.Binary.Operator -> when(this) {
                is Tr.Binary.Operator.Addition -> copy(formatting = fmt)
                is Tr.Binary.Operator.Subtraction -> copy(formatting = fmt)
                is Tr.Binary.Operator.Multiplication -> copy(formatting = fmt)
                is Tr.Binary.Operator.Division -> copy(formatting = fmt)
                is Tr.Binary.Operator.Modulo -> copy(formatting = fmt)
                is Tr.Binary.Operator.LessThan -> copy(formatting = fmt)
                is Tr.Binary.Operator.GreaterThan -> copy(formatting = fmt)
                is Tr.Binary.Operator.LessThanOrEqual -> copy(formatting = fmt)
                is Tr.Binary.Operator.GreaterThanOrEqual -> copy(formatting = fmt)
                is Tr.Binary.Operator.Equal -> copy(formatting = fmt)
                is Tr.Binary.Operator.NotEqual -> copy(formatting = fmt)
                is Tr.Binary.Operator.BitAnd -> copy(formatting = fmt)
                is Tr.Binary.Operator.BitOr -> copy(formatting = fmt)
                is Tr.Binary.Operator.BitXor -> copy(formatting = fmt)
                is Tr.Binary.Operator.LeftShift -> copy(formatting = fmt)
                is Tr.Binary.Operator.RightShift -> copy(formatting = fmt)
                is Tr.Binary.Operator.UnsignedRightShift -> copy(formatting = fmt)
                is Tr.Binary.Operator.Or -> copy(formatting = fmt)
                is Tr.Binary.Operator.And -> copy(formatting = fmt)
            }
            is Tr.Block<*> -> copy(formatting = fmt)
            is Tr.Break -> copy(formatting = fmt)
            is Tr.Case -> copy(formatting = fmt)
            is Tr.Catch -> copy(formatting = fmt)
            is Tr.ClassDecl -> copy(formatting = fmt)
            is Tr.ClassDecl.Kind -> when(this) {
                is Tr.ClassDecl.Kind.Class -> copy(formatting = fmt)
                is Tr.ClassDecl.Kind.Enum -> copy(formatting = fmt)
                is Tr.ClassDecl.Kind.Interface -> copy(formatting = fmt)
                is Tr.ClassDecl.Kind.Annotation -> copy(formatting = fmt)
            }
            is Tr.CompilationUnit -> copy(formatting = fmt)
            is Tr.Continue -> copy(formatting = fmt)
            is Tr.DoWhileLoop -> copy(formatting = fmt)
            is Tr.Empty -> copy(formatting = fmt)
            is Tr.EnumValue -> copy(formatting = fmt)
            is Tr.EnumValue.Arguments -> copy(formatting = fmt)
            is Tr.EnumValueSet -> copy(formatting = fmt)
            is Tr.FieldAccess -> copy(formatting = fmt)
            is Tr.ForEachLoop -> copy(formatting = fmt)
            is Tr.ForEachLoop.Control -> copy(formatting = fmt)
            is Tr.ForLoop -> copy(formatting = fmt)
            is Tr.ForLoop.Control -> copy(formatting = fmt)
            is Tr.Ident -> copy(formatting = fmt)
            is Tr.If -> copy(formatting = fmt)
            is Tr.If.Else -> copy(formatting = fmt)
            is Tr.Import -> copy(formatting = fmt)
            is Tr.InstanceOf -> copy(formatting = fmt)
            is Tr.Label -> copy(formatting = fmt)
            is Tr.Lambda -> copy(formatting = fmt)
            is Tr.Lambda.Parameters -> copy(formatting = fmt)
            is Tr.Lambda.Arrow -> copy(formatting = fmt)
            is Tr.Literal -> copy(formatting = fmt)
            is Tr.MemberReference -> copy(formatting = fmt)
            is Tr.MethodDecl -> copy(formatting = fmt)
            is Tr.MethodDecl.Default -> copy(formatting = fmt)
            is Tr.MethodDecl.Parameters -> copy(formatting = fmt)
            is Tr.MethodDecl.Throws -> copy(formatting = fmt)
            is Tr.MethodInvocation -> copy(formatting = fmt)
            is Tr.MethodInvocation.Arguments -> copy(formatting = fmt)
            is Tr.MethodInvocation.TypeParameters -> copy(formatting = fmt)
            is Tr.MultiCatch -> copy(formatting = fmt)
            is Tr.NewArray -> copy(formatting = fmt)
            is Tr.NewArray.Dimension -> copy(formatting = fmt)
            is Tr.NewArray.Initializer -> copy(formatting = fmt)
            is Tr.NewClass -> copy(formatting = fmt)
            is Tr.NewClass.Arguments -> copy(formatting = fmt)
            is Tr.Package -> copy(formatting = fmt)
            is Tr.ParameterizedType -> copy(formatting = fmt)
            is Tr.ParameterizedType.TypeArguments -> copy(formatting = fmt)
            is Tr.Parentheses<*> -> copy(formatting = fmt)
            is Tr.Primitive -> copy(formatting = fmt)
            is Tr.Return -> copy(formatting = fmt)
            is Tr.Switch -> copy(formatting = fmt)
            is Tr.Synchronized -> copy(formatting = fmt)
            is Tr.Ternary -> copy(formatting = fmt)
            is Tr.Throw -> copy(formatting = fmt)
            is Tr.Try -> copy(formatting = fmt)
            is Tr.Try.Resources -> copy(formatting = fmt)
            is Tr.Try.Finally -> copy(formatting = fmt)
            is Tr.TypeCast -> copy(formatting = fmt)
            is Tr.TypeParameter -> copy(formatting = fmt)
            is Tr.TypeParameter.Bounds -> copy(formatting = fmt)
            is Tr.TypeParameters -> copy(formatting = fmt)
            is Tr.Unary -> copy(formatting = fmt)
            is Tr.Unary.Operator -> when(this) {
                // These four operators have no formatting parameter, so fmt is ignored.
                is Tr.Unary.Operator.PreIncrement -> copy() // do nothing
                is Tr.Unary.Operator.PreDecrement -> copy() // do nothing
                is Tr.Unary.Operator.PostIncrement -> copy(formatting = fmt)
                is Tr.Unary.Operator.PostDecrement -> copy(formatting = fmt)
                is Tr.Unary.Operator.Positive -> copy() // do nothing
                is Tr.Unary.Operator.Negative -> copy() // do nothing
                is Tr.Unary.Operator.Complement -> copy(formatting = fmt)
                is Tr.Unary.Operator.Not -> copy(formatting = fmt)
            }
            is Tr.UnparsedSource -> copy(formatting = fmt)
            is Tr.VariableDecls -> copy(formatting = fmt)
            is Tr.Modifier -> when(this) {
                is Tr.Modifier.Abstract -> copy(formatting = fmt)
                is Tr.Modifier.Default -> copy(formatting = fmt)
                is Tr.Modifier.Final -> copy(formatting = fmt)
                is Tr.Modifier.Native -> copy(formatting = fmt)
                is Tr.Modifier.Private -> copy(formatting = fmt)
                is Tr.Modifier.Protected -> copy(formatting = fmt)
                is Tr.Modifier.Public -> copy(formatting = fmt)
                is Tr.Modifier.Static -> copy(formatting = fmt)
                is Tr.Modifier.Strictfp -> copy(formatting = fmt)
                is Tr.Modifier.Synchronized -> copy(formatting = fmt)
                is Tr.Modifier.Transient -> copy(formatting = fmt)
                is Tr.Modifier.Volatile -> copy(formatting = fmt)
            }
            is Tr.VariableDecls.Varargs -> copy(formatting = fmt)
            is Tr.VariableDecls.Dimension -> copy(formatting = fmt)
            is Tr.VariableDecls.NamedVar -> copy(formatting = fmt)
            is Tr.WhileLoop -> copy(formatting = fmt)
            is Tr.Wildcard -> copy(formatting = fmt)
            is Tr.Wildcard.Bound -> when(this) {
                is Tr.Wildcard.Bound.Super -> copy(formatting = fmt)
                is Tr.Wildcard.Bound.Extends -> copy(formatting = fmt)
            }
        } as T
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy