All downloads are free. Search and download functionality uses the official Maven repository.

com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.kt Maven / Gradle / Ivy

Go to download

The core implementation module for Anvil, responsible for hooking into the Kotlin compiler and orchestrating code generation

There is a newer version: 0.4.0
Show newest version
package com.squareup.anvil.compiler.codegen.dagger

import com.google.auto.service.AutoService
import com.google.devtools.ksp.KspExperimental
import com.google.devtools.ksp.getAllSuperTypes
import com.google.devtools.ksp.getAnnotationsByType
import com.google.devtools.ksp.getConstructors
import com.google.devtools.ksp.getDeclaredFunctions
import com.google.devtools.ksp.getVisibility
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunction
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSNode
import com.google.devtools.ksp.symbol.KSValueParameter
import com.google.devtools.ksp.symbol.Visibility.PROTECTED
import com.google.devtools.ksp.symbol.Visibility.PUBLIC
import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker
import com.squareup.anvil.compiler.api.AnvilCompilationException
import com.squareup.anvil.compiler.api.AnvilContext
import com.squareup.anvil.compiler.api.CodeGenerator
import com.squareup.anvil.compiler.api.GeneratedFileWithSources
import com.squareup.anvil.compiler.api.createGeneratedFile
import com.squareup.anvil.compiler.assistedFactoryFqName
import com.squareup.anvil.compiler.assistedFqName
import com.squareup.anvil.compiler.assistedInjectFqName
import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator
import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.AssistedParameterKey.Companion.toAssistedParameterKey
import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.Embedded.AssistedFactoryFunction.Companion.toAssistedFactoryFunction
import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.KspGenerator.AssistedFactoryFunction.Companion.toAssistedFactoryFunction
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor
import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider
import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException
import com.squareup.anvil.compiler.codegen.ksp.KspErrorTypeException
import com.squareup.anvil.compiler.codegen.ksp.contextualToTypeName
import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent
import com.squareup.anvil.compiler.codegen.ksp.isInterface
import com.squareup.anvil.compiler.codegen.ksp.resolveKSClassDeclaration
import com.squareup.anvil.compiler.internal.createAnvilSpec
import com.squareup.anvil.compiler.internal.joinSimpleNames
import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference
import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference
import com.squareup.anvil.compiler.internal.reference.ClassReference
import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference
import com.squareup.anvil.compiler.internal.reference.ParameterReference
import com.squareup.anvil.compiler.internal.reference.Visibility
import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReferences
import com.squareup.anvil.compiler.internal.reference.argumentAt
import com.squareup.anvil.compiler.internal.reference.asClassName
import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier.OVERRIDE
import com.squareup.kotlinpoet.KModifier.PRIVATE
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.TypeVariableName
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.jvm.jvmStatic
import com.squareup.kotlinpoet.ksp.TypeParameterResolver
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeParameterResolver
import com.squareup.kotlinpoet.ksp.toTypeVariableName
import com.squareup.kotlinpoet.ksp.writeTo
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
import dagger.internal.InstanceFactory
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.psi.KtFile
import java.io.File
import javax.inject.Provider

internal object AssistedFactoryCodeGen : AnvilApplicabilityChecker {

  /** Assisted factory implementations are only generated when factory generation is enabled. */
  override fun isApplicable(context: AnvilContext): Boolean {
    return context.generateFactories
  }

  /**
   * KSP implementation of the assisted factory code generator. For every
   * `@AssistedFactory`-annotated class it generates a `<Factory>_Impl` class that delegates to the
   * `<Target>_Factory` class of the type with the `@AssistedInject` constructor.
   */
  internal class KspGenerator(
    override val env: SymbolProcessorEnvironment,
  ) : AnvilSymbolProcessor() {
    @AutoService(SymbolProcessorProvider::class)
    class Provider : AnvilSymbolProcessorProvider(AssistedFactoryCodeGen, ::KspGenerator)

    /**
     * Writes an `_Impl` file for each `@AssistedFactory` class returned by [resolver].
     *
     * @return the classes that could not be processed yet because they reference error types,
     *   so that they can be deferred.
     */
    override fun processChecked(resolver: Resolver): List<KSAnnotated> {
      val deferred = mutableListOf<KSAnnotated>()
      resolver.getSymbolsWithAnnotation(assistedFactoryFqName.asString())
        .filterIsInstance<KSClassDeclaration>()
        .forEach { clazz ->
          val spec = generateFactoryClass(clazz)
          if (spec == null) {
            deferred += clazz
          } else {
            spec.writeTo(env.codeGenerator, aggregating = false, listOf(clazz.containingFile!!))
          }
        }
      return deferred
    }

    /**
     * Validates [clazz] the same way Dagger does and builds the `_Impl` file for it.
     *
     * @return the file spec to write, or `null` if [clazz] references error types and has to be
     *   deferred.
     * @throws KspAnvilException if [clazz] is not a valid assisted factory.
     */
    private fun generateFactoryClass(
      clazz: KSClassDeclaration,
    ): FileSpec? {
      val typeParameterResolver = clazz.typeParameters.toTypeParameterResolver()
      val function = try {
        clazz.requireSingleAbstractFunction(typeParameterResolver)
      } catch (e: KspErrorTypeException) {
        return null
      }

      // The return type is null when it doesn't resolve to a class declaration. Throw the same
      // error that Dagger would instead of failing with a NullPointerException.
      val returnType = function.returnType
        ?: throw KspAnvilException(
          message = "Invalid return type: ${clazz.qualifiedName?.asString()}. An assisted factory's " +
            "abstract method must return a type with an @AssistedInject-annotated constructor.",
          node = function.node,
        )

      // The return type of the function must have an @AssistedInject constructor.
      val constructor = returnType
        .getConstructors()
        .singleOrNull { it.isAnnotationPresent<AssistedInject>() }
        ?: throw KspAnvilException(
          message = "Invalid return type: ${returnType.qualifiedName?.asString()}. An assisted factory's abstract " +
            "method must return a type with an @AssistedInject-annotated constructor.",
          node = clazz,
        )

      val functionParameterKeys = function.parameterKeys
      val assistedParameters = constructor.parameters.filter { parameter ->
        parameter.isAnnotationPresent<Assisted>()
      }

      // Check that the parameters of the function match the @Assisted parameters of the constructor.
      if (assistedParameters.size != functionParameterKeys.size) {
        throw KspAnvilException(
          message = "The parameters in the factory method must match the @Assisted parameters in " +
            "${returnType.qualifiedName?.asString()}.",
          node = clazz,
        )
      }

      // Compute for each constructor parameter its key.
      val assistedParameterKeys = assistedParameters.map {
        it.toAssistedParameterKey(
          it.type.resolve().contextualToTypeName(it.type, typeParameterResolver),
        )
      }

      // The factory function may not have two or more parameters with the same key. Keys are
      // grouped by full equality (type and identifier) instead of the Int `key`, so two different
      // keys whose hash values happen to collide are never treated as the same key.
      val duplicateKeys = functionParameterKeys
        .groupBy { it }
        .filterValues { it.size > 1 }
        .keys

      // Complain about the first duplicate key that occurs, similar to Dagger.
      val duplicateKey = functionParameterKeys.firstOrNull { it in duplicateKeys }
      if (duplicateKey != null) {
        throw KspAnvilException(
          message = buildString {
            append("@AssistedFactory method has duplicate @Assisted types: ")
            if (duplicateKey.identifier.isNotEmpty()) {
              append("@Assisted(\"${duplicateKey.identifier}\") ")
            }
            append(duplicateKey.typeName)
          },
          node = clazz,
        )
      }

      // Check that for each parameter of the factory function there is a parameter with the same
      // key in the @AssistedInject constructor (and vice versa).
      val hasNotMatchingKeys = (functionParameterKeys + assistedParameterKeys)
        .groupBy { it }
        .values
        .any { it.size == 1 }

      if (hasNotMatchingKeys) {
        throw KspAnvilException(
          message = "The parameters in the factory method must match the @Assisted parameters in " +
            "${returnType.qualifiedName?.asString()}.",
          node = clazz,
        )
      }

      return buildSpec(
        originClassNAme = clazz.toClassName(),
        targetType = returnType.toClassName(),
        functionName = function.simpleName,
        typeParameters = clazz.typeParameters.map { it.toTypeVariableName(typeParameterResolver) },
        assistedParameterKeys = assistedParameterKeys,
        baseFactoryIsInterface = clazz.isInterface(),
        functionParameterPairs = function.parameterPairs.map { (parameter, typeName) ->
          parameter.name!!.asString() to typeName
        },
        functionParameterKeys = functionParameterKeys,
      )
    }

    /**
     * Returns the single abstract public/protected function declared by this class or any of its
     * supertypes.
     *
     * @throws KspErrorTypeException if a supertype or a candidate's return type is an error type.
     * @throws KspAnvilException if there is no such function or more than one.
     */
    private fun KSClassDeclaration.requireSingleAbstractFunction(
      typeParameterResolver: TypeParameterResolver,
    ): AssistedFactoryFunction {
      val implementingType = asType(emptyList())

      // The factory class itself must be first in the list because of `distinctBy { ... }`, which
      // keeps the first matched element. If the function's inherited, it can be overridden as
      // well. Prioritizing the version from the file we're parsing ensures the correct variance
      // of the referenced types.
      // TODO can't use getAllFunctions() yet due to https://github.com/google/ksp/issues/1619
      val assistedFunctions = sequenceOf(this)
        .plus(
          getAllSuperTypes()
            .onEach { if (it.isError) throw KspErrorTypeException() }
            .mapNotNull { it.resolveKSClassDeclaration() },
        )
        .distinctBy { it.qualifiedName?.asString() }
        .flatMap { clazz ->
          clazz.getDeclaredFunctions()
            .filter {
              it.isAbstract &&
                (it.getVisibility() == PUBLIC || it.getVisibility() == PROTECTED)
            }
        }
        .distinctBy { it.simpleName.asString() }
        .map {
          if (it.returnType?.resolve()?.isError == true) throw KspErrorTypeException()
          it.asMemberOf(implementingType)
            .toAssistedFactoryFunction(it, typeParameterResolver)
        }
        .toList()

      // Check for exact number of functions.
      return when (assistedFunctions.size) {
        0 -> throw KspAnvilException(
          message = "The @AssistedFactory-annotated type is missing an abstract, non-default " +
            "method whose return type matches the assisted injection type.",
          node = this,
        )

        1 -> assistedFunctions[0]
        else -> {
          val foundFunctions = assistedFunctions
            .sortedBy { it.simpleName }
            .joinToString { func ->
              // Use asString() so parameter names are printed, not KSName object descriptions.
              "${func.qualifiedName}(${func.parameterPairs.map { it.first.name?.asString() }})"
            }
          throw KspAnvilException(
            message = "The @AssistedFactory-annotated type should contain a single abstract, " +
              "non-default method but found multiple: [$foundFunctions]",
            node = this,
          )
        }
      }
    }

    /**
     * Represents a parsed function in an `@AssistedFactory`-annotated type.
     */
    private data class AssistedFactoryFunction(
      val simpleName: String,
      val qualifiedName: String,
      /**
       * The class declaration the function's return type resolves to, or `null` if it does not
       * resolve to a class declaration.
       */
      val returnType: KSClassDeclaration?,
      val node: KSNode,
      val parameterKeys: List<AssistedParameterKey>,
      /**
       * Pair of parameter reference to parameter type.
       */
      val parameterPairs: List<Pair<KSValueParameter, TypeName>>,
    ) {

      companion object {
        fun KSFunction.toAssistedFactoryFunction(
          originalDeclaration: KSFunctionDeclaration,
          typeParameterResolver: TypeParameterResolver,
        ): AssistedFactoryFunction {
          // Resolve each parameter's type name once; it's needed for both the keys and the pairs.
          val parameterTypeNames = originalDeclaration.parameters.mapIndexed { index, param ->
            parameterTypes[index]!!.contextualToTypeName(param, typeParameterResolver)
          }
          return AssistedFactoryFunction(
            simpleName = originalDeclaration.simpleName.asString(),
            qualifiedName = originalDeclaration.qualifiedName!!.asString(),
            returnType = returnType?.resolveKSClassDeclaration(),
            node = originalDeclaration,
            parameterKeys = originalDeclaration.parameters.zip(parameterTypeNames) { param, typeName ->
              param.toAssistedParameterKey(typeName)
            },
            parameterPairs = originalDeclaration.parameters.zip(parameterTypeNames),
          )
        }
      }
    }
  }

  /**
   * Embedded (non-KSP) implementation, registered as an Anvil [CodeGenerator]. For every
   * `@AssistedFactory`-annotated class in the project files it generates a `<Factory>_Impl` class
   * that delegates to the `<Target>_Factory` class of the type with the `@AssistedInject`
   * constructor.
   */
  @AutoService(CodeGenerator::class)
  internal class Embedded : PrivateCodeGenerator() {

    override fun isApplicable(context: AnvilContext) = AssistedFactoryCodeGen.isApplicable(context)

    override fun generateCodePrivate(
      codeGenDir: File,
      module: ModuleDescriptor,
      projectFiles: Collection<KtFile>,
    ): List<GeneratedFileWithSources> = projectFiles
      .classAndInnerClassReferences(module)
      .filter { it.isAnnotatedWith(assistedFactoryFqName) }
      .map { clazz ->
        generateFactoryClass(codeGenDir, clazz)
      }
      .toList()

    /**
     * Validates [clazz] the same way Dagger does and writes the `_Impl` file for it into
     * [codeGenDir]. Throws an Anvil compilation exception pointing at [clazz] (or its factory
     * function) if it is not a valid assisted factory.
     */
    private fun generateFactoryClass(
      codeGenDir: File,
      clazz: ClassReference.Psi,
    ): GeneratedFileWithSources {
      val function = clazz.requireSingleAbstractFunction()

      val returnType = try {
        function.function.resolveGenericReturnType(clazz)
      } catch (e: AnvilCompilationException) {
        // Catch the exception and throw the same error that Dagger would.
        throw AnvilCompilationExceptionFunctionReference(
          message = "Invalid return type: ${clazz.fqName}. An assisted factory's " +
            "abstract method must return a type with an @AssistedInject-annotated constructor.",
          functionReference = function.function,
          cause = e,
        )
      }

      // The return type of the function must have an @AssistedInject constructor.
      val constructor = returnType
        .constructors
        .singleOrNull { it.isAnnotatedWith(assistedInjectFqName) }
        ?: throw AnvilCompilationExceptionClassReference(
          message = "Invalid return type: ${returnType.fqName}. An assisted factory's abstract " +
            "method must return a type with an @AssistedInject-annotated constructor.",
          classReference = clazz,
        )

      val functionParameterKeys = function.parameterKeys
      val assistedParameters = constructor.parameters.filter { parameter ->
        parameter.annotations.any { it.fqName == assistedFqName }
      }

      // Check that the parameters of the function match the @Assisted parameters of the constructor.
      if (assistedParameters.size != functionParameterKeys.size) {
        throw AnvilCompilationExceptionClassReference(
          message = "The parameters in the factory method must match the @Assisted parameters in " +
            "${returnType.fqName}.",
          classReference = clazz,
        )
      }

      // Compute for each constructor parameter its key.
      val assistedParameterKeys = assistedParameters.map { it.toAssistedParameterKey(clazz) }

      // The factory function may not have two or more parameters with the same key. Keys are
      // grouped by full equality (type and identifier) instead of the Int `key`, so two different
      // keys whose hash values happen to collide are never treated as the same key.
      val duplicateKeys = functionParameterKeys
        .groupBy { it }
        .filterValues { it.size > 1 }
        .keys

      // Complain about the first duplicate key that occurs, similar to Dagger.
      val duplicateKey = functionParameterKeys.firstOrNull { it in duplicateKeys }
      if (duplicateKey != null) {
        throw AnvilCompilationExceptionClassReference(
          message = buildString {
            append("@AssistedFactory method has duplicate @Assisted types: ")
            if (duplicateKey.identifier.isNotEmpty()) {
              append("@Assisted(\"${duplicateKey.identifier}\") ")
            }
            append(duplicateKey.typeName)
          },
          classReference = clazz,
        )
      }

      // Check that for each parameter of the factory function there is a parameter with the same
      // key in the @AssistedInject constructor (and vice versa).
      val hasNotMatchingKeys = (functionParameterKeys + assistedParameterKeys)
        .groupBy { it }
        .values
        .any { it.size == 1 }

      if (hasNotMatchingKeys) {
        throw AnvilCompilationExceptionClassReference(
          message = "The parameters in the factory method must match the @Assisted parameters in " +
            "${returnType.fqName}.",
          classReference = clazz,
        )
      }

      val spec = buildSpec(
        originClassNAme = clazz.asClassName(),
        targetType = returnType.asClassName(),
        functionName = function.function.name,
        typeParameters = clazz.typeParameters.map { it.typeVariableName },
        assistedParameterKeys = assistedParameterKeys,
        baseFactoryIsInterface = clazz.isInterface(),
        functionParameterPairs = function.parameterPairs.map { (parameter, typeName) ->
          parameter.name to typeName
        },
        functionParameterKeys = functionParameterKeys,
      )

      return createGeneratedFile(
        codeGenDir = codeGenDir,
        packageName = spec.packageName,
        fileName = spec.name,
        content = spec.toString(),
        sourceFile = clazz.containingFileAsJavaFile,
      )
    }

    /**
     * Returns the single abstract public/protected function declared by this class or any of its
     * supertypes, or throws an Anvil compilation exception if there is none or more than one.
     */
    private fun ClassReference.Psi.requireSingleAbstractFunction(): AssistedFactoryFunction {
      // The factory class itself must be first in the list because of `distinctBy { ... }`, which
      // keeps the first matched element. If the function's inherited, it can be overridden as
      // well. Prioritizing the version from the file we're parsing ensures the correct variance
      // of the referenced types.
      val assistedFunctions = allSuperTypeClassReferences(includeSelf = true)
        .distinctBy { it.fqName }
        .flatMap { clazz ->
          clazz.functions
            .filter {
              it.isAbstract() &&
                (it.visibility() == Visibility.PUBLIC || it.visibility() == Visibility.PROTECTED)
            }
        }
        .distinctBy { it.name }
        .map { it.toAssistedFactoryFunction(this) }
        .toList()

      // Check for exact number of functions.
      return when (assistedFunctions.size) {
        0 -> throw AnvilCompilationExceptionClassReference(
          message = "The @AssistedFactory-annotated type is missing an abstract, non-default " +
            "method whose return type matches the assisted injection type.",
          classReference = this,
        )

        1 -> assistedFunctions[0]
        else -> {
          val foundFunctions = assistedFunctions
            .sortedBy { it.function.name }
            .joinToString { func ->
              "${func.function.fqName}(${func.parameterPairs.map { it.first.name }})"
            }
          throw AnvilCompilationExceptionClassReference(
            message = "The @AssistedFactory-annotated type should contain a single abstract, " +
              "non-default method but found multiple: [$foundFunctions]",
            classReference = this,
          )
        }
      }
    }

    /**
     * Represents a parsed function in an `@AssistedFactory`-annotated type.
     */
    private data class AssistedFactoryFunction(
      val function: MemberFunctionReference,
      val parameterKeys: List<AssistedParameterKey>,
      /**
       * Pair of parameter reference to parameter type.
       */
      val parameterPairs: List<Pair<ParameterReference, TypeName>>,
    ) {
      companion object {
        fun MemberFunctionReference.toAssistedFactoryFunction(
          factoryClass: ClassReference.Psi,
        ): AssistedFactoryFunction {
          return AssistedFactoryFunction(
            function = this,
            parameterKeys = parameters.map { it.toAssistedParameterKey(factoryClass) },
            parameterPairs = parameters.map { it to it.resolveTypeName(factoryClass) },
          )
        }
      }
    }
  }

  /**
   * Name of the constructor parameter and private property of the generated `_Impl` class that
   * holds the generated `_Factory` delegate; also used as the parameter name of the companion
   * `create` functions.
   */
  private const val DELEGATE_FACTORY_NAME: String = "delegateFactory"

  /**
   * Builds the file containing `<Factory>_Impl`, the implementation of the assisted factory
   * [originClassNAme]. The implementation overrides the factory function, forwards the arguments
   * to the generated `<Target>_Factory` of [targetType], and gets companion `create` /
   * `createFactoryProvider` functions that wrap an instance in an [InstanceFactory].
   *
   * @param originClassNAme the `@AssistedFactory`-annotated type being implemented.
   * @param targetType the type with the `@AssistedInject` constructor that the factory returns.
   * @param functionName name of the abstract factory function to override.
   * @param typeParameters type variables of the factory, reused by all generated types.
   * @param assistedParameterKeys keys of the `@Assisted` constructor parameters, in constructor
   *   order.
   * @param baseFactoryIsInterface `true` to implement [originClassNAme] as an interface, `false`
   *   to extend it as a class.
   * @param functionParameterPairs name and type of each factory function parameter, in order.
   * @param functionParameterKeys keys of the factory function parameters, parallel to
   *   [functionParameterPairs].
   */
  private fun buildSpec(
    originClassNAme: ClassName,
    targetType: ClassName,
    functionName: String,
    typeParameters: List<TypeVariableName>,
    assistedParameterKeys: List<AssistedParameterKey>,
    baseFactoryIsInterface: Boolean,
    functionParameterPairs: List<Pair<String, TypeName>>,
    functionParameterKeys: List<AssistedParameterKey>,
  ): FileSpec {
    val generatedFactoryTypeName = targetType.joinSimpleNames(suffix = "_Factory")
      .optionallyParameterizedByNames(typeParameters)

    val baseFactoryTypeName = originClassNAme.optionallyParameterizedByNames(typeParameters)
    val returnTypeName = targetType.optionallyParameterizedByNames(typeParameters)
    val implClassName = originClassNAme.joinSimpleNames(suffix = "_Impl")
    val implParameterizedTypeName = implClassName.optionallyParameterizedByNames(typeParameters)

    // We call the @AssistedInject constructor. Therefore, find for each assisted parameter (in
    // constructor order) the function parameter whose key is equal.
    val argumentNames = assistedParameterKeys.map { assistedParameterKey ->
      val functionIndex = functionParameterKeys.indexOfFirst { it == assistedParameterKey }
      check(functionIndex >= 0) {
        // Sanity check, callers already reject factories whose keys don't match.
        "Unexpected assistedIndex."
      }
      functionParameterPairs[functionIndex].first
    }

    // Argument names go through %N so KotlinPoet escapes names that are Kotlin keywords (e.g. a
    // Java parameter named `object`) instead of emitting uncompilable code.
    val delegateCallFormat = argumentNames.joinToString(
      prefix = "return %N.get(",
      postfix = ")",
    ) { "%N" }

    fun createFactory(name: String, providerTypeName: ClassName): FunSpec {
      return FunSpec.builder(name)
        .jvmStatic()
        .addTypeVariables(typeParameters)
        .addParameter(DELEGATE_FACTORY_NAME, generatedFactoryTypeName)
        .returns(providerTypeName.parameterizedBy(baseFactoryTypeName))
        .addStatement(
          "return %T.create(%T($DELEGATE_FACTORY_NAME))",
          InstanceFactory::class,
          implParameterizedTypeName,
        )
        .build()
    }

    val companionObject = TypeSpec.companionObjectBuilder()
      .addFunction(createFactory("create", Provider::class.asClassName()))
      // New in Dagger 2.50: factories for dagger.internal.Provider
      .addFunction(
        createFactory("createFactoryProvider", dagger.internal.Provider::class.asClassName()),
      )
      .build()

    return FileSpec.createAnvilSpec(implClassName.packageName, implClassName.simpleName) {
      val implType = TypeSpec.classBuilder(implClassName)
        .apply {
          addTypeVariables(typeParameters)

          if (baseFactoryIsInterface) {
            addSuperinterface(baseFactoryTypeName)
          } else {
            superclass(baseFactoryTypeName)
          }

          primaryConstructor(
            FunSpec.constructorBuilder()
              .addParameter(DELEGATE_FACTORY_NAME, generatedFactoryTypeName)
              .build(),
          )

          addProperty(
            PropertySpec.builder(DELEGATE_FACTORY_NAME, generatedFactoryTypeName)
              .initializer(DELEGATE_FACTORY_NAME)
              .addModifiers(PRIVATE)
              .build(),
          )
        }
        .addFunction(
          FunSpec.builder(functionName)
            .addModifiers(OVERRIDE)
            .returns(returnTypeName)
            .apply {
              functionParameterPairs.forEach { (parameterName, parameterType) ->
                addParameter(parameterName, parameterType)
              }
              addStatement(delegateCallFormat, DELEGATE_FACTORY_NAME, *argumentNames.toTypedArray())
            }
            .build(),
        )
        .addType(companionObject)
        .build()

      addType(implType)
    }
  }

  /**
   * Dagger matches parameters of the factory function with the parameters of the @AssistedInject
   * constructor through a key. Initially, they used the order of parameters, but that has changed.
   * The key is a combination of the type and the identifier (the `value` parameter) of the
   * `@Assisted("...")` annotation. For each parameter the key must be unique.
   *
   * @property typeName the parameter's type.
   * @property identifier the `@Assisted` value, or an empty string if there is none.
   */
  private data class AssistedParameterKey(
    val typeName: TypeName,
    val identifier: String,
  ) {

    // Key value is similar to a hash function. There used to be a special case for KotlinTypes
    // which were parameterized, but this is now handled by KotlinPoet's TypeName:
    // `MyType<String>` and `MyType<Int>` now generate different hashCodes.
    // Being a hash, distinct keys can collide; compare the keys themselves for exact matching.
    val key: Int = identifier.hashCode() * 31 + typeName.hashCode()

    companion object {
      /** Creates the key of a KSP parameter whose resolved type is [typeName]. */
      @OptIn(KspExperimental::class)
      fun KSValueParameter.toAssistedParameterKey(
        typeName: TypeName,
      ): AssistedParameterKey {
        return AssistedParameterKey(
          typeName = typeName,
          identifier = getAnnotationsByType(Assisted::class)
            .singleOrNull()
            ?.value
            .orEmpty(),
        )
      }

      /** Creates the key of a PSI parameter, resolving its type in the context of [factoryClass]. */
      fun ParameterReference.toAssistedParameterKey(
        factoryClass: ClassReference.Psi,
      ): AssistedParameterKey {
        return AssistedParameterKey(
          typeName = resolveTypeName(factoryClass),
          identifier = annotations
            .singleOrNull { it.fqName == assistedFqName }
            ?.argumentAt("value", index = 0)
            ?.value<String>()
            .orEmpty(),
        )
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy