All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.openrewrite.java.AddImportTest.kt Maven / Gradle / Ivy

/*
 * Copyright 2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java

import org.assertj.core.api.Assertions
import org.junit.jupiter.api.Test
import org.openrewrite.*
import org.openrewrite.Tree.randomId
import org.openrewrite.java.marker.JavaSourceSet
import org.openrewrite.java.tree.Flag
import org.openrewrite.java.tree.J
import org.openrewrite.java.tree.JavaType

/**
 * Tests for [AddImport], written as an interface so that an implementing test class
 * supplies the [JavaParser] each test receives as a parameter.
 *
 * Most tests wrap one or more [AddImport] visitors with [addImports] and compare the
 * Java source before and after the recipe runs.
 *
 * NOTE(review): the Java fixtures use the conventional "package / import groups / type"
 * layout with one blank line between groups. The `after` fixtures are compared against
 * the recipe's output, so confirm the blank-line placement whenever a fixture is edited.
 */
interface AddImportTest : JavaRecipeTest {

    /**
     * Turns every supplied visitor factory into a recipe and chains them, left to right,
     * with [Recipe.doNext].
     *
     * The fold lambda must *yield* its value. A `return` inside `reduce` is a non-local
     * return from this function, which would stop the fold after the first pair and
     * silently drop every recipe after the second one.
     *
     * @param adds factories for the visitors to run, in order; must not be empty
     * @return a single recipe that runs all of [adds]
     * @throws UnsupportedOperationException if [adds] is empty
     */
    fun addImports(vararg adds: () -> TreeVisitor<*, ExecutionContext>): Recipe = adds
        .map { add -> toRecipe(add) }
        .reduce { r1, r2 -> r1.doNext(r2) }

    @Test
    fun dontDuplicateImports(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("org.springframework.http.HttpStatus", null, false) },
            { AddImport("org.springframework.http.HttpStatus.Series", null, false) }
        ),
        before = "class A {}",
        after = """
            import org.springframework.http.HttpStatus;
            import org.springframework.http.HttpStatus.Series;

            class A {}
        """
    )

    @Test
    fun dontDuplicateImports2(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("org.junit.jupiter.api.Test", null, false) }
        ),
        before = """
            import org.junit.jupiter.api.AfterEach;
            import org.junit.jupiter.api.Assertions;
            import org.junit.jupiter.api.BeforeAll;
            import org.junit.jupiter.api.BeforeEach;
            import org.slf4j.Logger;
            import org.slf4j.LoggerFactory;

            class A {}
        """,
        after = """
            import org.junit.jupiter.api.*;
            import org.slf4j.Logger;
            import org.slf4j.LoggerFactory;

            class A {}
        """,
        cycles = 1,
        expectedCyclesThatMakeChanges = 1
    )

    @Test
    fun dontDuplicateImports3(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("org.junit.jupiter.api.Assertions", "assertNull", false) }
        ),
        before = """
            import static org.junit.jupiter.api.Assertions.assertFalse;
            import static org.junit.jupiter.api.Assertions.assertTrue;

            import java.util.List;

            class A {}
        """,
        after = """
            import static org.junit.jupiter.api.Assertions.*;

            import java.util.List;

            class A {}
        """,
        cycles = 1,
        expectedCyclesThatMakeChanges = 1
    )

    @Test
    fun dontImportYourself(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports({ AddImport("com.myorg.A", null, false) }),
        before = """
            package com.myorg;

            class A {
            }
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/777")
    @Test
    fun dontImportFromSamePackage(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports({ AddImport("com.myorg.B", null, false) }),
        dependsOn = arrayOf(
            """
                package com.myorg;

                class B {
                }
            """
        ),
        before = """
            package com.myorg;

            class A {
            }
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/772")
    @Test
    fun importOrderingIssue(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("org.springframework.http.HttpHeaders", null, false) },
        ),
        before = """
            import javax.ws.rs.core.Response.ResponseBuilder;
            import java.util.Locale;

            class A {}
        """,
        after = """
            import org.springframework.http.HttpHeaders;

            import javax.ws.rs.core.Response.ResponseBuilder;
            import java.util.Locale;

            class A {}
        """
    )

    @Test
    fun addMultipleImports(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) },
            { AddImport("java.util.Set", null, false) }
        ),
        before = """
            class A {}
        """,
        after = """
            import java.util.List;
            import java.util.Set;

            class A {}
        """
    )

    @Test
    fun addNamedImport(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = "class A {}",
        after = """
            import java.util.List;

            class A {}
        """
    )

    @Test
    fun doNotAddImportIfNotReferenced(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, true) }
        ),
        before = """
            package a;

            class A {}
        """
    )

    @Test
    fun addImportInsertsNewMiddleBlock(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            import com.sun.naming.*;

            import static java.util.Collections.*;

            class A {}
        """,
        after = """
            package a;

            import com.sun.naming.*;

            import java.util.List;

            import static java.util.Collections.*;

            class A {}
        """
    )

    @Test
    fun addFirstImport(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            class A {}
        """,
        after = """
            package a;

            import java.util.List;

            class A {}
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/484")
    @Test
    fun addImportIfReferenced(jp: JavaParser) = assertChanged(
        jp,
        recipe = toRecipe {
            object : JavaIsoVisitor<ExecutionContext>() {
                override fun visitClassDeclaration(
                    classDecl: J.ClassDeclaration,
                    ctx: ExecutionContext
                ): J.ClassDeclaration {
                    val c = super.visitClassDeclaration(classDecl, ctx)
                    var b = c.body
                    // Only insert the field while no cycle has reported a change yet.
                    if (ctx.getMessage("cyclesThatResultedInChanges", 0) == 0) {
                        val t = JavaTemplate.builder(
                            { cursor },
                            "BigDecimal d = BigDecimal.valueOf(1).setScale(1, RoundingMode.HALF_EVEN);"
                        )
                            .imports("java.math.BigDecimal", "java.math.RoundingMode")
                            .build()

                        b = b.withTemplate(t, b.coordinates.lastStatement())
                        maybeAddImport("java.math.BigDecimal")
                        maybeAddImport("java.math.RoundingMode")
                    }
                    return c.withBody(b)
                }
            }
        },
        before = """
            package a;

            class A {
            }
        """,
        after = """
            package a;

            import java.math.BigDecimal;
            import java.math.RoundingMode;

            class A {
                BigDecimal d = BigDecimal.valueOf(1).setScale(1, RoundingMode.HALF_EVEN);
            }
        """
    )

    @Test
    fun doNotAddWildcardImportIfNotReferenced(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.*", null, true) }
        ),
        before = """
            package a;

            class A {}
        """
    )

    @Test
    fun lastImportWhenFirstClassDeclarationHasJavadoc(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.Collections", "*", false) }
        ),
        before = """
            import java.util.List;

            /**
             * My type
             */
            class A {}
        """,
        after = """
            import java.util.List;

            import static java.util.Collections.*;

            /**
             * My type
             */
            class A {}
        """
    )

    @Test
    fun namedImportAddedAfterPackageDeclaration(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            class A {}
        """,
        after = """
            package a;

            import java.util.List;

            class A {}
        """
    )

    /**
     * Adds a `B` class from three different packages and checks that its import lands
     * at the expected index among the existing `c.C0`, `c.c.C1` and `c.c.c.C2` imports.
     */
    @Test
    fun importsAddedInAlphabeticalOrder(jp: JavaParser) {
        val otherPackages = listOf("c", "c.c", "c.c.c")
        val otherImports = otherPackages.mapIndexed { i, pkg ->
            "package $pkg;\npublic class C$i {}"
        }

        listOf("b" to 0, "c.b" to 1, "c.c.b" to 2).forEach { (pkg, order) ->
            val expectedImports = otherPackages
                .mapIndexed { i, otherPkg -> "$otherPkg.C$i" }
                .toMutableList()
                .apply { add(order, "$pkg.B") }

            assertChanged(
                jp,
                dependsOn = arrayOf(
                    *otherImports.toTypedArray(),
                    """
                        package $pkg;
                        public class B {}
                    """
                ),
                recipe = addImports(
                    { AddImport("$pkg.B", null, false) }
                ),
                before = """
                    package a;

                    import c.C0;
                    import c.c.C1;
                    import c.c.c.C2;

                    class A {}
                """,
                after = "package a;\n\n${expectedImports.joinToString("\n") { fqn -> "import $fqn;" }}\n\nclass A {}"
            )

            // The parser is reused across iterations, so reset it before the next one.
            jp.reset()
        }
    }

    @Test
    fun doNotAddImportIfAlreadyExists(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            import java.util.List;

            class A {}
        """
    )

    @Test
    fun doNotAddImportIfCoveredByStarImport(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            import java.util.*;

            class A {}
        """
    )

    @Test
    fun dontAddImportWhenClassHasNoPackage(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("C", null, false) }
        ),
        before = "class A {}"
    )

    @Test
    fun dontAddImportForPrimitive(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("int", null, false) }
        ),
        before = "class A {}"
    )

    @Test
    fun addNamedImportIfStarStaticImportExists(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.List", null, false) }
        ),
        before = """
            package a;

            import static java.util.List.*;

            class A {}
        """,
        after = """
            package a;

            import java.util.List;

            import static java.util.List.*;

            class A {}
        """
    )

    @Test
    fun addNamedStaticImport(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.Collections", "emptyList", false) }
        ),
        before = """
            import java.util.*;

            class A {}
        """,
        after = """
            import java.util.*;

            import static java.util.Collections.emptyList;

            class A {}
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/108")
    @Test
    fun addStaticImportField(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports(
            { AddImport("mycompany.Type", "FIELD", false) }
        ),
        dependsOn = arrayOf(
            """
                package mycompany;

                public class Type {
                    public static String FIELD;
                }
            """
        ),
        before = "class A {}",
        after = """
            import static mycompany.Type.FIELD;

            class A {}
        """
    )

    @Test
    fun dontAddStaticWildcardImportIfNotReferenced(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.Collections", "*", true) }
        ),
        before = """
            package a;

            class A {}
        """
    )

    @Test
    fun addNamedStaticImportWhenReferenced(jp: JavaParser) = assertChanged(
        jp,
        // First strip the receiver from every method invocation, then add the static import.
        recipe = object : Recipe() {
            override fun getDisplayName(): String = "Test"

            override fun getVisitor(): TreeVisitor<*, ExecutionContext> =
                object : JavaIsoVisitor<ExecutionContext>() {
                    override fun visitMethodInvocation(m: J.MethodInvocation, ctx: ExecutionContext) =
                        m.withSelect(null)
                }
        }.doNext(
            addImports({ AddImport("java.util.Collections", "emptyList", true) })
        ),
        before = """
            package a;

            import java.util.List;

            class A {
                public A() {
                    List list = java.util.Collections.emptyList();
                }
            }
        """,
        after = """
            package a;

            import java.util.List;

            import static java.util.Collections.emptyList;

            class A {
                public A() {
                    List list = emptyList();
                }
            }
        """
    )

    @Test
    fun doNotAddNamedStaticImportIfNotReferenced(jp: JavaParser) = assertUnchanged(
        jp,
        recipe = addImports(
            { AddImport("java.util.Collections", "emptyList", true) }
        ),
        before = """
            package a;

            class A {}
        """
    )

    @Test
    fun addStaticWildcardImportWhenReferenced(jp: JavaParser) = assertChanged(
        jp,
        recipe = FixEmptyListMethodType().doNext(
            addImports(
                { AddImport("java.util.Collections", "*", true) }
            )
        ),
        before = """
            package a;

            import java.util.List;

            class A {
                public A() {
                    List list = java.util.Collections.emptyList();
                }
            }
        """,
        after = """
            package a;

            import java.util.List;

            import static java.util.Collections.*;

            class A {
                public A() {
                    List list = emptyList();
                }
            }
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/477")
    @Test
    fun dontAddImportForStaticImportsIndirectlyReferenced(jp: JavaParser.Builder<*, *>) = assertUnchanged(
        jp.classpath("jackson-databind").build(),
        recipe = toRecipe {
            object : JavaIsoVisitor<ExecutionContext>() {
                override fun visitCompilationUnit(cu: J.CompilationUnit, p: ExecutionContext): J.CompilationUnit {
                    maybeAddImport("com.fasterxml.jackson.databind.ObjectMapper")
                    return super.visitCompilationUnit(cu, p)
                }
            }
        },
        dependsOn = arrayOf(
            """
                import com.fasterxml.jackson.databind.ObjectMapper;

                class Helper {
                    static ObjectMapper OBJECT_MAPPER;
                }
            """
        ),
        before = """
            class Test {
                void test() {
                    Helper.OBJECT_MAPPER.writer();
                }
            }
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/776")
    @Test
    fun addImportAndFoldIntoWildcard(jp: JavaParser) = assertChanged(
        jp,
        dependsOn = arrayOf(
            """
                package foo;
                public class B {
                }
                public class C {
                }
            """
        ),
        recipe = addImports(
            { AddImport("java.util.ArrayList", null, false) }
        ),
        before = """
            import foo.B;
            import foo.C;

            import java.util.Collections;
            import java.util.List;
            import java.util.HashSet;
            import java.util.HashMap;

            class A {
                B b = new B();
                C c = new C();
                Map map = new HashMap<>();
                Set set = new HashSet<>();
                List test = Collections.singletonList("test");
                List test2 = new ArrayList<>();
            }
        """,
        after = """
            import foo.B;
            import foo.C;

            import java.util.*;

            class A {
                B b = new B();
                C c = new C();
                Map map = new HashMap<>();
                Set set = new HashSet<>();
                List test = Collections.singletonList("test");
                List test2 = new ArrayList<>();
            }
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/780")
    @Test
    fun addImportWhenDuplicatesExist(jp: JavaParser) = assertChanged(
        jp,
        recipe = addImports({ AddImport("org.springframework.http.MediaType", null, false) }),
        before = """
            import javax.ws.rs.Path;
            import javax.ws.rs.Path;

            class A {}
        """,
        after = """
            import org.springframework.http.MediaType;

            import javax.ws.rs.Path;
            import javax.ws.rs.Path;

            class A {}
        """
    )

    @Issue("https://github.com/openrewrite/rewrite/issues/867")
    @Test
    fun addImportWithCommentOnClassAndNoImportsOrPackageName(jp: JavaParser) = assertChanged(
        jp,
        recipe = toRecipe {
            object : JavaIsoVisitor<ExecutionContext>() {
                val t = JavaTemplate.builder({ cursor }, """
                    /**
                     * Do suppress those warnings
                     */
                    @SuppressWarnings("other")
                """.trimIndent())
                    //.doBeforeParseTemplate(print)
                    .build()

                override fun visitClassDeclaration(
                    classDecl: J.ClassDeclaration,
                    p: ExecutionContext
                ): J.ClassDeclaration {
                    val cd = super.visitClassDeclaration(classDecl, p)
                    // Annotate only classes that carry no annotation yet.
                    if (cd.leadingAnnotations.size == 0) {
                        maybeAddImport("java.lang.SuppressWarnings")
                        return cd.withTemplate(t, cd.coordinates.addAnnotation { _, _ -> 0 })
                    }
                    return cd
                }
            }
        },
        before = """
            class Test {
                class Inner1 {
                }
            }
        """,
        after = """
            import java.lang.SuppressWarnings;

            /**
             * Do suppress those warnings
             */
            @SuppressWarnings("other")
            class Test {
                /**
                 * Do suppress those warnings
                 */
                @SuppressWarnings("other")
                class Inner1 {
                }
            }
        """
    )

    /**
     * `org.foo.Shared` and `org.bar.Shared` both exist and `org.bar.*` is already
     * imported; the test expects `org.foo.Shared` to be added as a sixth, named import.
     */
    @Issue("https://github.com/openrewrite/rewrite/issues/880")
    @Test
    fun doNotFoldNormalImportWithNamespaceConflict(jp: JavaParser) {
        val inputs = arrayOf(
            """
                package org.test;

                import org.bar.*;
                import org.foo.FooA;
                import org.foo.FooB;
                import org.foo.FooC;
                import org.foo.FooD;

                public class Test {
                    FooA fooA = new FooA();
                    FooB fooB = new FooB();
                    FooC fooC = new FooC();
                    FooD fooD = new FooD();

                    Shared shared = new Shared();
                    BarA barA = new BarA();
                    BarB barB = new BarB();
                    BarC barC = new BarC();
                    BarD barD = new BarD();
                    BarE barE = new BarE();
                }
            """.trimIndent(),
            """package org.foo; public class Shared {}""".trimIndent(),
            """package org.foo; public class FooA {}""".trimIndent(),
            """package org.foo; public class FooB {}""".trimIndent(),
            """package org.foo; public class FooC {}""".trimIndent(),
            """package org.foo; public class FooD {}""".trimIndent(),
            """package org.foo; public class FooE {}""".trimIndent(),
            """package org.bar; public class Shared {}""".trimIndent(),
            """package org.bar; public class BarA {}""".trimIndent(),
            """package org.bar; public class BarB {}""".trimIndent(),
            """package org.bar; public class BarC {}""".trimIndent(),
            """package org.bar; public class BarD {}""".trimIndent(),
            """package org.bar; public class BarE {}""".trimIndent()
        )

        val sourceFiles = parser.parse(executionContext, *inputs)

        val classNames = arrayOf(
            "org.foo.Shared", "org.foo.FooA", "org.foo.FooB", "org.foo.FooC", "org.foo.FooD", "org.foo.FooE",
            "org.bar.Shared", "org.bar.BarA", "org.bar.BarB", "org.bar.BarC", "org.bar.BarD", "org.bar.BarE"
        )
        val fqns = classNames.map { JavaType.Class.build(it) }.toSet()

        val sourceSet = JavaSourceSet(randomId(), "main", fqns)
        val markedFiles = sourceFiles.map { it.withMarkers(it.markers.addIfAbsent(sourceSet)) }

        val visitor = AddImport<ExecutionContext>("org.foo.Shared", null, false)
        val result = visitor.visit(markedFiles[0], InMemoryExecutionContext()) as J.CompilationUnit

        // Five imports were already present; the new one must be appended, not folded.
        Assertions.assertThat(result.imports).hasSize(6)
        Assertions.assertThat(result.imports[5].qualid.printTrimmed()).isEqualTo("org.foo.Shared")
    }

    /**
     * `org.fuz.Fuz` and `org.buz.Buz` both declare a static `assertShared`; the test
     * expects `assertThatC` to be added as a sixth, named static import.
     */
    @Issue("https://github.com/openrewrite/rewrite/issues/880")
    @Test
    fun doNotFoldStaticsWithNamespaceConflict(jp: JavaParser) {
        val flags = setOf(Flag.Public, Flag.Static)
        val methodSignature = JavaType.Method.Signature(JavaType.buildType("boolean"), listOf())

        val fuzMethods = listOf("assertShared", "assertA", "assertB", "assertC").map { name ->
            JavaType.Method.build(
                flags, JavaType.Class.build("org.fuz.Fuz"), name, null, methodSignature,
                listOf(), listOf(), listOf()
            )
        }
        val fuz = JavaType.Class.build(
            Flag.flagsToBitMap(flags), "org.fuz.Fuz", JavaType.Class.Kind.Class,
            listOf(), listOf(), fuzMethods, null, null, listOf(), false
        )

        val buzMethods = listOf("assertShared", "assertThatA", "assertThatB", "assertThatC").map { name ->
            JavaType.Method.build(
                flags, JavaType.Class.build("org.buz.Buz"), name, null, methodSignature,
                listOf(), listOf(), listOf()
            )
        }
        val buz = JavaType.Class.build(
            Flag.flagsToBitMap(flags), "org.buz.Buz", JavaType.Class.Kind.Class,
            listOf(), listOf(), buzMethods, null, null, listOf(), false
        )

        val sourceSet = JavaSourceSet(randomId(), "main", setOf(fuz, buz))

        val inputs = arrayOf(
            """
                package org.fuz;

                public class Fuz {
                    public static boolean assertShared() { return true; }
                    public static boolean assertA() { return true; }
                    public static boolean assertB() { return true; }
                    public static boolean assertC() { return true; }
                }
            """.trimIndent(),
            """
                package org.buz;

                public class Buz {
                    public static boolean assertShared() { return true; }
                    public static boolean assertThatA() { return true; }
                    public static boolean assertThatB() { return true; }
                    public static boolean assertThatC() { return true; }
                }
            """.trimIndent(),
            """
                package org.test;

                import static org.fuz.Fuz.assertA;
                import static org.fuz.Fuz.assertB;
                import static org.fuz.Fuz.assertC;
                import static org.buz.Buz.assertThatA;
                import static org.buz.Buz.assertThatB;

                public class Test {
                    boolean fooA = assertA();
                    boolean fooB = assertB();
                    boolean fooC = assertC();

                    boolean barA = assertThatA();
                    boolean barB = assertThatB();
                    boolean barC = org.buz.Buz.assertThatC();
                }
            """.trimIndent()
        )

        // Inputs are processed last so that fqns are setup properly in flyweights.
        val sourceFiles = parser.parse(executionContext, *inputs)
        val markedFiles = sourceFiles.map { it.withMarkers(it.markers.addIfAbsent(sourceSet)) }

        val visitor = AddImport<ExecutionContext>("org.buz.Buz", "assertThatC", false)
        val result = visitor.visit(markedFiles[2], InMemoryExecutionContext()) as J.CompilationUnit

        // Five static imports were already present; the new one must be appended, not folded.
        Assertions.assertThat(result.imports).hasSize(6)
        Assertions.assertThat(result.imports[5].qualid.printTrimmed()).isEqualTo("org.buz.Buz.assertThatC")
    }

    /**
     * This recipe removes the receiver from every method invocation named "emptyList",
     * e.g. turning "java.util.Collections.emptyList()" into "emptyList()".
     * This allows us to test that AddImport with onlyIfReferenced = true will add a static
     * import when an applicable static method call is present.
     */
    private class FixEmptyListMethodType : Recipe() {
        override fun getDisplayName(): String = "Fix Empty List"

        override fun getVisitor(): TreeVisitor<*, ExecutionContext> =
            object : JavaIsoVisitor<ExecutionContext>() {
                override fun visitMethodInvocation(
                    method: J.MethodInvocation,
                    ctx: ExecutionContext
                ): J.MethodInvocation {
                    val original: J.MethodInvocation = super.visitMethodInvocation(method, ctx)
                    return if (original.name.simpleName == "emptyList") {
                        original.withSelect(null)
                    } else {
                        original
                    }
                }
            }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy