All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.vk.compiler.plugin.composable.skippability.checker.fir.FirStability.kt Maven / Gradle / Ivy

package com.vk.compiler.plugin.composable.skippability.checker.fir

import com.vk.compiler.plugin.composable.skippability.checker.ComposeClassId
import com.vk.compiler.plugin.composable.skippability.checker.KnownStableConstructs
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.Visibilities
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.checkers.typeParameterSymbols
import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
import org.jetbrains.kotlin.fir.declarations.isInlineOrValueClass
import org.jetbrains.kotlin.fir.declarations.utils.isCompanion
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
import org.jetbrains.kotlin.fir.declarations.utils.isEnumEntry
import org.jetbrains.kotlin.fir.declarations.utils.isExpect
import org.jetbrains.kotlin.fir.declarations.utils.isFinal
import org.jetbrains.kotlin.fir.declarations.utils.isInner
import org.jetbrains.kotlin.fir.declarations.utils.isInterface
import org.jetbrains.kotlin.fir.declarations.utils.visibility
import org.jetbrains.kotlin.fir.expressions.UnresolvedExpressionTypeAccess
import org.jetbrains.kotlin.fir.isPrimitiveType
import org.jetbrains.kotlin.fir.resolve.defaultType
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.resolve.toFirRegularClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassifierSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFieldSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeDynamicType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeKotlinTypeProjection
import org.jetbrains.kotlin.fir.types.ConeNullability
import org.jetbrains.kotlin.fir.types.ConeTypeParameterType
import org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import org.jetbrains.kotlin.fir.types.FirTypeRef
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.coneTypeSafe
import org.jetbrains.kotlin.fir.types.isAny
import org.jetbrains.kotlin.fir.types.isNullable
import org.jetbrains.kotlin.fir.types.isPrimitive
import org.jetbrains.kotlin.fir.types.isString
import org.jetbrains.kotlin.fir.types.isUnit
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
import org.jetbrains.kotlin.fir.types.typeContext
import org.jetbrains.kotlin.fir.types.withNullability
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.types.checker.SimpleClassicTypeSystemContext.getUnsubstitutedUnderlyingType

sealed class FirStability {
    // class Foo(val bar: Int)
    class Certain(val stable: Boolean) : FirStability() {
        override fun toString(): String {
            return if (stable) "Stable" else "Unstable"
        }
    }

    // class Foo(val bar: ExternalType) -> ExternalType.$stable
    class Runtime(val classSymbol: FirRegularClassSymbol) : FirStability() {
        override fun toString(): String {
            return "Runtime(${classSymbol.name.asString()})"
        }
    }

    // interface Foo { fun result(): Int }
    class Unknown(val declaration: FirRegularClassSymbol) : FirStability() {
        override fun toString(): String {
            return "Uncertain(${declaration.name.asString()})"
        }
    }

    // class  Foo(val value: T)
    class Parameter(val parameter: FirTypeParameterSymbol) : FirStability() {
        override fun toString(): String {
            return "Parameter(${parameter.name.asString()})"
        }
    }

    // class Foo(val foo: A, val bar: B)
    class Combined(val elements: List) : FirStability() {
        override fun toString(): String {
            return elements.joinToString(",")
        }
    }

    operator fun plus(other: FirStability): FirStability = when {
        other is Certain -> if (other.stable) this else other
        this is Certain -> if (stable) other else this
        else -> Combined(listOf(this, other))
    }

    operator fun plus(other: List): FirStability {
        var stability = this
        for (el in other) {
            stability += el
        }
        return stability
    }

    companion object {
        val Stable: FirStability = Certain(true)
        val Unstable: FirStability = Certain(false)
    }
}

fun FirStability.knownUnstable(): Boolean = when (this) {
    is FirStability.Certain -> !stable
    is FirStability.Runtime -> false
    is FirStability.Unknown -> false
    is FirStability.Parameter -> false
    is FirStability.Combined -> elements.any { it.knownUnstable() }
}

fun FirStability.knownStable(): Boolean = when (this) {
    is FirStability.Certain -> stable
    is FirStability.Runtime -> false
    is FirStability.Unknown -> false
    is FirStability.Parameter -> false
    is FirStability.Combined -> elements.all { it.knownStable() }
}

fun FirStability.isUncertain(): Boolean = when (this) {
    is FirStability.Certain -> false
    is FirStability.Runtime -> true
    is FirStability.Unknown -> true
    is FirStability.Parameter -> true
    is FirStability.Combined -> elements.any { it.isUncertain() }
}

fun FirStability.isExpressible(): Boolean = when (this) {
    is FirStability.Certain -> true
    is FirStability.Runtime -> true
    is FirStability.Unknown -> false
    is FirStability.Parameter -> true
    is FirStability.Combined -> elements.all { it.isExpressible() }
}

fun FirStability.normalize(): FirStability {
    when (this) {
        // if not combined, there is no normalization needed
        is FirStability.Certain,
        is FirStability.Parameter,
        is FirStability.Runtime,
        is FirStability.Unknown -> return this

        is FirStability.Combined -> {
            // if combined, we perform the more expensive normalization process
        }
    }
    val parameters = mutableSetOf()
    val parts = mutableListOf()
    val stack = mutableListOf(this)
    while (stack.isNotEmpty()) {
        when (val stability: FirStability = stack.removeAt(stack.size - 1)) {
            is FirStability.Combined -> {
                stack.addAll(stability.elements)
            }

            is FirStability.Certain -> {
                if (!stability.stable)
                    return FirStability.Unstable
            }

            is FirStability.Parameter -> {
                if (parameters.contains(stability.parameter)) {
                    parameters.add(stability.parameter)
                    parts.add(stability)
                }
            }

            is FirStability.Runtime -> parts.add(stability)
            is FirStability.Unknown -> {
                /* do nothing */
            }
        }
    }
    return FirStability.Combined(parts)
}

fun FirStability.forEach(callback: (FirStability) -> Unit) {
    if (this is FirStability.Combined) {
        elements.forEach { it.forEach(callback) }
    } else {
        callback(this)
    }
}

@OptIn(UnresolvedExpressionTypeAccess::class)
fun FirRegularClassSymbol.hasStableMarkerAnnotation(session: FirSession): Boolean {
    return resolvedAnnotationsWithClassIds.map { it.coneTypeOrNull }
        .mapNotNull { it?.toRegularClassSymbol(session) }
        .any { it.hasAnnotation(ComposeClassId.StableMarker, session) }
}

private fun FirRegularClassSymbol.hasStableMarkedDescendant(session: FirSession): Boolean {
    if (hasStableMarkerAnnotation(session)) return true
    return resolvedSuperTypes.any {
        !it.isAny && it.toRegularClassSymbol(session)?.hasStableMarkedDescendant(session) == true
    }
}

private fun FirRegularClassSymbol.isProtobufType(session: FirSession): Boolean {
    // Quick exit as all protos are final
    if (!isFinal) return false
    val directParentClassName =
        resolvedSuperTypes
            .mapNotNull { it.toRegularClassSymbol(session) }
            .lastOrNull { !it.isInterface }
            ?.classId?.asFqNameString()
    return directParentClassName == "com.google.protobuf.GeneratedMessageLite" ||
            directParentClassName == "com.google.protobuf.GeneratedMessage"
}

//private fun IrAnnotationContainer.stabilityParamBitmask(): Int? =
//    (annotations.findAnnotation(StabilityInferred)
//        ?.getValueArgument(0) as? IrConst<*>
//            )?.value as? Int

fun firStabilityOf(firTypeRef: FirTypeRef, session: FirSession): FirStability =
    firStabilityOf(firTypeRef, session, emptyMap(), emptySet())

@OptIn(UnresolvedExpressionTypeAccess::class)
@Suppress("ReturnCount", "NestedBlockDepth") // expected
private fun firStabilityOf(
    classSymbol: FirRegularClassSymbol,
    session: FirSession,
    substitutions: Map,
    currentlyAnalyzing: Set>
): FirStability {
    if (currentlyAnalyzing.contains(classSymbol)) return FirStability.Unstable
    if (classSymbol.hasStableMarkedDescendant(session)) return FirStability.Stable
    if (classSymbol.isEnumClass || classSymbol.isEnumEntry) return FirStability.Stable
    if (classSymbol.isPrimitiveType()) return FirStability.Stable
    if (classSymbol.isProtobufType(session)) return FirStability.Stable

//    if (classSymbol.origin == IrDeclarationOrigin.IR_BUILTINS_STUB) {
//        error("Builtins Stub: ${declaration.name}")
//    }

    val analyzing = currentlyAnalyzing + classSymbol

    if (canInferStability(classSymbol)) {
        val fqName = classSymbol.classId.asFqNameString()
        val stability: FirStability
        val mask: Int
        if (KnownStableConstructs.stableTypes.contains(fqName)) {
            mask = KnownStableConstructs.stableTypes[fqName] ?: 0
            stability = FirStability.Stable
        } else {
            mask = retrieveParameterMask(classSymbol, session, substitutions, analyzing)
                ?: return FirStability.Unstable
            stability = FirStability.Runtime(classSymbol)
        }
        return stability + FirStability.Combined(
            classSymbol.typeParameterSymbols.mapIndexedNotNull { index, parameter ->
                if (mask and (0b1 shl index) != 0) {
                    val type = substitutions[parameter]?.type
                    if (type != null)
                        firStabilityOf(type, session, substitutions, analyzing)
                    else
                        FirStability.Stable
                } else null
            }
        )
    } else if (classSymbol.origin is FirDeclarationOrigin.Java) {
        return FirStability.Unstable
    }

    if (classSymbol.isInterface) {
        return FirStability.Unknown(classSymbol)
    }

    var stability = FirStability.Stable

    for (member in classSymbol.declarationSymbols) {
        when (member) {
            is FirPropertySymbol -> {
                member.backingFieldSymbol?.let {
                    if (member.isVar && member.delegate == null) return FirStability.Unstable
                    val delegateType = member.delegate?.coneTypeOrNull
                    stability += firStabilityOf(
                        delegateType ?: it.resolvedReturnType,
                        session,
                        substitutions,
                        analyzing
                    )
                }
            }

            is FirFieldSymbol -> {
                stability += firStabilityOf(
                    member.resolvedReturnType,
                    session,
                    substitutions,
                    analyzing
                )
            }
        }
    }

    return stability
}

private fun canInferStability(declaration: FirRegularClassSymbol): Boolean {
    val fqName = declaration.classId.asFqNameString()
    return KnownStableConstructs.stableTypes.contains(fqName) || declaration.origin is FirDeclarationOrigin.BuiltIns
}

private fun firStabilityOf(
    firTypeRef: FirTypeRef,
    session: FirSession,
    substitutions: Map,
    currentlyAnalyzing: Set>
): FirStability {
    if (firTypeRef !is FirResolvedTypeRef) return FirStability.Unstable
    return firStabilityOf(firTypeRef.type, session, substitutions, currentlyAnalyzing)
}

private fun firStabilityOf(
    coneKotlinType: ConeKotlinType,
    session: FirSession,
    substitutions: Map,
    currentlyAnalyzing: Set>
): FirStability {
    val expandedType = coneKotlinType.fullyExpandedType(session)
    val regularClassSymbol = expandedType.toRegularClassSymbol(session)
    return when {
        expandedType is ConeDynamicType -> FirStability.Unstable

        expandedType.isUnit ||
                expandedType.isPrimitive ||
                expandedType.isFunctionOrKFunctionType(session, false) ||
                regularClassSymbol?.isSyntheticComposableFunction() == true ||
                expandedType.isString -> FirStability.Stable

        expandedType.isNonReifiedTypeParameter() -> {
            val parameterSymbol = expandedType.lookupTag.symbol
            val arg = substitutions[parameterSymbol]?.type
            if (arg != null) {
                firStabilityOf(arg, session, substitutions, currentlyAnalyzing)
            } else {
                FirStability.Parameter(parameterSymbol)
            }
        }

        expandedType.isNullable -> {
            val type = expandedType.withNullability(ConeNullability.NOT_NULL, session.typeContext)
            firStabilityOf(
                coneKotlinType = type,
                session = session,
                substitutions = substitutions,
                currentlyAnalyzing = currentlyAnalyzing
            )
        }

        regularClassSymbol?.isInlineOrValueClass() == true -> {
            if (regularClassSymbol.hasAnnotationSafe(ComposeClassId.StableMarker, session)) {
                FirStability.Stable
            } else {
                val type = expandedType.getUnsubstitutedUnderlyingType() as? ConeKotlinType
                if (type != null) {
                    firStabilityOf(
                        coneKotlinType = type,
                        session = session,
                        substitutions = substitutions,
                        currentlyAnalyzing = currentlyAnalyzing
                    )
                } else FirStability.Unstable
            }
        }

        expandedType is ConeClassLikeType -> {
            val classSymbol = expandedType.lookupTag.toFirRegularClassSymbol(session)
            if (classSymbol != null) {
                firStabilityOf(
                    classSymbol,
                    session,
                    substitutions + classSymbol.substitutionMap(coneKotlinType),
                    currentlyAnalyzing
                )
            } else FirStability.Unstable
        }

        expandedType is ConeTypeParameterType -> {
            FirStability.Unstable
        }
//      isAlias -> we expanded type early
        else -> error("Unexpected coneKotlinType: $coneKotlinType")
    }
}


private fun retrieveParameterMask(
    classSymbol: FirRegularClassSymbol,
    session: FirSession,
    substitutions: Map,
    currentlyAnalyzing: Set>
): Int? {

    if (
        classSymbol.visibility !== Visibilities.Public ||
        classSymbol.isEnumClass ||
        classSymbol.isEnumEntry ||
        classSymbol.isInterface ||
        classSymbol.classKind == ClassKind.ANNOTATION_CLASS ||
        classSymbol.name == SpecialNames.NO_NAME_PROVIDED ||
        classSymbol.isExpect ||
        classSymbol.isInner ||
        classSymbol.isCompanion ||
        classSymbol.isInlineOrValueClass() ||
        classSymbol.defaultType().isAny
    ) return null

    val stability =
        firStabilityOf(classSymbol, session, substitutions, currentlyAnalyzing).normalize()

    var parameterMask = 0

    if (classSymbol.typeParameterSymbols.isNotEmpty()) {

        stability.forEach {
            when (it) {
                is FirStability.Parameter -> {
                    val index = classSymbol.typeParameterSymbols.indexOf(it.parameter)
                    if (index != -1) {
                        // the stability of this parameter matters for the stability of the
                        // class
                        parameterMask = parameterMask or 0b1 shl index
                    }
                }

                else -> {
                    /* No action necessary */
                }
            }
        }
    }
    return parameterMask
}

private fun FirClassifierSymbol<*>.substitutionMap(coneKotlinType: ConeKotlinType): Map {
    val params = typeParameterSymbols
    val args = coneKotlinType.typeArguments
    return params?.zip(args)?.filter { (_, arg) -> arg is ConeKotlinTypeProjection }
        ?.associateTo(mutableMapOf()) { (param, arg) -> param to arg as ConeKotlinTypeProjection }
        .orEmpty()
}

private fun FirRegularClassSymbol.isSyntheticComposableFunction() =
    classId.asFqNameString().startsWith("androidx.compose.runtime.internal.ComposableFunction")




© 2015 - 2024 Weber Informatics LLC | Privacy Policy