All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.backend.jvm.lower.FunctionReferenceLowering.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.backend.jvm.lower

import org.jetbrains.kotlin.backend.common.FileLoweringPass
import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext
import org.jetbrains.kotlin.backend.common.ir.moveBodyTo
import org.jetbrains.kotlin.backend.common.lower.LoweredStatementOrigins
import org.jetbrains.kotlin.backend.common.lower.SamEqualsHashCodeMethodsGenerator
import org.jetbrains.kotlin.backend.common.lower.VariableRemapper
import org.jetbrains.kotlin.backend.common.phaser.PhaseDescription
import org.jetbrains.kotlin.backend.jvm.JvmBackendContext
import org.jetbrains.kotlin.backend.jvm.JvmLoweredDeclarationOrigin
import org.jetbrains.kotlin.backend.jvm.JvmSymbols
import org.jetbrains.kotlin.backend.jvm.ir.*
import org.jetbrains.kotlin.backend.jvm.lower.indy.*
import org.jetbrains.kotlin.config.JvmClosureGenerationScheme
import org.jetbrains.kotlin.config.LanguageFeature
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.*
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.declarations.impl.IrVariableImpl
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.impl.IrVariableSymbolImpl
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.load.java.JvmAnnotationNames
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.resolve.jvm.annotations.JVM_SERIALIZABLE_LAMBDA_ANNOTATION_FQ_NAME

@PhaseDescription(
    name = "FunctionReference",
    description = "Construct instances of anonymous KFunction subclasses for function references"
)
internal class FunctionReferenceLowering(private val context: JvmBackendContext) : FileLoweringPass, IrElementTransformerVoidWithContext() {
    // Lambdas passed as arguments to `crossinline` parameters in the file currently being lowered.
    // Filled at the start of `lower` and cleared at its end; handed to LambdaMetafactoryArgumentsBuilder.
    private val crossinlineLambdas = HashSet<IrSimpleFunction>()

    // A reference is left untouched by this lowering when it is an inline lambda, or when its type is
    // neither a (K)Function type nor it is a suspend function reference handled here.
    private val IrFunctionReference.isIgnored: Boolean
        get() {
            if (origin == LoweredStatementOrigins.INLINE_LAMBDA) return true
            return !(type.isFunctionOrKFunction() || isSuspendFunctionReference())
        }

    // References to `suspend` functions are lowered like ordinary ones (they only carry an extra continuation
    // parameter). Suspending lambdas, however, double as their own continuation classes and are therefore
    // generated separately by SuspendLambdaLowering, so only the origins listed below are accepted here.
    // TODO: Currently, origin of callable references explicitly written in source code is null. Do we need to create one?
    private fun IrFunctionReference.isSuspendFunctionReference(): Boolean {
        if (!isSuspend) return false
        return when (origin) {
            null,
            IrStatementOrigin.ADAPTED_FUNCTION_REFERENCE,
            IrStatementOrigin.SUSPEND_CONVERSION -> true
            else -> false
        }
    }

    // Entry point of the pass for a single file:
    //   1. remember every lambda passed to a crossinline parameter,
    //   2. transform the file with this transformer,
    //   3. drop the per-file state again.
    override fun lower(irFile: IrFile) {
        irFile.findInlineLambdas(context) { lambdaReference, _, targetParameter, _ ->
            if (!targetParameter.isCrossinline) return@findInlineLambdas
            crossinlineLambdas.add(lambdaReference.symbol.owner as IrSimpleFunction)
        }

        irFile.transformChildrenVoid(this)

        crossinlineLambdas.clear()
    }

    // True when the configured scheme for SAM conversions is INDY.
    private val shouldGenerateIndySamConversions =
        context.config.samConversionsScheme == JvmClosureGenerationScheme.INDY

    // True when the configured scheme for lambdas is INDY.
    private val shouldGenerateIndyLambdas =
        context.config.lambdasScheme == JvmClosureGenerationScheme.INDY

    // Lightweight lambdas need both the INDY lambda scheme and the LightweightLambdas language feature.
    private val shouldGenerateLightweightLambdas =
        shouldGenerateIndyLambdas && context.config.languageVersionSettings.supportsFeature(LanguageFeature.LightweightLambdas)

    // True when the JavaSamConversionEqualsHashCode language feature is supported. When set, SAM conversions
    // on functional expressions are not generated via indy (see canGenerateIndySamConversionOnFunctionalExpression)
    // and equals/hashCode may be generated for Java SAM wrappers as well.
    private val isJavaSamConversionWithEqualsHashCode =
        context.config.languageVersionSettings.supportsFeature(LanguageFeature.JavaSamConversionEqualsHashCode)

    // Blocks reached through the regular traversal never force serializability of the generated lambda.
    override fun visitBlock(expression: IrBlock): IrExpression =
        processBlock(expression, forceSerializability = false)

    // Lowers a lambda block (a block with a lambda origin whose last statement is the function reference).
    // Non-lambda blocks and blocks with an ignored reference are delegated to the default traversal.
    // With the indy lambda scheme the reference is replaced by an indy SAM conversion when the metafactory
    // arguments can be computed; otherwise a function reference class is built.
    private fun processBlock(expression: IrBlock, forceSerializability: Boolean): IrExpression {
        if (!expression.origin.isLambda) return super.visitBlock(expression)

        val blockStatements = expression.statements
        val lambdaReference = blockStatements.last() as IrFunctionReference
        if (lambdaReference.isIgnored) return super.visitBlock(expression)

        // Visit everything preceding the reference, then the children of the reference itself.
        for (leadingStatement in blockStatements.dropLast(1)) {
            leadingStatement.transform(this, null)
        }
        lambdaReference.transformChildrenVoid(this)

        if (shouldGenerateIndyLambdas) {
            val metafactoryResult = LambdaMetafactoryArgumentsBuilder(context, crossinlineLambdas)
                .getLambdaMetafactoryArguments(lambdaReference, lambdaReference.type, plainLambda = true, forceSerializability)
            if (metafactoryResult is LambdaMetafactoryArguments) {
                return wrapLambdaReferenceWithIndySamConversion(expression, lambdaReference, metafactoryResult)
            }
            // TODO MetafactoryArgumentsResult.Failure.FunctionHazard?
        }

        return FunctionReferenceBuilder(lambdaReference).build()
    }

    // Replaces the trailing function reference of a lambda block with an indy SAM conversion call
    // and retypes the block to the type of that call. The same block instance is returned.
    private fun wrapLambdaReferenceWithIndySamConversion(
        expression: IrBlock,
        reference: IrFunctionReference,
        lambdaMetafactoryArguments: LambdaMetafactoryArguments,
    ): IrBlock {
        val conversionCall = wrapWithIndySamConversion(reference.type, lambdaMetafactoryArguments)
        return expression.apply {
            statements[statements.lastIndex] = conversionCall
            type = conversionCall.type
        }
    }

    // Standalone function references: children are always visited; the reference itself is replaced
    // by a generated class instance unless it is ignored by this lowering.
    override fun visitFunctionReference(expression: IrFunctionReference): IrExpression {
        expression.transformChildrenVoid(this)
        if (expression.isIgnored) return expression
        return FunctionReferenceBuilder(expression).build()
    }

    // Finds the innermost enclosing scope whose owner is an IrDeclarationParent.
    // Throws AssertionError (listing all scope owners) when there is none.
    private fun getDeclarationParentForDelegatingLambda(): IrDeclarationParent =
        allScopes.asReversed()
            .firstNotNullOfOrNull { scopeWithIr -> scopeWithIr.scope.scopeOwnerSymbol.owner as? IrDeclarationParent }
            ?: throw AssertionError(
                "No IrDeclarationParent found in scopes:\n" +
                        allScopes.joinToString(separator = "\n") { "  " + it.scope.scopeOwnerSymbol.owner.render() }
            )

    // Handle SAM conversions which wrap a function reference:
    //     class sam$n(private val receiver: R) : Interface { override fun method(...) = receiver.target(...) }
    //
    // This avoids materializing an invokable KFunction object for the reference, thus producing one less class.
    // This is actually very common, as `Interface { something }` is a local function + a SAM-conversion
    // of a reference to it into an implementation.
    //
    // Overall flow:
    //   - anything other than SAM_CONVERSION goes to the default traversal;
    //   - a plain reference or a lambda block yields `reference`, which is then lowered below;
    //   - any other functional expression may be wrapped into a delegating lambda + indy SAM conversion;
    //   - for `reference`, indy is attempted first (optionally through a proxy local function),
    //     and a SAM wrapper class is built as the fallback.
    override fun visitTypeOperator(expression: IrTypeOperatorCall): IrExpression {
        if (expression.operator != IrTypeOperator.SAM_CONVERSION) {
            return super.visitTypeOperator(expression)
        }

        val samSuperType = expression.typeOperand

        val invokable = expression.argument
        val reference = if (invokable is IrFunctionReference) {
            invokable
        } else if (invokable is IrBlock && invokable.origin.isLambda && invokable.statements.last() is IrFunctionReference) {
            // NOTE(review): the results of `transform` are discarded here; presumably the leading statements
            // are expected to be transformed in place — confirm.
            invokable.statements.dropLast(1).forEach { it.transform(this, null) }
            invokable.statements.last() as IrFunctionReference
        } else if (shouldGenerateIndySamConversions && canGenerateIndySamConversionOnFunctionalExpression(samSuperType, invokable)) {
            // The argument is an arbitrary functional expression: build a lambda delegating to it
            // and try to SAM-convert that lambda via indy.
            val lambdaBlock = SamDelegatingLambdaBuilder(context)
                .build(invokable, samSuperType, currentScope!!.scope.scopeOwnerSymbol, getDeclarationParentForDelegatingLambda())
            val lambdaMetafactoryArguments = LambdaMetafactoryArgumentsBuilder(context, crossinlineLambdas)
                .getLambdaMetafactoryArguments(lambdaBlock.ref, samSuperType, plainLambda = false, forceSerializability = false)

            if (lambdaMetafactoryArguments !is LambdaMetafactoryArguments) {
                // TODO MetafactoryArgumentsResult.Failure.FunctionHazard?
                return super.visitTypeOperator(expression)
            }

            // This is what IR contains if SAM-converted argument needs additional cast
            // See ArgumentsGenerationUtilsKt.castArgumentToFunctionalInterfaceForSamType (K1)
            // or AdapterGenerator.castArgumentToFunctionalInterfaceForSamType (K2)
            // In this case we should propagate serialization flag from the SAM type regardless of delegating lambda type (see KT-70306)
            if (invokable is IrTypeOperatorCall && invokable.operator == IrTypeOperator.IMPLICIT_CAST && invokable.argument is IrBlock) {
                invokable.argument = processBlock(invokable.argument as IrBlock, lambdaMetafactoryArguments.shouldBeSerializable)
            } else {
                invokable.transformChildrenVoid()
            }

            return wrapSamDelegatingLambdaWithIndySamConversion(samSuperType, lambdaBlock, lambdaMetafactoryArguments)
        } else {
            return super.visitTypeOperator(expression)
        }

        // From here on `reference` is the function reference being SAM-converted.
        reference.transformChildrenVoid()

        if (shouldGenerateIndySamConversions) {
            val lambdaMetafactoryArguments =
                LambdaMetafactoryArgumentsBuilder(context, crossinlineLambdas)
                    .getLambdaMetafactoryArguments(reference, samSuperType, plainLambda = false, forceSerializability = false)
            if (lambdaMetafactoryArguments is LambdaMetafactoryArguments) {
                return wrapSamConversionArgumentWithIndySamConversion(expression) { samType ->
                    wrapWithIndySamConversion(samType, lambdaMetafactoryArguments, expression.startOffset, expression.endOffset)
                }
            } else if (lambdaMetafactoryArguments is MetafactoryArgumentsResult.Failure.FunctionHazard) {
                // Try wrapping function with a proxy local function and see if that helps.
                val proxyLocalFunBlock = createProxyLocalFunctionForIndySamConversion(reference)
                val proxyLocalFunRef = proxyLocalFunBlock.statements[proxyLocalFunBlock.statements.size - 1] as IrFunctionReference
                val proxyLambdaMetafactoryArguments = LambdaMetafactoryArgumentsBuilder(context, crossinlineLambdas)
                    .getLambdaMetafactoryArguments(proxyLocalFunRef, samSuperType, plainLambda = false, forceSerializability = false)
                if (proxyLambdaMetafactoryArguments is LambdaMetafactoryArguments) {
                    return wrapSamConversionArgumentWithIndySamConversion(expression) { samType ->
                        // Replace the trailing `::proxy` reference of the proxy block with the indy call
                        // and retype the block to the SAM type.
                        proxyLocalFunBlock.statements[proxyLocalFunBlock.statements.size - 1] =
                            wrapWithIndySamConversion(samType, proxyLambdaMetafactoryArguments)
                        proxyLocalFunBlock.type = samType
                        proxyLocalFunBlock
                    }
                }
            }
        }

        // Erase generic arguments in the SAM type, because they are not easy to approximate correctly otherwise,
        // and LambdaMetafactory also uses erased type.
        val erasedSamSuperType = samSuperType.erasedUpperBound.rawType(context)

        return FunctionReferenceBuilder(reference, erasedSamSuperType).build()
    }

    // Builds a block that wraps the target of [reference] into a local proxy function and ends with
    // a reference to that proxy. Bound arguments of the original reference are evaluated into temporary
    // variables placed before the proxy function; unbound parameters become parameters of the proxy.
    //
    // Returns a block of the form `{ <temporaries>; <proxy function>; ::<proxy function> }` typed as `reference.type`.
    // Throws AssertionError when there is no current declaration parent.
    private fun createProxyLocalFunctionForIndySamConversion(reference: IrFunctionReference): IrBlock {
        val startOffset = reference.startOffset
        val endOffset = reference.endOffset
        val targetFun = reference.symbol.owner

        // For a function reference with possibly bound value parameters
        //      [ dispatchReceiver = ..., extensionReceiver = ..., ... ]::foo
        // create a proxy wrapper:
        //      {
        //          val tmp_proxy_0 = <bound value 0>
        //          ...
        //          val tmp_proxy_N = <bound value N>
        //          fun `$proxy`(p_0: TP_0, ..., p_M: TP_M): TR =
        //              foo(... arg_J ...)
        //              // here for each J arg_J is either 'tmp_proxy_K' or 'p_K' for some K
        //          ::`$proxy`
        //      }

        // Explicit element type: an argument-less `ArrayList()` cannot be inferred, and the list
        // only ever receives the temporary variables created below.
        val temporaryVals = ArrayList<IrVariable>()

        val targetCall: IrFunctionAccessExpression =
            when (targetFun) {
                is IrSimpleFunction ->
                    IrCallImpl.fromSymbolOwner(startOffset, endOffset, targetFun.symbol)
                is IrConstructor ->
                    IrConstructorCallImpl.fromSymbolOwner(startOffset, endOffset, targetFun.returnType, targetFun.symbol)
            }

        val proxyFun = context.irFactory.buildFun {
            name = Name.identifier("${targetFun.name.asString()}__proxy")
            returnType = targetFun.returnType.eraseTypeParameters()
            visibility = DescriptorVisibilities.LOCAL
            modality = Modality.FINAL
            isSuspend = false
            isInline = false
            origin =
                if (targetFun.isInline || targetFun.isArrayOf())
                    JvmLoweredDeclarationOrigin.PROXY_FUN_FOR_METAFACTORY
                else
                    JvmLoweredDeclarationOrigin.SYNTHETIC_PROXY_FUN_FOR_METAFACTORY
        }.also { proxyFun ->
            proxyFun.parent = currentDeclarationParent
                ?: throw AssertionError("No declaration parent when processing $reference")

            var temporaryValIndex = 0
            var proxyParameterIndex = 0

            // Stores a bound argument into a fresh `tmp_proxy_<i>` variable and returns a read of it.
            fun addAndGetTemporaryVal(initializer: IrExpression): IrGetValue {
                val tmpVal = IrVariableImpl(
                    startOffset, endOffset,
                    IrDeclarationOrigin.IR_TEMPORARY_VARIABLE,
                    IrVariableSymbolImpl(),
                    Name.identifier("tmp_proxy_${temporaryValIndex++}"),
                    initializer.type,
                    isVar = false, isConst = false, isLateinit = false
                )
                tmpVal.initializer = initializer
                // The temporary lives next to the proxy function, not inside it.
                tmpVal.parent = proxyFun.parent
                temporaryVals.add(tmpVal)
                return IrGetValueImpl(startOffset, endOffset, tmpVal.symbol)
            }

            // Adds a proxy parameter mirroring an unbound parameter of the target and returns a read of it.
            fun addAndGetProxyValueParameter(originalParameter: IrValueParameter): IrGetValue {
                val proxyParameter = buildValueParameter(proxyFun) {
                    updateFrom(originalParameter)
                    name = Name.identifier("p$proxyParameterIndex\$${originalParameter.name.asString()}")
                    type = originalParameter.type.eraseTypeParameters()
                    proxyParameterIndex++
                }.apply {
                    parent = proxyFun
                }
                proxyFun.valueParameters += proxyParameter
                return IrGetValueImpl(startOffset, endOffset, proxyParameter.symbol)
            }

            // Bound value wins; otherwise forward a proxy parameter; `null` when the target has no such parameter.
            fun getTargetCallArgument(boundValue: IrExpression?, originalParameter: IrValueParameter?): IrExpression? =
                when {
                    boundValue != null ->
                        addAndGetTemporaryVal(boundValue)
                    originalParameter != null ->
                        addAndGetProxyValueParameter(originalParameter)
                    else ->
                        null
                }

            targetCall.dispatchReceiver = getTargetCallArgument(reference.dispatchReceiver, targetFun.dispatchReceiverParameter)
            targetCall.extensionReceiver = getTargetCallArgument(reference.extensionReceiver, targetFun.extensionReceiverParameter)
            for ((valueParameterIndex, valueParameter) in targetFun.valueParameters.withIndex()) {
                targetCall.putValueArgument(
                    valueParameterIndex,
                    getTargetCallArgument(
                        reference.getValueArgument(valueParameterIndex),
                        valueParameter
                    )
                )
            }

            for (typeParameterIndex in targetFun.typeParameters.indices) {
                targetCall.putTypeArgument(typeParameterIndex, reference.getTypeArgument(typeParameterIndex))
            }

            val proxyFunBody = context.irFactory.createBlockBody(startOffset, endOffset).also { proxyFun.body = it }
            when {
                // Unit-returning targets are called as a plain statement, without an explicit return.
                targetFun.returnType.isUnit() -> {
                    proxyFunBody.statements.add(targetCall)
                }
                else -> {
                    proxyFunBody.statements.add(
                        IrReturnImpl(
                            startOffset, endOffset,
                            context.irBuiltIns.nothingType,
                            proxyFun.symbol,
                            targetCall
                        )
                    )
                }
            }
        }

        val proxyFunRef = IrFunctionReferenceImpl(
            startOffset, endOffset,
            reference.type,
            proxyFun.symbol,
            0, // TODO generic function reference?
            proxyFun.valueParameters.size
        )

        return IrBlockImpl(
            startOffset, endOffset,
            reference.type,
            origin = null,
            temporaryVals + proxyFun + proxyFunRef
        )
    }

    // Decides whether an arbitrary functional expression may be SAM-converted through indy.
    // Only Java SAM interfaces qualify, and only when Java SAM equals/hashCode generation is off;
    // adapted function reference blocks are excluded. Throws AssertionError for a non-class SAM type.
    private fun canGenerateIndySamConversionOnFunctionalExpression(samSuperType: IrType, expression: IrExpression): Boolean {
        val samClassSymbol = samSuperType.classOrNull
            ?: throw AssertionError("Class type expected: ${samSuperType.render()}")
        return when {
            !samClassSymbol.owner.isFromJava() -> false
            isJavaSamConversionWithEqualsHashCode -> false
            expression is IrBlock && expression.origin == IrStatementOrigin.ADAPTED_FUNCTION_REFERENCE -> false
            else -> true
        }
    }

    // Substitutes the reference inside a SAM-delegating lambda block with an indy SAM conversion call
    // and returns the underlying block.
    private fun wrapSamDelegatingLambdaWithIndySamConversion(
        samSuperType: IrType,
        lambdaBlock: SamDelegatingLambdaBlock,
        lambdaMetafactoryArguments: LambdaMetafactoryArguments,
    ): IrExpression {
        lambdaBlock.replaceRefWith(wrapWithIndySamConversion(samSuperType, lambdaMetafactoryArguments))
        return lambdaBlock.block
    }

    // Applies `produceSamConversion` to the argument of a SAM conversion:
    //   - a bare function reference is replaced by the produced expression;
    //   - a block gets its last statement replaced by the produced expression.
    // Any other argument shape is an error (AssertionError).
    private fun wrapSamConversionArgumentWithIndySamConversion(
        expression: IrTypeOperatorCall,
        produceSamConversion: (IrType) -> IrExpression,
    ): IrExpression {
        val samType = expression.typeOperand
        val convertedArgument = expression.argument

        if (convertedArgument is IrFunctionReference) {
            return produceSamConversion(samType)
        }
        if (convertedArgument is IrBlock) {
            return wrapFunctionReferenceInsideBlockWithIndySamConversion(samType, convertedArgument, produceSamConversion)
        }
        throw AssertionError("Block or function reference expected: ${expression.render()}")
    }

    // Overwrites the last statement of `block` with the expression produced for `samType`
    // and retypes the block accordingly. The same block instance is returned.
    private fun wrapFunctionReferenceInsideBlockWithIndySamConversion(
        samType: IrType,
        block: IrBlock,
        produceSamConversion: (IrType) -> IrExpression,
    ): IrExpression =
        block.also { target ->
            val conversion = produceSamConversion(samType)
            target.statements[target.statements.lastIndex] = conversion
            target.type = conversion.type
        }

    // Symbol of the intrinsic that wrapWithIndySamConversion emits a call to.
    private val jvmIndyLambdaMetafactoryIntrinsic = context.ir.symbols.indyLambdaMetafactoryIntrinsic

    // Nullability-related annotations that wrapWithIndySamConversion strips from the SAM type.
    private val specialNullabilityAnnotationsFqNames =
        setOf(
            JvmSymbols.FLEXIBLE_NULLABILITY_ANNOTATION_FQ_NAME,
            JvmAnnotationNames.ENHANCED_NULLABILITY_ANNOTATION,
        )

    // Creates the call to the indy LambdaMetafactory intrinsic for the given SAM type.
    // The SAM type is made non-null and stripped of the special nullability annotations first; it is used
    // both as the call's type and as its single type argument. Value arguments, in order:
    // SAM method, implementation method reference, fake instance method, extra overridden methods, serializable flag.
    private fun wrapWithIndySamConversion(
        samType: IrType,
        lambdaMetafactoryArguments: LambdaMetafactoryArguments,
        startOffset: Int = UNDEFINED_OFFSET,
        endOffset: Int = UNDEFINED_OFFSET,
    ): IrCall {
        val strippedSamType = samType.makeNotNull()
            .removeAnnotations { annotation -> annotation.type.classFqName in specialNullabilityAnnotationsFqNames }
        val builder = context.createJvmIrBuilder(currentScope!!, startOffset, endOffset)
        return with(builder) {
            // See [org.jetbrains.kotlin.backend.jvm.JvmSymbols::indyLambdaMetafactoryIntrinsic].
            val intrinsicCall = irCall(jvmIndyLambdaMetafactoryIntrinsic, strippedSamType)
            intrinsicCall.putTypeArgument(0, strippedSamType)
            intrinsicCall.putValueArgument(0, irRawFunctionRef(lambdaMetafactoryArguments.samMethod))
            intrinsicCall.putValueArgument(1, lambdaMetafactoryArguments.implMethodReference)
            intrinsicCall.putValueArgument(2, irRawFunctionRef(lambdaMetafactoryArguments.fakeInstanceMethod))
            intrinsicCall.putValueArgument(3, irVarargOfRawFunctionRefs(lambdaMetafactoryArguments.extraOverriddenMethods))
            intrinsicCall.putValueArgument(4, irBoolean(lambdaMetafactoryArguments.shouldBeSerializable))
            intrinsicCall
        }
    }

    // Builds a raw function reference to `irFun`, typed as `Any`.
    private fun IrBuilderWithScope.irRawFunctionRef(irFun: IrFunction) =
        irRawFunctionReference(context.irBuiltIns.anyType, irFun.symbol)

    // Builds a vararg (element type `Any`) holding one raw function reference per function in `irFuns`.
    // The element type of the list is spelled out: a bare `List` is not a valid Kotlin type.
    private fun IrBuilderWithScope.irVarargOfRawFunctionRefs(irFuns: List<IrFunction>) =
        irVararg(context.irBuiltIns.anyType, irFuns.map { irRawFunctionRef(it) })

    private inner class FunctionReferenceBuilder(val irFunctionReference: IrFunctionReference, val samSuperType: IrType? = null) {
        // Whether the reference originates from a lambda (as opposed to a callable reference).
        private val isLambda = irFunctionReference.origin.isLambda
        // A lambda is lightweight when lightweight lambdas are enabled and the target function
        // is not annotated with the JVM serializable-lambda annotation.
        private val isLightweightLambda = isLambda
                && shouldGenerateLightweightLambdas
                && !irFunctionReference.symbol.owner.hasAnnotation(JVM_SERIALIZABLE_LAMBDA_ANNOTATION_FQ_NAME)
        // Every other lambda is heavyweight.
        private val isHeavyweightLambda = isLambda && !isLightweightLambda
        // The function the reference points to.
        private val callee = irFunctionReference.symbol.owner

        // Only function references can bind a receiver and even then we can only bind either an extension or a dispatch receiver.
        // However, when we bind a value of an inline class type as a receiver, the receiver will turn into an argument of
        // the function in question. Yet we still need to record it as the "receiver" in CallableReference in order for reflection
        // to work correctly.
        //
        // The pair is (parameter the value is bound to, bound value expression); `null` when nothing is bound.
        // `singleOrNull()` also yields `null` when more than one argument is present.
        private val boundReceiver: Pair<IrValueParameter, IrExpression>? =
            if (callee.isJvmStaticInObject()) createFakeBoundReceiverForJvmStaticInObject()
            else irFunctionReference.getArgumentsWithIr().singleOrNull()

        // The type of the reference is KFunction<in A1, ..., in An, out R>:
        // all type arguments but the last are argument types, the last one is the return type.
        // Star projections are treated as `Any?`.
        private val parameterTypes = (irFunctionReference.type as IrSimpleType).arguments.map {
            when (it) {
                is IrTypeProjection -> it.type
                is IrStarProjection -> context.irBuiltIns.anyNType
            }
        }
        // A1, ..., An
        private val argumentTypes = parameterTypes.dropLast(1)
        // R
        private val referenceReturnType = parameterTypes.last()

        // Type substitution map taken from the reference (its use is outside the visible part of this class).
        private val typeArgumentsMap = irFunctionReference.typeSubstitutionMap

        // The functional interface being implemented: the SAM class when a SAM super type is given,
        // otherwise SuspendFunctionN / FunctionN with N = number of argument types.
        private val functionSuperClass =
            samSuperType?.classOrNull
                ?: if (irFunctionReference.isSuspend)
                    context.irBuiltIns.suspendFunctionN(argumentTypes.size).symbol
                else
                    context.irBuiltIns.functionN(argumentTypes.size).symbol
        // The single abstract method of that interface; AssertionError if there is none.
        private val superMethod =
            functionSuperClass.owner.getSingleAbstractMethod()
                ?: throw AssertionError("Not a SAM class: ${functionSuperClass.owner.render()}")

        // For a callable reference adapter, the call to the original (adapted) function; `null` for any other callee.
        private val adapteeCall: IrFunctionAccessExpression? =
            if (callee.origin != IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE) {
                null
            } else {
                // The adapter body holds exactly one statement: a call, possibly wrapped into a return or into
                // an IMPLICIT_COERCION_TO_UNIT type operator. The target of that call is the original function,
                // which is needed to obtain owner/name/signature.
                val onlyStatement = callee.body!!.statements.single()
                val unwrapped = when (onlyStatement) {
                    is IrTypeOperatorCall -> {
                        assert(onlyStatement.operator == IrTypeOperator.IMPLICIT_COERCION_TO_UNIT) {
                            "Unexpected type operator in ADAPTER_FOR_CALLABLE_REFERENCE: ${callee.render()}"
                        }
                        onlyStatement.argument
                    }
                    is IrReturn -> onlyStatement.value
                    else -> onlyStatement
                }
                unwrapped as? IrFunctionAccessExpression
                    ?: throw UnsupportedOperationException("Unknown structure of ADAPTER_FOR_CALLABLE_REFERENCE: ${callee.render()}")
            }

        // The function the adapter delegates to, if the callee is a callable reference adapter.
        private val adaptedReferenceOriginalTarget: IrFunction? = adapteeCall?.symbol?.owner
        // Whether the callee is an adapter for a fun interface constructor.
        private val isFunInterfaceConstructorReference =
            callee.origin == IrDeclarationOrigin.ADAPTER_FOR_FUN_INTERFACE_CONSTRUCTOR
        // For a fun interface constructor reference: the class of the adapter's return type
        // (AssertionError if the return type is not a class type); `null` otherwise.
        private val constructedFunInterfaceSymbol: IrClassSymbol? =
            if (isFunInterfaceConstructorReference)
                callee.returnType.classOrNull
                    ?: throw AssertionError("Fun interface type expected: ${callee.returnType.render()}")
            else
                null
        // Adapted references cover both fun interface constructor adapters and callable reference adapters.
        private val isAdaptedReference =
            isFunInterfaceConstructorReference || adaptedReferenceOriginalTarget != null

        // The SAM interface class, when a SAM super type was supplied.
        private val samInterface = samSuperType?.getClass()
        // A SAM interface that does not come from Java.
        private val isKotlinFunInterface = samInterface != null && !samInterface.isFromJava()

        // equals/hashCode are generated for Kotlin fun interfaces (or for any SAM when the Java SAM
        // equals/hashCode feature is on), but only for adapted references and non-lambdas.
        private val needToGenerateSamEqualsHashCodeMethods =
            (isKotlinFunInterface || isJavaSamConversionWithEqualsHashCode) &&
                    (isAdaptedReference || !isLambda)

        // First super type of the generated class: the SAM type when given, otherwise the default type
        // of the runtime base class matching the kind of reference (first matching case wins).
        private val superType =
            samSuperType
                ?: when {
                    isLightweightLambda -> context.ir.symbols.any
                    isHeavyweightLambda -> context.ir.symbols.lambdaClass
                    isFunInterfaceConstructorReference -> context.ir.symbols.funInterfaceConstructorReferenceClass
                    isAdaptedReference -> context.ir.symbols.adaptedFunctionReference
                    else -> context.ir.symbols.functionReferenceImpl
                }.defaultType

        // The local, unnamed class generated for this reference. It is created once, when the builder
        // is constructed; members are added to it later by `build`.
        private val functionReferenceClass = context.irFactory.buildClass {
            setSourceRange(irFunctionReference)
            visibility = DescriptorVisibilities.LOCAL
            // A callable reference results in a synthetic class, while a lambda is not synthetic.
            // We don't produce GENERATED_SAM_IMPLEMENTATION, which is always synthetic.
            origin = if (isLambda) JvmLoweredDeclarationOrigin.LAMBDA_IMPL else JvmLoweredDeclarationOrigin.FUNCTION_REFERENCE_IMPL
            name = SpecialNames.NO_NAME_PROVIDED
        }.apply {
            parent = currentDeclarationParent ?: error("No current declaration parent at ${irFunctionReference.dump()}")
            // Super types: the base type, plus the FunctionN/SuspendFunctionN instantiation for non-SAM
            // references, plus FunctionAdapter when equals/hashCode are going to be generated.
            superTypes = listOfNotNull(
                superType,
                if (samSuperType == null)
                    functionSuperClass.typeWith(parameterTypes)
                else null,
                if (needToGenerateSamEqualsHashCodeMethods)
                    context.ir.symbols.functionAdapter.defaultType
                else null,
            )
            if (samInterface != null && origin == JvmLoweredDeclarationOrigin.LAMBDA_IMPL) {
                // Old back-end generates formal type parameters as in SAM supertype.
                // Here we create formal type parameters with same names and equivalent upper bounds.
                // We don't really perform any type substitutions within class body
                // (it's all fine as soon as we have required generic signatures and don't fail anywhere).
                // NB this would no longer matter if we generate SAM wrapper classes as synthetic.
                typeParameters = createFakeFormalTypeParameters(samInterface.typeParameters, this)
            }
            createImplicitParameterDeclarationWithWrappedDescriptor()
            copyAttributes(irFunctionReference)
            // Only heavyweight lambdas carry over the metadata of the lambda function.
            if (isHeavyweightLambda) {
                metadata = irFunctionReference.symbol.owner.metadata
            }
        }

        // Creates type parameters for `irClass` that mirror `sourceTypeParameters`: same names, and super types
        // (upper bounds) remapped so that references to a source type parameter point to its copy.
        // Returns an empty list when there is nothing to copy.
        private fun createFakeFormalTypeParameters(
            sourceTypeParameters: List<IrTypeParameter>,
            irClass: IrClass,
        ): List<IrTypeParameter> {
            if (sourceTypeParameters.isEmpty()) return emptyList()

            val fakeTypeParameters = sourceTypeParameters.map {
                buildTypeParameter(irClass) {
                    updateFrom(it)
                    name = it.name
                }
            }
            // Source and fake parameters are matched by their `index`.
            val typeRemapper = IrTypeParameterRemapper(sourceTypeParameters.associateWith { fakeTypeParameters[it.index] })
            for (fakeTypeParameter in fakeTypeParameters) {
                val sourceTypeParameter = sourceTypeParameters[fakeTypeParameter.index]
                fakeTypeParameter.superTypes = sourceTypeParameter.superTypes.map { typeRemapper.remapType(it) }
            }

            return fakeTypeParameters
        }

        // Produces the replacement expression for the reference: a block that declares the generated class
        // and then instantiates it through its constructor.
        fun build(): IrExpression = context.createJvmIrBuilder(currentScope!!).run {
            irBlock(irFunctionReference.startOffset, irFunctionReference.endOffset) {
                val constructor = createConstructor()
                // For SAM wrappers the bound receiver is evaluated into a temporary of the enclosing block
                // instead of being passed to the constructor.
                val boundReceiverVar =
                    if (samSuperType != null && boundReceiver != null) {
                        irTemporary(boundReceiver.second)
                    } else null
                createInvokeMethod(boundReceiverVar)

                if (needToGenerateSamEqualsHashCodeMethods) {
                    generateSamEqualsHashCodeMethods(boundReceiverVar)
                }
                if (isKotlinFunInterface) {
                    functionReferenceClass.addFakeOverrides(backendContext.typeSystem)
                }

                +functionReferenceClass
                +irCall(constructor.symbol).apply {
                    // The constructor takes an argument only when createConstructor added the "receiver" parameter
                    // (non-SAM reference with a bound receiver), so `boundReceiver` is expected to be non-null here.
                    if (valueArgumentsCount > 0) putValueArgument(0, boundReceiver!!.second)
                }
            }
        }

        // Generates equals/hashCode for a SAM wrapper via SamEqualsHashCodeMethodsGenerator.
        // Requires a SAM super type (IllegalStateException from checkNotNull otherwise).
        // The lambda passed to the generator builds a constructor call of the runtime reference class
        // (AdaptedFunctionReference for adapted references, FunctionReferenceImpl otherwise).
        private fun JvmIrBuilder.generateSamEqualsHashCodeMethods(boundReceiverVar: IrVariable?) {
            checkNotNull(samSuperType) { "equals/hashCode can only be generated for fun interface wrappers: ${callee.render()}" }

            SamEqualsHashCodeMethodsGenerator(backendContext, functionReferenceClass, samSuperType) {
                val internalClass = when {
                    isAdaptedReference -> backendContext.ir.symbols.adaptedFunctionReference
                    else -> backendContext.ir.symbols.functionReferenceImpl
                }
                // Pick the constructor overload by parameter count.
                val constructor = internalClass.owner.constructors.single {
                    // arity, [receiver], owner, name, signature, flags
                    it.valueParameters.size == 1 + (if (boundReceiver != null) 1 else 0) + 4
                }
                irCallConstructor(constructor.symbol, emptyList()).apply {
                    // The receiver, when needed, is read from the temporary created in `build`.
                    generateConstructorCallArguments(this) { irGet(boundReceiverVar!!) }
                }
            }.generate()
        }

        /**
         * Creates the primary constructor of the generated function reference class.
         *
         * The constructor takes a single `receiver` parameter when the reference has a bound receiver and is not
         * a SAM wrapper. Its body delegates to the matching super constructor and then runs the instance
         * initializer of the generated class.
         */
        private fun createConstructor(): IrConstructor =
            functionReferenceClass.addConstructor {
                origin = JvmLoweredDeclarationOrigin.GENERATED_MEMBER_IN_CALLABLE_REFERENCE
                returnType = functionReferenceClass.defaultType
                isPrimary = true
            }.apply {
                val needsReceiverParameter = samSuperType == null && boundReceiver != null
                if (needsReceiverParameter) {
                    addValueParameter("receiver", context.irBuiltIns.anyNType)
                }

                // Which super constructor is called:
                // - fun interface constructor references: the super class is
                //   kotlin.jvm.internal.FunInterfaceConstructorReference, whose only constructor is
                //   'public FunInterfaceConstructorReference(Class funInterface)'
                // - SAM references: the super class is Any
                // - lambdas: the constructor accepting arity
                // - optimized function references (1.4+): the constructor accepting
                //       arity, [receiver], owner, name, signature, flags
                // - unoptimized function references: the constructor accepting
                //       arity, [receiver]
                val superConstructor =
                    if (isFunInterfaceConstructorReference) {
                        context.ir.symbols.funInterfaceConstructorReferenceClass.owner.constructors.single()
                    } else if (samSuperType != null) {
                        context.irBuiltIns.anyClass.owner.constructors.single()
                    } else {
                        val expectedArity = when {
                            isLightweightLambda && !isAdaptedReference -> 0
                            isHeavyweightLambda && !isAdaptedReference -> 1
                            else -> 1 + (if (boundReceiver != null) 1 else 0) + 4
                        }
                        superType.getClass()!!.constructors.single { candidate ->
                            candidate.valueParameters.size == expectedArity
                        }
                    }

                body = context.createJvmIrBuilder(symbol).run {
                    irBlockBody(startOffset, endOffset) {
                        +irDelegatingConstructorCall(superConstructor).also { superCall ->
                            // The SAM case delegates to the Any constructor selected above, so no arguments are filled.
                            if (samSuperType == null) {
                                generateConstructorCallArguments(superCall) { irGet(valueParameters.first()) }
                            }
                        }
                        +IrInstanceInitializerCallImpl(startOffset, endOffset, functionReferenceClass.symbol, context.irBuiltIns.unitType)
                    }
                }
            }

        /**
         * Fills the value arguments of [call], a call to a function reference super/internal class constructor.
         *
         * For fun interface constructor references the only argument is the Java class of the constructed
         * fun interface. Otherwise the arguments are emitted in this order, each one only when applicable:
         * arity (skipped for lightweight lambdas), bound receiver, and - for non-lambdas - owner class, name,
         * signature and flags of the reference target.
         *
         * @param call the constructor call whose arguments are populated.
         * @param generateBoundReceiver produces the bound receiver expression; invoked only when there is a bound receiver.
         */
        private fun JvmIrBuilder.generateConstructorCallArguments(
            call: IrFunctionAccessExpression,
            generateBoundReceiver: IrBuilder.() -> IrExpression,
        ) {
            if (isFunInterfaceConstructorReference) {
                val funInterfaceType = constructedFunInterfaceSymbol!!.owner.defaultType
                call.putValueArgument(0, kClassToJavaClass(kClassReference(funInterfaceType)))
                return
            }

            var nextIndex = 0
            fun putNextArgument(argument: IrExpression) {
                call.putValueArgument(nextIndex++, argument)
            }

            if (!isLightweightLambda) {
                // Suspend references take one more argument than the declared argument types.
                val suspendExtra = if (irFunctionReference.isSuspend) 1 else 0
                putNextArgument(irInt(argumentTypes.size + suspendExtra))
            }
            if (boundReceiver != null) {
                putNextArgument(generateBoundReceiver())
            }
            if (isLambda) return

            // For adapted references the arguments describe the original target, not the adapter.
            val target = adaptedReferenceOriginalTarget ?: callee
            putNextArgument(kClassToJavaClass(calculateOwnerKClass(target.parent)))
            putNextArgument(irString(target.originalName.asString()))
            putNextArgument(generateSignature(target.symbol))
            putNextArgument(irInt(getFunctionReferenceFlags(target)))
        }

        /**
         * Computes the `flags` constructor argument for [callableReferenceTarget]: the target's top-level flag
         * plus the adapted-reference flags shifted left by one bit.
         */
        private fun getFunctionReferenceFlags(callableReferenceTarget: IrFunction): Int =
            callableReferenceTarget.getCallableReferenceTopLevelFlag() + (getAdaptedCallableReferenceFlags() shl 1)

        /**
         * Returns the adaptation flags of this reference, or 0 when it is not an adapted reference.
         *
         * Bit 0: a vararg is mapped to an element; bit 1: a non-suspend target is adapted to a suspend callee;
         * bit 2: a non-Unit target return type is coerced to a Unit callee return type.
         */
        private fun getAdaptedCallableReferenceFlags(): Int {
            val originalTarget = adaptedReferenceOriginalTarget ?: return 0

            var flags = 0
            if (hasVarargMappedToElement()) flags = flags or 1
            if (!originalTarget.isSuspend && callee.isSuspend) flags = flags or (1 shl 1)
            if (!originalTarget.returnType.isUnit() && callee.returnType.isUnit()) flags = flags or (1 shl 2)
            return flags
        }

        /**
         * Tells whether the adaptee call passes a vararg argument that contains a plain value read
         * ([IrGetValue]) among its elements. Returns `false` when there is no adaptee call.
         */
        private fun hasVarargMappedToElement(): Boolean {
            val adaptee = adapteeCall ?: return false
            return (0 until adaptee.valueArgumentsCount).any { argumentIndex ->
                val vararg = adaptee.getValueArgument(argumentIndex) as? IrVararg
                vararg != null && vararg.elements.any { element -> element is IrGetValue }
            }
        }

        /**
         * Adds the method overriding [superMethod] to the generated class and fills in its parameters and body.
         *
         * The body is produced by one of three strategies, depending on whether this is a lambda, a fun interface
         * constructor reference, or any other function reference.
         *
         * @param receiverVar value holding the bound receiver; forwarded to the function reference strategy only.
         * @return the newly added method.
         */
        private fun createInvokeMethod(receiverVar: IrValueDeclaration?): IrSimpleFunction {
            val invokeMethod = functionReferenceClass.addFunction {
                setSourceRange(if (isLambda) callee else irFunctionReference)
                name = superMethod.name
                returnType = referenceReturnType
                isSuspend = callee.isSuspend
            }
            with(invokeMethod) {
                metadata = functionReferenceClass.metadata
                overriddenSymbols += superMethod.symbol
                dispatchReceiverParameter = buildReceiverParameter(
                    this,
                    IrDeclarationOrigin.INSTANCE_RECEIVER,
                    functionReferenceClass.symbol.defaultType
                )

                if (isLambda) {
                    createLambdaInvokeMethod()
                } else if (isFunInterfaceConstructorReference) {
                    createFunInterfaceConstructorInvokeMethod()
                } else {
                    createFunctionReferenceInvokeMethod(receiverVar)
                }
            }
            return invokeMethod
        }

        // Turns this method into the lambda itself: copies the anonymous function's annotations and explicit
        // parameters here and moves its body over, remapping the old parameters to their copies.
        private fun IrSimpleFunction.createLambdaInvokeMethod() {
            annotations = annotations + callee.annotations
            val copiedParameters = callee.explicitParameters.associateWith { original -> original.copyTo(this) }
            valueParameters = valueParameters + copiedParameters.values
            body = callee.moveBodyTo(this, copiedParameters)
        }

        /**
         * Fills this method for a fun interface constructor reference: the callee's single value parameter is
         * copied here, the callee's body is moved into this method with that parameter remapped, and the callee
         * is left without a body.
         *
         * @throws AssertionError if the callee does not have exactly one value parameter.
         */
        private fun IrSimpleFunction.createFunInterfaceConstructorInvokeMethod() {
            val adapterParameter = callee.valueParameters.singleOrNull()
            if (adapterParameter == null) {
                throw AssertionError("Single value parameter expected: ${callee.render()}")
            }
            val invokeParameter = adapterParameter.copyTo(this)
            valueParameters = listOf(invokeParameter)
            body = callee.moveBodyTo(this, mapOf(adapterParameter to invokeParameter))
            callee.body = null
        }

        /**
         * Fills this method for a regular function reference: declares one value parameter `p<index>` per entry of
         * `argumentTypes` and sets the body to a call of the callee that forwards those parameters.
         *
         * The callee's explicit parameters are visited in order. The bound receiver parameter (if any) is supplied
         * from the receiver field or from [receiver]; remaining parameters consume the `p<index>` parameters one by
         * one; once those are exhausted no argument is passed. The resulting call is handed to
         * [inlineAdapterCallIfPossible] before being wrapped into an expression body.
         *
         * @param receiver value holding the bound receiver; read only for SAM references with a bound receiver.
         */
        private fun IrSimpleFunction.createFunctionReferenceInvokeMethod(receiver: IrValueDeclaration?) {
            for ((index, argumentType) in argumentTypes.withIndex()) {
                addValueParameter {
                    name = Name.identifier("p$index")
                    type = argumentType
                }
            }

            body = context.createJvmIrBuilder(symbol, startOffset, endOffset).run {
                // Index of the next `p<index>` parameter that has not been forwarded to the callee yet.
                var unboundIndex = 0
                val call = irCall(callee.symbol, referenceReturnType).apply {
                    for (typeParameter in irFunctionReference.symbol.owner.allTypeParameters) {
                        putTypeArgument(typeParameter.index, typeArgumentsMap[typeParameter.symbol])
                    }

                    for (parameter in callee.explicitParameters) {
                        when {
                            boundReceiver?.first == parameter ->
                                // Bound receiver parameter. For function references, this is stored in a field of the superclass.
                                // For sam references, we just capture the value in a local variable and LocalDeclarationsLowering
                                // will put it into a field.
                                if (samSuperType == null)
                                    irImplicitCast(
                                        irGetField(
                                            irGet(dispatchReceiverParameter!!),
                                            functionReferenceClass.getReceiverField(backendContext)
                                        ),
                                        boundReceiver.second.type
                                    )
                                else
                                    irGet(receiver!!)

                            unboundIndex >= argumentTypes.size ->
                                // Default value argument (this pass doesn't handle suspend functions, otherwise
                                // it could also be the continuation argument)
                                null

                            else ->
                                irGet(valueParameters[unboundIndex++])
                        }?.let { putArgument(callee, parameter, it) } // a null branch result leaves the argument unset
                    }
                }
                irExprBody(
                    inlineAdapterCallIfPossible(call, this@createFunctionReferenceInvokeMethod)
                )
            }
        }

        /**
         * Replaces a call to a callable reference adapter with the adapter's own result expression, so that the
         * adapter body ends up directly inside [invokeMethod].
         *
         * The expression is returned unchanged when it is not an [IrCall], when its callee is not an
         * `ADAPTER_FOR_CALLABLE_REFERENCE` function, or when the adapter is a suspend function. Otherwise each
         * adapter parameter is mapped to the value passed for it (a temporary variable is introduced for arguments
         * that are not plain value reads), the adapter's single statement is remapped accordingly, and the adapter
         * is left without a body.
         *
         * @param expression the call generated for the `invoke` body.
         * @param invokeMethod the method that becomes the parent of the inlined declarations.
         * @return [expression] itself, or a block holding the temporaries followed by the inlined adapter result.
         * @throws AssertionError if the adapter body or one of the call arguments has an unexpected shape.
         */
        private fun inlineAdapterCallIfPossible(
            expression: IrFunctionAccessExpression,
            invokeMethod: IrSimpleFunction,
        ): IrExpression {
            val irCall = expression as? IrCall
                ?: return expression
            val callee = irCall.symbol.owner
            if (callee.origin != IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE)
                return expression

            // TODO fix testSuspendUnitConversion
            if (callee.isSuspend) return expression

            // Callable reference adapter is a simple function that delegates to callable reference target,
            // adapting its signature for required functional type.
            // Usually it simply forwards arguments to target function.
            // It also passes 'receiver' field for bound references, with downcast to the actual receiver type.
            // In any case, adapter itself is synthetic and is not necessarily debuggable, so we can reuse variables freely.
            // Inlining adapter into 'invoke' saves us two methods (adapter & synthetic accessor).
            val adapterBody = callee.body as? IrBlockBody
            if (adapterBody == null || adapterBody.statements.size != 1)
                throw AssertionError("Unexpected adapter body: ${callee.dump()}")
            val resultStatement = adapterBody.statements[0]
            // A 'return' contributes its value; the other accepted statement kinds are used as they are.
            val resultExpression: IrExpression =
                when {
                    resultStatement is IrReturn ->
                        resultStatement.value
                    resultStatement is IrTypeOperatorCall && resultStatement.operator == IrTypeOperator.IMPLICIT_COERCION_TO_UNIT ->
                        resultStatement
                    resultStatement is IrCall ->
                        resultStatement
                    resultStatement is IrConstructorCall ->
                        resultStatement
                    else ->
                        throw AssertionError("Unexpected adapter body: ${callee.dump()}")
                }

            val startOffset = irCall.startOffset
            val endOffset = irCall.endOffset

            // Adapter parameter -> value that replaces it in the inlined body. The type arguments are spelled out
            // because nothing at this declaration lets the compiler infer them.
            val callArguments = LinkedHashMap<IrValueParameter, IrValueDeclaration>()
            val inlinedAdapterBlock = IrBlockImpl(startOffset, endOffset, irCall.type, origin = null)
            var tmpVarIndex = 0

            // Plain value reads are reused directly; a type operator applied to a field read is stored in a fresh
            // temporary that is appended to the inlined block. Any other argument shape is rejected.
            fun wrapIntoTemporaryVariableIfNecessary(expression: IrExpression): IrValueDeclaration {
                if (expression is IrGetValue)
                    return expression.symbol.owner
                if (expression !is IrTypeOperatorCall || expression.argument !is IrGetField)
                    throw AssertionError("Unexpected adapter argument:\n${expression.dump()}")
                val temporaryVar = IrVariableImpl(
                    startOffset, endOffset, IrDeclarationOrigin.IR_TEMPORARY_VARIABLE,
                    IrVariableSymbolImpl(),
                    Name.identifier("tmp_${tmpVarIndex++}"),
                    expression.type,
                    isVar = false, isConst = false, isLateinit = false
                )
                temporaryVar.parent = invokeMethod
                temporaryVar.initializer = expression
                inlinedAdapterBlock.statements.add(temporaryVar)
                return temporaryVar
            }

            callee.dispatchReceiverParameter?.let {
                callArguments[it] = wrapIntoTemporaryVariableIfNecessary(
                    irCall.dispatchReceiver
                        ?: throw AssertionError("No dispatch receiver in adapter call: ${irCall.dump()}")
                )
            }
            callee.extensionReceiverParameter?.let {
                callArguments[it] = wrapIntoTemporaryVariableIfNecessary(
                    irCall.extensionReceiver
                        ?: throw AssertionError("No extension receiver in adapter call: ${irCall.dump()}")
                )
            }
            for (valueParameter in callee.valueParameters) {
                callArguments[valueParameter] = wrapIntoTemporaryVariableIfNecessary(
                    irCall.getValueArgument(valueParameter.index)
                        ?: throw AssertionError("No value argument #${valueParameter.index} in adapter call: ${irCall.dump()}")
                )
            }

            val inlinedAdapterResult = resultExpression.transform(VariableRemapper(callArguments), null)
            inlinedAdapterBlock.statements.add(inlinedAdapterResult)

            // The adapter's statement now lives in the inlined block, so the adapter itself keeps no body.
            callee.body = null
            return inlinedAdapterBlock.patchDeclarationParents(invokeMethod)
        }

        // Name taken from the function's metadata when it provides one, otherwise the function's own IR name.
        private val IrFunction.originalName: Name
            get() {
                val metadataName = metadata?.name
                return metadataName ?: name
            }

        /**
         * Builds a call to the signature string intrinsic whose only argument is a fresh function reference to
         * [target], carrying the type arguments of the reference being lowered.
         */
        private fun JvmIrBuilder.generateSignature(target: IrFunctionSymbol): IrExpression {
            val signatureCall = irCall(backendContext.ir.symbols.signatureStringIntrinsic)
            // No receivers are passed to the new reference, otherwise LocalDeclarationLowering will create
            // additional captured parameters.
            val receiverlessReference = IrFunctionReferenceImpl(
                UNDEFINED_OFFSET, UNDEFINED_OFFSET, irFunctionReference.type, target,
                irFunctionReference.typeArgumentsCount, target.owner.valueParameters.size,
                irFunctionReference.reflectionTarget, null
            )
            receiverlessReference.copyTypeArgumentsFrom(irFunctionReference)
            signatureCall.putValueArgument(0, receiverlessReference)
            return signatureCall
        }

        /**
         * Creates a synthetic bound receiver for a reference to a JvmStatic function declared in an object.
         *
         * @return a pair of a newly built `$this` value parameter (typed as the containing object class) and an
         * expression reading that object's instance.
         */
        private fun createFakeBoundReceiverForJvmStaticInObject(): Pair<IrValueParameter, IrExpression> {
            // JvmStatic functions in objects are special in that they are generated as static methods in the bytecode, and JVM IR lowers
            // both declarations and call sites early on in jvmStaticInObjectPhase because it's easier that way in subsequent lowerings.
            // However from the point of view of Kotlin language (and thus reflection), these functions still take the dispatch receiver
            // parameter of the object type. So we pretend here that a JvmStatic function in object has an additional dispatch receiver
            // parameter, so that the correct function reference object will be created and reflective calls will work at runtime.
            val objectClass = callee.parentAsClass
            return buildValueParameter(callee) {
                name = Name.identifier("\$this")
                type = objectClass.typeWith()
            } to IrGetObjectValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, objectClass.typeWith(), objectClass.symbol)
        }
    }

    companion object {
        /** Builds a `KClass` reference expression for the callable reference owner type of [irContainer]. */
        internal fun JvmIrBuilder.calculateOwnerKClass(irContainer: IrDeclarationParent): IrExpression {
            val ownerKClassType = irContainer.getCallableReferenceOwnerKClassType(backendContext)
            return kClassReference(ownerKClassType)
        }

        // Builds an IrField standing for `kotlin.jvm.internal.CallableReference.receiver`, presented as if it were declared in the
        // anonymous class of this callable reference. Strictly speaking that is incorrect, since the field is not declared here; a
        // fake override would be more accurate, but that looks like more work for no clear benefit. Codegen generates the correct
        // field access anyway, even though the field is not present in this parent.
        // The field access must have this anonymous class as the field's parent, and NOT a supertype such as
        // k.j.i.CallableReference or k.j.i.FunctionReferenceImpl: otherwise AddSuperQualifierToJavaFieldAccess lowering would add
        // superQualifierSymbol, which would break inlining of bound function references, because the inliner would not understand
        // how to transform this getfield instruction in the bytecode.
        internal fun IrClass.getReceiverField(context: JvmBackendContext): IrField {
            val receiverField = context.irFactory.buildField {
                name = Name.identifier("receiver")
                type = context.irBuiltIns.anyNType
                visibility = DescriptorVisibilities.PROTECTED
            }
            receiverField.parent = this
            return receiverField
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy