All downloads are free. The search and download functionality uses the official Maven repository.

androidx.compose.compiler.plugins.kotlin.lower.ComposerParamTransformer.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:OptIn(UnsafeDuringIrConstructionAPI::class)

package androidx.compose.compiler.plugins.kotlin.lower

import androidx.compose.compiler.plugins.kotlin.ComposeNames
import androidx.compose.compiler.plugins.kotlin.FeatureFlags
import androidx.compose.compiler.plugins.kotlin.ModuleMetrics
import androidx.compose.compiler.plugins.kotlin.analysis.StabilityInferencer
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.ir.moveBodyTo
import org.jetbrains.kotlin.backend.jvm.ir.isInlineClassType
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.load.java.JvmAbi
import org.jetbrains.kotlin.name.JvmStandardClassIds
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.platform.jvm.isJvm
import org.jetbrains.kotlin.resolve.DescriptorUtils
import org.jetbrains.kotlin.util.OperatorNameConventions
import kotlin.collections.set
import kotlin.math.min

class ComposerParamTransformer(
    context: IrPluginContext,
    stabilityInferencer: StabilityInferencer,
    metrics: ModuleMetrics,
    featureFlags: FeatureFlags,
) : AbstractComposeLowering(context, metrics, stabilityInferencer, featureFlags),
    ModuleLoweringPass {

    /**
     * The module fragment handed to the most recent [lower] call (`null` before the first call).
     *
     * Used to identify module fragment in case of incremental compilation
     * see [externallyTransformed]
     *
     * NOTE(review): this class only ever writes the field, and `externallyTransformed` is not
     * declared in this file; confirm the field is still needed.
     */
    private var currentModule: IrModuleFragment? = null

    // Locator queried through `isInlineLambda` while rewriting function bodies. It is filled by
    // `scan` once for the whole module in `lower` and again for every copied function.
    private var inlineLambdaInfo = ComposeInlineLambdaLocator(context)

    /**
     * Runs the lowering over [module] in three steps:
     *  1. scans the module for inline lambdas and transforms every declaration with this
     *     transformer,
     *  2. remaps types with a [ComposableTypeTransformer] so that @Composable lambda types are
     *     realized,
     *  3. patches the declaration parents of the whole module.
     */
    override fun lower(module: IrModuleFragment) {
        currentModule = module
        inlineLambdaInfo.scan(module)

        // Step 1: this transformer rewrites the declarations and calls it visits.
        module.transformChildrenVoid(this)

        // Step 2: for each declaration, remap types to ensure that @Composable lambda types
        // are realized.
        val composableTypeTransformer = ComposableTypeTransformer(
            context,
            ComposableTypeRemapper(context, composerType)
        )
        module.transformChildrenVoid(composableTypeTransformer)

        // Step 3: go through and patch all of the parents to make sure things are properly
        // wired up.
        module.patchDeclarationParents()
    }

    /**
     * Cache from an original function to the function produced for it by
     * `copyWithComposerParam`. Default-value stubs are registered as mapping to themselves.
     */
    private val transformedFunctions: MutableMap<IrSimpleFunction, IrSimpleFunction> =
        mutableMapOf()

    /**
     * Functions that were produced by this transformer; `withComposerParamIfNeeded` returns
     * members of this set unchanged so they are never transformed a second time.
     */
    private val transformedFunctionSet = mutableSetOf<IrSimpleFunction>()

    /** Default type of the composer class with all type arguments replaced by star projections. */
    private val composerType = composerIrClass.defaultType.replaceArgumentsWithStarProjections()

    /**
     * Swaps [declaration] for its composer-aware version (when one is needed) before delegating
     * to the default traversal.
     */
    override fun visitSimpleFunction(declaration: IrSimpleFunction): IrStatement {
        val maybeTransformed = declaration.withComposerParamIfNeeded()
        return super.visitSimpleFunction(maybeTransformed)
    }

    /**
     * Rebuilds a local delegated property reference so that its getter symbol points at the
     * composer-aware version of the getter; every other component of [expression] is reused.
     */
    override fun visitLocalDelegatedPropertyReference(
        expression: IrLocalDelegatedPropertyReference,
    ): IrExpression {
        val getterWithComposer = expression.getter.owner.withComposerParamIfNeeded()
        val rewrittenReference = IrLocalDelegatedPropertyReferenceImpl(
            expression.startOffset,
            expression.endOffset,
            expression.type,
            expression.symbol,
            expression.delegate,
            getterWithComposer.symbol,
            expression.setter,
            expression.origin
        )
        return super.visitLocalDelegatedPropertyReference(rewrittenReference)
    }

    /**
     * Adds a composable annotation to the getter and (if present) the setter of a local
     * delegated property whenever the accessor is a composable delegated accessor, then
     * continues with the default traversal.
     */
    override fun visitLocalDelegatedProperty(declaration: IrLocalDelegatedProperty): IrStatement {
        val getter = declaration.getter
        if (getter.isComposableDelegatedAccessor()) {
            getter.annotations += createComposableAnnotation()
        }

        val setter = declaration.setter
        if (setter != null && setter.isComposableDelegatedAccessor()) {
            setter.annotations += createComposableAnnotation()
        }

        return super.visitLocalDelegatedProperty(declaration)
    }

    /**
     * Builds a constructor call of `composableIrClass`'s primary constructor at synthetic
     * offsets, with no type arguments, for use as an annotation.
     */
    private fun createComposableAnnotation() =
        composableIrClass.let { composableClass ->
            IrConstructorCallImpl(
                SYNTHETIC_OFFSET,
                SYNTHETIC_OFFSET,
                composableClass.defaultType,
                composableClass.primaryConstructor!!.symbol,
                typeArgumentsCount = 0,
                constructorTypeArgumentsCount = 0
            )
        }

    /**
     * Rebuilds this call so that it targets the composer-aware version of its callee and passes
     * the synthetic arguments that version declares.
     *
     * The receiver is returned unchanged when its callee is not a composable delegated accessor,
     * the call is not a composable lambda `invoke`, and the callee carries no composable
     * annotation. Otherwise a new [IrCallImpl] is created that copies attributes, type
     * arguments, receivers and value arguments from the receiver and then appends, in order:
     *  - a read of [composerParam],
     *  - a constant `0` for every `$changed` slot,
     *  - for every `$default` slot, a mask built by `bitMask` from the flags recording which
     *    arguments were omitted although a default expression is defined for them.
     *
     * Arguments omitted at the call site are filled with a placeholder from
     * [defaultArgumentFor]; omitted vararg arguments are left unset.
     *
     * @param composerParam parameter whose value is forwarded as the composer argument.
     * @return the rewritten call, or the receiver itself for non-composable calls.
     * @throws IllegalStateException (via `error`) when the target function declares fewer value
     *   parameters than the synthetic arguments require.
     */
    fun IrCall.withComposerParamIfNeeded(composerParam: IrValueParameter): IrCall {
        val ownerFn = when {
            symbol.owner.isComposableDelegatedAccessor() -> {
                // Delegated accessors may not be annotated yet; annotate before transforming,
                // because the transformation ignores functions without the annotation.
                if (!symbol.owner.hasComposableAnnotation()) {
                    symbol.owner.annotations += createComposableAnnotation()
                }
                symbol.owner.withComposerParamIfNeeded()
            }
            isComposableLambdaInvoke() ->
                symbol.owner.lambdaInvokeWithComposerParam()
            symbol.owner.hasComposableAnnotation() ->
                symbol.owner.withComposerParamIfNeeded()
            // Not a composable call
            else -> return this
        }

        return IrCallImpl(
            startOffset,
            endOffset,
            type,
            ownerFn.symbol,
            typeArgumentsCount,
            origin,
            superQualifierSymbol
        ).also {
            it.copyAttributes(this)
            it.copyTypeArgumentsFrom(this)
            it.dispatchReceiver = dispatchReceiver
            it.extensionReceiver = extensionReceiver
            // One flag per original value argument: true when the argument was omitted at the
            // call site and a default expression exists for the parameter.
            val argumentsMissing = mutableListOf<Boolean>()
            for (i in 0 until valueArgumentsCount) {
                val arg = getValueArgument(i)
                val param = ownerFn.valueParameters[i]
                val hasDefault = ownerFn.hasDefaultExpressionDefinedForValueParameter(i)
                argumentsMissing.add(arg == null && hasDefault)
                if (arg != null) {
                    it.putValueArgument(i, arg)
                } else if (param.isVararg) {
                    // do nothing
                } else {
                    it.putValueArgument(i, defaultArgumentFor(param))
                }
            }
            val valueParams = valueArgumentsCount
            val realValueParams = valueParams - ownerFn.contextReceiverParametersCount
            // Synthetic arguments are appended right after the original value arguments.
            var argIndex = valueArgumentsCount
            it.putValueArgument(
                argIndex++,
                IrGetValueImpl(
                    UNDEFINED_OFFSET,
                    UNDEFINED_OFFSET,
                    composerParam.symbol
                )
            )

            // $changed[n]
            for (i in 0 until changedParamCount(realValueParams, ownerFn.thisParamCount)) {
                if (argIndex < ownerFn.valueParameters.size) {
                    it.putValueArgument(
                        argIndex++,
                        irConst(0)
                    )
                } else {
                    error("1. expected value parameter count to be higher: ${this.dumpSrc()}")
                }
            }

            // $default[n]
            // The flags no longer change, so convert them once instead of once per mask.
            val missingFlags = argumentsMissing.toBooleanArray()
            for (i in 0 until defaultParamCount(valueParams)) {
                val start = i * BITS_PER_INT
                val end = min(start + BITS_PER_INT, valueParams)
                if (argIndex < ownerFn.valueParameters.size) {
                    val bits = missingFlags.sliceArray(start until end)
                    it.putValueArgument(
                        argIndex++,
                        irConst(bitMask(*bits))
                    )
                } else if (argumentsMissing.any { it }) {
                    // No $default slot is left on the target although some argument relies on
                    // its default value.
                    error("2. expected value parameter count to be higher: ${this.dumpSrc()}")
                }
            }
        }
    }

    /**
     * Produces the placeholder expression passed for an argument that was omitted at the call
     * site: the default value of the parameter's type wrapped in a composite with origin
     * [IrStatementOrigin.DEFAULT_VALUE].
     *
     * @return the placeholder, or `null` when [param] has a vararg element type.
     */
    private fun defaultArgumentFor(param: IrValueParameter): IrExpression? {
        if (param.varargElementType != null) return null

        val placeholder = param.type.defaultValue()
        return IrCompositeImpl(
            placeholder.startOffset,
            placeholder.endOffset,
            placeholder.type,
            IrStatementOrigin.DEFAULT_VALUE,
            listOf(placeholder)
        )
    }

    /**
     * Builds a default-value expression for this type.
     *
     *  - Types that are not a non-nullable inline class: a `null` constant when the type is
     *    marked nullable, otherwise [IrConstImpl.defaultValueForType].
     *  - Non-nullable inline classes on the JVM: the default value of the unboxed type, coerced
     *    to the inline class with `coerceInlineClasses`.
     *  - Non-nullable inline classes on other platforms: a call of the primary constructor with
     *    the default value of the underlying type as its only argument.
     *
     * TODO(lmr): There is an equivalent function in IrUtils, but we can't use it because it
     *  expects a JvmBackendContext. That implementation uses a special "unsafe coerce" builtin
     *  method, which is only available on the JVM backend. On the JS and Native backends we
     *  don't have access to that so instead we are just going to construct the inline class
     *  itself and hope that it gets lowered properly.
     */
    private fun IrType.defaultValue(
        startOffset: Int = UNDEFINED_OFFSET,
        endOffset: Int = UNDEFINED_OFFSET,
    ): IrExpression {
        if (this !is IrSimpleType || isMarkedNullable() || !isInlineClassType()) {
            return when {
                isMarkedNullable() ->
                    IrConstImpl.constNull(startOffset, endOffset, context.irBuiltIns.nothingNType)
                else ->
                    IrConstImpl.defaultValueForType(startOffset, endOffset, this)
            }
        }

        if (context.platform.isJvm()) {
            val unboxedType = unboxInlineClass()
            val unboxedDefault =
                IrConstImpl.defaultValueForType(startOffset, endOffset, unboxedType)
            return coerceInlineClasses(unboxedDefault, unboxedType, this)
        }

        val inlineClassSymbol = classOrNull!!
        val primaryCtor = inlineClassSymbol.constructors.first { it.owner.isPrimary }
        val wrappedType = getInlineClassUnderlyingType(inlineClassSymbol.owner)

        // TODO(lmr): We should not be calling the constructor here, but this seems like a
        //  reasonable interim solution.
        val ctorCall = IrConstructorCallImpl(
            startOffset,
            endOffset,
            this,
            primaryCtor,
            typeArgumentsCount = 0,
            constructorTypeArgumentsCount = 0,
            origin = null
        )
        ctorCall.putValueArgument(0, wrappedType.defaultValue(startOffset, endOffset))
        return ctorCall
    }

    /**
     * Transforms `@Composable fun foo(params): RetType` into
     * `fun foo(params, $composer: Composer): RetType`.
     *
     * The receiver is returned as-is when
     *  - it was itself produced by this transformer (it already has the synthetic composer
     *    parameter and must not be transformed further),
     *  - it has no composable annotation, or
     *  - it is an `expect` function (those exist only for type resolution).
     *
     * Otherwise the cached transformed function is returned, or a new copy is created.
     */
    private fun IrSimpleFunction.withComposerParamIfNeeded(): IrSimpleFunction = when {
        this in transformedFunctionSet -> this
        !hasComposableAnnotation() -> this
        isExpect -> this
        else -> transformedFunctions[this] ?: copyWithComposerParam()
    }

    /**
     * Finds the `invoke` function of the function class whose arity is this function's value
     * parameter count plus the number of synthetic Compose parameters for that count.
     */
    private fun IrSimpleFunction.lambdaInvokeWithComposerParam(): IrSimpleFunction {
        val originalArity = valueParameters.size
        val widenedArity = originalArity + composeSyntheticParamCount(originalArity)
        val widenedFunctionClass = context.function(widenedArity).owner
        return widenedFunctionClass.functions.first { candidate ->
            candidate.name == OperatorNameConventions.INVOKE
        }
    }

    /**
     * Creates a new function with a fresh symbol that mirrors the receiver.
     *
     * Name, visibility, modality, flags, container source, attributes, type parameters,
     * receivers, value parameters (renamed with `dexSafeName`), context receiver count,
     * annotations and metadata are carried over, and the body is transferred with `moveBodyTo`.
     * Types, default values and the body are remapped to the new function's type parameters.
     *
     * Side effect: when the receiver is the getter or setter of its corresponding property, that
     * property is re-pointed at the new function.
     *
     * @return the new function, whose parent is the receiver's parent.
     */
    private fun IrSimpleFunction.copy(): IrSimpleFunction {
        // TODO(lmr): use deepCopy instead?
        return context.irFactory.createSimpleFunction(
            startOffset = startOffset,
            endOffset = endOffset,
            origin = origin,
            name = name,
            visibility = visibility,
            isInline = isInline,
            isExpect = isExpect,
            returnType = returnType,
            modality = modality,
            symbol = IrSimpleFunctionSymbolImpl(),
            isTailrec = isTailrec,
            isSuspend = isSuspend,
            isOperator = isOperator,
            isInfix = isInfix,
            isExternal = isExternal,
            containerSource = containerSource
        ).also { fn ->
            fn.copyAttributes(this)
            // Keep the property <-> accessor link intact: the property must reference the copy
            // instead of the original accessor.
            val propertySymbol = correspondingPropertySymbol
            if (propertySymbol != null) {
                fn.correspondingPropertySymbol = propertySymbol
                if (propertySymbol.owner.getter == this) {
                    propertySymbol.owner.getter = fn
                }
                if (propertySymbol.owner.setter == this) {
                    propertySymbol.owner.setter = fn
                }
            }
            fn.parent = parent
            fn.copyTypeParametersFrom(this)

            // Local shorthand: remap a type from the original's type parameters to the copy's.
            fun IrType.remapTypeParameters(): IrType =
                remapTypeParameters(this@copy, fn)

            // The return type passed to createSimpleFunction still refers to the original's type
            // parameters, so it is replaced with the remapped one here.
            fn.returnType = returnType.remapTypeParameters()

            fn.dispatchReceiverParameter = dispatchReceiverParameter?.copyTo(fn)
            fn.extensionReceiverParameter = extensionReceiverParameter?.copyTo(fn)
            fn.valueParameters = valueParameters.map { param ->
                // Composable lambdas will always have `IrGet`s of all of their parameters
                // generated, since they are passed into the restart lambda. This causes an
                // interesting corner case with "anonymous parameters" of composable functions.
                // If a parameter is anonymous (using the name `_`) in user code, you can usually
                // make the assumption that it is never used, but this is technically not the
                // case in composable lambdas. The synthetic name that kotlin generates for
                // anonymous parameters has an issue where it is not safe to dex, so we sanitize
                // the names here to ensure that dex is always safe.
                val newName = dexSafeName(param.name)
                param.copyTo(
                    fn,
                    name = newName,
                    // Only parameters that declare a default value are made assignable.
                    isAssignable = param.defaultValue != null,
                    defaultValue = param.defaultValue?.copyWithNewTypeParams(
                        source = this, target = fn
                    )
                )
            }
            fn.contextReceiverParametersCount = contextReceiverParametersCount
            fn.annotations = annotations.toList()
            fn.metadata = metadata
            fn.body = moveBodyTo(fn)?.copyWithNewTypeParams(this, fn)
        }
    }

    /**
     * Builds a `@JvmName` annotation call whose single argument is the string constant [name].
     */
    private fun jvmNameAnnotation(name: String): IrConstructorCall {
        val jvmNameClass = getTopLevelClass(JvmStandardClassIds.Annotations.JvmName)
        val primaryCtor = jvmNameClass.constructors.first { it.owner.isPrimary }
        val annotationType = jvmNameClass.createType(false, emptyList())

        val annotationCall = IrConstructorCallImpl(
            startOffset = UNDEFINED_OFFSET,
            endOffset = UNDEFINED_OFFSET,
            type = annotationType,
            symbol = primaryCtor,
            typeArgumentsCount = 0,
            constructorTypeArgumentsCount = 0,
        )
        val nameLiteral = IrConstImpl.string(
            UNDEFINED_OFFSET,
            UNDEFINED_OFFSET,
            builtIns.stringType,
            name
        )
        annotationCall.putValueArgument(0, nameLiteral)
        return annotationCall
    }

    /**
     * Tells whether a default mask parameter is needed, which is the case when any value
     * parameter has a default expression. For a "fake override" method only the overridden
     * symbols carry the default value expressions, so those are checked recursively as well.
     */
    private fun IrSimpleFunction.requiresDefaultParameter(): Boolean {
        if (valueParameters.any { it.defaultValue != null }) return true
        return overriddenSymbols.any { it.owner.requiresDefaultParameter() }
    }

    /**
     * Tells whether the value parameter at [index] has a default expression, either on this
     * function or on any function it overrides (checked recursively).
     */
    private fun IrSimpleFunction.hasDefaultExpressionDefinedForValueParameter(index: Int): Boolean {
        // Checking this function alone isn't enough: an override may inherit the default from
        // one of the functions it overrides.
        if (valueParameters[index].defaultValue != null) return true

        for (overriddenSymbol in overriddenSymbols) {
            val overridden = overriddenSymbol.owner
            // ComposableFunInterfaceLowering copies extension receiver as a value
            // parameter, which breaks indices for overrides. fun interface cannot
            // have parameters defaults, however, so we can skip here if mismatch is detected.
            if (overridden.valueParameters.size != valueParameters.size) continue
            if (overridden.hasDefaultExpressionDefinedForValueParameter(index)) return true
        }
        return false
    }

    /**
     * Creates the composer-aware version of this function.
     *
     * The receiver is copied with [copy] and the copy is then extended and rewritten:
     *  1. the copy is registered in `transformedFunctionSet` / `transformedFunctions`,
     *  2. overridden symbols are replaced by their transformed versions,
     *  3. property accessors without a `@JvmName` annotation receive one,
     *  4. the `$composer` parameter, the `$changed` parameters and (when a default parameter is
     *     required) the `$default` parameters are appended,
     *  5. a default-value stub is created if needed and added to the parent class or file,
     *  6. value parameter types are replaced by the result of [defaultParameterType],
     *  7. the body is rewritten: reads of the old parameters are redirected to the copied ones,
     *     returns targeting the old function are retargeted to the copy, and calls outside of
     *     nested scopes are passed through `IrCall.withComposerParamIfNeeded`.
     *
     * @return the transformed copy.
     */
    private fun IrSimpleFunction.copyWithComposerParam(): IrSimpleFunction {
        assert(explicitParameters.lastOrNull()?.name != ComposeNames.COMPOSER_PARAMETER) {
            "Attempted to add composer param to $this, but it has already been added."
        }
        return copy().also { fn ->
            val oldFn = this

            // NOTE: it's important to add these here before we recurse into the body in
            // order to avoid an infinite loop on circular/recursive calls
            transformedFunctionSet.add(fn)
            transformedFunctions[oldFn] = fn

            // The overridden symbols might also be composable functions, so we want to make sure
            // and transform them as well
            fn.overriddenSymbols = overriddenSymbols.map {
                it.owner.withComposerParamIfNeeded().symbol
            }

            // if we are transforming a composable property, the jvm signature of the
            // corresponding getters and setters have a composer parameter. Since Kotlin uses the
            // lack of a parameter to determine if it is a getter, this breaks inlining for
            // composable property getters since it ends up looking for the wrong jvmSignature.
            // In this case, we manually add the appropriate "@JvmName" annotation so that the
            // inliner doesn't get confused.
            fn.correspondingPropertySymbol?.let { propertySymbol ->
                if (!fn.hasAnnotation(DescriptorUtils.JVM_NAME)) {
                    val propertyName = propertySymbol.owner.name.identifier
                    val name = if (fn.isGetter) {
                        JvmAbi.getterName(propertyName)
                    } else {
                        JvmAbi.setterName(propertyName)
                    }
                    fn.annotations += jvmNameAnnotation(name)
                }
            }

            // Maps each explicit parameter of the original function to its counterpart in the
            // copy. It is built before any synthetic parameter is added, so both lists pair up.
            val valueParametersMapping = explicitParameters
                .zip(fn.explicitParameters)
                .toMap()

            // Parameter counts before the synthetic parameters are appended.
            val currentParams = fn.valueParameters.size
            val realParams = currentParams - fn.contextReceiverParametersCount

            // $composer
            val composerParam = fn.addValueParameter {
                name = ComposeNames.COMPOSER_PARAMETER
                type = composerType.makeNullable()
                origin = IrDeclarationOrigin.DEFINED
                isAssignable = true
            }

            // $changed[n]
            val changed = ComposeNames.CHANGED_PARAMETER.identifier
            for (i in 0 until changedParamCount(realParams, fn.thisParamCount)) {
                fn.addValueParameter(
                    // the first parameter has no numeric suffix, later ones are suffixed with i
                    if (i == 0) changed else "$changed$i",
                    context.irBuiltIns.intType
                )
            }

            // $default[n]
            if (oldFn.requiresDefaultParameter()) {
                val defaults = ComposeNames.DEFAULT_PARAMETER.identifier
                for (i in 0 until defaultParamCount(currentParams)) {
                    fn.addValueParameter(
                        if (i == 0) defaults else "$defaults$i",
                        context.irBuiltIns.intType,
                        IrDeclarationOrigin.MASK_FOR_DEFAULT_FUNCTION
                    )
                }
            }

            // Create the default-value stub (if one is needed) and attach it next to the
            // transformed function. This runs before the parameter types are updated below.
            fn.makeStubForDefaultValuesIfNeeded()?.also {
                when (val parent = fn.parent) {
                    is IrClass -> parent.addChild(it)
                    is IrFile -> parent.addChild(it)
                    else -> {
                        // ignore
                    }
                }
            }

            // update parameter types so they are ready to accept the default values
            fn.valueParameters.fastForEach { param ->
                val newType = defaultParameterType(
                    param,
                    fn.hasDefaultExpressionDefinedForValueParameter(param.index)
                )
                param.type = newType
            }

            inlineLambdaInfo.scan(fn)

            fn.transformChildrenVoid(object : IrElementTransformerVoid() {
                // true while visiting a function through which the composer parameter must not
                // be passed (see visitFunction)
                var isNestedScope = false
                override fun visitGetValue(expression: IrGetValue): IrGetValue {
                    // Redirect reads of the original function's parameters to the copied ones.
                    val newParam = valueParametersMapping[expression.symbol.owner]
                    return if (newParam != null) {
                        IrGetValueImpl(
                            expression.startOffset,
                            expression.endOffset,
                            expression.type,
                            newParam.symbol,
                            expression.origin
                        )
                    } else expression
                }

                override fun visitReturn(expression: IrReturn): IrExpression {
                    if (expression.returnTargetSymbol == oldFn.symbol) {
                        // update the return statement to point to the new function, or else
                        // it will be interpreted as a non-local return
                        return super.visitReturn(
                            IrReturnImpl(
                                expression.startOffset,
                                expression.endOffset,
                                expression.type,
                                fn.symbol,
                                expression.value
                            )
                        )
                    }
                    return super.visitReturn(expression)
                }

                override fun visitFunction(declaration: IrFunction): IrStatement {
                    val wasNested = isNestedScope
                    try {
                        // we don't want to pass the composer parameter in to composable calls
                        // inside of nested scopes.... *unless* the scope was inlined.
                        isNestedScope = wasNested ||
                                !inlineLambdaInfo.isInlineLambda(declaration) ||
                                declaration.hasComposableAnnotation()
                        return super.visitFunction(declaration)
                    } finally {
                        // restore the flag for the enclosing scope
                        isNestedScope = wasNested
                    }
                }

                override fun visitCall(expression: IrCall): IrExpression {
                    val expr = if (!isNestedScope) {
                        expression.withComposerParamIfNeeded(composerParam)
                    } else
                        expression
                    return super.visitCall(expr)
                }
            })
        }
    }

    /**
     * Computes the type [param] should have so that a placeholder can be passed for it when the
     * caller relies on its default value.
     *
     * The declared type is kept when the parameter has no default value, when it is a primitive
     * type, or when it is an inline class wrapping a primitive whose placeholder can be built
     * (on the JVM, or when the class has a primary constructor). Every other type is made
     * nullable.
     */
    private fun defaultParameterType(
        param: IrValueParameter,
        hasDefaultValue: Boolean,
    ): IrType {
        val declaredType = param.type
        if (!hasDefaultValue || declaredType.isPrimitiveType()) return declaredType

        val constructorAccessible = declaredType.classOrNull?.owner?.primaryConstructor != null

        // k/js and k/native: private constructors of value classes may be inaccessible, in which
        // case no "fake" default argument can be created for calls. Such types fall through to
        // the nullable case so that null can be passed instead.
        val keepsDeclaredType = declaredType.isInlineClassType() &&
            (context.platform.isJvm() || constructorAccessible) &&
            declaredType.unboxInlineClass().isPrimitiveType()

        return if (keepsDeclaredType) declaredType else declaredType.makeNullable()
    }

    /**
     * Creates a stub for a @Composable function that has at least one value class parameter with
     * a default value wrapping a non-primitive, non-nullable instance.
     *
     * Background: before Compose compiler 1.5.12, omitting such a parameter caused a nullptr
     * exception at runtime (b/330655412) because the Kotlin compiler inserted
     * checkParameterNotNull. Making these parameters nullable fixed that, but nullability
     * changes the value class mangling of the function signature and therefore broke binary
     * compatibility. The stub is a binary compatible function that keeps old callers working by
     * redirecting to the new function.
     *
     * @return the stub, or `null` when the function is neither public API nor annotated with
     *   `@PublishedApi`, is not composable, or has no parameter that requires a stub.
     */
    private fun IrSimpleFunction.makeStubForDefaultValuesIfNeeded(): IrSimpleFunction? {
        if (!visibility.isPublicAPI && !hasAnnotation(PublishedApiFqName)) return null
        if (!hasComposableAnnotation()) return null

        val needsStub = valueParameters.indices.any { index ->
            val paramType = valueParameters[index].type
            hasDefaultExpressionDefinedForValueParameter(index) &&
                paramType.isInlineClassType() &&
                !paramType.isNullable() &&
                paramType.unboxInlineClass().let { unboxed ->
                    !unboxed.isPrimitiveType() && !unboxed.isNullable()
                }
        }
        if (!needsStub) return null

        val target = this
        val stub = target.deepCopyWithSymbols(parent)
        stub.valueParameters.fastForEach { stubParam ->
            stubParam.defaultValue = null
        }
        stub.origin = ComposeDefaultValueStubOrigin
        stub.annotations += IrConstructorCallImpl(
            startOffset = UNDEFINED_OFFSET,
            endOffset = UNDEFINED_OFFSET,
            type = jvmSyntheticIrClass.defaultType,
            symbol = jvmSyntheticIrClass.primaryConstructor!!.symbol,
            typeArgumentsCount = 0,
            constructorTypeArgumentsCount = 0,
        )

        // The stub body is a single `return target(...)` that forwards the stub's own
        // receivers, type parameters and value parameters.
        stub.body = context.irFactory.createBlockBody(UNDEFINED_OFFSET, UNDEFINED_OFFSET) {
            val forwardingCall = irCall(target).apply {
                dispatchReceiver = stub.dispatchReceiverParameter?.let { irGet(it) }
                extensionReceiver = stub.extensionReceiverParameter?.let { irGet(it) }
                stub.typeParameters.fastForEachIndexed { index, typeParam ->
                    putTypeArgument(index, typeParam.defaultType)
                }
                stub.valueParameters.fastForEachIndexed { index, valueParam ->
                    putValueArgument(index, irGet(valueParam))
                }
            }
            statements.add(
                IrReturnImpl(
                    UNDEFINED_OFFSET,
                    UNDEFINED_OFFSET,
                    stub.returnType,
                    stub.symbol,
                    forwardingCall
                )
            )
        }

        // Register the stub as already transformed so it is not transformed again.
        transformedFunctions[stub] = stub
        transformedFunctionSet += stub

        return stub
    }

    // Fully-qualified name of the `PublishedApi` annotation; non-public functions carrying it
    // are still considered when deciding whether a default-value stub has to be generated.
    private val PublishedApiFqName = StandardClassIds.Annotations.PublishedApi.asSingleFqName()

    /**
     * Declaration origin assigned to the stubs created by `makeStubForDefaultValuesIfNeeded`.
     * It reports itself as synthetic.
     */
    object ComposeDefaultValueStubOrigin : IrDeclarationOrigin {
        override val name = "ComposeDefaultValueStubOrigin"
        override val isSynthetic = true
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy