// org.jetbrains.kotlin.backend.jvm.lower.SuspendLambdaLowering.kt Maven / Gradle / Ivy
/*
* Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.backend.jvm.lower
import org.jetbrains.kotlin.backend.common.FileLoweringPass
import org.jetbrains.kotlin.backend.common.IrElementTransformerVoidWithContext
import org.jetbrains.kotlin.backend.common.ir.moveBodyTo
import org.jetbrains.kotlin.backend.common.lower.LocalDeclarationsLowering
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.ir.util.parents
import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase
import org.jetbrains.kotlin.backend.jvm.JvmBackendContext
import org.jetbrains.kotlin.backend.jvm.JvmLoweredDeclarationOrigin
import org.jetbrains.kotlin.backend.jvm.ir.hasChild
import org.jetbrains.kotlin.backend.jvm.ir.isReadOfCrossinline
import org.jetbrains.kotlin.codegen.coroutines.COROUTINE_LABEL_FIELD_NAME
import org.jetbrains.kotlin.codegen.coroutines.INVOKE_SUSPEND_METHOD_NAME
import org.jetbrains.kotlin.codegen.coroutines.SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME
import org.jetbrains.kotlin.codegen.coroutines.normalize
import org.jetbrains.kotlin.codegen.inline.coroutines.FOR_INLINE_SUFFIX
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.builders.declarations.*
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrBlock
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import org.jetbrains.kotlin.ir.expressions.IrGetValue
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.IrValueParameterSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.load.java.JavaDescriptorVisibilities
import org.jetbrains.kotlin.name.FqNameUnsafe
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.resolve.jvm.AsmTypes
import org.jetbrains.org.objectweb.asm.Type
import kotlin.collections.set
/**
 * Per-file compiler phase that runs [SuspendLambdaLowering], which turns suspend lambdas into continuation classes.
 */
internal val suspendLambdaPhase = makeIrFilePhase(
    { context: JvmBackendContext -> SuspendLambdaLowering(context) },
    "SuspendLambda",
    "Transform suspend lambdas into continuation classes",
)
/**
 * Returns `true` when [hasChild] finds, inside this function, a read of a crossinline value
 * ([isReadOfCrossinline]) whose declaring parent is one of this function's enclosing declarations ([parents]).
 */
private fun IrFunction.capturesCrossinline(): Boolean {
    val enclosingDeclarations = parents.toSet()
    return hasChild { element ->
        element is IrGetValue &&
            element.isReadOfCrossinline() &&
            element.symbol.owner.parent in enclosingDeclarations
    }
}
/**
 * Helpers shared by the lowerings that build continuation classes for suspend code.
 *
 * @property context the JVM backend context used for IR builders, the IR factory and well-known symbols.
 */
internal abstract class SuspendLoweringUtils(protected val context: JvmBackendContext) {
    /**
     * Adds to this class a new function named like [function] that lists [function] as its only overridden symbol.
     *
     * Exactly one supertype of this class must have [function]'s parent class as its classifier. The return
     * type and the value-parameter types of [function] are substituted with the type arguments this class
     * passes to that supertype. This overload does not set a body.
     *
     * @param startOffset start offset given to the new function.
     * @param endOffset end offset given to the new function.
     * @return the newly added function.
     */
    protected fun IrClass.addFunctionOverride(
        function: IrSimpleFunction,
        startOffset: Int = UNDEFINED_OFFSET,
        endOffset: Int = UNDEFINED_OFFSET,
    ): IrSimpleFunction {
        val overriddenType = superTypes.single { superType -> superType.classifierOrFail == function.parentAsClass.symbol }
        val superTypeParameters = (overriddenType.classifierOrFail.owner as IrClass).typeParameters.map { it.symbol }
        // No star projections in this lowering, so every type argument can be unwrapped as a plain projection.
        val superTypeArguments = (overriddenType as IrSimpleType).arguments.map { argument ->
            (argument as IrTypeProjection).type
        }
        val typeSubstitution = superTypeParameters.zip(superTypeArguments).toMap()

        val overriding = addFunction(
            function.name.asString(),
            function.returnType.substitute(typeSubstitution),
            startOffset = startOffset,
            endOffset = endOffset,
        )
        overriding.overriddenSymbols = listOf(function.symbol)
        overriding.valueParameters = function.valueParameters.map { parameter ->
            parameter.copyTo(overriding, type = parameter.type.substitute(typeSubstitution))
        }
        return overriding
    }

    /**
     * Like the body-less overload of `addFunctionOverride`, but also gives the new function a block body built
     * by [makeBody]. [makeBody] runs with an [IrBlockBodyBuilder] for the new function as its receiver and gets
     * the new function as its argument.
     */
    protected fun IrClass.addFunctionOverride(
        function: IrSimpleFunction,
        startOffset: Int = UNDEFINED_OFFSET,
        endOffset: Int = UNDEFINED_OFFSET,
        makeBody: IrBlockBodyBuilder.(IrFunction) -> Unit
    ): IrSimpleFunction {
        val overriding = addFunctionOverride(function, startOffset, endOffset)
        overriding.body = context.createIrBuilder(overriding.symbol).irBlockBody { makeBody(overriding) }
        return overriding
    }

    /**
     * Replaces this function's body with an expression body that is an error expression. The expression's
     * message explains that this function is a stub and that the real body lives on the original function.
     */
    protected fun IrSimpleFunction.generateErrorForInlineBody() {
        val message = "This is a stub representing a copy of a suspend method without the state machine " +
                "(used by the inliner). Since the difference is at the bytecode level, the body is " +
                "still on the original function. Use suspendForInlineToOriginal() to retrieve it."
        val errorExpression = IrErrorExpressionImpl(startOffset, endOffset, returnType, message)
        body = context.irFactory.createExpressionBody(startOffset, endOffset, errorExpression)
    }

    /**
     * Appends a value parameter named [SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME], typed by [continuationType],
     * and returns it.
     */
    protected fun IrFunction.addCompletionValueParameter(): IrValueParameter {
        val completionType = continuationType()
        return addValueParameter(SUSPEND_FUNCTION_COMPLETION_PARAMETER_NAME, completionType)
    }

    /**
     * Returns the nullable type of the `continuationClass` symbol, applied to this function's return type.
     */
    protected fun IrFunction.continuationType(): IrType {
        val continuationClassSymbol = context.ir.symbols.continuationClass
        return continuationClassSymbol.typeWith(returnType).makeNullable()
    }
}
/**
 * Transforms suspend lambdas into continuation classes.
 *
 * A suspend lambda appears as an [IrBlock] whose last statement is a suspend [IrFunctionReference] with a lambda
 * origin. Each such block is replaced by a block that:
 * - declares a local class extending `suspendLambdaClass` (or `restrictedSuspendLambdaClass`) and the JVM
 *   `FunctionN` interface of the matching arity, and
 * - calls that class's constructor with a `null` completion.
 */
private class SuspendLambdaLowering(context: JvmBackendContext) : SuspendLoweringUtils(context), FileLoweringPass {
    override fun lower(irFile: IrFile) {
        irFile.transformChildrenVoid(object : IrElementTransformerVoidWithContext() {
            override fun visitBlock(expression: IrBlock): IrExpression {
                val reference = expression.statements.lastOrNull() as? IrFunctionReference ?: return super.visitBlock(expression)
                if (reference.isSuspend && reference.origin.isLambda) {
                    // The block is expected to be `{ <lambda function>; <reference to it> }`.
                    assert(expression.statements.size == 2 && expression.statements[0] is IrFunction)
                    // Lower anything nested inside the lambda first. Its body is later moved into the new class.
                    expression.transformChildrenVoid(this)
                    val parent = currentDeclarationParent ?: error("No current declaration parent at ${reference.dump()}")
                    return generateAnonymousObjectForLambda(reference, parent)
                }
                return super.visitBlock(expression)
            }
        })
    }

    /**
     * Builds the replacement for a suspend lambda [reference]. The result is a block that declares the generated
     * continuation class and then calls its single constructor with `null` as the completion argument.
     * The reference must not have bound arguments.
     */
    private fun generateAnonymousObjectForLambda(reference: IrFunctionReference, parent: IrDeclarationParent) =
        context.createIrBuilder(reference.symbol).irBlock(reference.startOffset, reference.endOffset) {
            assert(reference.getArgumentsWithIr().isEmpty()) { "lambda with bound arguments: ${reference.render()}" }
            val continuation = generateContinuationClassForLambda(reference, parent)
            +continuation
            +irCall(continuation.constructors.single().symbol).apply {
                // Pass null as completion parameter
                putValueArgument(0, irNull())
            }
        }

    /**
     * Creates the local continuation class for the lambda referenced by [reference], declared inside [parent].
     *
     * The class is populated as follows:
     * - one field per explicit lambda parameter that the lambda body actually reads, plus the label field;
     * - a primary constructor;
     * - `invokeSuspend`, which receives the lambda's body;
     * - a `...$$forInline` stub when the lambda captures crossinline values;
     * - `invoke`, plus `create` when the base class declares a matching one.
     *
     * The per-type field counts and the original lambda function are recorded in [context], keyed by the
     * class's `attributeOwnerId`.
     */
    private fun generateContinuationClassForLambda(reference: IrFunctionReference, parent: IrDeclarationParent): IrClass =
        context.irFactory.buildClass {
            name = SpecialNames.NO_NAME_PROVIDED
            origin = JvmLoweredDeclarationOrigin.SUSPEND_LAMBDA
            visibility = DescriptorVisibilities.LOCAL
        }.apply {
            this.parent = parent
            createImplicitParameterDeclarationWithWrappedDescriptor()
            copyAttributes(reference)
            val function = reference.symbol.owner
            // Extension lambdas whose receiver class carries @RestrictsSuspension extend the restricted base class.
            val extensionReceiver = function.extensionReceiverParameter?.type?.classOrNull
            val isRestricted = extensionReceiver != null && extensionReceiver.owner.annotations.any {
                it.type.classOrNull?.isClassWithFqName(FqNameUnsafe("kotlin.coroutines.RestrictsSuspension")) == true
            }
            val suspendLambda =
                if (isRestricted) context.ir.symbols.restrictedSuspendLambdaClass.owner
                else context.ir.symbols.suspendLambdaClass.owner
            // All type arguments of the reference's type except the last one (presumably the return type).
            val arity = (reference.type as IrSimpleType).arguments.size - 1
            // The JVM FunctionN takes one extra (continuation) parameter, hence arity + 1.
            val functionNClass = context.ir.symbols.getJvmFunctionClass(arity + 1)
            val functionNType = functionNClass.typeWith(
                function.explicitParameters.subList(0, arity).map { it.type }
                        + function.continuationType()
                        + context.irBuiltIns.anyNType
            )
            superTypes = listOf(suspendLambda.defaultType, functionNType)
            val usedParams = mutableSetOf<IrValueDeclaration>()
            // Collect the explicit parameters that the lambda body reads; only those get backing fields.
            function.acceptChildrenVoid(
                object : IrElementVisitorVoid {
                    override fun visitElement(element: IrElement) = element.acceptChildrenVoid(this)
                    override fun visitGetValue(expression: IrGetValue) {
                        if (expression.symbol is IrValueParameterSymbol && expression.symbol.owner in function.explicitParameters) {
                            usedParams += expression.symbol.owner
                        }
                    }
                },
            )
            addField(COROUTINE_LABEL_FIELD_NAME, context.irBuiltIns.intType, JavaDescriptorVisibilities.PACKAGE_VISIBILITY)
            // Counts fields per normalized JVM type so names are unique per type (e.g. `L$0`, `L$1`, `I$0`).
            val varsCountByType = HashMap<Type, Int>()
            val parametersFields = function.explicitParameters.map {
                val field = if (it in usedParams) addField {
                    val normalizedType = context.defaultTypeMapper.mapType(it.type).normalize()
                    val index = varsCountByType[normalizedType]?.plus(1) ?: 0
                    varsCountByType[normalizedType] = index
                    // Rename `$this` to avoid being caught by inlineCodegenUtils.isCapturedFieldName()
                    name = Name.identifier("${normalizedType.descriptor[0]}$$index")
                    // Fields whose normalized type is Object are typed `Any?`; others keep the parameter's own type.
                    type = if (normalizedType == AsmTypes.OBJECT_TYPE) context.irBuiltIns.anyNType else it.type
                    origin = LocalDeclarationsLowering.DECLARATION_ORIGIN_FIELD_FOR_CAPTURED_VALUE
                    isFinal = false
                    // Receivers (negative index) are private; ordinary value parameters are package-visible.
                    visibility = if (it.index < 0) DescriptorVisibilities.PRIVATE else JavaDescriptorVisibilities.PACKAGE_VISIBILITY
                } else null
                ParameterInfo(field, it.type, it.name, it.origin)
            }
            context.continuationClassesVarsCountByType[attributeOwnerId] = varsCountByType
            val constructor = addPrimaryConstructorForLambda(suspendLambda, arity)
            val invokeToOverride = functionNClass.functions.single {
                it.owner.valueParameters.size == arity + 1 && it.owner.name.asString() == "invoke"
            }
            // `create` with this arity exists only for some arities of the base class.
            // When it is absent, `invoke` calls the constructor directly.
            val createToOverride = suspendLambda.symbol.functions.singleOrNull {
                it.owner.valueParameters.size == arity + 1 && it.owner.name.asString() == "create"
            }
            val invokeSuspend = addInvokeSuspendForLambda(function, suspendLambda, parametersFields)
            if (function.capturesCrossinline()) {
                addInvokeSuspendForInlineLambda(invokeSuspend)
            }
            if (createToOverride != null) {
                addInvokeCallingCreate(addCreate(constructor, createToOverride, parametersFields), invokeSuspend, invokeToOverride)
            } else {
                addInvokeCallingConstructor(constructor, invokeSuspend, invokeToOverride, parametersFields)
            }
            this.metadata = function.metadata
            context.suspendLambdaToOriginalFunctionMap[attributeOwnerId] = function
        }

    /**
     * Adds the `invokeSuspend(Result)` override and moves [irFunction]'s body into it.
     *
     * The new body starts with one local variable per used entry of [parameterInfos]. Each local is initialized
     * from that parameter's field on `this`. Reads of [irFunction]'s parameters in the moved body are redirected
     * to these locals. Annotations of [irFunction] are copied onto the override.
     */
    private fun IrClass.addInvokeSuspendForLambda(
        irFunction: IrFunction,
        suspendLambda: IrClass,
        parameterInfos: List<ParameterInfo>
    ): IrSimpleFunction {
        val superMethod = suspendLambda.functions.single {
            it.name.asString() == INVOKE_SUSPEND_METHOD_NAME && it.valueParameters.size == 1 &&
                    it.valueParameters[0].type.isKotlinResult()
        }
        return addFunctionOverride(superMethod, irFunction.startOffset, irFunction.endOffset).apply {
            // Indexed like parameterInfos (i.e. like irFunction.explicitParameters); null for unused parameters.
            val localVals: List<IrVariable?> = parameterInfos.map { param ->
                if (param.isUsed) {
                    buildVariable(
                        parent = this,
                        startOffset = UNDEFINED_OFFSET,
                        endOffset = UNDEFINED_OFFSET,
                        origin = param.origin,
                        name = param.name,
                        type = param.type
                    ).apply {
                        val receiver = IrGetValueImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, dispatchReceiverParameter!!.symbol)
                        val initializerBlock = IrBlockImpl(UNDEFINED_OFFSET, UNDEFINED_OFFSET, type)
                        initializerBlock.statements += IrGetFieldImpl(
                            UNDEFINED_OFFSET, UNDEFINED_OFFSET, param.field!!.symbol, type, receiver
                        )
                        initializer = initializerBlock
                    }
                } else null
            }
            body = irFunction.moveBodyTo(this, mapOf())?.let { body ->
                body.transform(object : IrElementTransformerVoid() {
                    override fun visitGetValue(expression: IrGetValue): IrExpression {
                        val parameter = (expression.symbol.owner as? IrValueParameter)?.takeIf { it.parent == irFunction }
                            ?: return expression
                        // Map the parameter to its position in explicitParameters: a receiver (negative index) goes
                        // right after the context receivers, and value parameters after the context receivers
                        // shift by one when there is an extension receiver.
                        // NOTE(review): assumes a lambda never has a dispatch receiver — confirm.
                        val varIndex = if (parameter.index < 0) irFunction.contextReceiverParametersCount
                        else if (parameter.index < irFunction.contextReceiverParametersCount || irFunction.extensionReceiverParameter == null) parameter.index
                        else parameter.index + 1
                        val lvar = localVals[varIndex]
                            ?: return expression
                        return IrGetValueImpl(expression.startOffset, expression.endOffset, lvar.symbol)
                    }
                }, null)
                context.irFactory.createBlockBody(UNDEFINED_OFFSET, UNDEFINED_OFFSET, localVals.filterNotNull() + body.statements)
            }
            copyAnnotationsFrom(irFunction)
        }
    }

    /**
     * Adds a final `invokeSuspend` + [FOR_INLINE_SUFFIX] function returning `Any?`. It copies the attributes and
     * value parameters of [invokeSuspend], and its body is the inliner stub error from [generateErrorForInlineBody].
     */
    private fun IrClass.addInvokeSuspendForInlineLambda(invokeSuspend: IrSimpleFunction): IrSimpleFunction {
        return addFunction(
            INVOKE_SUSPEND_METHOD_NAME + FOR_INLINE_SUFFIX,
            context.irBuiltIns.anyNType,
            Modality.FINAL,
            origin = JvmLoweredDeclarationOrigin.FOR_INLINE_STATE_MACHINE_TEMPLATE_CAPTURES_CROSSINLINE
        ).apply {
            copyAttributes(invokeSuspend)
            generateErrorForInlineBody()
            valueParameters = invokeSuspend.valueParameters.map { it.copyTo(this) }
        }
    }

    // Invoke function in lambdas is responsible for
    // 1) calling `create`
    // 2) starting newly created coroutine by calling `invokeSuspend`.
    // Thus, it creates a clone of suspend lambda and starts it.
    // TODO: fix the generic signature -- type parameters of FunctionN should be substituted
    private fun IrClass.addInvokeCallingCreate(
        create: IrFunction,
        invokeSuspend: IrSimpleFunction,
        invokeToOverride: IrSimpleFunctionSymbol
    ) = addFunctionOverride(invokeToOverride.owner) { function ->
        val newlyCreatedObject = irCall(create).also { createCall ->
            createCall.dispatchReceiver = irGet(function.dispatchReceiverParameter!!)
            for ((index, param) in function.valueParameters.withIndex()) {
                createCall.putValueArgument(index, irGet(param))
            }
        }
        +irReturn(callInvokeSuspend(invokeSuspend, irImplicitCast(newlyCreatedObject, defaultType)))
    }

    // Same as above, but with `create` inlined. `create` is only defined in `SuspendLambda` in unary and binary
    // versions; for other lambdas, there's no point in generating a non-overriding `create` separately.
    private fun IrClass.addInvokeCallingConstructor(
        constructor: IrFunction,
        invokeSuspend: IrSimpleFunction,
        invokeToOverride: IrSimpleFunctionSymbol,
        fieldsForUnbound: List<ParameterInfo>
    ) = addFunctionOverride(invokeToOverride.owner) { function ->
        +irReturn(callInvokeSuspend(invokeSuspend, cloneLambda(function, constructor, fieldsForUnbound)))
    }

    /** Adds the `create` override, which returns a fresh clone of the lambda built by [cloneLambda]. */
    private fun IrClass.addCreate(
        constructor: IrFunction,
        createToOverride: IrSimpleFunctionSymbol,
        fieldsForUnbound: List<ParameterInfo>
    ) = addFunctionOverride(createToOverride.owner) { function ->
        +irReturn(cloneLambda(function, constructor, fieldsForUnbound))
    }

    /**
     * Emits code that creates a new instance of the continuation class via [constructor]. The last value parameter
     * of [scope] is passed as the completion. For every used entry of [fieldsForUnbound], the value parameter of
     * [scope] at the same index is stored into that entry's field.
     *
     * @return an expression evaluating to the new instance.
     */
    private fun IrBlockBodyBuilder.cloneLambda(
        scope: IrFunction,
        constructor: IrFunction,
        fieldsForUnbound: List<ParameterInfo>
    ): IrExpression {
        val constructorCall = irCall(constructor).also {
            for (typeParameter in constructor.parentAsClass.typeParameters) {
                it.putTypeArgument(typeParameter.index, typeParameter.defaultType)
            }
            it.putValueArgument(0, irGet(scope.valueParameters.last()))
        }
        if (fieldsForUnbound.none { it.isUsed }) {
            return constructorCall
        }
        val result = irTemporary(constructorCall, "result")
        for ((index, field) in fieldsForUnbound.withIndex()) {
            if (field.isUsed) {
                +irSetField(irGet(result), field.field!!, irGet(scope.valueParameters[index]))
            }
        }
        return irGet(result)
    }

    /**
     * Calls [invokeSuspend] on [lambda]. The argument is `Unit`, passed through the `unsafeCoerceIntrinsic`
     * (from `Any?` to `resultOfAnyType`) so that it matches the `Result`-typed parameter.
     */
    private fun IrBlockBodyBuilder.callInvokeSuspend(invokeSuspend: IrSimpleFunction, lambda: IrExpression): IrExpression =
        irCallOp(invokeSuspend.symbol, invokeSuspend.returnType, lambda, irCall(
            this@SuspendLambdaLowering.context.ir.symbols.unsafeCoerceIntrinsic,
            this@SuspendLambdaLowering.context.ir.symbols.resultOfAnyType
        ).apply {
            putTypeArgument(0, context.irBuiltIns.anyNType)
            putTypeArgument(1, type)
            putValueArgument(0, irUnit())
        })

    /**
     * Adds the primary constructor of the continuation class. It takes only the completion parameter and
     * delegates to the [superClass] constructor taking `(Int, Continuation?)`, passing `arity + 1` and the completion.
     */
    private fun IrClass.addPrimaryConstructorForLambda(superClass: IrClass, arity: Int): IrConstructor =
        addConstructor {
            origin = JvmLoweredDeclarationOrigin.SUSPEND_LAMBDA
            isPrimary = true
            returnType = defaultType
            visibility = DescriptorVisibilities.LOCAL
        }.also { constructor ->
            val completionParameterSymbol = constructor.addCompletionValueParameter()
            val superClassConstructor = superClass.constructors.single {
                it.valueParameters.size == 2 && it.valueParameters[0].type.isInt() && it.valueParameters[1].type.isNullableContinuation()
            }
            constructor.body = context.createIrBuilder(constructor.symbol).irBlockBody {
                +irDelegatingConstructorCall(superClassConstructor).also {
                    it.putValueArgument(0, irInt(arity + 1))
                    it.putValueArgument(1, irGet(completionParameterSymbol))
                }
                +IrInstanceInitializerCallImpl(startOffset, endOffset, symbol, context.irBuiltIns.unitType)
            }
        }
}
/**
 * One explicit parameter of a suspend lambda, as seen by the generated continuation class.
 *
 * @property field the continuation-class field that holds the parameter's value, or `null` if no field was created.
 * @property type the parameter's declared IR type.
 * @property name the parameter's name.
 * @property origin the parameter's declaration origin.
 */
private data class ParameterInfo(
    val field: IrField?,
    val type: IrType,
    val name: Name,
    val origin: IrDeclarationOrigin,
) {
    /** Whether a backing [field] exists for this parameter. */
    val isUsed: Boolean
        get() = field != null
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy