
org.diffkt.tracing.TopologicalSort.kt Maven / Gradle / Ivy
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.diffkt.tracing
import java.util.*
import org.diffkt.*
/**
 * Topologically sorts the graph reachable from [roots] (Kahn's algorithm).
 *
 * Every node reachable from [roots] through [successors] appears exactly once in the result,
 * and each node precedes every one of its successors. Nodes for which [skip] returns true are
 * neither visited nor output, and nodes reachable only through them are never reached.
 * Ready nodes are taken in LIFO order, so the result is one valid ordering among possibly many.
 *
 * [successors] is invoked once per reachable node while counting predecessors and again while
 * sorting, so it must return the same list for a node each time. A successor listed more than
 * once is treated as that many parallel edges.
 *
 * @param roots the starting nodes; roots for which [skip] is true are ignored.
 * @param successors the outgoing edges of a node.
 * @param skip predicate selecting nodes to exclude from the traversal.
 * @return the sorted nodes, or null if a cycle is reachable from [roots].
 */
fun <TNode> topologicalSort(
    roots: List<TNode>,
    successors: (TNode) -> List<TNode>,
    skip: (TNode) -> Boolean = { false }
): List<TNode>? {
    // First, count the predecessors of each reachable node.
    val predecessorCounts: HashMap<TNode, Int> = predecessorCounts(roots, successors, skip)
    // Seed the ready set with the nodes that have no predecessors.
    val ready = ArrayDeque<TNode>()
    for ((node, count) in predecessorCounts) {
        if (count == 0) ready.addLast(node)
    }
    // Output a ready node, then release each successor whose last predecessor was just output.
    val result = ArrayList<TNode>(predecessorCounts.size)
    while (ready.isNotEmpty()) {
        val node = ready.removeLast()
        result.add(node)
        for (succ in successors(node)) {
            if (skip(succ)) continue
            // Every non-skipped successor of a reachable node was counted above.
            val count = predecessorCounts.getValue(succ)
            assert(count != 0)
            predecessorCounts[succ] = count - 1
            if (count == 1) ready.addLast(succ)
        }
    }
    // Nodes on or downstream of a cycle never reach a zero count, so they are never output.
    return if (result.size == predecessorCounts.size) result else null
}
private fun predecessorCounts(
roots: List,
successors: (TNode) -> List,
skip: (TNode) -> Boolean = { false }): HashMap {
val predecessorCounts = HashMap()
val counted = HashSet()
val toCount = Stack()
for (r in roots) {
if (!skip(r))
toCount.add(r)
}
while (!toCount.isEmpty()) {
val n = toCount.pop();
if (!counted.add(n))
continue
if (!predecessorCounts.containsKey(n))
predecessorCounts.put(n, 0)
for (succ in successors(n)) if (!skip(succ)) {
toCount.push(succ)
if (predecessorCounts.containsKey(succ))
predecessorCounts[succ] = predecessorCounts[succ]!! + 1
else
predecessorCounts.put(succ, 1)
}
}
return predecessorCounts
}
/**
 * Counts how many times each [Traceable] reachable from [roots] is used.
 *
 * A use is either an appearance in another node's [children] list or an appearance in [roots]
 * itself (every occurrence counts), so each root has a count of at least 1.
 */
fun useCounts(roots: List<Traceable>): HashMap<Traceable, Int> {
    val result = predecessorCounts(roots, ::children) { false }
    // Also count appearances in roots as uses. Nothing is skipped above, so every root is a key.
    for (r in roots) {
        result[r] = result.getValue(r) + 1
    }
    return result
}
/**
 * A topological sort of the tracing graph reachable from [roots], following [children] edges.
 *
 * Each node is listed before its children (the inputs it is computed from), so the roots come
 * first and leaves last. Callers that need every tensor after its inputs should iterate the
 * result in reverse. Nodes for which [skip] returns true are omitted, along with anything
 * reachable only through them.
 *
 * @throws IllegalStateException if a cycle is reachable from [roots].
 */
internal fun topologicalSort(
    roots: List<Traceable>,
    skip: (Traceable) -> Boolean
): List<Traceable> =
    topologicalSort(roots, ::children, skip)
        ?: error("topologicalSort: cycle detected in the tracing graph")
/** Returns the nodes [x] is directly built from, as enumerated by `childrenVisitor` (empty for leaves). */
internal fun children(x: Traceable): List<Traceable> {
    return x.accept(childrenVisitor)
}
/**
 * Maps each tracing node to the list of nodes it directly depends on, in operand order.
 * Constants, variables, zeros, identity gradients and random-key variables have none.
 */
private object childrenVisitor : TracingVisitor<List<Traceable>> {
    // Leaves: no traced operands.
    override fun visitConstant(x: TracingTensor.Constant): List<Traceable> = emptyList()
    override fun visitVariable(x: TracingTensor.Variable): List<Traceable> = emptyList()
    override fun visitZero(x: TracingTensor.Zero): List<Traceable> = emptyList()
    override fun visitIdentityGradient(x: TracingTensor.IdentityGradient): List<Traceable> = emptyList()
    override fun visitRandomVariable(x: TracingRandomKey.Variable): List<Traceable> = emptyList()

    // Single operand stored in `x`.
    override fun visitUnaryMinus(x: TracingTensor.UnaryMinus): List<Traceable> = listOf(x.x)
    override fun visitSin(x: TracingTensor.Sin): List<Traceable> = listOf(x.x)
    override fun visitCos(x: TracingTensor.Cos): List<Traceable> = listOf(x.x)
    override fun visitTan(x: TracingTensor.Tan): List<Traceable> = listOf(x.x)
    override fun visitAtan(x: TracingTensor.Atan): List<Traceable> = listOf(x.x)
    override fun visitExp(x: TracingTensor.Exp): List<Traceable> = listOf(x.x)
    override fun visitLn(x: TracingTensor.Ln): List<Traceable> = listOf(x.x)
    override fun visitLgamma(x: TracingTensor.Lgamma): List<Traceable> = listOf(x.x)
    override fun visitDigamma(x: TracingTensor.Digamma): List<Traceable> = listOf(x.x)
    override fun visitPolygamma(x: TracingTensor.Polygamma): List<Traceable> = listOf(x.x)
    override fun visitSqrt(x: TracingTensor.Sqrt): List<Traceable> = listOf(x.x)
    override fun visitTanh(x: TracingTensor.Tanh): List<Traceable> = listOf(x.x)
    override fun visitSplit(x: TracingTensor.Split): List<Traceable> = listOf(x.x)
    override fun visitBroadcastTo(x: TracingTensor.BroadcastTo): List<Traceable> = listOf(x.x)
    override fun visitExpand(x: TracingTensor.Expand): List<Traceable> = listOf(x.x)
    override fun visitFlip(x: TracingTensor.Flip): List<Traceable> = listOf(x.x)
    override fun visitLogSoftmax(x: TracingTensor.LogSoftmax): List<Traceable> = listOf(x.x)
    override fun visitView1(x: TracingTensor.View1): List<Traceable> = listOf(x.x)
    override fun visitView2(x: TracingTensor.View2): List<Traceable> = listOf(x.x)
    override fun visitView3(x: TracingTensor.View3): List<Traceable> = listOf(x.x)
    override fun visitReshape(x: TracingTensor.Reshape): List<Traceable> = listOf(x.x)
    override fun visitReshapeToScalar(x: TracingScalar.ReshapeToScalar): List<Traceable> = listOf(x.x)
    override fun visitSqueeze(x: TracingTensor.Squeeze): List<Traceable> = listOf(x.x)
    override fun visitUnsqueeze(x: TracingTensor.Unsqueeze): List<Traceable> = listOf(x.x)
    override fun visitTranspose(x: TracingTensor.Transpose): List<Traceable> = listOf(x.x)
    override fun visitRelu(x: TracingTensor.Relu): List<Traceable> = listOf(x.x)
    override fun visitSigmoid(x: TracingTensor.Sigmoid): List<Traceable> = listOf(x.x)
    override fun visitSum(x: TracingTensor.Sum): List<Traceable> = listOf(x.x)
    override fun visitAvgPool(x: TracingTensor.AvgPool): List<Traceable> = listOf(x.x)
    override fun visitAvgPoolGrad(x: TracingTensor.AvgPoolGrad): List<Traceable> = listOf(x.x)
    override fun visitMaxPoolWithIndices(x: TracingTensor.MaxPoolWithIndices): List<Traceable> = listOf(x.x)
    override fun visitGather(x: TracingTensor.Gather): List<Traceable> = listOf(x.x)
    override fun visitGatherAtIndices(x: TracingTensor.GatherAtIndices): List<Traceable> = listOf(x.x)
    override fun visitScatter(x: TracingTensor.Scatter): List<Traceable> = listOf(x.x)
    override fun visitScatterAtIndices(x: TracingTensor.ScatterAtIndices): List<Traceable> = listOf(x.x)

    // Single operand stored under another name.
    override fun visitSplitPart(x: TracingTensor.SplitPart): List<Traceable> = listOf(x.from)
    // NOTE(review): only `base` is listed; confirm Pow carries no other traced operand.
    override fun visitPow(x: TracingTensor.Pow): List<Traceable> = listOf(x.base)
    override fun visitRandomFloats(x: TracingTensor.RandomFloats): List<Traceable> = listOf(x.key)
    override fun visitRandomSplit(x: TracingRandomKey.Split): List<Traceable> = listOf(x.key)
    override fun visitRandomSplitPart(x: TracingRandomKey.SplitPart): List<Traceable> = listOf(x.split)

    // Two operands.
    override fun visitPlus(x: TracingTensor.Plus): List<Traceable> = listOf(x.left, x.right)
    override fun visitMinus(x: TracingTensor.Minus): List<Traceable> = listOf(x.left, x.right)
    override fun visitTimes(x: TracingTensor.Times): List<Traceable> = listOf(x.left, x.right)
    override fun visitTimesScalar(x: TracingTensor.TimesScalar): List<Traceable> = listOf(x.left, x.right)
    override fun visitDiv(x: TracingTensor.Div): List<Traceable> = listOf(x.left, x.right)
    override fun visitCompare(x: TracingTensor.Compare): List<Traceable> = listOf(x.left, x.right)
    override fun visitMatmul(x: TracingTensor.Matmul): List<Traceable> = listOf(x.x, x.y)
    override fun visitOuterProduct(x: TracingTensor.OuterProduct): List<Traceable> = listOf(x.x, x.y)
    override fun visitConvImpl(x: TracingTensor.ConvImpl): List<Traceable> = listOf(x.filter, x.signal)
    override fun visitReluGrad(x: TracingTensor.ReluGrad): List<Traceable> = listOf(x.x, x.upstream)

    // Three operands.
    override fun visitLogSoftmaxGrad(x: TracingTensor.LogSoftmaxGrad): List<Traceable> =
        listOf(x.x, x.logSoftmax, x.upstream)
    override fun visitIfThenElse(x: TracingTensor.IfThenElse): List<Traceable> =
        listOf(x.cond, x.whenTrue, x.whenFalse)

    // Variable number of operands.
    override fun visitMeld(x: TracingTensor.Meld): List<Traceable> = x.values
    override fun visitConcat(x: TracingTensor.Concat): List<Traceable> = x.slices
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy