All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.psi2ir.transformations.InsertImplicitCasts.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.psi2ir.transformations

import org.jetbrains.kotlin.builtins.KotlinBuiltIns
import org.jetbrains.kotlin.builtins.StandardNames
import org.jetbrains.kotlin.builtins.isFunctionType
import org.jetbrains.kotlin.builtins.isSuspendFunctionType
import org.jetbrains.kotlin.descriptors.*
import org.jetbrains.kotlin.incremental.components.NoLookupLocation
import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.PsiIrFileEntry
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.descriptors.IrBasedDeclarationDescriptor
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.toKotlinType
import org.jetbrains.kotlin.ir.util.SymbolTable
import org.jetbrains.kotlin.ir.util.TypeTranslator
import org.jetbrains.kotlin.ir.util.render
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtBinaryExpression
import org.jetbrains.kotlin.psi2ir.containsNull
import org.jetbrains.kotlin.psi2ir.findSingleFunction
import org.jetbrains.kotlin.psi2ir.generators.GeneratorContext
import org.jetbrains.kotlin.psi2ir.generators.GeneratorExtensions
import org.jetbrains.kotlin.psi2ir.generators.OPERATORS_DESUGARED_TO_CALLS
import org.jetbrains.kotlin.psi2ir.generators.getSubstitutedFunctionTypeForSamType
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.types.*
import org.jetbrains.kotlin.types.checker.KotlinTypeChecker
import org.jetbrains.kotlin.types.typeUtil.*
import org.jetbrains.kotlin.util.OperatorNameConventions

internal fun insertImplicitCasts(file: IrFile, context: GeneratorContext) {
    InsertImplicitCasts(
        context.irBuiltIns,
        context.typeTranslator,
        context.callToSubstitutedDescriptorMap,
        context.extensions,
        context.symbolTable,
        file,
    ).run(file)
}

internal class InsertImplicitCasts(
    private val irBuiltIns: IrBuiltIns,
    private val typeTranslator: TypeTranslator,
    private val callToSubstitutedDescriptorMap: Map,
    private val generatorExtensions: GeneratorExtensions,
    private val symbolTable: SymbolTable,
    private val file: IrFile,
) : IrElementTransformerVoid() {

    private val expectedFunctionExpressionReturnType = hashMapOf()

    fun run(element: IrElement) {
        element.transformChildrenVoid(this)
        postprocessReturnExpressions(element)
    }

    private fun postprocessReturnExpressions(element: IrElement) {
        // We need to re-create type parameter context for casts of postprocessed return values.
        element.acceptChildrenVoid(object : IrElementVisitorVoid {
            override fun visitReturn(expression: IrReturn) {
                super.visitReturn(expression)
                val expectedReturnType = expectedFunctionExpressionReturnType[expression.returnTargetSymbol.descriptor] ?: return
                expression.value = expression.value.cast(expectedReturnType)
            }

            override fun visitClass(declaration: IrClass) {
                typeTranslator.buildWithScope(declaration) {
                    super.visitClass(declaration)
                }
            }

            override fun visitFunction(declaration: IrFunction) {
                typeTranslator.buildWithScope(declaration) {
                    super.visitFunction(declaration)
                }
            }

            override fun visitElement(element: IrElement) {
                element.acceptChildrenVoid(this)
            }

            override fun visitCall(expression: IrCall) {
                expression.acceptChildrenVoid(this)
            }
        })
    }

    private fun KotlinType.toIrType() = typeTranslator.translateType(this)

    private val IrDeclarationReference.substitutedDescriptor
        get() = callToSubstitutedDescriptorMap[this] ?: symbol.descriptor as CallableDescriptor

    override fun visitCallableReference(expression: IrCallableReference<*>): IrExpression {
        val substitutedDescriptor = expression.substitutedDescriptor
        return expression.transformPostfix {
            transformReceiverArguments(substitutedDescriptor)
        }
    }

    private fun IrMemberAccessExpression<*>.transformReceiverArguments(substitutedDescriptor: CallableDescriptor) {
        dispatchReceiver = dispatchReceiver?.cast(getEffectiveDispatchReceiverType(substitutedDescriptor))
        val extensionReceiverType = substitutedDescriptor.extensionReceiverParameter?.type
        val originalExtensionReceiverType = substitutedDescriptor.original.extensionReceiverParameter?.type
        extensionReceiver = extensionReceiver?.cast(extensionReceiverType, originalExtensionReceiverType)
    }

    private fun getEffectiveDispatchReceiverType(descriptor: CallableDescriptor): KotlinType? =
        when {
            descriptor !is CallableMemberDescriptor ->
                null

            descriptor.kind == CallableMemberDescriptor.Kind.FAKE_OVERRIDE -> {
                val containingDeclaration = descriptor.containingDeclaration
                if (containingDeclaration !is ClassDescriptor)
                    throw AssertionError("Containing declaration for $descriptor should be a class: $containingDeclaration")
                else
                    containingDeclaration.defaultType.replaceArgumentsWithStarProjections()
            }

            else ->
                descriptor.dispatchReceiverParameter?.type
        }

    override fun visitMemberAccess(expression: IrMemberAccessExpression<*>): IrExpression {
        val substitutedDescriptor = expression.substitutedDescriptor
        return expression.transformPostfix {
            transformReceiverArguments(substitutedDescriptor)
            for (index in substitutedDescriptor.valueParameters.indices) {
                val irIndex = index + substitutedDescriptor.contextReceiverParameters.size
                val argument = getValueArgument(irIndex) ?: continue
                val parameterType = substitutedDescriptor.valueParameters[index].type
                val originalParameterType = substitutedDescriptor.original.valueParameters[index].type

                // Hack to support SAM conversions on out-projected types.
                // See SamType#createByValueParameter and genericSamProjectedOut.kt for more details.
                val expectedType =
                    if (argument.isSamConversion() && KotlinBuiltIns.isNothing(parameterType))
                        substitutedDescriptor.original.valueParameters[index].type.replaceArgumentsWithNothing()
                    else
                        parameterType

                putValueArgument(irIndex, argument.cast(expectedType, originalExpectedType = originalParameterType))
            }
        }
    }

    private fun IrExpression.isSamConversion(): Boolean =
        this is IrTypeOperatorCall && operator == IrTypeOperator.SAM_CONVERSION

    override fun visitBlockBody(body: IrBlockBody): IrBody =
        body.transformPostfix {
            statements.forEachIndexed { i, irStatement ->
                if (irStatement is IrExpression) {
                    body.statements[i] = irStatement.coerceToUnit()
                }
            }
        }

    override fun visitContainerExpression(expression: IrContainerExpression): IrExpression =
        expression.transformPostfix {
            if (statements.isEmpty()) return this

            val lastIndex = statements.lastIndex
            statements.forEachIndexed { i, irStatement ->
                if (irStatement is IrExpression) {
                    statements[i] =
                        if (i == lastIndex)
                            irStatement.cast(type)
                        else
                            irStatement.coerceToUnit()
                }
            }
        }

    override fun visitReturn(expression: IrReturn): IrExpression =
        expression.transformPostfix {
            value = if (expression.returnTargetSymbol is IrConstructorSymbol) {
                value.coerceToUnit()
            } else {
                value.cast(expression.returnTargetSymbol.descriptor.returnType)
            }
        }

    override fun visitSetValue(expression: IrSetValue): IrExpression =
        expression.transformPostfix {
            value = value.cast(expression.symbol.owner.type)
        }

    override fun visitGetField(expression: IrGetField): IrExpression =
        expression.transformPostfix {
            receiver = receiver?.cast(getEffectiveDispatchReceiverType(expression.substitutedDescriptor))
        }

    override fun visitSetField(expression: IrSetField): IrExpression =
        expression.transformPostfix {
            val substituted = expression.substitutedDescriptor as PropertyDescriptor
            receiver = receiver?.cast(getEffectiveDispatchReceiverType(substituted))
            value = value.cast(substituted.type)
        }

    override fun visitVariable(declaration: IrVariable): IrVariable =
        declaration.transformPostfix {
            initializer = initializer?.cast(declaration.type)
        }

    override fun visitField(declaration: IrField): IrStatement {
        return typeTranslator.withTypeErasure(declaration.correspondingPropertySymbol?.descriptor ?: declaration.descriptor) {
            declaration.transformPostfix {
                initializer?.coerceInnerExpression(descriptor.type)
            }
        }
    }

    override fun visitFunction(declaration: IrFunction): IrStatement =
        typeTranslator.buildWithScope(declaration) {
            declaration.transformPostfix {
                valueParameters.forEach {
                    it.defaultValue?.coerceInnerExpression(it.descriptor.type)
                }
            }
        }

    override fun visitClass(declaration: IrClass): IrStatement =
        typeTranslator.buildWithScope(declaration) {
            super.visitClass(declaration)
        }

    override fun visitWhen(expression: IrWhen): IrExpression =
        expression.transformPostfix {
            for (irBranch in branches) {
                irBranch.condition = irBranch.condition.cast(irBuiltIns.booleanType)
                irBranch.result = irBranch.result.cast(type)
            }
        }

    override fun visitLoop(loop: IrLoop): IrExpression =
        loop.transformPostfix {
            condition = condition.cast(irBuiltIns.booleanType)
            body = body?.coerceToUnit()
        }

    override fun visitThrow(expression: IrThrow): IrExpression =
        expression.transformPostfix {
            value = value.cast(irBuiltIns.throwableType)
        }

    override fun visitTry(aTry: IrTry): IrExpression =
        aTry.transformPostfix {
            tryResult = tryResult.cast(type)

            for (aCatch in catches) {
                aCatch.result = aCatch.result.cast(type)
            }

            finallyExpression = finallyExpression?.coerceToUnit()
        }

    override fun visitTypeOperator(expression: IrTypeOperatorCall): IrExpression =
        when (expression.operator) {
            IrTypeOperator.SAM_CONVERSION ->
                expression.transformPostfix {
                    argument = argument.cast(typeOperand.originalKotlinType!!.getSubstitutedFunctionTypeForSamType())
                }

            IrTypeOperator.IMPLICIT_CAST -> {
                // This branch is required for handling specific ambiguous cases in implicit cast insertion,
                // such as SAM conversion VS smart cast.
                // Here IMPLICIT_CAST serves as a type hint.
                // Replace IrTypeOperatorCall(IMPLICIT_CAST, ...) with an argument cast to the required type
                // (possibly generating another IrTypeOperatorCall(IMPLICIT_CAST, ...), if required).

                expression.transformChildrenVoid()
                expression.argument.cast(expression.typeOperand)
            }

            else ->
                super.visitTypeOperator(expression)
        }

    override fun visitVararg(expression: IrVararg): IrExpression =
        expression.transformPostfix {
            elements.forEachIndexed { i, element ->
                when (element) {
                    is IrSpreadElement ->
                        element.expression = element.expression.cast(expression.type)
                    is IrExpression ->
                        putElement(i, element.cast(varargElementType))
                }
            }
        }

    private fun IrExpressionBody.coerceInnerExpression(expectedType: KotlinType) {
        expression = expression.cast(expectedType)
    }

    private fun IrExpression.cast(irType: IrType): IrExpression =
        cast(irType.originalKotlinType)

    private fun KotlinType.getFunctionReturnTypeOrNull(): KotlinType? =
        if (isFunctionType || isSuspendFunctionType)
            arguments.last().type
        else
            null

    private fun IrExpression.cast(
        possiblyNonDenotableExpectedType: KotlinType?,
        originalExpectedType: KotlinType? = possiblyNonDenotableExpectedType
    ): IrExpression {
        if (possiblyNonDenotableExpectedType == null) return this
        if (possiblyNonDenotableExpectedType.isError) return this

        val expectedType = typeTranslator.approximate(possiblyNonDenotableExpectedType)

        if (this is IrFunctionExpression && originalExpectedType != null) {
            recordExpectedLambdaReturnTypeIfAppropriate(expectedType, originalExpectedType)
        }

        val notNullableExpectedType = expectedType.makeNotNullable()

        val valueType = this.type.originalKotlinType ?: error("Expecting original kotlin type for IrType ${type.render()}")

        return when {
            expectedType.isUnit() ->
                coerceToUnit()

            valueType.isDynamic() && !expectedType.isDynamic() ->
                if (expectedType.isNullableAny())
                    this
                else
                    implicitCast(expectedType, IrTypeOperator.IMPLICIT_DYNAMIC_CAST)

            valueType.isNullabilityFlexible() && valueType.containsNull() && !expectedType.acceptsNullValues() ->
                implicitNonNull(valueType, expectedType)

            valueType.hasEnhancedNullability() && !expectedType.acceptsNullValues() ->
                implicitNonNull(valueType, expectedType)

            KotlinTypeChecker.DEFAULT.isSubtypeOf(valueType.toNonIrBased(), expectedType.toNonIrBased().makeNullable()) ->
                this

            KotlinBuiltIns.isInt(valueType) && notNullableExpectedType.isBuiltInIntegerType() ->
                coerceIntToAnotherIntegerType(notNullableExpectedType)

            else -> {
                val targetType = if (!valueType.containsNull()) notNullableExpectedType else expectedType
                implicitCast(targetType, IrTypeOperator.IMPLICIT_CAST)
            }
        }
    }

    private fun IrFunctionExpression.recordExpectedLambdaReturnTypeIfAppropriate(
        expectedType: KotlinType,
        originalExpectedType: KotlinType
    ) {
        // TODO see KT-35849

        val returnTypeFromExpected = expectedType.getFunctionReturnTypeOrNull() ?: return
        val returnTypeFromOriginalExpected = originalExpectedType.getFunctionReturnTypeOrNull()

        if (returnTypeFromOriginalExpected?.isTypeParameter() != true) {
            expectedFunctionExpressionReturnType[function.descriptor] = returnTypeFromExpected.toIrType()
        }
    }

    private fun KotlinType.acceptsNullValues() =
        containsNull() || hasEnhancedNullability()

    private fun KotlinType.hasEnhancedNullability() =
        generatorExtensions.enhancedNullability.hasEnhancedNullability(this)

    private fun IrExpression.implicitNonNull(valueType: KotlinType, expectedType: KotlinType): IrExpression {
        val nonNullFlexibleType = valueType.upperIfFlexible().makeNotNullable()
        val nonNullValueType = generatorExtensions.enhancedNullability.stripEnhancedNullability(nonNullFlexibleType)
        return implicitCast(nonNullValueType, IrTypeOperator.IMPLICIT_NOTNULL).cast(expectedType)
    }

    private fun IrExpression.implicitCast(targetType: KotlinType, typeOperator: IrTypeOperator): IrExpression {
        val irType = targetType.toIrType()
        return IrTypeOperatorCallImpl(startOffset, endOffset, irType, typeOperator, irType, this)
    }

    private fun IrExpression.coerceIntToAnotherIntegerType(targetType: KotlinType): IrExpression {
        if (!type.originalKotlinType!!.isInt()) throw AssertionError("Expression of type 'kotlin.Int' expected: $this")
        if (targetType.isInt()) return this

        if (generatorExtensions.shouldPreventDeprecatedIntegerValueTypeLiteralConversion &&
            this is IrCall && preventDeprecatedIntegerValueTypeLiteralConversion()
        ) return this

        return if (this is IrConst<*>) {
            val value = this.value as Int
            val irType = targetType.toIrType()
            when {
                targetType.isByte() -> IrConstImpl.byte(startOffset, endOffset, irType, value.toByte())
                targetType.isShort() -> IrConstImpl.short(startOffset, endOffset, irType, value.toShort())
                targetType.isLong() -> IrConstImpl.long(startOffset, endOffset, irType, value.toLong())
                KotlinBuiltIns.isUByte(targetType) -> IrConstImpl.byte(startOffset, endOffset, irType, value.toByte())
                KotlinBuiltIns.isUShort(targetType) -> IrConstImpl.short(startOffset, endOffset, irType, value.toShort())
                KotlinBuiltIns.isUInt(targetType) -> IrConstImpl.int(startOffset, endOffset, irType, value)
                KotlinBuiltIns.isULong(targetType) -> IrConstImpl.long(startOffset, endOffset, irType, value.toLong())
                else -> throw AssertionError("Unexpected target type for integer coercion: $targetType")
            }
        } else {
            when {
                targetType.isByte() -> invokeIntegerCoercionFunction(targetType, "toByte")
                targetType.isShort() -> invokeIntegerCoercionFunction(targetType, "toShort")
                targetType.isLong() -> invokeIntegerCoercionFunction(targetType, "toLong")
                KotlinBuiltIns.isUByte(targetType) -> invokeUnsignedIntegerCoercionFunction(targetType, "toUByte")
                KotlinBuiltIns.isUShort(targetType) -> invokeUnsignedIntegerCoercionFunction(targetType, "toUShort")
                KotlinBuiltIns.isUInt(targetType) -> invokeUnsignedIntegerCoercionFunction(targetType, "toUInt")
                KotlinBuiltIns.isULong(targetType) -> invokeUnsignedIntegerCoercionFunction(targetType, "toULong")
                else -> throw AssertionError("Unexpected target type for integer coercion: $targetType")
            }
        }
    }

    // In JVM, we don't convert values resulted from calling built-in operators on integer literals to another integer type.
    // The reason is that doing so would change behavior, which we want to avoid, see KT-42321.
    // At the same time, such structure seems possible to achieve only via the magical integer value type, but inferring the result of
    // the operator call based on an expected type is deprecated behavior which is going to be removed in the future, see KT-38895.
    private fun IrCall.preventDeprecatedIntegerValueTypeLiteralConversion(): Boolean {
        val descriptor = symbol.descriptor
        if (descriptor.name !in operatorsWithDeprecatedIntegerValueTypeLiteralConversion) return false

        // This bug is only reproducible for non-operator calls, for example "1.plus(2)", NOT "1 + 2".
        if (origin in OPERATORS_DESUGARED_TO_CALLS) return false

        // For infix methods, this bug is only reproducible for non-infix calls, for example "1.shl(2)", NOT "1 shl 2".
        if (descriptor.isInfix) {
            if ((file.fileEntry as? PsiIrFileEntry)?.findPsiElement(this) is KtBinaryExpression) return false
        }

        return descriptor.dispatchReceiverParameter?.type?.let { KotlinBuiltIns.isPrimitiveType(it) } == true
    }

    private val operatorsWithDeprecatedIntegerValueTypeLiteralConversion = with(OperatorNameConventions) {
        setOf(PLUS, MINUS, TIMES, DIV, REM, UNARY_PLUS, UNARY_MINUS, SHL, SHR, USHR, AND, OR, XOR, INV)
    }

    private fun IrExpression.invokeIntegerCoercionFunction(targetType: KotlinType, coercionFunName: String): IrExpression {
        val coercionFunction = irBuiltIns.intClass.descriptor.unsubstitutedMemberScope.findSingleFunction(Name.identifier(coercionFunName))
        return IrCallImpl(
            startOffset, endOffset,
            targetType.toIrType(),
            symbolTable.descriptorExtension.referenceSimpleFunction(coercionFunction),
            typeArgumentsCount = 0, valueArgumentsCount = 0
        ).also { irCall ->
            irCall.dispatchReceiver = this
        }
    }

    private fun IrExpression.invokeUnsignedIntegerCoercionFunction(targetType: KotlinType, coercionFunName: String): IrExpression {
        // 'toUByte', 'toUShort', 'toUInt', 'toULong' are top-level extension functions in 'kotlin' package.
        // There are several such functions (one for each built-in integer type: Byte, Short, Int, Long),
        // we need one that takes Int.
        val coercionFunction = targetType.constructor.declarationDescriptor!!.module
            .getPackage(StandardNames.BUILT_INS_PACKAGE_FQ_NAME)
            .memberScope.getContributedFunctions(Name.identifier(coercionFunName), NoLookupLocation.FROM_BACKEND)
            .find {
                val extensionReceiver = it.extensionReceiverParameter
                extensionReceiver != null && extensionReceiver.type.isInt()
            }
            ?: throw AssertionError("Coercion function '$coercionFunName' not found")
        return IrCallImpl(
            startOffset, endOffset,
            targetType.toIrType(),
            symbolTable.descriptorExtension.referenceSimpleFunction(coercionFunction),
            typeArgumentsCount = 0, valueArgumentsCount = 0
        ).also { irCall ->
            irCall.extensionReceiver = this
        }
    }

    private fun KotlinType.isBuiltInIntegerType(): Boolean =
        KotlinBuiltIns.isByte(this) ||
                KotlinBuiltIns.isShort(this) ||
                KotlinBuiltIns.isInt(this) ||
                KotlinBuiltIns.isLong(this) ||
                KotlinBuiltIns.isUByte(this) ||
                KotlinBuiltIns.isUShort(this) ||
                KotlinBuiltIns.isUInt(this) ||
                KotlinBuiltIns.isULong(this)

    private fun IrExpression.coerceToUnit(): IrExpression {
        return if (KotlinTypeChecker.DEFAULT.isSubtypeOf(type.toKotlinType(), irBuiltIns.unitType.toKotlinType()))
            this
        else
            IrTypeOperatorCallImpl(
                startOffset, endOffset,
                irBuiltIns.unitType,
                IrTypeOperator.IMPLICIT_COERCION_TO_UNIT,
                irBuiltIns.unitType,
                this
            )
    }

    // KotlinType subtype checking fails when one of the types uses IR-based descriptors, the other one regular descriptors.
    // This is a kludge to remove IR-based descriptors where possible.
    private fun KotlinType.toNonIrBased(): KotlinType {
        if (this !is SimpleType) return this
        if (this.isError) return this
        val newDescriptor = constructor.declarationDescriptor?.let {
            if (it is IrBasedDeclarationDescriptor<*> && it.owner.symbol.hasDescriptor)
                it.owner.symbol.descriptor as ClassifierDescriptor
            else
                it
        } ?: return this
        val newArguments = arguments.mapIndexed { index, it ->
            if (it.isStarProjection)
                StarProjectionImpl((newDescriptor as ClassDescriptor).typeConstructor.parameters[index])
            else
                TypeProjectionImpl(it.projectionKind, it.type.toNonIrBased())
        }
        return newDescriptor.defaultType.replace(newArguments = newArguments).makeNullableAsSpecified(isMarkedNullable)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy