// All downloads are free. Search and download functionality uses the official Maven repository.

// codegen.CodegenStep.kt Maven / Gradle / Ivy

/*
 * Copyright (C) 2024 Meowool 
 *
 * This file is part of the MMKV-KTX project .
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:Suppress("MemberVisibilityCanBePrivate")

package com.meowool.mmkv.ktx.compiler.codegen

import com.google.devtools.ksp.findActualType
import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.processing.KSPLogger
import com.google.devtools.ksp.symbol.ClassKind
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSAnnotation
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSDeclaration
import com.google.devtools.ksp.symbol.KSFile
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSTypeAlias
import com.google.devtools.ksp.symbol.KSTypeReference
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.ksp.kspDependencies
import com.squareup.kotlinpoet.ksp.originatingKSFiles
import com.squareup.kotlinpoet.ksp.toClassName
import java.util.Locale

/**
 * Base class for one code-generation step of the symbol processor.
 *
 * A step is driven by [start], which stores the round's [Context] and calls [generate].
 * Subclasses typically iterate their declarations through [process], which records every
 * failure and marks the failing declaration as deferred so [start] can hand it back to
 * the caller for the next processing round.
 *
 * The class also hosts a set of small KSP / KotlinPoet helper extensions shared by all
 * steps.
 */
abstract class CodegenStep {
  /** Every throwable caught by [process], across all rounds; reported by [reportException]. */
  private val recordedExceptions = mutableSetOf<Throwable>()

  /** Declarations whose generation failed during the current round. */
  private val deferredNodes = mutableListOf<KSClassDeclaration>()

  /** Context of the current round; assigned by [start] before [generate] is invoked. */
  protected lateinit var context: Context

  /** Performs the actual generation work of this step for the current [context]. */
  protected abstract fun generate()

  /**
   * Runs this step for a new processing round.
   *
   * @param context the round's shared state.
   * @return the declarations that failed in this round. The result is a snapshot, so it
   *   is not affected when the internal list is cleared at the start of the next round.
   */
  fun start(context: Context): List<KSClassDeclaration> {
    this.context = context
    // Deferred nodes of the previous round are stale once a new round starts.
    deferredNodes.clear()
    generate()
    // Hand out a copy: returning the live list would let the next `clear()` silently
    // empty a list the caller may still be holding.
    return deferredNodes.toList()
  }

  /**
   * Logs every recorded exception through the context's logger, reporting only one
   * exception per distinct message.
   */
  fun reportException() = recordedExceptions
    .distinctBy { it.message }
    .forEach { context.logger.exception(it) }

  /** Hook applied to the generated source text before it is written; identity by default. */
  open fun String.fixGeneratedCode() = this

  /**
   * Invokes [generate] for each declaration of the receiver list.
   *
   * Any throwable raised for a declaration is recorded and the declaration is added to
   * the deferred nodes instead of aborting the remaining declarations.
   */
  protected fun List<KSClassDeclaration>.process(generate: (KSClassDeclaration) -> Unit) =
    forEach { declaration ->
      try {
        generate(declaration)
      } catch (e: Throwable) {
        // Record it if this symbol processing fails, so that it can be deferred to
        // the next round of symbol processing.
        recordedExceptions.add(e)
        deferredNodes.add(declaration)
      }
    }

  /**
   * Converts this function declaration to a KotlinPoet [MemberName].
   *
   * The owner part is the qualified name of the parent declaration; for a top-level
   * function (no parent declaration) it is the function's package name.
   *
   * @throws IllegalArgumentException if the parent declaration has no qualified name.
   */
  fun KSFunctionDeclaration.toMemberName(): MemberName {
    val parent = parentDeclaration
    val owner = when (parent) {
      // Top-level function: it is addressed through its package.
      null -> packageName.asString()
      else -> requireNotNull(parent.qualifiedName) {
        "Cannot create a member name for '${simpleName.asString()}': " +
          "its parent declaration has no qualified name"
      }.asString()
    }
    return MemberName(packageName = owner, simpleName = simpleName.asString())
  }

  /** Returns the qualified name of this declaration, or its simple name if it has none. */
  fun KSDeclaration.logName() = qualifiedName?.asString() ?: simpleName.asString()

  /**
   * Renders the resolved type of this reference for log messages, e.g. `Map<String, Int?>`.
   *
   * Only simple names are used. Star projections are rendered as `*`, and a `null`
   * receiver yields an empty string.
   */
  fun KSTypeReference?.logName(): String {
    val type = this?.resolve() ?: return ""
    return buildString {
      append(type.declaration.simpleName.asString())
      if (type.arguments.isNotEmpty()) {
        append("<")
        // A star projection carries no type reference, so it has nothing to resolve.
        append(type.arguments.joinToString { argument -> argument.type?.logName() ?: "*" })
        append(">")
      }
      if (type.isMarkedNullable) append("?")
    }
  }

  /**
   * Checks whether the type behind this reference denotes the same declaration as [other].
   *
   * Supported kinds of [other]:
   * - [ClassName]: compared by canonical name;
   * - [KSDeclaration]: compared by qualified name;
   * - [KSTypeReference]: resolved and then compared as a [KSType];
   * - [KSType]: the type arguments must match pairwise (recursively) and the declarations
   *   must share the same qualified name.
   *
   * Two `null`s match each other, any other kind of [other] never matches. Nullability
   * markers are not compared here (see [matchesNullable]).
   *
   * NOTE(review): two declarations that both lack a qualified name compare as equal.
   */
  fun KSTypeReference?.matches(other: Any?): Boolean {
    if (this == null && other == null) return true
    if (this == null || other == null) return false
    val selfType = resolve()
    val selfDeclaration = selfType.declaration
    val otherDeclaration = when (other) {
      is ClassName -> return selfDeclaration.qualifiedName?.asString() == other.canonicalName
      is KSDeclaration -> other
      is KSTypeReference -> return matches(other.resolve())
      is KSType -> {
        if (other.arguments.size != selfType.arguments.size) return false
        val argumentsMatch = selfType.arguments.zip(other.arguments)
          .all { (mine, theirs) -> mine.type.matches(theirs.type) }
        if (!argumentsMatch) return false
        other.declaration
      }
      else -> return false
    }
    return selfDeclaration.qualifiedName?.asString() == otherDeclaration.qualifiedName?.asString()
  }

  /**
   * Returns `true` when both types are `null`, or when both are non-null and agree on
   * being marked nullable.
   */
  fun KSType?.matchesNullable(other: KSType?): Boolean = when {
    this == null && other == null -> true
    this == null || other == null -> false
    else -> isMarkedNullable == other.isMarkedNullable
  }

  /** Returns `true` if any reference resolves to a declaration named like [className]. */
  fun Sequence<KSTypeReference>.contains(className: ClassName): Boolean = any {
    it.resolve().declaration.qualifiedName?.asString() == className.canonicalName
  }

  /**
   * Returns the class declaration behind this type, expanding type aliases.
   *
   * For any other kind of declaration a warning is logged and `null` is returned.
   */
  fun KSType.findActualDeclaration(): KSClassDeclaration? = when (val symbol = declaration) {
    is KSClassDeclaration -> symbol
    is KSTypeAlias -> symbol.findActualType()
    else -> {
      context.logger.warn("Cannot find actual declaration of the type", symbol)
      null
    }
  }

  /**
   * Finds the first annotation of this element whose type is [className].
   *
   * The short name is compared first, so the annotation type is only resolved for
   * annotations whose short name already matches.
   */
  fun KSAnnotated.findAnnotation(className: ClassName): KSAnnotation? = annotations.find {
    it.shortName.asString() == className.simpleName && it.annotationType.resolve()
      .declaration.qualifiedName?.asString() == className.canonicalName
  }

  /** Finds the argument called [name], or `null` if the annotation or argument is absent. */
  fun KSAnnotation?.findArgument(name: String) = this?.arguments?.find {
    it.name?.asString() == name
  }

  /**
   * Returns the value of the argument [name] if it is a non-empty [String], otherwise `null`.
   */
  fun KSAnnotation?.findStringArgument(name: String): String? =
    (findArgument(name)?.value as? String)?.takeIf { it.isNotEmpty() }

  /** Returns the value of the argument [name] if it is an [Int], otherwise `null`. */
  fun KSAnnotation?.findIntArgument(name: String): Int? =
    findArgument(name)?.value as? Int

  /**
   * Writes this file, using the containing files of [originatingDeclarations] as its
   * originating files. Declarations without a containing file are skipped.
   */
  fun FileSpec.write(originatingDeclarations: Iterable<KSDeclaration>) {
    val files: Set<KSFile> = originatingDeclarations.mapNotNull { it.containingFile }.toSet()
    write(files)
  }

  /**
   * Writes this file, using the containing file of [originatingDeclaration] as its
   * originating file. A declaration without a containing file contributes nothing,
   * consistent with the [Iterable] overload.
   */
  fun FileSpec.write(originatingDeclaration: KSDeclaration) {
    val files: List<KSFile> = listOfNotNull(originatingDeclaration.containingFile)
    write(files)
  }

  /**
   * Writes this file through the context's code generator.
   *
   * The source text is passed through [fixGeneratedCode], and every escaped `` `get`() ``
   * is replaced by `get()` before writing.
   *
   * @param originatingKSFiles the files this output depends on; defaults to the
   *   originating files recorded on the spec itself.
   */
  @JvmName("writeWithFiles")
  fun FileSpec.write(originatingKSFiles: Iterable<KSFile> = originatingKSFiles()) {
    val dependencies = kspDependencies(
      aggregating = true,
      // We need the source file of the converter anyway, because we always depend on them.
      originatingKSFiles = originatingKSFiles + context.typeConverters.mapNotNull { it.containingFile }
    )
    val source = this.toString().fixGeneratedCode().replace("`get`()", "get()")

    context.codeGenerator
      .createNewFile(dependencies, packageName, name)
      .bufferedWriter()
      .use { it.write(source) }
  }

  /**
   * Lower-cases the first character.
   *
   * Uses [Locale.ROOT] so the result does not depend on the locale of the machine
   * running the processor.
   */
  fun String.lowercaseFirstChar() = replaceFirstChar { it.lowercase(Locale.ROOT) }

  /**
   * Upper-cases the first character.
   *
   * Uses [Locale.ROOT] so the result does not depend on the locale of the machine
   * running the processor.
   */
  fun String.uppercaseFirstChar() = replaceFirstChar { it.uppercase(Locale.ROOT) }

  /**
   * State shared by the steps of one processing round.
   *
   * @property preferences the class declarations to generate preferences for
   *   (presumably the annotated preference classes; verify against the processor).
   * @property typeConverters the declarations of the available type converters.
   * @property logger the KSP logger used for diagnostics.
   * @property codeGenerator the KSP code generator used to create output files.
   * @property packageName the package that all generated classes are placed in.
   */
  data class Context(
    val preferences: List<KSClassDeclaration>,
    val typeConverters: List<KSClassDeclaration>,
    val logger: KSPLogger,
    val codeGenerator: CodeGenerator,
    val packageName: String,
  ) {
    /** Name of the generated `PreferencesFactory` class. */
    val factoryClassName = className("PreferencesFactory")

    /** Name of the generated `PreferencesFactoryImpl` class. */
    val factoryImplClassName = className("PreferencesFactoryImpl")

    /** Type converters that are not declared as `object`. */
    val classTypeConverters = typeConverters.filter { it.classKind != ClassKind.OBJECT }

    /**
     * One parameter per entry of [classTypeConverters], named after the converter's
     * simple name with its first character lower-cased.
     */
    val classTypeConvertersParams = classTypeConverters.map {
      ParameterSpec(
        name = it.simpleName.asString().replaceFirstChar(Char::lowercase),
        type = it.toClassName()
      )
    }

    /** Returns `Mutable<Name>` in the generated package. */
    fun mutableClassName(raw: KSClassDeclaration) =
      className("Mutable" + raw.simpleName.asString())

    /** Returns `Mutable<Name>Impl` in the generated package. */
    fun mutableImplClassName(raw: KSClassDeclaration) =
      className("Mutable" + raw.simpleName.asString() + "Impl")

    /** Returns `<Name>Preferences` in the generated package. */
    fun preferencesClassName(raw: KSClassDeclaration) =
      className(raw.simpleName.asString() + "Preferences")

    /** Returns `<Name>PreferencesImpl` in the generated package. */
    fun preferencesImplClassName(raw: KSClassDeclaration) =
      className(raw.simpleName.asString() + "PreferencesImpl")

    /** Creates a [ClassName] for [name] inside [packageName]. */
    private fun className(name: String) = ClassName(packageName, name)
  }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy