aima.core.probability.bayes.impl.BayesNet Maven / Gradle / Ivy
package aima.core.probability.bayes.impl;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.Node;
/**
* Default implementation of the BayesianNetwork interface.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
*/
public class BayesNet implements BayesianNetwork {
protected Set rootNodes = new LinkedHashSet();
protected List variables = new ArrayList();
protected Map varToNodeMap = new HashMap();
public BayesNet(Node... rootNodes) {
if (null == rootNodes) {
throw new IllegalArgumentException(
"Root Nodes need to be specified.");
}
for (Node n : rootNodes) {
this.rootNodes.add(n);
}
if (this.rootNodes.size() != rootNodes.length) {
throw new IllegalArgumentException(
"Duplicate Root Nodes Passed in.");
}
// Ensure is a DAG
checkIsDAGAndCollectVariablesInTopologicalOrder();
variables = Collections.unmodifiableList(variables);
}
//
// START-BayesianNetwork
@Override
public List getVariablesInTopologicalOrder() {
return variables;
}
@Override
public Node getNode(RandomVariable rv) {
return varToNodeMap.get(rv);
}
// END-BayesianNetwork
//
//
// PRIVATE METHODS
//
private void checkIsDAGAndCollectVariablesInTopologicalOrder() {
// Topological sort based on logic described at:
// http://en.wikipedia.org/wiki/Topoligical_sorting
Set seenAlready = new HashSet();
Map> incomingEdges = new HashMap>();
Set s = new LinkedHashSet();
for (Node n : this.rootNodes) {
walkNode(n, seenAlready, incomingEdges, s);
}
while (!s.isEmpty()) {
Node n = s.iterator().next();
s.remove(n);
variables.add(n.getRandomVariable());
varToNodeMap.put(n.getRandomVariable(), n);
for (Node m : n.getChildren()) {
List edges = incomingEdges.get(m);
edges.remove(n);
if (edges.isEmpty()) {
s.add(m);
}
}
}
for (List edges : incomingEdges.values()) {
if (!edges.isEmpty()) {
throw new IllegalArgumentException(
"Network contains at least one cycle in it, must be a DAG.");
}
}
}
private void walkNode(Node n, Set seenAlready,
Map> incomingEdges, Set rootNodes) {
if (!seenAlready.contains(n)) {
seenAlready.add(n);
// Check if has no incoming edges
if (n.isRoot()) {
rootNodes.add(n);
}
incomingEdges.put(n, new ArrayList(n.getParents()));
for (Node c : n.getChildren()) {
walkNode(c, seenAlready, incomingEdges, rootNodes);
}
}
}
}