All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.fir.scopes.impl.FirTypeIntersectionScopeContext.kt Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.fir.scopes.impl

import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.descriptors.Visibilities
import org.jetbrains.kotlin.descriptors.Visibility
import org.jetbrains.kotlin.fir.*
import org.jetbrains.kotlin.fir.caches.*
import org.jetbrains.kotlin.fir.declarations.FirCallableDeclaration
import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin
import org.jetbrains.kotlin.fir.declarations.FirMemberDeclaration
import org.jetbrains.kotlin.fir.declarations.utils.isExpect
import org.jetbrains.kotlin.fir.declarations.utils.modality
import org.jetbrains.kotlin.fir.declarations.utils.visibility
import org.jetbrains.kotlin.fir.resolve.substitution.ConeSubstitutor
import org.jetbrains.kotlin.fir.resolve.transformers.ReturnTypeCalculatorForFullBodyResolve
import org.jetbrains.kotlin.fir.scopes.*
import org.jetbrains.kotlin.fir.scopes.impl.FirIntersectionOverrideStorage.ContextForIntersectionOverrideConstruction
import org.jetbrains.kotlin.fir.scopes.impl.FirTypeIntersectionScopeContext.ResultOfIntersection
import org.jetbrains.kotlin.fir.symbols.impl.*
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeSimpleKotlinType
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.Name
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract

/**
 * Callables of kind [D] grouped by the scope they were collected from: each pair holds a scope and
 * the (non-empty, as produced by `collectMembersGroupedByScope`) list of members found in it.
 */
typealias MembersByScope<D> = List<Pair<FirTypeScope, List<D>>>

/**
 * Shared state and algorithms used to compute the members of an intersection of several type scopes.
 *
 * The context collects classifiers and callables from [scopes] and, for callables, groups mutually
 * overridable members into [ResultOfIntersection] values. Intersection override symbols are created
 * lazily through [intersectionOverrides].
 *
 * @property session session providing the override service and the intersection override storage.
 * @param overrideChecker checker handed to the override service when extracting overridable members.
 * @property scopes scopes whose members are intersected; they are queried in list order.
 * @param dispatchReceiverType type used as the key of the per-type override cache and as the dispatch
 * receiver type of every intersection override created by this context.
 */
class FirTypeIntersectionScopeContext(
    val session: FirSession,
    private val overrideChecker: FirOverrideChecker,
    val scopes: List<FirTypeScope>,
    private val dispatchReceiverType: ConeSimpleKotlinType,
) {
    private val overrideService = session.overrideService

    /** Cache of intersection overrides for [dispatchReceiverType], keyed by the most specific member. */
    val intersectionOverrides: FirCache<FirCallableSymbol<*>, MemberWithBaseScope<FirCallableSymbol<*>>, ContextForIntersectionOverrideConstruction<*>> =
        session.intersectionOverrideStorage.cacheByScope.getValue(dispatchReceiverType).intersectionOverrides

    /**
     * Result of intersecting one group of members.
     *
     * @property overriddenMembers members (with their base scopes) this result was built from.
     * @property containingScope scope the chosen member comes from, or `null` when none was supplied.
     */
    sealed class ResultOfIntersection<D : FirCallableSymbol<*>>(
        val overriddenMembers: List<MemberWithBaseScope<D>>,
        val containingScope: FirTypeScope?
    ) {
        /** Symbol that represents the whole group. */
        abstract val chosenSymbol: D

        /** Result for which an already existing member is chosen; no new symbol is created. */
        class SingleMember<D : FirCallableSymbol<*>>(
            override val chosenSymbol: D,
            overriddenMembers: List<MemberWithBaseScope<D>>,
            containingScope: FirTypeScope?
        ) : ResultOfIntersection<D>(overriddenMembers, containingScope) {
            /** Shortcut for a group consisting of exactly one member taken from its own base scope. */
            constructor(
                chosenSymbol: D,
                overriddenMember: MemberWithBaseScope<D>
            ) : this(chosenSymbol, listOf(overriddenMember), overriddenMember.baseScope)
        }

        /**
         * Result that is represented by an intersection override.
         *
         * The override symbol is not created eagerly: [chosenSymbol] requests it from
         * [intersectionOverridesCache] on first access, passing [context] as the construction context.
         */
        class NonTrivial<D : FirCallableSymbol<*>>(
            private val intersectionOverridesCache: FirCache<FirCallableSymbol<*>, MemberWithBaseScope<FirCallableSymbol<*>>, ContextForIntersectionOverrideConstruction<*>>,
            private val context: ContextForIntersectionOverrideConstruction<D>,
            overriddenMembers: List<MemberWithBaseScope<D>>,
            containingScope: FirTypeScope?
        ) : ResultOfIntersection<D>(overriddenMembers, containingScope) {
            override val chosenSymbol: D by lazy {
                // The cache is declared for FirCallableSymbol<*>, so the member has to be cast back to D.
                @Suppress("UNCHECKED_CAST")
                intersectionOverridesCache.getValue(
                    context.mostSpecific,
                    context
                ).member as D
            }

            /** Most specific original member of the group; available without creating the override. */
            val mostSpecific: D
                get() = context.mostSpecific
        }
    }

    /**
     * Reports every classifier named [name] found in [scopes] to [processor].
     *
     * Names known to be absent are skipped immediately; a name for which nothing is found is added to
     * [absentClassifierNames] so that later lookups can short-circuit.
     */
    fun processClassifiersByNameWithSubstitution(
        name: Name,
        absentClassifierNames: MutableSet<Name>,
        processor: (FirClassifierSymbol<*>, ConeSubstitutor) -> Unit
    ) {
        if (name in absentClassifierNames) return
        val classifiers = collectClassifiers(name)
        if (classifiers.isEmpty()) {
            absentClassifierNames += name
            return
        }
        for ((symbol, substitution) in classifiers) {
            processor(symbol, substitution)
        }
    }

    /**
     * Collects classifiers named [name] from all [scopes], skipping symbols already reported by an
     * earlier scope.
     */
    private fun collectClassifiers(name: Name): List<Pair<FirClassifierSymbol<*>, ConeSubstitutor>> {
        val accepted = HashSet<FirClassifierSymbol<*>>()
        val pending = mutableListOf<FirClassifierSymbol<*>>()
        val result = mutableListOf<Pair<FirClassifierSymbol<*>, ConeSubstitutor>>()
        for (scope in scopes) {
            scope.processClassifiersByNameWithSubstitution(name) { symbol, substitution ->
                if (symbol !in accepted) {
                    pending += symbol
                    result += symbol to substitution
                }
            }
            // Symbols become "accepted" only after the whole scope has been processed, so the
            // de-duplication applies across scopes, not within a single scope.
            accepted += pending
            pending.clear()
        }
        return result
    }

    /** Collects intersection results for functions named [name]. */
    fun collectFunctions(name: Name): List<ResultOfIntersection<FirNamedFunctionSymbol>> {
        return collectIntersectionResultsForCallables(name, FirScope::processFunctionsByName)
    }

    /**
     * Collects callables named [name] from every scope using [processCallables].
     *
     * Constructor symbols are ignored and scopes that contribute no members are omitted from the result.
     */
    @OptIn(PrivateForInline::class)
    inline fun <D : FirCallableSymbol<*>> collectMembersGroupedByScope(
        name: Name,
        processCallables: FirScope.(Name, (D) -> Unit) -> Unit
    ): MembersByScope<D> {
        return scopes.mapNotNull { scope ->
            val resultForScope = mutableListOf<D>()
            scope.processCallables(name) {
                if (it !is FirConstructorSymbol) {
                    resultForScope.add(it)
                }
            }

            resultForScope.takeIf { it.isNotEmpty() }?.let {
                scope to it
            }
        }
    }

    /**
     * Collects callables named [name] with [processCallables] and converts them to intersection results.
     */
    @OptIn(PrivateForInline::class)
    inline fun <D : FirCallableSymbol<*>> collectIntersectionResultsForCallables(
        name: Name,
        processCallables: FirScope.(Name, (D) -> Unit) -> Unit
    ): List<ResultOfIntersection<D>> {
        return convertGroupedCallablesToIntersectionResults(collectMembersGroupedByScope(name, processCallables))
    }

    /**
     * Turns members grouped by scope into a list of [ResultOfIntersection].
     *
     * - No groups: the result is empty.
     * - Exactly one group: every member becomes a [ResultOfIntersection.SingleMember].
     * - Otherwise members are repeatedly split into sets of mutually overridable members; each set
     *   yields a [ResultOfIntersection.NonTrivial] when more than one base member remains, and a
     *   [ResultOfIntersection.SingleMember] otherwise.
     */
    fun <D : FirCallableSymbol<*>> convertGroupedCallablesToIntersectionResults(
        membersByScope: List<Pair<FirTypeScope, List<D>>>
    ): List<ResultOfIntersection<D>> {
        if (membersByScope.isEmpty()) {
            return emptyList()
        }

        membersByScope.singleOrNull()?.let { (scope, members) ->
            return members.map { ResultOfIntersection.SingleMember(it, MemberWithBaseScope(it, scope)) }
        }

        val allMembersWithScope = membersByScope.flatMapTo(linkedSetOf()) { (scope, members) ->
            members.map { MemberWithBaseScope(it, scope) }
        }

        val result = mutableListOf<ResultOfIntersection<D>>()

        // NOTE(review): nothing in this loop removes elements explicitly; termination presumably relies on
        // extractBothWaysOverridable removing the extracted members from allMembersWithScope — confirm.
        while (allMembersWithScope.size > 1) {
            val maxByVisibility = findMemberWithMaxVisibility(allMembersWithScope)
            val extractBothWaysWithPrivate = overrideService.extractBothWaysOverridable(maxByVisibility, allMembersWithScope, overrideChecker)
            // Private members are dropped from the group unless that would leave the group empty.
            val extractedOverrides = extractBothWaysWithPrivate.filterNotTo(mutableListOf()) {
                Visibilities.isPrivate((it.member.fir as FirMemberDeclaration).visibility)
            }.takeIf { it.isNotEmpty() } ?: extractBothWaysWithPrivate
            val baseMembersForIntersection = extractedOverrides.calcBaseMembersForIntersectionOverride()
            if (baseMembersForIntersection.size > 1) {
                val (mostSpecific, scopeForMostSpecific) = overrideService.selectMostSpecificMember(
                    baseMembersForIntersection,
                    ReturnTypeCalculatorForFullBodyResolve
                )
                val intersectionOverrideContext = ContextForIntersectionOverrideConstruction(
                    mostSpecific,
                    this,
                    extractedOverrides,
                    scopeForMostSpecific
                )
                result += ResultOfIntersection.NonTrivial(
                    intersectionOverrides,
                    intersectionOverrideContext,
                    extractedOverrides,
                    containingScope = null
                )
            } else {
                val (mostSpecific, containingScope) = baseMembersForIntersection.single()
                result += ResultOfIntersection.SingleMember(mostSpecific, extractedOverrides, containingScope)
            }
        }

        // A single leftover member has nothing to intersect with and is reported as is.
        if (allMembersWithScope.isNotEmpty()) {
            val (single, containingScope) = allMembersWithScope.single()
            result += ResultOfIntersection.SingleMember(single, allMembersWithScope.toList(), containingScope)
        }

        return result
    }

    /**
     * Creates the intersection override symbol for [mostSpecific] based on [extractedOverrides].
     *
     * Modality and visibility of the new declaration are derived from [extractedOverrides].
     *
     * @return the new symbol paired with [scopeForMostSpecific].
     * @throws IllegalStateException if [mostSpecific] is neither a named function nor a property symbol.
     */
    fun <D : FirCallableSymbol<*>> createIntersectionOverride(
        extractedOverrides: List<MemberWithBaseScope<D>>,
        mostSpecific: D,
        scopeForMostSpecific: FirTypeScope
    ): MemberWithBaseScope<FirCallableSymbol<*>> {
        val newModality = chooseIntersectionOverrideModality(extractedOverrides)
        val newVisibility = chooseIntersectionVisibility(extractedOverrides)
        val extractedOverridesSymbols = extractedOverrides.map { it.member }
        return when (mostSpecific) {
            is FirNamedFunctionSymbol -> createIntersectionOverride(mostSpecific, extractedOverridesSymbols, newModality, newVisibility)
            is FirPropertySymbol -> createIntersectionOverride(mostSpecific, extractedOverridesSymbols, newModality, newVisibility)
            else -> throw IllegalStateException("Should not be here")
        }.withScope(scopeForMostSpecific)
    }

    /**
     * Reduces a group of mutually overridable members to the members an intersection override has to
     * be based on.
     *
     * Members whose unwrapped symbol is overridden by another member of the group are removed. When
     * all members unwrap to the same symbol, only the most specific one is returned.
     */
    private fun <S : FirCallableSymbol<*>> List<MemberWithBaseScope<S>>.calcBaseMembersForIntersectionOverride(): List<MemberWithBaseScope<S>> {
        if (size == 1) return this
        val unwrappedMemberSet = mutableSetOf<MemberWithBaseScope<S>>()
        for ((member, scope) in this) {
            @Suppress("UNCHECKED_CAST")
            unwrappedMemberSet += MemberWithBaseScope(member.fir.unwrapSubstitutionOverrides().symbol as S, scope)
        }
        // If in fact extracted overrides are the same symbols,
        // we should just take most specific member without creating intersection
        // A typical sample here is inheritance of the same class in different places of hierarchy
        if (unwrappedMemberSet.size == 1) {
            return listOf(overrideService.selectMostSpecificMember(this, ReturnTypeCalculatorForFullBodyResolve))
        }

        // Unwrapped symbols that are overridden by some member of this group.
        val baseMembers = mutableSetOf<S>()
        for ((member, scope) in this) {
            @Suppress("UNCHECKED_CAST")
            if (member is FirNamedFunctionSymbol) {
                scope.processOverriddenFunctions(member) {
                    val symbol = it.fir.unwrapSubstitutionOverrides().symbol
                    if (symbol != member.fir.unwrapSubstitutionOverrides().symbol) {
                        baseMembers += symbol as S
                    }
                    ProcessorAction.NEXT
                }
            } else if (member is FirPropertySymbol) {
                scope.processOverriddenProperties(member) {
                    val symbol = it.fir.unwrapSubstitutionOverrides().symbol
                    if (symbol != member.fir.unwrapSubstitutionOverrides().symbol) {
                        baseMembers += symbol as S
                    }
                    ProcessorAction.NEXT
                }
            }
        }

        val result = this.toMutableList()
        result.removeIf { (member, _) -> member.fir.unwrapSubstitutionOverrides().symbol in baseMembers }
        return result
    }

    /**
     * Returns the member of [members] with the greatest visibility.
     *
     * Candidates whose visibility is incomparable with the current best one are skipped; among equally
     * visible members the first one in iteration order wins. [members] must not be empty.
     */
    private fun <D : FirCallableSymbol<*>> findMemberWithMaxVisibility(members: Collection<MemberWithBaseScope<D>>): MemberWithBaseScope<D> {
        assert(members.isNotEmpty())

        var member: MemberWithBaseScope<D>? = null
        for (candidate in members) {
            if (member == null) {
                member = candidate
                continue
            }

            val result = Visibilities.compare(
                member.member.fir.status.visibility,
                candidate.member.fir.status.visibility
            )
            if (result != null && result < 0) {
                member = candidate
            }
        }
        return member!!
    }

    /**
     * Chooses the modality of an intersection override built from [extractedOverridden].
     *
     * A final member makes the result final and a sealed member yields `null`. If only open or only
     * abstract members are present, that modality is used. Otherwise the real (not inherited through
     * fake overrides) overridden members are collected, members overridden by others are filtered out
     * and the smallest modality in the enum's natural order is taken, treating a missing modality as
     * abstract.
     */
    private fun <D : FirCallableSymbol<*>> chooseIntersectionOverrideModality(
        extractedOverridden: Collection<MemberWithBaseScope<D>>
    ): Modality? {
        var hasOpen = false
        var hasAbstract = false

        for ((member) in extractedOverridden) {
            when ((member.fir as FirMemberDeclaration).modality) {
                Modality.FINAL -> return Modality.FINAL
                Modality.SEALED -> {
                    // Members should not be sealed. But, that will be reported as WRONG_MODIFIER_TARGET, and here we shouldn't raise an
                    // internal error. Instead, let the intersection override have the default modality: null.
                    return null
                }
                Modality.OPEN -> {
                    hasOpen = true
                }
                Modality.ABSTRACT -> {
                    hasAbstract = true
                }
                null -> {
                }
            }
        }

        if (hasAbstract && !hasOpen) return Modality.ABSTRACT
        if (!hasAbstract && hasOpen) return Modality.OPEN

        @Suppress("UNCHECKED_CAST")
        val processDirectOverridden: ProcessOverriddenWithBaseScope<D> = when (extractedOverridden.first().member) {
            is FirNamedFunctionSymbol -> FirTypeScope::processDirectOverriddenFunctionsWithBaseScope as ProcessOverriddenWithBaseScope<D>
            is FirPropertySymbol -> FirTypeScope::processDirectOverriddenPropertiesWithBaseScope as ProcessOverriddenWithBaseScope<D>
            else -> error("Unexpected callable kind: ${extractedOverridden.first().member}")
        }

        val realOverridden = extractedOverridden.flatMap { realOverridden(it.member, it.baseScope, processDirectOverridden) }
        val filteredOverridden = filterOutOverridden(realOverridden, processDirectOverridden)

        return filteredOverridden.minOf { (it.member.fir as FirMemberDeclaration).modality ?: Modality.ABSTRACT }
    }

    /**
     * Collects the members reachable from [symbol] through direct overrides whose origin is not
     * "from supertypes", each paired with the scope it was found in.
     */
    private fun <D : FirCallableSymbol<*>> realOverridden(
        symbol: D,
        scope: FirTypeScope,
        processDirectOverridden: ProcessOverriddenWithBaseScope<D>,
    ): Collection<MemberWithBaseScope<D>> {
        val result = mutableSetOf<MemberWithBaseScope<D>>()

        collectRealOverridden(symbol, scope, result, mutableSetOf(), processDirectOverridden)

        return result
    }

    /** Follows `originalForSubstitutionOverride` links until a declaration without an original is reached. */
    private inline fun <reified D : FirCallableDeclaration> D.unwrapSubstitutionOverrides(): D {
        var current = this

        do {
            val next = current.originalForSubstitutionOverride ?: return current
            current = next
        } while (true)
    }

    /**
     * Recursive worker for [realOverridden].
     *
     * A symbol whose origin is not from supertypes is added to [result] and not traversed further;
     * otherwise its direct overridden members are visited. [visited] guards against processing the same
     * symbol twice.
     */
    private fun <D : FirCallableSymbol<*>> collectRealOverridden(
        symbol: D,
        scope: FirTypeScope,
        result: MutableCollection<MemberWithBaseScope<D>>,
        visited: MutableSet<D>,
        processDirectOverridden: FirTypeScope.(D, (D, FirTypeScope) -> ProcessorAction) -> ProcessorAction,
    ) {
        if (!visited.add(symbol)) return
        if (!symbol.fir.origin.fromSupertypes) {
            result.add(MemberWithBaseScope(symbol, scope))
            return
        }

        scope.processDirectOverridden(symbol) { overridden, baseScope ->
            collectRealOverridden(overridden, baseScope, result, visited, processDirectOverridden)
            ProcessorAction.NEXT
        }
    }

    /**
     * Chooses the visibility of an intersection override: the greatest visibility among
     * [extractedOverrides], starting from private. If some visibility cannot be compared with the
     * current maximum, [Visibilities.DEFAULT_VISIBILITY] is returned.
     */
    private fun <D : FirCallableSymbol<*>> chooseIntersectionVisibility(
        extractedOverrides: Collection<MemberWithBaseScope<D>>
    ): Visibility {
        var maxVisibility: Visibility = Visibilities.Private
        for ((override) in extractedOverrides) {
            val visibility = (override.fir as FirMemberDeclaration).visibility
            // TODO: There is more complex logic at org.jetbrains.kotlin.resolve.OverridingUtil.resolveUnknownVisibilityForMember
            // TODO: and org.jetbrains.kotlin.resolve.OverridingUtil.findMaxVisibility
            val compare = Visibilities.compare(visibility, maxVisibility) ?: return Visibilities.DEFAULT_VISIBILITY
            if (compare > 0) {
                maxVisibility = visibility
            }
        }
        return maxVisibility
    }

    /**
     * Creates an intersection override function symbol copied from [mostSpecific].
     *
     * The callable id uses the class id of [dispatchReceiverType] and falls back to the dispatch
     * receiver class of [mostSpecific]; the copy is bound to the returned symbol and remembers
     * [mostSpecific] as its original.
     */
    private fun createIntersectionOverride(
        mostSpecific: FirNamedFunctionSymbol,
        overrides: Collection<FirCallableSymbol<*>>,
        newModality: Modality?,
        newVisibility: Visibility,
    ): FirNamedFunctionSymbol {

        val newSymbol =
            FirIntersectionOverrideFunctionSymbol(
                CallableId(
                    dispatchReceiverType.classId ?: mostSpecific.dispatchReceiverClassOrNull()?.classId!!,
                    mostSpecific.fir.name
                ),
                overrides
            )
        val mostSpecificFunction = mostSpecific.fir
        FirFakeOverrideGenerator.createCopyForFirFunction(
            newSymbol,
            mostSpecificFunction, session, FirDeclarationOrigin.IntersectionOverride,
            mostSpecificFunction.isExpect,
            newDispatchReceiverType = dispatchReceiverType,
            newModality = newModality,
            newVisibility = newVisibility,
        ).apply {
            originalForIntersectionOverrideAttr = mostSpecific.fir
        }
        return newSymbol
    }

    /**
     * Creates an intersection override property symbol copied from [mostSpecific].
     *
     * The callable id uses the class id of [dispatchReceiverType] and falls back to the dispatch
     * receiver class of [mostSpecific]; the copy is bound to the returned symbol and remembers
     * [mostSpecific] as its original.
     */
    private fun createIntersectionOverride(
        mostSpecific: FirPropertySymbol,
        overrides: Collection<FirCallableSymbol<*>>,
        newModality: Modality?,
        newVisibility: Visibility,
    ): FirPropertySymbol {
        val callableId = CallableId(
            dispatchReceiverType.classId ?: mostSpecific.dispatchReceiverClassOrNull()?.classId!!,
            mostSpecific.fir.name
        )
        val newSymbol = FirIntersectionOverridePropertySymbol(callableId, overrides)
        val mostSpecificProperty = mostSpecific.fir
        FirFakeOverrideGenerator.createCopyForFirProperty(
            newSymbol, mostSpecificProperty, session, FirDeclarationOrigin.IntersectionOverride,
            newModality = newModality,
            newVisibility = newVisibility,
            newDispatchReceiverType = dispatchReceiverType,
        ).apply {
            originalForIntersectionOverrideAttr = mostSpecific.fir
        }
        return newSymbol
    }
}

/** Pairs this symbol with [baseScope]. */
private fun <D : FirCallableSymbol<*>> D.withScope(baseScope: FirTypeScope) = MemberWithBaseScope(this, baseScope)

/**
 * Session component that owns the caches of created intersection overrides, one [CacheForScope] per
 * type used as a key of [cacheByScope].
 */
class FirIntersectionOverrideStorage(val session: FirSession) : FirSessionComponent {
    private val cachesFactory = session.firCachesFactory

    /** Cache of intersection overrides belonging to one key type of [cacheByScope]. */
    class CacheForScope(cachesFactory: FirCachesFactory) {
        /**
         * Maps the most specific member of a group to its intersection override. A missing value is
         * created by the intersection context stored in the supplied construction context.
         */
        val intersectionOverrides: FirCache<FirCallableSymbol<*>, MemberWithBaseScope<FirCallableSymbol<*>>, ContextForIntersectionOverrideConstruction<*>> =
            cachesFactory.createCache { mostSpecific, context ->
                // The first component (context.mostSpecific) is skipped: the cache key is used instead.
                val (_, intersectionScope, extractedOverrides, scopeForMostSpecific) = context
                intersectionScope.createIntersectionOverride(extractedOverrides, mostSpecific, scopeForMostSpecific)
            }
    }

    /**
     * Everything needed to create an intersection override on demand.
     *
     * @property mostSpecific most specific member of the group.
     * @property intersectionContext context whose `createIntersectionOverride` builds the override.
     * @property extractedOverrides members the override is built from.
     * @property scopeForMostSpecific scope paired with the created override.
     */
    data class ContextForIntersectionOverrideConstruction<D : FirCallableSymbol<*>>(
        val mostSpecific: D,
        val intersectionContext: FirTypeIntersectionScopeContext,
        val extractedOverrides: List<MemberWithBaseScope<D>>,
        val scopeForMostSpecific: FirTypeScope
    )

    /** Per-type caches; a [CacheForScope] is created the first time a type is requested. */
    val cacheByScope: FirCache<ConeKotlinType, CacheForScope, Nothing?> =
        cachesFactory.createCache { _ -> CacheForScope(cachesFactory) }
}

/** The [FirIntersectionOverrideStorage] of this session, obtained through the session component accessor delegate. */
private val FirSession.intersectionOverrideStorage: FirIntersectionOverrideStorage by FirSession.sessionComponentAccessor()

/**
 * Returns `true` when this result is a [ResultOfIntersection.NonTrivial], i.e. it is represented by an
 * intersection override. The contract lets callers smart-cast the receiver after a positive check.
 */
@OptIn(ExperimentalContracts::class)
fun <D : FirCallableSymbol<*>> ResultOfIntersection<D>.isIntersectionOverride(): Boolean {
    contract {
        returns(true) implies (this@isIntersectionOverride is ResultOfIntersection.NonTrivial<D>)
    }
    return this is ResultOfIntersection.NonTrivial<D>
}