All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.fir.extensions.FirExtensionDeclarationsSymbolProvider.kt Maven / Gradle / Ivy

There is a newer version: 2.1.0-RC
Show newest version
/*
 * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.fir.extensions

import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.FirSessionComponent
import org.jetbrains.kotlin.fir.caches.*
import org.jetbrains.kotlin.fir.declarations.validate
import org.jetbrains.kotlin.fir.ownerGenerator
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolNamesProvider
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProvider
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProviderInternals
import org.jetbrains.kotlin.fir.resolve.providers.symbolProvider
import org.jetbrains.kotlin.fir.scopes.impl.nestedClassifierScope
import org.jetbrains.kotlin.fir.scopes.processClassifiersByName
import org.jetbrains.kotlin.fir.symbols.impl.*
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.flatGroupBy

/**
 * Symbol provider that serves declarations produced by the session's declaration generators
 * ([FirDeclarationGenerationExtension]).
 *
 * Top-level classes, functions and properties are requested from the extensions on demand and the
 * results are memoized in caches created by the supplied [FirCachesFactory]. Nested classes are not
 * generated here directly; they are looked up in the nested classifier scope of their outer class.
 *
 * Instances are created through [createIfNeeded].
 */
@OptIn(FirExtensionApiInternals::class)
class FirExtensionDeclarationsSymbolProvider private constructor(
    session: FirSession,
    cachesFactory: FirCachesFactory,
    private val extensions: List<FirDeclarationGenerationExtension>
) : FirSymbolProvider(session), FirSessionComponent {
    companion object {
        /**
         * Creates a provider for [session], or returns `null` when the session has no registered
         * declaration generators.
         */
        fun createIfNeeded(session: FirSession): FirExtensionDeclarationsSymbolProvider? {
            val extensions = session.extensionService.declarationGenerators
            if (extensions.isEmpty()) return null
            return FirExtensionDeclarationsSymbolProvider(session, session.firCachesFactory, extensions)
        }
    }

    // ------------------------------------------ caches ------------------------------------------

    /** Class-like symbol per [ClassId]; a `null` value means no symbol was produced for that id. */
    private val classCache: FirCache<ClassId, FirClassLikeSymbol<*>?, Nothing?> = cachesFactory.createCache { classId, _ ->
        generateClassLikeDeclaration(classId)
    }

    /** Generated top-level functions per [CallableId]. */
    private val functionCache: FirCache<CallableId, List<FirNamedFunctionSymbol>, Nothing?> = cachesFactory.createCache { callableId, _ ->
        generateTopLevelFunctions(callableId)
    }

    /** Generated top-level properties per [CallableId]. */
    private val propertyCache: FirCache<CallableId, List<FirPropertySymbol>, Nothing?> = cachesFactory.createCache { callableId, _ ->
        generateTopLevelProperties(callableId)
    }

    /** Whether any extension reports the given package, per package [FqName]. */
    private val packageCache: FirCache<FqName, Boolean, Nothing?> = cachesFactory.createCache { packageFqName, _ ->
        hasPackage(packageFqName)
    }

    /** Names of top-level callables declared by the extensions, grouped by package. */
    private val callableNamesInPackageCache: FirLazyValue<Map<FqName, Set<Name>>> =
        cachesFactory.createLazyValue {
            computeNamesGroupedByPackage(
                FirDeclarationGenerationExtension::getTopLevelCallableIds,
                CallableId::packageName, CallableId::callableName
            )
        }

    /** Short names (as strings) of top-level classes declared by the extensions, grouped by package. */
    private val classNamesInPackageCache: FirLazyValue<Map<FqName, Set<String>>> =
        cachesFactory.createLazyValue {
            computeNamesGroupedByPackage(
                FirDeclarationGenerationExtension::getTopLevelClassIds,
                ClassId::getPackageFqName
            ) { it.shortClassName.asString() }
        }

    /**
     * Collects the ids returned by [ids] from every extension and groups their short names by package.
     *
     * @param ids yields the ids declared by a single extension
     * @param packageFqName extracts the package of an id
     * @param shortName extracts the name to store for an id
     * @return a map from package to the set of names found in that package
     */
    private fun <I, N> computeNamesGroupedByPackage(
        ids: FirDeclarationGenerationExtension.() -> Collection<I>,
        packageFqName: (I) -> FqName,
        shortName: (I) -> N,
    ): Map<FqName, Set<N>> =
        buildMap<FqName, MutableSet<N>> {
            for (extension in extensions) {
                for (id in extension.ids()) {
                    getOrPut(packageFqName(id)) { mutableSetOf() }.add(shortName(id))
                }
            }
        }

    /** Extensions grouped by the top-level class ids they declare. */
    private val extensionsByTopLevelClassId: FirLazyValue<Map<ClassId, List<FirDeclarationGenerationExtension>>> =
        session.firCachesFactory.createLazyValue {
            extensions.flatGroupBy { it.topLevelClassIdsCache.getValue() }
        }

    /** Extensions grouped by the top-level callable ids they declare. */
    private val extensionsByTopLevelCallableId: FirLazyValue<Map<CallableId, List<FirDeclarationGenerationExtension>>> =
        session.firCachesFactory.createLazyValue {
            extensions.flatGroupBy { it.topLevelCallableIdsCache.getValue() }
        }

    // ------------------------------------------ generators ------------------------------------------

    /**
     * Computes the class-like symbol for [classId].
     *
     * - Local class ids always yield `null`.
     * - Nested class ids are resolved through the nested classifier scope of the outer class, which is
     *   obtained from the session-wide symbol provider.
     * - Top-level class ids are generated by the extensions that declared them. Each generated
     *   declaration gets its `ownerGenerator` set to the producing extension and is validated.
     *
     * @return the symbol, or `null` if nothing was found or generated
     * @throws IllegalStateException if more than one extension generated a class for the same [classId]
     */
    private fun generateClassLikeDeclaration(classId: ClassId): FirClassLikeSymbol<*>? {
        return when {
            classId.isLocal -> null
            classId.isNestedClass -> {
                // `outerClassId` is expected to be non-null for a nested class id, hence the `!!`.
                val owner = session.symbolProvider.getClassLikeSymbolByClassId(classId.outerClassId!!) as? FirClassSymbol<*> ?: return null
                val nestedClassifierScope = session.nestedClassifierScope(owner.fir) ?: return null
                var result: FirClassLikeSymbol<*>? = null
                // If the scope reports several class-like symbols for the name, the last one is kept.
                nestedClassifierScope.processClassifiersByName(classId.shortClassName) {
                    if (it is FirClassLikeSymbol<*>) {
                        result = it
                    }
                }
                result
            }
            else -> {
                val matchedExtensions = extensionsByTopLevelClassId.getValue()[classId] ?: return null
                val generatedClasses = matchedExtensions
                    .mapNotNull { generatorExtension ->
                        generatorExtension.generateTopLevelClassLikeDeclaration(classId)?.also { symbol ->
                            symbol.fir.ownerGenerator = generatorExtension
                        }
                    }
                    .onEach { it.fir.validate() }
                when (generatedClasses.size) {
                    0 -> null
                    1 -> generatedClasses.first()
                    else -> error("Multiple plugins generated classes with same classId $classId\n${generatedClasses.joinToString("\n") { it.fir.render() }}")
                }
            }
        }
    }

    /**
     * Asks every extension that declared [callableId] to generate its top-level functions and
     * validates each generated declaration.
     */
    private fun generateTopLevelFunctions(callableId: CallableId): List<FirNamedFunctionSymbol> {
        return extensionsByTopLevelCallableId.getValue()[callableId].orEmpty()
            .flatMap { it.generateFunctions(callableId, context = null) }
            .onEach { it.fir.validate() }
    }

    /**
     * Asks every extension that declared [callableId] to generate its top-level properties and
     * validates each generated declaration.
     */
    private fun generateTopLevelProperties(callableId: CallableId): List<FirPropertySymbol> {
        return extensionsByTopLevelCallableId.getValue()[callableId].orEmpty()
            .flatMap { it.generateProperties(callableId, context = null) }
            .onEach { it.fir.validate() }
    }

    /** Returns `true` if at least one extension reports that it has [packageFqName]. */
    private fun hasPackage(packageFqName: FqName): Boolean {
        return extensions.any { it.hasPackage(packageFqName) }
    }

    // ------------------------------------------ provider methods ------------------------------------------

    /** Names provider answering from the ids declared by the extensions; it never returns `null`. */
    override val symbolNamesProvider: FirSymbolNamesProvider = object : FirSymbolNamesProvider() {
        override fun getTopLevelClassifierNamesInPackage(packageFqName: FqName): Set<String> =
            classNamesInPackageCache.getValue()[packageFqName] ?: emptySet()

        override fun getPackageNamesWithTopLevelCallables(): Set<String> =
            extensions.flatMapTo(mutableSetOf()) { extension ->
                extension.topLevelCallableIdsCache.getValue().map { it.packageName.asString() }
            }

        override fun getTopLevelCallableNamesInPackage(packageFqName: FqName): Set<Name> =
            callableNamesInPackageCache.getValue()[packageFqName].orEmpty()
    }

    /** Returns the cached (or freshly computed) class-like symbol for [classId], if any. */
    override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? {
        return classCache.getValue(classId)
    }

    /** Appends generated functions first, then generated properties, for the given package and name. */
    @FirSymbolProviderInternals
    override fun getTopLevelCallableSymbolsTo(destination: MutableList<FirCallableSymbol<*>>, packageFqName: FqName, name: Name) {
        val callableId = CallableId(packageFqName, name)
        destination += functionCache.getValue(callableId)
        destination += propertyCache.getValue(callableId)
    }

    /** Appends the generated top-level functions for the given package and name to [destination]. */
    @FirSymbolProviderInternals
    override fun getTopLevelFunctionSymbolsTo(destination: MutableList<FirNamedFunctionSymbol>, packageFqName: FqName, name: Name) {
        destination += functionCache.getValue(CallableId(packageFqName, name))
    }

    /** Appends the generated top-level properties for the given package and name to [destination]. */
    @FirSymbolProviderInternals
    override fun getTopLevelPropertySymbolsTo(destination: MutableList<FirPropertySymbol>, packageFqName: FqName, name: Name) {
        destination += propertyCache.getValue(CallableId(packageFqName, name))
    }

    /** Returns [fqName] itself if some extension reports that package, otherwise `null`. */
    override fun getPackage(fqName: FqName): FqName? {
        return fqName.takeIf { packageCache.getValue(fqName, null) }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy