org.jetbrains.kotlin.fir.analysis.cfa.FirCallsEffectAnalyzer.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kotlin-compiler-embeddable Show documentation
Show all versions of kotlin-compiler-embeddable Show documentation
the Kotlin compiler embeddable
/*
* Copyright 2010-2022 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.fir.analysis.cfa
import kotlinx.collections.immutable.PersistentMap
import kotlinx.collections.immutable.persistentMapOf
import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.contracts.description.EventOccurrencesRange
import org.jetbrains.kotlin.contracts.description.MarkedEventOccurrencesRange
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import org.jetbrains.kotlin.diagnostics.reportOn
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.cfa.util.*
import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import org.jetbrains.kotlin.fir.analysis.checkers.cfa.FirControlFlowChecker
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors
import org.jetbrains.kotlin.fir.contracts.FirContractDescription
import org.jetbrains.kotlin.fir.contracts.description.ConeCallsEffectDeclaration
import org.jetbrains.kotlin.fir.contracts.effects
import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
import org.jetbrains.kotlin.fir.declarations.FirContractDescriptionOwner
import org.jetbrains.kotlin.fir.declarations.FirFunction
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
import org.jetbrains.kotlin.fir.references.FirReference
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.FirThisReference
import org.jetbrains.kotlin.fir.resolve.dfa.cfg.*
import org.jetbrains.kotlin.fir.resolve.isFunctionOrSuspendFunctionInvoke
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.types.FirTypeRef
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isSomeFunctionType
import org.jetbrains.kotlin.utils.addIfNotNull
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
object FirCallsEffectAnalyzer : FirControlFlowChecker(MppCheckerKind.Common) {
override fun analyze(graph: ControlFlowGraph, reporter: DiagnosticReporter, context: CheckerContext) {
// TODO, KT-59816: this is quadratic due to `graph.traverse`, surely there is a better way?
for (subGraph in graph.subGraphs) {
analyze(subGraph, reporter, context)
}
val session = context.session
val function = (graph.declaration as? FirFunction) ?: return
if (function !is FirContractDescriptionOwner) return
val contractDescription = function.contractDescription ?: return
if (contractDescription.effects?.any { it.effect is ConeCallsEffectDeclaration } != true) return
val functionalTypeEffects = mutableMapOf, ConeCallsEffectDeclaration>()
function.valueParameters.forEachIndexed { index, parameter ->
if (parameter.returnTypeRef.isFunctionTypeRef(session)) {
val effectDeclaration = contractDescription.getParameterCallsEffectDeclaration(index)
if (effectDeclaration != null) functionalTypeEffects[parameter.symbol] = effectDeclaration
}
}
if (function.receiverParameter?.typeRef.isFunctionTypeRef(session)) {
val effectDeclaration = contractDescription.getParameterCallsEffectDeclaration(-1)
if (effectDeclaration != null) functionalTypeEffects[function.symbol] = effectDeclaration
}
if (functionalTypeEffects.isEmpty()) return
val leakedSymbols = mutableMapOf, MutableList>()
graph.traverse(
CapturedLambdaFinder(function),
IllegalScopeContext(functionalTypeEffects.keys, leakedSymbols)
)
for ((symbol, leakedPlaces) in leakedSymbols) {
reporter.reportOn(contractDescription.source, FirErrors.LEAKED_IN_PLACE_LAMBDA, symbol, context)
leakedPlaces.forEach {
reporter.reportOn(it, FirErrors.LEAKED_IN_PLACE_LAMBDA, symbol, context)
}
}
val invocationData = graph.collectDataForNode(
TraverseDirection.Forward,
InvocationDataCollector(functionalTypeEffects.keys.filterTo(mutableSetOf()) { it !in leakedSymbols })
)
for ((symbol, effectDeclaration) in functionalTypeEffects) {
graph.exitNode.previousCfgNodes.forEach { node ->
val requiredRange = effectDeclaration.kind
val pathAwareInfo = invocationData.getValue(node)
val info = pathAwareInfo[NormalPath] ?: return@forEach
investigate(info, symbol, requiredRange, function, reporter, context)
}
}
}
private fun investigate(
info: LambdaInvocationInfo,
symbol: FirBasedSymbol<*>,
requiredRange: EventOccurrencesRange,
function: FirContractDescriptionOwner,
reporter: DiagnosticReporter,
context: CheckerContext
) {
val foundRange = info[symbol]?.withoutMarker ?: EventOccurrencesRange.ZERO
if (foundRange !in requiredRange) {
reporter.reportOn(
function.contractDescription?.source,
FirErrors.WRONG_INVOCATION_KIND,
symbol,
requiredRange,
foundRange,
context
)
}
}
private class IllegalScopeContext(
private val functionalTypeSymbols: Set>,
private val leakedSymbols: MutableMap, MutableList>,
) {
private var scopeDepth: Int = 0
private var illegalScopeDepth: Int? = null
val inIllegalScope: Boolean get() = illegalScopeDepth != null
fun enterScope(legal: Boolean) {
scopeDepth++
if (illegalScopeDepth == null && !legal) illegalScopeDepth = scopeDepth
}
fun exitScope() {
if (scopeDepth == illegalScopeDepth) illegalScopeDepth = null
scopeDepth--
}
inline fun checkExpressionForLeakedSymbols(
fir: FirExpression?,
source: KtSourceElement? = fir?.source,
illegalUsage: () -> Boolean = { false }
) {
val symbol = referenceToSymbol(fir.toQualifiedReference())
if (symbol != null && symbol in functionalTypeSymbols && (inIllegalScope || illegalUsage())) {
leakedSymbols.getOrPut(symbol, ::mutableListOf).addIfNotNull(source)
}
}
}
private class CapturedLambdaFinder(val rootFunction: FirFunction) : ControlFlowGraphVisitor() {
override fun visitNode(node: CFGNode<*>, data: IllegalScopeContext) {}
override fun visitFunctionEnterNode(node: FunctionEnterNode, data: IllegalScopeContext) {
// TODO, KT-59668: this is not how CFG works, this should be done by FIR tree traversal. Especially considering that
// none of these methods use anything from the CFG other than `node.fir`, which should've been a hint.
data.enterScope(node.fir === rootFunction || node.fir.isInPlaceLambda())
}
override fun visitFunctionExitNode(node: FunctionExitNode, data: IllegalScopeContext) {
data.exitScope()
}
override fun visitPropertyInitializerEnterNode(node: PropertyInitializerEnterNode, data: IllegalScopeContext) {
data.enterScope(false)
data.checkExpressionForLeakedSymbols(node.fir.initializer)
}
override fun visitPropertyInitializerExitNode(node: PropertyInitializerExitNode, data: IllegalScopeContext) {
data.exitScope()
}
override fun visitInitBlockEnterNode(node: InitBlockEnterNode, data: IllegalScopeContext) {
data.enterScope(false)
}
override fun visitInitBlockExitNode(node: InitBlockExitNode, data: IllegalScopeContext) {
data.exitScope()
}
override fun visitVariableAssignmentNode(node: VariableAssignmentNode, data: IllegalScopeContext) {
data.checkExpressionForLeakedSymbols(node.fir.rValue) { true }
}
override fun visitVariableDeclarationNode(node: VariableDeclarationNode, data: IllegalScopeContext) {
data.checkExpressionForLeakedSymbols(node.fir.initializer) { true }
}
override fun visitFunctionCallNode(node: FunctionCallNode, data: IllegalScopeContext) {
val functionSymbol = node.fir.toResolvedCallableSymbol() as? FirFunctionSymbol<*>?
val contractDescription = functionSymbol?.resolvedContractDescription
val callSource = node.fir.explicitReceiver?.source ?: node.fir.source
data.checkExpressionForLeakedSymbols(node.fir.explicitReceiver, callSource) {
functionSymbol?.callableId?.isFunctionOrSuspendFunctionInvoke() != true && contractDescription?.getParameterCallsEffect(-1) == null
}
for (arg in node.fir.argumentList.arguments) {
val unwrappedArg = arg.unwrapArgument()
data.checkExpressionForLeakedSymbols(unwrappedArg) {
node.fir.getArgumentCallsEffect(arg) == null
}
}
}
}
@Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE") // K2 warning suppression, TODO: KT-62472
class LambdaInvocationInfo(
map: PersistentMap, EventOccurrencesRangeAtNode> = persistentMapOf(),
) : EventOccurrencesRangeInfo>(map) {
companion object {
val EMPTY = LambdaInvocationInfo()
}
override val constructor: (PersistentMap, EventOccurrencesRangeAtNode>) -> LambdaInvocationInfo =
::LambdaInvocationInfo
}
private class InvocationDataCollector(
val functionalTypeSymbols: Set>
) : PathAwareControlFlowGraphVisitor() {
companion object {
private val EMPTY_INFO: PathAwareLambdaInvocationInfo = persistentMapOf(NormalPath to LambdaInvocationInfo.EMPTY)
}
override val emptyInfo: PathAwareLambdaInvocationInfo
get() = EMPTY_INFO
override fun visitFunctionCallNode(
node: FunctionCallNode,
data: PathAwareLambdaInvocationInfo
): PathAwareLambdaInvocationInfo {
var dataForNode = visitNode(node, data)
val functionSymbol = node.fir.toResolvedCallableSymbol() as? FirFunctionSymbol<*>?
val contractDescription = functionSymbol?.resolvedContractDescription
dataForNode = dataForNode.checkReference(node.fir.explicitReceiver.toQualifiedReference()) {
when {
functionSymbol?.callableId?.isFunctionOrSuspendFunctionInvoke() == true -> MarkedEventOccurrencesRange.ExactlyOnce(node)
else -> contractDescription.getParameterCallsEffect(-1)?.at(node)
?: MarkedEventOccurrencesRange.Zero
}
}
for (arg in node.fir.argumentList.arguments) {
dataForNode = dataForNode.checkReference(arg.toQualifiedReference()) {
node.fir.getArgumentCallsEffect(arg)?.at(node) ?: MarkedEventOccurrencesRange.Zero
}
}
return dataForNode
}
@OptIn(ExperimentalContracts::class)
private fun collectDataForReference(reference: FirReference?): Boolean {
contract {
returns(true) implies (reference != null)
}
return reference != null && referenceToSymbol(reference) in functionalTypeSymbols
}
private inline fun PathAwareLambdaInvocationInfo.checkReference(
reference: FirReference?,
rangeGetter: () -> EventOccurrencesRangeAtNode
): PathAwareLambdaInvocationInfo {
return if (collectDataForReference(reference)) addInvocationInfo(reference, rangeGetter()) else this
}
private fun PathAwareLambdaInvocationInfo.addInvocationInfo(
reference: FirReference,
range: EventOccurrencesRangeAtNode
): PathAwareLambdaInvocationInfo {
val symbol = referenceToSymbol(reference)
return if (symbol != null) addRange(symbol, range) else this
}
}
private fun FirTypeRef?.isFunctionTypeRef(session: FirSession): Boolean {
return this?.coneType?.isSomeFunctionType(session) == true
}
private fun FirContractDescription?.getParameterCallsEffectDeclaration(index: Int): ConeCallsEffectDeclaration? {
val effects = this?.effects?.map { it.effect }
val callsEffect = effects?.find { it is ConeCallsEffectDeclaration && it.valueParameterReference.parameterIndex == index }
return callsEffect as? ConeCallsEffectDeclaration?
}
private fun FirFunctionCall.getArgumentCallsEffect(arg: FirExpression): EventOccurrencesRange? {
val functionSymbol = (this.toResolvedCallableSymbol() as? FirFunctionSymbol<*>?)
val contractDescription = functionSymbol?.resolvedContractDescription
val resolvedArguments = argumentList as? FirResolvedArgumentList
return if (functionSymbol != null && resolvedArguments != null) {
val parameter = resolvedArguments.mapping[arg]
contractDescription.getParameterCallsEffect(functionSymbol.valueParameterSymbols.indexOf(parameter?.symbol))
} else null
}
private fun FirContractDescription?.getParameterCallsEffect(index: Int): EventOccurrencesRange? {
return getParameterCallsEffectDeclaration(index)?.kind
}
private fun FirFunction.isInPlaceLambda(): Boolean {
return this is FirAnonymousFunction && this.isLambda && this.invocationKind != null
}
private fun FirExpression?.toQualifiedReference(): FirReference? {
return (this?.unwrapArgument() as? FirQualifiedAccessExpression)?.calleeReference
}
private fun referenceToSymbol(reference: FirReference?): FirBasedSymbol<*>? = when (reference) {
is FirResolvedNamedReference -> reference.resolvedSymbol
is FirThisReference -> reference.boundSymbol
else -> null
}
}
private typealias PathAwareLambdaInvocationInfo = PathAwareControlFlowInfo