All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.jeppeman.mockposable.compiler.IrUtils.kt Maven / Gradle / Ivy

package com.jeppeman.mockposable.compiler

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.IrVariable
import org.jetbrains.kotlin.ir.expressions.IrBody
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrContainerExpression
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression
import org.jetbrains.kotlin.ir.expressions.IrGetField
import org.jetbrains.kotlin.ir.expressions.IrLoop
import org.jetbrains.kotlin.ir.expressions.IrReturn
import org.jetbrains.kotlin.ir.expressions.IrTry
import org.jetbrains.kotlin.ir.expressions.IrTypeOperatorCall
import org.jetbrains.kotlin.ir.expressions.IrWhen
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.types.defaultType
import org.jetbrains.kotlin.ir.types.makeNullable
import org.jetbrains.kotlin.ir.util.dumpKotlinLike
import org.jetbrains.kotlin.ir.util.hasAnnotation
import org.jetbrains.kotlin.ir.util.statements
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.util.Logger

/**
 * Builds a [ClassId] from this fully-qualified name: `parent()` becomes the package part and
 * `shortName()` becomes the top-level class name.
 *
 * @param isLocal passed through unchanged to the [ClassId] constructor.
 */
fun FqName.classId(
    isLocal: Boolean = false
): ClassId {
    val packageFqName = parent()
    val relativeClassName = FqName.topLevel(shortName())
    return ClassId(packageFqName, relativeClassName, isLocal)
}

/**
 * Resolves the class with the fully-qualified name [fqName] through [IrPluginContext.referenceClass].
 *
 * @param fqName fully-qualified name of a top-level class, e.g. `androidx.compose.runtime.Composer`.
 * @return the symbol of the resolved class.
 * @throws MockposablePluginException if the class cannot be resolved (previously this surfaced
 * as a message-less [NullPointerException] from `!!`).
 */
fun IrPluginContext.classSymbol(
    fqName: String
): IrClassSymbol = referenceClass(FqName(fqName).classId())
    ?: pluginError(
        "Failed to resolve class $fqName. Make sure the library declaring it is on the compilation classpath."
    )

/**
 * Treats this string as the name of a callable declared in [packageName] and wraps both into a
 * [CallableId].
 */
fun String.callableId(
    packageName: FqName
): CallableId {
    val callableName = Name.identifier(this)
    return CallableId(packageName, callableName)
}

/**
 * Runs [block] with a fresh [DeclarationIrBuilder] as receiver and returns its result.
 *
 * @param declarationSymbol symbol handed to the [DeclarationIrBuilder] constructor.
 * @param block builder code to execute.
 * @return whatever [block] returns.
 */
fun <T> IrPluginContext.buildIr(
    declarationSymbol: IrSymbol,
    block: DeclarationIrBuilder.() -> T
): T = DeclarationIrBuilder(this, declarationSymbol).block()

/** Fully-qualified name of the `Composable` annotation in the Compose runtime package. */
val composableAnnotationFqName: FqName
    get() {
        val qualifiedName = "${COMPOSE_RUNTIME_PACKAGE_NAME}.Composable"
        return FqName(qualifiedName)
    }

/** Symbol of the `Composer` class from the Compose runtime package, resolved on every access. */
val IrPluginContext.composerClassSymbol: IrClassSymbol
    get() {
        val qualifiedName = "${COMPOSE_RUNTIME_PACKAGE_NAME}.Composer"
        return classSymbol(qualifiedName)
    }

/** `true` when the function targeted by this call carries the `Composable` annotation. */
val IrCall.isComposableFunctionCall: Boolean
    get() {
        val callee = symbol.owner
        return callee.hasAnnotation(composableAnnotationFqName)
    }

/** All [IrCall]s found in the result expressions of this `when`'s branches (conditions are not inspected). */
val IrWhen.calls: List<IrCall> get() = branches.flatMap { branch -> branch.result.calls }

/** All [IrCall]s found in the statements of this container expression. */
val IrContainerExpression.calls: List<IrCall> get() = statements.flatMap { statement -> statement.calls }

/** All [IrCall]s found in the statements collected from this `try` via [allStatements]. */
val IrTry.calls: List<IrCall> get() = allStatements().flatMap { (statement) -> statement.calls }

/** All [IrCall]s found in the statements of this loop's body; empty when the loop has no body. */
val IrLoop.calls: List<IrCall>
    get() = body?.allStatements()?.flatMap { (statement) -> statement.calls }.orEmpty()

/** All [IrCall]s found in the statements of this body. */
val IrBody.calls: List<IrCall> get() = statements.flatMap { statement -> statement.calls }

/** All [IrCall]s found in the argument of this type operator call. */
val IrTypeOperatorCall.calls: List<IrCall> get() = argument.calls

/** All [IrCall]s found in the returned value expression. */
val IrReturn.calls: List<IrCall> get() = value.calls

/** All [IrCall]s found in this variable's initializer; empty when there is no initializer. */
val IrVariable.calls: List<IrCall> get() = initializer?.calls.orEmpty()

/**
 * Recursively extracts all [IrCall]s from an [IrElement].
 *
 * Each branch smart-casts the receiver so that the more specific `calls` extension for that
 * element kind is used. An [IrCall] yields itself (its arguments are not traversed here), and
 * any element kind not listed yields an empty list.
 */
val IrElement.calls: List<IrCall>
    get() = when (this) {
        is IrTypeOperatorCall -> calls
        is IrReturn -> calls
        is IrVariable -> calls
        is IrWhen -> calls
        is IrLoop -> calls
        is IrTry -> calls
        is IrContainerExpression -> calls
        is IrBody -> calls
        is IrCall -> listOf(this)
        else -> emptyList()
    }

/** Statements collected via [allStatements] from the result expression of every branch. */
val IrWhen.statements: List<IrStatementWithParent>
    get() = branches.flatMap { branch -> branch.result.allStatements() }

/**
 * Statements collected via [allStatements] from this loop's body. When the loop has no body,
 * the loop itself is returned as the single (parent-less) statement.
 */
val IrLoop.statements: List<IrStatementWithParent>
    get() = body?.allStatements() ?: listOf(IrStatementWithParent(this))

/**
 * Statements collected via [allStatements] from the `try` result, then every `catch` result,
 * then the `finally` expression (if present), in that order.
 */
val IrTry.statements: List<IrStatementWithParent>
    get() = tryResult.allStatements() +
            catches.flatMap { it.result.allStatements() } +
            finallyExpression?.allStatements().orEmpty()

/**
 * Recursively extracts all [IrStatement]s from an [IrElement].
 *
 * Container expressions and bodies are flattened, recording the container/body as the parent of
 * the statements found directly inside it. `when`, loop and `try` elements delegate to their
 * respective `statements` extensions. Any other [IrStatement] is returned as a single entry.
 *
 * @param parent element recorded as parent when the receiver itself is a leaf statement.
 * @return the flattened statements, each paired with its parent where known; empty for
 * elements that are not statements.
 */
fun IrElement.allStatements(
    parent: IrElement? = null
): List<IrStatementWithParent> = when (this) {
    is IrContainerExpression -> statements.flatMap { it.allStatements(this) }
    is IrBody -> statements.flatMap { it.allStatements(this) }
    is IrWhen -> statements
    is IrLoop -> statements
    is IrTry -> statements
    is IrStatement -> listOf(IrStatementWithParent(this, parent))
    else -> emptyList()
}

/**
 * An [IrStatement] paired with the [IrElement] whose statement list it was collected from, or
 * `null` when no parent was recorded.
 */
data class IrStatementWithParent(
    val statement: IrStatement,
    val parent: IrElement? = null,
)

/**
 * Finds the actual composable lambda for stubbing or verification, i.e. the function expression
 * passed as the value argument named [name] of this call.
 *
 * If the argument is not a plain [IrFunctionExpression], the lambda is looked up through
 * [extractComposableLambdaInstance] instead.
 *
 * @param name name of the value parameter that receives the composable block.
 * @return the function expression holding the composable block.
 * @throws MockposablePluginException if no parameter called [name] exists, if no argument was
 * passed for it, or if no function expression can be extracted from the argument.
 */
fun IrCall.extractComposableCallFromBlockArg(
    name: String,
): IrFunctionExpression {
    val composableBlockValueParameter = symbol.owner.valueParameters.find { valueParameter ->
        valueParameter.name.asString() == name
    } ?: pluginError("Failed to find @Composable function in lambda from $name.")

    val composableBlockValueArgument = getValueArgument(composableBlockValueParameter.index)
        ?: pluginError("No value argument found for ${composableBlockValueParameter.name}.")

    // Fast path: the lambda was passed directly. Otherwise unwrap whatever wraps it.
    return composableBlockValueArgument.cast<IrFunctionExpression>()
        ?: composableBlockValueArgument.extractComposableLambdaInstance()
}

/**
 * The compose compiler may transform as follows:
 *
 * everyComposable(
 *     stubBlock = @Composable { ... }
 * )
 * ->
 * everyComposable(
 *     stubBlock = composableLambdaInstance(key = ..., tracked = ..., block = @Composable { ... })
 * )
 * or ->
 * everyComposable(
 *     stubBlock = (a call into a generated `ComposableSingletons$...` container)
 * )
 *
 * If that is the case, we need to extract the stub block from one of the composable lambda
 * containers instead of from the call directly.
 *
 * @param composableLambdaBlockParameterName name of the parameter of
 * `composableLambdaInstance` / `composableLambda` that holds the lambda.
 * @return the function expression holding the composable block.
 * @throws MockposablePluginException if no function expression can be located.
 */
private fun IrExpression.extractComposableLambdaInstance(
    composableLambdaBlockParameterName: String = "block"
): IrFunctionExpression = when (this) {
    is IrCall -> when (symbol.owner.name.asString()) {
        // If this branch is matched we found a composableLambdaInstance(key, tracked, block = @Composable { ... })
        "composableLambdaInstance", "composableLambda" -> {
            val blockValueParameter = symbol.owner.valueParameters.find { valueParameter ->
                valueParameter.name.asString() == composableLambdaBlockParameterName
            } ?: pluginError(
                "Failed to find @Composable block in composableLambdaInstance from $composableLambdaBlockParameterName. Parameter names were: ${symbol.owner.valueParameters.map { it.name }}."
            )

            val valueArgument = getValueArgument(blockValueParameter.index)
                ?: pluginError("No value argument found for ${blockValueParameter.name}.")
            valueArgument.cast<IrFunctionExpression>() ?: pluginError(
                "Assumed wrong type for value argument ${blockValueParameter.name}, expected ${IrFunctionExpression::class.java.name}, got ${valueArgument::class.java.name}."
            )
        }
        // If not, the call is most likely contained within ComposableSingletons$,
        // we'll backtrack from the return statement of the call.
        else -> symbol.owner.body
            ?.statements
            ?.lastOrNull()
            ?.cast<IrReturn>()
            ?.value
            ?.let { returnValue ->
                when (returnValue) {
                    // The returned value is read from a field: follow the field's initializer.
                    is IrGetField -> returnValue.symbol.owner.initializer
                        ?.expression
                        ?.extractComposableLambdaInstance()
                    is IrCall -> returnValue.extractComposableLambdaInstance()
                    is IrFunctionExpression -> returnValue.extractComposableLambdaInstance()
                    else -> null
                }
            }
            ?: pluginError("Failed to find composableLambdaInstance from expression ${dumpKotlinLike()} with body ${symbol.owner.body?.dumpKotlinLike()}.")
    }
    is IrFunctionExpression -> this
    else -> pluginError("Failed to find composableLambdaInstance from expression ${dumpKotlinLike()}.")
}

/**
 * Recursively tries to find all @Composable calls from a function body.
 *
 * The body is flattened with [allStatements], every statement is scanned for [IrCall]s, and only
 * calls whose target function is annotated as composable are kept.
 */
fun IrBody.findComposableCalls(): List<IrCall> = allStatements()
    .map(IrStatementWithParent::statement)
    .flatMap(IrStatement::calls)
    .filter(IrCall::isComposableFunctionCall)

/**
 * Transforms the $composer and $changed arguments that the Compose compiler generates into
 * argument matchers for the given mocking framework(s) that is in use.
 *
 * The $composer argument is located by type (nullable `Composer`); $changed is taken to be the
 * argument directly after it. Both are replaced in place on this call.
 *
 * @param pluginContext context used to resolve the `Composer` class and to build IR.
 * @param composerArgument builds the replacement for $composer from the original argument.
 * @param changedArgument builds the replacement for $changed from the original argument.
 * @param logger receives a before/after dump of the call.
 * @throws MockposablePluginException if the call has no $composer argument/parameter, or if no
 * $changed argument follows $composer.
 */
fun IrCall.transformComposeArgs(
    pluginContext: IrPluginContext,
    composerArgument: DeclarationIrBuilder.(IrExpression) -> IrExpression,
    changedArgument: DeclarationIrBuilder.(IrExpression) -> IrExpression,
    logger: Logger
) {
    val beforeTransform = dumpKotlinLike()

    val composerType = pluginContext.composerClassSymbol.defaultType.makeNullable()

    // NOTE(review): the argument type is compared as-is while the parameter type below is made
    // nullable first — confirm the asymmetry is intended.
    val composerArg = (0 until valueArgumentsCount)
        .mapNotNull { index -> getValueArgument(index)?.let { index to it } }
        .find { (_, arg) -> arg.type == composerType }
    val composerArgMissing = symbol.owner.valueParameters.none {
        it.type.makeNullable() == composerType
    }

    if (composerArgMissing || composerArg == null) {
        // Function has not been transformed with $composer and $changed if we get here.
        pluginError(
            "Composable function\n${symbol.owner.dumpKotlinLike()}was not transformed with ${"\$composer"} and ${"\$changed"} args. This most likely means that the Compose compiler plugin is not on the kotlinc classpath."
        )
    }

    val (composerArgIndex, originalComposerArg) = composerArg

    // $changed is always added as following $composer. Validate that before mutating anything
    // so a malformed call fails with a descriptive error rather than an opaque `!!`/index failure.
    val changedArgIndex = composerArgIndex + 1
    val originalChangedArg =
        (if (changedArgIndex < valueArgumentsCount) getValueArgument(changedArgIndex) else null)
            ?: pluginError(
                "Expected a ${"\$changed"} argument at index $changedArgIndex, directly after ${"\$composer"}, in call\n$beforeTransform"
            )

    putValueArgument(
        composerArgIndex,
        pluginContext.buildIr(symbol) { composerArgument(originalComposerArg) }
    )

    putValueArgument(
        changedArgIndex,
        pluginContext.buildIr(symbol) { changedArgument(originalChangedArg) }
    )

    val afterTransform = dumpKotlinLike()
    logger.log("Transformed $beforeTransform -> $afterTransform")
}

/** Package prefix used in this file to build the fully-qualified names of `Composable` and `Composer`. */
private const val COMPOSE_RUNTIME_PACKAGE_NAME = "androidx.compose.runtime"

/**
 * Raised when the plugin hits a state it cannot handle. The supplied message is followed, on a
 * new line, by a fixed note asking the user to report the problem.
 */
class MockposablePluginException(message: String) : Exception(
    message + "\nThis is a bug in the Mockposable compiler plugin, please report the issue on GitHub."
)

/**
 * Aborts by throwing a [MockposablePluginException] carrying [message]. Never returns, so it
 * can be used on the right-hand side of `?:`.
 */
fun pluginError(message: String): Nothing {
    throw MockposablePluginException(message)
}

/**
 * Safe cast helper: returns the receiver as [T] when it is an instance of [T], otherwise `null`
 * (including when the receiver itself is `null`).
 *
 * [T] is reified so the `as?` check is performed against the actual type at the call site.
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified T> Any?.cast(): T? = this as? T




© 2015 - 2025 Weber Informatics LLC | Privacy Policy