All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.improbable.keanu.network.TransitiveClosure Maven / Gradle / Ivy

package io.improbable.keanu.network;

import io.improbable.keanu.vertices.Vertex;
import lombok.Value;

import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static io.improbable.keanu.network.Propagation.getVertices;

/**
 * A Transitive Closure is defined as a given vertex and all the vertices that it affects (downstream) OR
 * all of the vertices that affect it (upstream). Unlike a Lambda Section it does not stop at observed or
 * probabilistic vertices.
 * <p>
 * For example:
 * <pre>
 * A = SomeDistribution(...)
 * B = A.cos()
 * C = SomeDistribution(B, ...)
 * D = C.times(2)
 * </pre>
 * <ul>
 * <li>The downstream Transitive Closure of A would be [A, B, C, D]</li>
 * <li>The upstream Transitive Closure of D would be [D, C, B, A]</li>
 * <li>The upstream Transitive Closure of C would be [C, B, A]</li>
 * </ul>
 */
@Value
public class TransitiveClosure {

    /** Accepts every vertex. */
    private static final Predicate<Vertex> ADD_ALL = vertex -> true;

    /** Accepts only vertices that report themselves as observed or probabilistic. */
    private static final Predicate<Vertex> PROBABILISTIC_OR_OBSERVED_ONLY =
        vertex -> vertex.isObserved() || vertex.isProbabilistic();

    /** Every vertex that was collected for this closure. */
    private final Set<Vertex> allVertices;

    /** The members of {@link #allVertices} that are observed or probabilistic. */
    private final Set<Vertex> latentAndObservedVertices;

    /**
     * Stores the collected vertices and derives the observed-or-probabilistic subset from them.
     *
     * @param allVertices the vertices collected by the traversal
     */
    private TransitiveClosure(Set<Vertex> allVertices) {
        this.allVertices = allVertices;
        this.latentAndObservedVertices = allVertices.stream()
            .filter(PROBABILISTIC_OR_OBSERVED_ONLY)
            .collect(Collectors.toSet());
    }

    /**
     * @param aVertex                 the starting vertex
     * @param includeNonProbabilistic false if only the probabilistic or observed vertices are wanted
     * @return All upstream vertices, not including non probabilistic if includeNonProbabilistic is false.
     */
    public static TransitiveClosure getUpstreamVertices(Vertex aVertex, boolean includeNonProbabilistic) {
        return getUpstreamVerticesForCollection(Collections.singletonList(aVertex), includeNonProbabilistic);
    }

    /**
     * @param aVertex                 the starting vertex
     * @param includeNonProbabilistic false if only the probabilistic or observed vertices are wanted
     * @return All downstream vertices, not including non probabilistic if includeNonProbabilistic is false.
     */
    public static TransitiveClosure getDownstreamVertices(Vertex aVertex, boolean includeNonProbabilistic) {
        return getDownstreamVerticesForCollection(Collections.singletonList(aVertex), includeNonProbabilistic);
    }

    /**
     * @param vertices                the starting vertices
     * @param includeNonProbabilistic false if only the probabilistic or observed vertices are wanted
     * @return All upstream vertices, not including non probabilistic if includeNonProbabilistic is false.
     */
    public static TransitiveClosure getUpstreamVerticesForCollection(List<Vertex> vertices,
                                                                     boolean includeNonProbabilistic) {
        // NOTE(review): the third argument (v -> false) is presumably the "stop traversing" predicate,
        // so the walk along getParents never stops early - confirm against Propagation.getVertices.
        Set<Vertex> upstreamVertices = getVertices(
            vertices,
            Vertex::getParents,
            v -> false,
            selectShouldAdd(includeNonProbabilistic)
        );

        return new TransitiveClosure(upstreamVertices);
    }

    /**
     * @param vertices                the starting vertices
     * @param includeNonProbabilistic false if only the probabilistic or observed vertices are wanted
     * @return All downstream vertices, not including non probabilistic if includeNonProbabilistic is false.
     */
    public static TransitiveClosure getDownstreamVerticesForCollection(List<Vertex> vertices,
                                                                       boolean includeNonProbabilistic) {
        // NOTE(review): the third argument (v -> false) is presumably the "stop traversing" predicate,
        // so the walk along getChildren never stops early - confirm against Propagation.getVertices.
        Set<Vertex> downstreamVertices = getVertices(
            vertices,
            Vertex::getChildren,
            v -> false,
            selectShouldAdd(includeNonProbabilistic)
        );

        return new TransitiveClosure(downstreamVertices);
    }

    /**
     * Picks the predicate that is handed to the traversal as its "should add" filter.
     *
     * @param includeNonProbabilistic true to accept every vertex, false to accept only observed or
     *                                probabilistic vertices
     * @return {@link #ADD_ALL} or {@link #PROBABILISTIC_OR_OBSERVED_ONLY}
     */
    private static Predicate<Vertex> selectShouldAdd(boolean includeNonProbabilistic) {
        return includeNonProbabilistic ? ADD_ALL : PROBABILISTIC_OR_OBSERVED_ONLY;
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy