All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jacodb.analysis.engine.MainIfdsUnitManager.kt Maven / Gradle / Ivy

/*
 *  Copyright 2022 UnitTestBot contributors (utbot.org)
 * 

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at *

* http://www.apache.org/licenses/LICENSE-2.0 *

* Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jacodb.analysis.engine import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.delay import kotlinx.coroutines.ensureActive import kotlinx.coroutines.flow.map import kotlinx.coroutines.isActive import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeoutOrNull import org.jacodb.analysis.logger import org.jacodb.analysis.runAnalysis import org.jacodb.api.JcMethod import org.jacodb.api.analysis.JcApplicationGraph import java.util.concurrent.ConcurrentHashMap /** * This manager launches and manages [IfdsUnitRunner]s for all units, reachable from [startMethods]. * It also merges [TraceGraph]s from different units giving a complete [TraceGraph] for each vulnerability. * See [runAnalysis] for more info. */ class MainIfdsUnitManager( private val graph: JcApplicationGraph, private val unitResolver: UnitResolver, private val ifdsUnitRunnerFactory: IfdsUnitRunnerFactory, private val startMethods: List, private val timeoutMillis: Long ) : IfdsUnitManager { private val foundMethods: MutableMap> = mutableMapOf() private val crossUnitCallers: MutableMap> = mutableMapOf() private val summaryEdgesStorage = SummaryStorageImpl() private val tracesStorage = SummaryStorageImpl() private val crossUnitCallsStorage = SummaryStorageImpl() private val vulnerabilitiesStorage = SummaryStorageImpl() private val aliveRunners: MutableMap> = ConcurrentHashMap() private val queueEmptiness: MutableMap = mutableMapOf() private val dependencies: MutableMap> = mutableMapOf() private val dependenciesRev: MutableMap> = mutableMapOf() private fun getAllCallees(method: JcMethod): Set { val result = mutableSetOf() for (inst in method.flowGraph().instructions) { graph.callees(inst).forEach { result.add(it) } } return result } private fun addStart(method: JcMethod) { val unit = unitResolver.resolve(method) if (method in foundMethods[unit].orEmpty()) { return } foundMethods.getOrPut(unit) { mutableSetOf() }.add(method) val dependencies = getAllCallees(method) dependencies.forEach { addStart(it) } } private val IfdsVertex.traceGraph: TraceGraph get() = tracesStorage .getCurrentFacts(method) .map { it.graph } .singleOrNull { it.sink == this } ?: TraceGraph.bySink(this) /** * Launches [IfdsUnitRunner] for each observed unit, handles respective jobs, * and gathers results into list of vulnerabilities, restoring full traces */ fun analyze(): List = runBlocking(Dispatchers.Default) { withTimeoutOrNull(timeoutMillis) { logger.info { "Searching for units to analyze..." } startMethods.forEach { ensureActive() addStart(it) } val allUnits = foundMethods.keys.toList() logger.info { "Starting analysis. Number of found units: ${allUnits.size}" } val progressLoggerJob = launch { while (isActive) { delay(1000) val totalCount = allUnits.size val aliveCount = aliveRunners.size logger.info { "Current progress: ${totalCount - aliveCount} / $totalCount units completed" } } } launch { dispatchDependencies() } // TODO: do smth smarter here val allJobs = allUnits.map { unit -> val runner = ifdsUnitRunnerFactory.newRunner( graph, this@MainIfdsUnitManager, unitResolver, unit, foundMethods[unit]!!.toList() ) aliveRunners[unit] = runner runner.launchIn(this) } allJobs.joinAll() eventChannel.close() progressLoggerJob.cancel() } logger.info { "All jobs completed, gathering results..." } val foundVulnerabilities = foundMethods.values.flatten().flatMap { method -> vulnerabilitiesStorage.getCurrentFacts(method) } foundMethods.values.flatten().forEach { method -> for (crossUnitCall in crossUnitCallsStorage.getCurrentFacts(method)) { val calledMethod = graph.methodOf(crossUnitCall.calleeVertex.statement) crossUnitCallers.getOrPut(calledMethod) { mutableSetOf() }.add(crossUnitCall) } } logger.info { "Restoring traces..." } foundVulnerabilities .map { VulnerabilityInstance(it.vulnerabilityDescription, extendTraceGraph(it.sink.traceGraph)) } .filter { it.traceGraph.sources.any { source -> graph.methodOf(source.statement) in startMethods || source.domainFact == ZEROFact } } } private val TraceGraph.methods: List get() { return (edges.keys.map { graph.methodOf(it.statement) } + listOf(graph.methodOf(sink.statement))).distinct() } /** * Given a [traceGraph], searches for other traceGraphs (from different units) * and merges them into given if they extend any path leading to sink. * * This method allows to restore traces that pass through several units. */ private fun extendTraceGraph(traceGraph: TraceGraph): TraceGraph { var result = traceGraph val methodQueue: MutableSet = traceGraph.methods.toMutableSet() val addedMethods: MutableSet = methodQueue.toMutableSet() while (methodQueue.isNotEmpty()) { val method = methodQueue.first() methodQueue.remove(method) for (callFact in crossUnitCallers[method].orEmpty()) { // TODO: merge calleeVertices here val sFacts = setOf(callFact.calleeVertex) val upGraph = callFact.callerVertex.traceGraph val newValue = result.mergeWithUpGraph(upGraph, sFacts) if (result != newValue) { result = newValue for (nMethod in upGraph.methods) { if (nMethod !in addedMethods) { addedMethods.add(nMethod) methodQueue.add(nMethod) } } } } } return result } override suspend fun handleEvent(event: IfdsUnitRunnerEvent, runner: IfdsUnitRunner) { when (event) { is EdgeForOtherRunnerQuery -> { val otherRunner = aliveRunners[unitResolver.resolve(event.edge.method)] ?: return if (otherRunner.job?.isActive == true) { otherRunner.submitNewEdge(event.edge) } } is NewSummaryFact -> { when (val fact = event.fact) { is CrossUnitCallFact -> crossUnitCallsStorage.send(fact) is SummaryEdgeFact -> summaryEdgesStorage.send(fact) is TraceGraphFact -> tracesStorage.send(fact) is VulnerabilityLocation -> vulnerabilitiesStorage.send(fact) } } is QueueEmptinessChanged -> { eventChannel.send(Pair(event, runner)) } is SubscriptionForSummaryEdges -> { eventChannel.send(Pair(event, runner)) summaryEdgesStorage.getFacts(event.method).map { it.edge }.collect(event.collector) } } } // Used to linearize all events that change dependencies or queue emptiness of runners private val eventChannel: Channel>> = Channel(capacity = Int.MAX_VALUE) private suspend fun dispatchDependencies() = eventChannel.consumeEach { (event, runner) -> when (event) { is SubscriptionForSummaryEdges -> { dependencies.getOrPut(runner.unit) { mutableSetOf() } .add(unitResolver.resolve(event.method)) dependenciesRev.getOrPut(unitResolver.resolve(event.method)) { mutableSetOf() } .add(runner.unit) } is QueueEmptinessChanged -> { if (runner.unit !in aliveRunners) { return@consumeEach } queueEmptiness[runner.unit] = event.isEmpty if (event.isEmpty) { val toDelete = mutableListOf(runner.unit) while (toDelete.isNotEmpty()) { val current = toDelete.removeLast() if (current in aliveRunners && dependencies[runner.unit].orEmpty().all { queueEmptiness[it] != false } ) { aliveRunners[current]!!.job?.cancel() ?: error("Runner's job is not instantiated") aliveRunners.remove(current) for (next in dependenciesRev[current].orEmpty()) { if (queueEmptiness[next] == true) { toDelete.add(next) } } } } } } else -> error("Unexpected event for dependencies dispatcher") } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy