io.improbable.keanu.algorithms.graphtraversal.DifferentiableChecker Maven / Gradle / Ivy
package io.improbable.keanu.algorithms.graphtraversal;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.tensor.number.floating.dbl.DoubleVertex;
import lombok.experimental.UtilityClass;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/**
* Utility class for checking whether the given vertices are all differentiable w.r.t latents.
* When given latent variables, this ensures that the dLogProb can be calculated.
*
* This check is performed by traversing up each vertex's parents and ensuring that the path to next RV is
* differentiable or constant valued.
* If there is a non differentiable vertex on this path, then if it is constant valued (0 gradient) it has no effect
* and therefore will return true.
*
* -- Examples --
* RV = Random Variable
* (G) = A vertex we want to check whether differentiable
* ND = Non-differentiable vertex
* D = Differentiable vertex
* C = Constant valued vertex
*
* - Differentiable -
*
* RV RV
* \ /
* D RV
* \ /
* RV(G)
*
* This graph is differentiable as traversing up each of the vertex's parent to the next RV is a differentiable path.
*
* C C
* \ /
* ND RV
* \ /
* RV(G)
*
* This graph is differentiable as the path that is non differentiable is constant valued.
*
* - Not Differentiable -
*
* RV RV - Both RV not observed
* \ /
* ND RV
* \ /
* RV(G)
*
* This is not differentiable as there is a non differentiable path which does not have a constant value.
*
* For more examples see DifferentiableCheckerTest.java.
*/
@UtilityClass
public class DifferentiableChecker {
/**
* @param vertices the vertices to check are differentiable w.r.t latents.
* @return true if all given vertices are differentiable, false otherwise.
*/
public static boolean isDifferentiableWrtLatents(Collection vertices) {
// All probabilistic need to be double or observed to ensure that the dLogProb can be calculated, for example
// the dLogProb of BernoulliVertex can only be calculated when it is observed.
if (!allProbabilisticAreDoubleOrObserved(vertices)) {
return false;
}
Set allParents = allParentsOf(vertices);
Set constantValueVerticesCache = new HashSet<>();
return diffableOrConstantUptoNextRV(allParents, constantValueVerticesCache);
}
private static boolean allProbabilisticAreDoubleOrObserved(Collection vertices) {
return vertices.stream().filter(Vertex::isProbabilistic)
.allMatch(DifferentiableChecker::isDoubleOrObserved);
}
private static boolean isDoubleOrObserved(Vertex v) {
return (v instanceof DoubleVertex || v.isObserved());
}
private static Set allParentsOf(Collection vertices) {
Set allParents = new HashSet<>();
for (Vertex vertex : vertices) {
allParents.addAll(vertex.getParents());
}
return allParents;
}
private static boolean diffableOrConstantUptoNextRV(Collection vertices, Set constantValueVerticesCache) {
return BreadthFirstSearch.bfsWithFailureCondition(vertices,
vertex -> isNonDiffableAndNotConstant(vertex, constantValueVerticesCache),
DifferentiableChecker::getParentsIfVertexIsNotProbabilistic,
BreadthFirstSearch::doNothing);
}
private static Collection getParentsIfVertexIsNotProbabilistic(Vertex visiting) {
return visiting.isProbabilistic() ? Collections.emptySet() : visiting.getParents();
}
private static boolean isNonDiffableAndNotConstant(Vertex vertex, Set constantValueVerticesCache) {
return !vertex.isDifferentiable() &&
!isVertexValueConstant(vertex, constantValueVerticesCache);
}
private static boolean isVertexValueConstant(Vertex vertex, Set constantValueVerticesCache) {
if (isValueKnownToBeConstant(vertex, constantValueVerticesCache)) {
return true;
}
return BreadthFirstSearch.bfsWithFailureCondition(Collections.singletonList(vertex),
DifferentiableChecker::isUnobservedProbabilistic,
visiting -> getParentsIfValueNotKnownToBeConstant(visiting, constantValueVerticesCache),
constantValueVerticesCache::addAll);
}
private static boolean isUnobservedProbabilistic(Vertex vertex) {
return vertex.isProbabilistic() && !vertex.isObserved();
}
private static Collection getParentsIfValueNotKnownToBeConstant(Vertex visiting,
Set constantValueVerticesCache) {
return isValueKnownToBeConstant(visiting, constantValueVerticesCache) ? Collections.emptySet() : visiting.getParents();
}
// We know whether these are constant. For cases such as a MultiplicationVertex we would need to
// explore its parents to ensure its constant.
private static boolean isValueKnownToBeConstant(Vertex vertex, Set constantValueVerticesCache) {
return vertex instanceof ConstantVertex || vertex.isObserved() || constantValueVerticesCache.contains(vertex);
}
}