All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.netflix.rewrite.refactor.op.AddImport.kt Maven / Gradle / Ivy

/**
 * Copyright 2016 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.netflix.rewrite.refactor.op

import com.netflix.rewrite.ast.*
import com.netflix.rewrite.refactor.RefactorVisitor
import com.netflix.rewrite.search.FindType

/**
 * Refactoring operation that adds `import <clazz>` to a compilation unit. When [staticMethod] is
 * non-null it adds `import static <clazz>.<staticMethod>` instead.
 *
 * No transform is produced when:
 *  - an existing import already covers the addition. For a type import, that is an import that
 *    `matches(clazz)` or a `*` import of the class's package. For a static import, that is a static
 *    import that `matches(clazz)` and names either [staticMethod] or `*`.
 *  - [clazz] has no package; classes in the default package are never imported.
 *  - [onlyIfReferenced] is true and [FindType] finds no reference to [clazz] in the unit.
 *
 * The new import is inserted directly after the import returned by [lastPriorImport]. If there is
 * no such import, it is inserted at the top of the import list.
 *
 * NOTE(review): the generic arguments of the `List>` return types appear to have been stripped
 * from this copy of the file. Restore them from [RefactorVisitor]'s declared visit result type.
 */
class AddImport(val clazz: String,
                val staticMethod: String? = null,
                val onlyIfReferenced: Boolean = false,
                override val ruleName: String = "add-import"): RefactorVisitor() {

    /** Set by [visitImport] when an existing import already makes the target available. */
    private var coveredByExistingImport = false
    private val packageComparator = PackageComparator()

    /** The compilation unit most recently passed to [visitCompilationUnit]. */
    private lateinit var cu: Tr.CompilationUnit
    private val classType by lazy { Type.Class.build(clazz) }

    /** Whether the current unit references [clazz]; only computed when [onlyIfReferenced] is set. */
    private var hasReferences: Boolean = false

    override fun visitCompilationUnit(cu: Tr.CompilationUnit): List> {
        this.cu = cu
        // All per-unit state is reset here. Otherwise a covering import seen in a previously
        // visited unit would suppress the addition in this one.
        coveredByExistingImport = false
        // The reference search walks the whole tree, so only run it when its answer is used.
        hasReferences = onlyIfReferenced && FindType(clazz).visit(cu).isNotEmpty()
        return super.visitCompilationUnit(cu)
    }

    override fun visitImport(import: Tr.Import): List> {
        val importedName = import.qualid.simpleName

        val covers = if (addingStaticImport()) {
            // A static import of the same class naming either the member itself or `*`.
            import.matches(clazz) && import.static && (importedName == staticMethod || importedName == "*")
        } else {
            // An import matching the class, or a wildcard import of its package.
            import.matches(clazz) ||
                    (import.qualid.target.printTrimmed() == classType.packageName() && importedName == "*")
        }

        if (covers) coveredByExistingImport = true
        return emptyList()
    }

    override fun visitEnd(): List> {
        if (coveredByExistingImport)
            return emptyList()

        if (onlyIfReferenced && !hasReferences)
            return emptyList()

        // A class in the default package cannot be imported.
        if (classType.packageName().isEmpty())
            return emptyList()

        val classImportField = TreeBuilder.buildName(clazz, format(" ")) as Tr.FieldAccess
        val member = staticMethod
        val importStatementToAdd = if (member != null) {
            Tr.Import(Tr.FieldAccess(classImportField, Tr.Ident.build(member, null, Formatting.Empty), null, Formatting.Empty),
                    true, Formatting.Infer)
        } else {
            Tr.Import(classImportField, false, Formatting.Infer)
        }

        // Build the new import list now, against the unit that was just visited. The mutable `cu`
        // property is not read lazily when the transform eventually runs.
        val imports = cu.imports
        val lastPrior = lastPriorImport()
        val insertAt = if (lastPrior == null) 0 else imports.indexOfLast { it === lastPrior } + 1
        val newImports = imports.take(insertAt) + importStatementToAdd + imports.drop(insertAt)

        return transform(ruleName) { copy(imports = newImports) }
    }

    /**
     * Returns the existing import that the new import should follow. Returns `null` if the new
     * import belongs at the very top of the import list.
     *
     * The result is the last import in the current unit that sorts before the new one:
     *  - A new static import sorts after every non-static import.
     *  - A new non-static import sorts before every static import.
     *  - Otherwise, imports are compared by [PackageComparator] on their qualifier, then by simple
     *    name. The qualifier is the package for type imports and the class for static imports.
     */
    fun lastPriorImport(): Tr.Import? {
        val addingStatic = addingStaticImport()
        val newQualifier = if (addingStatic) clazz else classType.packageName()

        return cu.imports.lastOrNull { existing ->
            when {
                // static imports go after all non-static imports
                addingStatic && !existing.static -> true
                // non-static imports should always go before static imports
                !addingStatic && existing.static -> false
                else -> {
                    val comp = packageComparator.compare(existing.qualid.target.printTrimmed(), newQualifier)
                    if (comp == 0) {
                        // staticMethod is non-null exactly when adding a static import
                        existing.qualid.simpleName < (staticMethod ?: classType.className())
                    } else {
                        comp < 0
                    }
                }
            }
        }
    }

    /** True when this operation adds a static member import rather than a type import. */
    fun addingStaticImport() = staticMethod != null
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy