
com.github.vangj.jbayes.inf.exact.sampling.Table Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jbayes-inference Show documentation
Show all versions of jbayes-inference Show documentation
A very cool project for BBN inference using approximate and exact algorithms.
The newest version!
package com.github.vangj.jbayes.inf.exact.sampling;
import com.github.vangj.jbayes.inf.exact.graph.Node;
import com.github.vangj.jbayes.inf.exact.graph.Variable;
import com.github.vangj.jbayes.inf.exact.graph.util.NodeUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class Table {
private final Node node;
private final List parents;
private final List parentIds;
private Map> probs;
public Table(Node node, List parents) {
this.node = node;
parents.sort((n1, n2) -> n1.getId().compareTo(n2.getId()));
this.parents = parents;
this.parentIds = this.parents.stream()
.map(Variable::getId)
.collect(Collectors.toList());
this.probs = new HashMap<>();
if (0 == this.parents.size()) {
List cumsums = getCumsums(node.probs());
this.probs.put("default", cumsums);
} else {
List> lists = parents.stream()
.map(Variable::getValues)
.map(s -> (List) new ArrayList(s))
.collect(Collectors.toList());
List> cartesian = NodeUtil.product(lists);
List keys = cartesian.stream()
.map(values -> {
final int paSize = this.parents.size();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < paSize; i++) {
Node n = this.parents.get(i);
String v = values.get(i);
String s = n.getId() + "=" + v;
sb.append(s);
if (i < paSize - 1) {
sb.append(",");
}
}
return sb.toString();
})
.collect(Collectors.toList());
final int n = node.getValues().size();
List nodeProbs = node.probs();
List> probs = NodeUtil.groupList(n, nodeProbs).stream()
.map(Table::getCumsums)
.collect(Collectors.toList());
this.probs = new HashMap<>();
final int numPairs = keys.size();
for (int i = 0; i < numPairs; i++) {
String k = keys.get(i);
List v = probs.get(i);
this.probs.put(k, v);
}
}
}
private static List getCumsums(List probs) {
final int size = probs.size();
List cumsums = new ArrayList<>();
cumsums.add(probs.get(0));
for (int i = 1; i < size; i++) {
double sum = cumsums.get(i - 1) + probs.get(i);
cumsums.add(sum);
}
return cumsums;
}
public String getValue(final double prob, Map sample) {
if (!this.hasParents()) {
List probs = this.probs.get("default");
int index = NodeUtil.bisectRight(probs, prob);
return node.getValueList().get(index);
} else {
String k = this.parentIds.stream()
.map(id -> id + "=" + sample.get(id))
.collect(Collectors.joining(","));
List probs = this.probs.get(k);
int index = NodeUtil.bisectRight(probs, prob);
return node.getValueList().get(index);
}
}
public boolean hasParents() {
return parents.size() > 0;
}
public Map> getProbs() {
return probs;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy