// edu.cmu.tetrad.bayes.JunctionTreeAlgorithm Maven / Gradle / Ivy
/*
* Copyright (C) 2019 University of Pittsburgh.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
package edu.cmu.tetrad.bayes;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.util.TetradSerializable;
import org.apache.commons.math3.util.FastMath;
import java.util.*;
import java.util.stream.Collectors;
/**
* Junction Tree Algorithm.
*
* This implementation follows Weka's implementation.
*
* Nov 8, 2019 2:22:34 PM
*
* @author Kevin V. Bui ([email protected])
* @see MarginCalculator.java
*/
public class JunctionTreeAlgorithm implements TetradSerializable {
private static final long serialVersionUID = 23L;

/** Root clique of the junction tree, as returned by {@code buildJunctionTree()}. */
private final TreeNode root;

/** Nodes of the Bayes IM's DAG; the public methods address nodes by their index in this array. */
private final Node[] graphNodes;

/**
 * Marginal table of every node, indexed by the node index reported by the Bayes IM. The tables are the ones
 * computed by the cliques and are not normalized here.
 */
private final double[][] margins;

/** Maximum cardinality ordering of the nodes, filled in while the junction tree is built. */
private final Node[] maxCardOrdering;

private final BayesPm bayesPm;
private final BayesIm bayesIm;

/** Cliques of the junction tree, keyed by the node under which each clique was registered. */
private final Map<Node, TreeNode> treeNodes;
/**
 * Builds a junction tree from a graph and a data set.
 *
 * <p>A Bayes PM is created over the variables of the data model with the edges of the graph, a Bayes IM is
 * estimated from the data with {@code EmBayesEstimator}, and the junction tree is then built and initialized
 * from that Bayes IM.
 *
 * @param graph     graph providing the edges of the network
 * @param dataModel data used for estimation; it is cast to {@code DataSet}
 */
public JunctionTreeAlgorithm(Graph graph, DataModel dataModel) {
    this.bayesPm = createBayesPm(dataModel, graph);
    this.bayesIm = createBayesIm(dataModel, this.bayesPm);
    this.treeNodes = new HashMap<>();

    // The DAG used for inference is built from the data model's variables, so its node count can differ
    // from graph.getNumNodes(). Size every per-node array from the DAG itself so that the node indices
    // reported by the Bayes IM always fit and the node array never carries null padding.
    this.graphNodes = this.bayesIm.getDag().getNodes().toArray(new Node[0]);
    int numOfNodes = this.graphNodes.length;
    this.margins = new double[numOfNodes][];
    this.maxCardOrdering = new Node[numOfNodes];

    this.root = buildJunctionTree();

    initialize();
}
/**
 * Builds a junction tree for an existing Bayes IM and initializes it.
 *
 * @param bayesIm the instantiated Bayes model to run inference on
 */
public JunctionTreeAlgorithm(BayesIm bayesIm) {
    this.bayesIm = bayesIm;
    this.bayesPm = bayesIm.getBayesPm();

    final int nodeCount = this.bayesPm.getDag().getNumNodes();
    final List<Node> dagNodes = bayesIm.getDag().getNodes();

    this.treeNodes = new HashMap<>();
    this.graphNodes = dagNodes.toArray(new Node[nodeCount]);
    this.maxCardOrdering = new Node[nodeCount];
    this.margins = new double[nodeCount][];

    this.root = buildJunctionTree();

    initialize();
}
/**
 * (Re)initializes all cliques: first an upward pass that visits the maximum cardinality ordering from its last
 * entry to its first, then a non-recursive downward pass in ordering order. Entries of the ordering that have
 * no clique registered are skipped.
 */
private void initialize() {
    // upward pass, last node of the ordering first
    int position = this.maxCardOrdering.length;
    while (position-- > 0) {
        TreeNode clique = this.treeNodes.get(this.maxCardOrdering[position]);
        if (clique == null) {
            continue;
        }
        clique.initializeUp();
    }

    // downward pass, first node of the ordering first
    for (int k = 0; k < this.maxCardOrdering.length; k++) {
        TreeNode clique = this.treeNodes.get(this.maxCardOrdering[k]);
        if (clique == null) {
            continue;
        }
        clique.initializeDown(false);
    }
}
/**
 * Create the junction tree from the DAG of the Bayes IM.
 *
 * <p>The DAG is moralized, filled in along a maximum cardinality ordering, and then decomposed by
 * {@code GraphTools} into cliques, separators and a clique tree. One {@code TreeNode} is created per clique and
 * every clique that has a parent clique is linked to it through a {@code TreeSeparator}.
 *
 * @return the root of the junction tree, i.e. a clique without a parent clique; if several cliques have no
 * parent, the last one met while iterating the clique map is returned, and {@code null} if there is none
 */
private TreeNode buildJunctionTree() {
    // moralize dag
    Graph undirectedGraph = GraphTools.moralize(this.bayesIm.getDag());

    // triangulate
    computeMaximumCardinalityOrdering(undirectedGraph, this.maxCardOrdering);
    GraphTools.fillIn(undirectedGraph, this.maxCardOrdering);

    // get set of cliques (the ordering is recomputed on the filled-in graph)
    computeMaximumCardinalityOrdering(undirectedGraph, this.maxCardOrdering);
    Map<Node, Set<Node>> cliques = GraphTools.getCliques(this.maxCardOrdering, undirectedGraph);

    // get separator sets
    Map<Node, Set<Node>> separators = GraphTools.getSeparators(this.maxCardOrdering, cliques);

    // get clique tree
    Map<Node, Node> parentCliques = GraphTools.getCliqueTree(this.maxCardOrdering, cliques, separators);

    // create tree nodes; the shared set records which nodes already had their CPT assigned to a clique
    Set<Node> finishedCalculated = new HashSet<>();
    for (Node node : this.maxCardOrdering) {
        if (cliques.containsKey(node)) {
            this.treeNodes.put(node, new TreeNode(cliques.get(node), finishedCalculated));
        }
    }

    // create tree separators
    for (Node node : this.maxCardOrdering) {
        if (cliques.containsKey(node) && parentCliques.containsKey(node)) {
            TreeNode parent = this.treeNodes.get(parentCliques.get(node));
            TreeNode treeNode = this.treeNodes.get(node);

            treeNode.setParentSeparator(new TreeSeparator(separators.get(node), treeNode, parent));
            parent.addChildClique(treeNode);
        }
    }

    // the root is a clique that has no parent clique
    TreeNode rootNode = null;
    for (Node node : this.treeNodes.keySet()) {
        if (!parentCliques.containsKey(node)) {
            rootNode = this.treeNodes.get(node);
        }
    }

    return rootNode;
}
/**
 * Fills {@code nodes} with a maximum cardinality ordering of the graph: at every position the unnumbered node
 * adjacent to the largest number of already numbered nodes is chosen. Ties go to the node that comes first in
 * {@code graph.getNodes()}.
 *
 * @param graph the graph to order
 * @param nodes output array; one node is written per array slot
 */
private void computeMaximumCardinalityOrdering(Graph graph, Node[] nodes) {
    // the graph is not modified here, so the node list is fetched once
    List<Node> candidates = graph.getNodes();
    Set<Node> numbered = new HashSet<>();

    for (int i = 0; i < nodes.length; i++) {
        // find an unnumbered node that is adjacent to the most number of numbered nodes
        Node maxCardinalityNode = null;
        int maxCardinality = -1;
        for (Node v : candidates) {
            if (numbered.contains(v)) {
                continue;
            }

            // count the numbered nodes adjacent to v
            int cardinality = 0;
            for (Node w : graph.getAdjacentNodes(v)) {
                if (numbered.contains(w)) {
                    cardinality++;
                }
            }

            // strictly greater, so the earliest candidate wins a tie
            if (cardinality > maxCardinality) {
                maxCardinality = cardinality;
                maxCardinalityNode = v;
            }
        }

        // add the node with maximum cardinality to the ordering and number it
        nodes[i] = maxCardinalityNode;
        numbered.add(maxCardinalityNode);
    }
}
/**
 * Creates a Bayes PM whose DAG has the variables of the data model as nodes and a copy of every edge of the
 * graph, re-attached to the data model's nodes by node name.
 *
 * @param dataModel supplies the variables
 * @param graph     supplies the edges
 * @return a Bayes PM over the resulting DAG
 */
private BayesPm createBayesPm(DataModel dataModel, Graph graph) {
    final Dag target = new Dag(dataModel.getVariables());
    final Dag source = new Dag(graph);

    for (Edge edge : source.getEdges()) {
        Node first = target.getNode(edge.getNode1().getName());
        Node second = target.getNode(edge.getNode2().getName());

        target.addEdge(new Edge(first, second, edge.getEndpoint1(), edge.getEndpoint2()));
    }

    return new BayesPm(target);
}
/**
 * Estimates a Bayes IM for the given Bayes PM from the data using {@code EmBayesEstimator}.
 *
 * @param dataModel the data; it is cast to {@code DataSet}
 * @param bayesPm   the parametric model to instantiate
 * @return the estimated Bayes IM
 */
private BayesIm createBayesIm(DataModel dataModel, BayesPm bayesPm) {
    final DataSet dataSet = (DataSet) dataModel;
    final EmBayesEstimator estimator = new EmBayesEstimator(bayesPm, dataSet);

    return estimator.getEstimatedIm();
}
/**
 * Put the nodes from the set to an array in the order they appear in the graph.
 *
 * <p>The returned array always has {@code nodes.size()} slots. Slots for set members that do not occur in
 * {@code graphNodes} are left {@code null} at the end of the array.
 *
 * @param nodes set of nodes
 * @return nodes in the order they appear in the graph
 */
private Node[] toArray(Set<Node> nodes) {
    int size = nodes.size();
    Node[] order = new Node[size];

    int index = 0;
    for (Node node : this.graphNodes) {
        if (nodes.contains(node)) {
            order[index++] = node;

            // every member has been placed; no need to scan the rest of the graph
            if (index == size) {
                break;
            }
        }
    }

    return order;
}
/**
 * Scales the array in place so that its entries add up to one by dividing every entry by the sum of all
 * entries. No check is made for a zero sum.
 *
 * @param values the array to normalize in place
 */
private void normalize(double[] values) {
    double total = 0;
    for (int k = 0; k < values.length; k++) {
        total += values[k];
    }

    for (int k = 0; k < values.length; k++) {
        values[k] = values[k] / total;
    }
}
/**
 * Computes the number of joint value combinations of the given nodes, i.e. the product of their numbers of
 * categories as reported by the Bayes PM.
 *
 * @param nodes the nodes to combine
 * @return the product of the category counts; 1 for an empty set
 */
private int getCardinality(Set<Node> nodes) {
    int count = 1;
    for (Node node : nodes) {
        count *= this.bayesPm.getNumCategories(node);
    }

    return count;
}
/**
 * Advances {@code values} to the next value combination, treating it as a mixed-radix counter whose last
 * position changes fastest. A position that reaches its node's number of categories is reset to zero and the
 * carry moves one position to the left; after the last combination the array is all zeros again.
 *
 * @param size   number of positions in use
 * @param values the current combination, updated in place
 * @param nodes  the node of each position, used to look up the number of categories
 */
private void updateValues(int size, int[] values, Node[] nodes) {
    int position = size - 1;
    values[position] += 1;

    while (position >= 0) {
        if (values[position] != this.bayesPm.getNumCategories(nodes[position])) {
            return;
        }

        // overflow at this position: reset it and carry to the left neighbour
        values[position] = 0;
        position -= 1;
        if (position >= 0) {
            values[position] += 1;
        }
    }
}
/**
 * Computes the table index of a value combination for {@code nodes}, where the values are given for the
 * larger node array {@code order}. {@code order} is scanned once from left to right and each entry that is the
 * next expected node of {@code nodes} contributes its value, so {@code nodes} must appear in {@code order} in
 * the same relative order.
 *
 * @param nodes  the nodes the table ranges over
 * @param values one value per entry of {@code order}
 * @param order  the nodes the values belong to
 * @return the mixed-radix index over the matched nodes
 */
private int getIndexOfCPT(Node[] nodes, int[] values, Node[] order) {
    int cptIndex = 0;
    int matched = 0;
    int position = 0;

    while (position < order.length && matched < nodes.length) {
        Node wanted = nodes[matched];
        if (order[position] == wanted) {
            cptIndex = cptIndex * this.bayesPm.getNumCategories(wanted) + values[position];
            matched++;
        }
        position++;
    }

    return cptIndex;
}
/**
 * Computes the table index of a value combination: the values are read as the digits of a mixed-radix number
 * whose radix at each position is the number of categories of the node at that position.
 *
 * @param nodes  the nodes the table ranges over
 * @param values one value per node, in the same order
 * @return the index into the table
 */
private int getIndexOfCPT(Node[] nodes, int[] values) {
    int cptIndex = 0;

    int position = 0;
    for (Node node : nodes) {
        cptIndex = cptIndex * this.bayesPm.getNumCategories(node) + values[position];
        position++;
    }

    return cptIndex;
}
/**
 * Sets every entry of the array to zero.
 *
 * @param array the array to reset in place
 */
private void clear(double[] array) {
    for (int k = 0; k < array.length; k++) {
        array[k] = 0;
    }
}
/**
 * Finds a clique that contains the given node. Cliques are examined in the order of their registering nodes
 * in {@code graphNodes}, and the first match is returned.
 *
 * @param node the node to look for
 * @return a clique containing the node, or {@code null} if no clique contains it
 */
private TreeNode getCliqueContainsNode(Node node) {
    for (Node candidate : this.graphNodes) {
        if (!this.treeNodes.containsKey(candidate)) {
            continue;
        }

        TreeNode clique = this.treeNodes.get(candidate);
        if (clique.contains(node)) {
            return clique;
        }
    }

    return null;
}
/**
 * Checks that a node index addresses an entry of {@code margins}.
 *
 * @param iNode the node index to check
 * @throws IllegalArgumentException if the index is negative or beyond the last entry
 */
private void validate(int iNode) {
    final int lastIndex = this.margins.length - 1;
    if (0 <= iNode && iNode <= lastIndex) {
        return;
    }

    throw new IllegalArgumentException(String.format(
            "Invalid node index %d. Node index must be between 0 and %d.",
            iNode,
            lastIndex));
}
/**
 * Checks a node index and a value for that node. The value must address an entry of the node's marginal table.
 *
 * @param iNode the node index to check
 * @param value the value to check
 * @throws IllegalArgumentException if the node index or the value is out of range
 */
private void validate(int iNode, int value) {
    validate(iNode);

    final int highestValue = this.margins[iNode].length - 1;
    if (0 <= value && value <= highestValue) {
        return;
    }

    throw new IllegalArgumentException(String.format(
            "Invalid value %d for node index %d. Value must be between 0 and %d.",
            value,
            iNode,
            highestValue));
}
/**
 * Checks the shape of an array of node indices: it must be non-null, non-empty and no longer than the number
 * of graph nodes. The individual indices are not inspected here.
 *
 * @param nodes the node indices to check
 * @throws IllegalArgumentException if any of the conditions is violated
 */
private void validate(int[] nodes) {
    if (nodes == null) {
        throw new IllegalArgumentException("Node indices cannot be null.");
    }

    final int given = nodes.length;
    if (given == 0) {
        throw new IllegalArgumentException("Node indices are required.");
    }

    final int limit = this.graphNodes.length;
    if (given > limit) {
        throw new IllegalArgumentException(String.format(
                "Number of nodes cannot exceed %d.",
                limit));
    }
}
/**
 * Checks parallel arrays of node indices and node values: both must be non-null and non-empty, have the same
 * length, and every (index, value) pair must be valid.
 *
 * @param nodes  the node indices
 * @param values the value for each node index
 * @throws IllegalArgumentException if any of the conditions is violated
 */
private void validate(int[] nodes, int[] values) {
    validate(nodes);

    if (values == null) {
        throw new IllegalArgumentException("Node values cannot be null.");
    }
    if (values.length == 0) {
        throw new IllegalArgumentException("Node values are required.");
    }
    if (nodes.length != values.length) {
        throw new IllegalArgumentException("Number of nodes values must be equal to the number of nodes.");
    }

    int position = 0;
    while (position < nodes.length) {
        validate(nodes[position], values[position]);
        position++;
    }
}
/**
 * Checks an array that holds one value for every graph node, in node index order.
 *
 * @param values the value of each node; {@code values[i]} belongs to node index {@code i}
 * @throws IllegalArgumentException if the array is null, empty, not of the same length as the number of graph
 *                                  nodes, or contains a value outside the range of its node
 */
private void validateAll(int[] values) {
    if (values == null) {
        throw new IllegalArgumentException("Node values cannot be null.");
    }
    if (values.length == 0) {
        throw new IllegalArgumentException("Node values are required.");
    }
    if (values.length != this.graphNodes.length) {
        throw new IllegalArgumentException("Number of nodes values must be equal to the number of nodes.");
    }
    for (int i = 0; i < values.length; i++) {
        int maxValue = this.margins[i].length - 1;

        // a value is invalid when it lies on EITHER side of the range; '&&' could never be true
        if (values[i] < 0 || values[i] > maxValue) {
            String msg = String.format(
                    "Invalid value %d for node index %d. Value must be between 0 and %d.",
                    values[i],
                    i,
                    maxValue);
            throw new IllegalArgumentException(msg);
        }
    }
}
/**
 * Enters evidence for a node: the evidence is set on a clique that contains the node. The evidence stays in
 * place until the tree is initialized again.
 *
 * @param iNode index of the node
 * @param value observed value of the node
 * @throws IllegalArgumentException if the index or value is out of range, or no clique contains the node
 */
public void setEvidence(int iNode, int value) {
    validate(iNode, value);

    final Node target = this.graphNodes[iNode];
    final TreeNode clique = getCliqueContainsNode(target);
    if (clique != null) {
        clique.setEvidence(target, value);
        return;
    }

    throw new IllegalArgumentException(String.format("Node %s is not in junction tree.", target.getName()));
}
/**
 * Computes the distribution of a node given a single conditioning node: the evidence is entered, the node's
 * marginal table is copied and normalized, and the tree is re-initialized to remove the evidence again.
 *
 * @param iNode       index of the queried node
 * @param parent      index of the conditioning node
 * @param parentValue value of the conditioning node
 * @return the normalized probabilities of all values of the queried node
 */
private double[] getConditionalProbabilities(int iNode, int parent, int parentValue) {
    validate(iNode);
    validate(parent, parentValue);

    setEvidence(parent, parentValue);

    final double[] distribution = this.margins[iNode].clone();
    normalize(distribution);

    // reset
    initialize();

    return distribution;
}
/**
 * Tells whether {@code nodes} lists every node index exactly once and in index order, i.e. is exactly
 * {@code 0, 1, ..., n-1}.
 *
 * <p>The caller uses a positive answer to hand the accompanying values to
 * {@code getJointProbabilityAll}, which reads {@code values[i]} as the value of node {@code i}. Comparing
 * only the sum of the indices is not enough for that: duplicates such as {@code {0, 0, 3, 3}} have the same
 * sum as {@code {0, 1, 2, 3}}, and a permutation would misalign the values with the nodes.
 *
 * @param nodes node indices to inspect
 * @return {@code true} only if {@code nodes[i] == i} for every graph node index {@code i}
 */
private boolean isAllNodes(int[] nodes) {
    if (nodes.length != this.graphNodes.length) {
        return false;
    }

    for (int i = 0; i < nodes.length; i++) {
        if (nodes[i] != i) {
            return false;
        }
    }

    return true;
}
/**
 * Get the joint probability of the nodes given their parents. Example: given x &lt;-- z --&gt; y, we can find
 * P(x,y|z). Another example: given x &lt;-- z --&gt; y &lt;-- w, we can find P(x,y|z,w).
 *
 * <p>The result is the product, over the queried nodes, of each node's normalized marginal for its requested
 * value after the parent values have been entered as evidence. The tree is re-initialized before returning.
 *
 * @param nodes        indices of the queried nodes
 * @param values       requested value of each queried node
 * @param parents      indices of the conditioning nodes
 * @param parentValues value of each conditioning node
 * @return the product of the individual conditional probabilities
 */
public double getConditionalProbabilities(int[] nodes, int[] values, int[] parents, int[] parentValues) {
    validate(nodes, values);
    validate(parents, parentValues);

    int p = 0;
    while (p < parents.length) {
        setEvidence(parents[p], parentValues[p]);
        p++;
    }

    double product = 1;
    for (int q = 0; q < nodes.length; q++) {
        final double[] distribution = this.margins[nodes[q]].clone();
        normalize(distribution);
        product *= distribution[values[q]];
    }

    // reset
    initialize();

    return product;
}
/**
 * Get the conditional probability of a node for all of its values.
 *
 * <p>The parent values are entered as evidence, the node's marginal table is copied and normalized, and the
 * tree is re-initialized before returning.
 *
 * @param iNode        index of the queried node
 * @param parents      indices of the conditioning nodes
 * @param parentValues value of each conditioning node
 * @return the normalized probabilities of all values of the queried node
 */
public double[] getConditionalProbabilities(int iNode, int[] parents, int[] parentValues) {
    validate(iNode);
    validate(parents, parentValues);

    // a single conditioning node is handled by the dedicated overload
    if (parents.length == 1) {
        return getConditionalProbabilities(iNode, parents[0], parentValues[0]);
    }

    int p = 0;
    while (p < parents.length) {
        setEvidence(parents[p], parentValues[p]);
        p++;
    }

    final double[] distribution = this.margins[iNode].clone();
    normalize(distribution);

    // reset
    initialize();

    return distribution;
}
/**
 * Get the conditional probability of one value of a node given the values of the conditioning nodes.
 *
 * @param iNode        index of the queried node
 * @param value        requested value of the queried node
 * @param parents      indices of the conditioning nodes
 * @param parentValues value of each conditioning node
 * @return the conditional probability of {@code value}
 */
public double getConditionalProbability(int iNode, int value, int[] parents, int[] parentValues) {
    validate(iNode, value);

    final double[] distribution = getConditionalProbabilities(iNode, parents, parentValues);

    return distribution[value];
}
/**
 * Get the joint probability of all nodes (variables). Given the nodes are X1, X2,...,Xn, then nodeValues[0] =
 * value(X1), nodeValues[1] = value(X2),...,nodeValues[n-1] = value(Xn).
 *
 * <p>The result is computed in log space as the exponential of the summed log clique potentials minus the
 * summed log separator potentials, both collected from the root downwards.
 *
 * @param nodeValues an array of values for each node
 * @return the joint probability of the given value assignment
 */
public double getJointProbabilityAll(int[] nodeValues) {
    validateAll(nodeValues);

    final double logCliques = this.root.getLogJointClusterPotentials(nodeValues);
    final double logSeparators = this.root.getLogJointSeparatorPotentials(nodeValues);
    final double logJoint = logCliques - logSeparators;

    return FastMath.exp(logJoint);
}
/**
 * Get the joint probability of a set of nodes taking the given values.
 *
 * <p>When {@code isAllNodes} accepts the index array, the values are passed on to
 * {@code getJointProbabilityAll}. Otherwise the values are entered as evidence, the marginal table of one node
 * is summed up, and the tree is re-initialized. The node whose table is summed is the first index {@code i}
 * for which {@code nodes[i] != i} or which lies beyond the end of {@code nodes}.
 *
 * @param nodes  indices of the nodes
 * @param values value of each node
 * @return the joint probability
 */
public double getJointProbability(int[] nodes, int[] values) {
    validate(nodes, values);

    if (isAllNodes(nodes)) {
        return getJointProbabilityAll(values);
    }

    for (int k = 0; k < nodes.length; k++) {
        setEvidence(nodes[k], values[k]);
    }

    // sum out a non-evidence variable
    double prob = 0;
    for (int i = 0; i < this.margins.length; i++) {
        final boolean listedInPlace = i < nodes.length && nodes[i] == i;
        if (!listedInPlace) {
            prob += Arrays.stream(this.margins[i]).sum();
            break;
        }
    }

    // reset
    initialize();

    return prob;
}
/**
 * Get the marginal distribution of a node as a normalized copy of its current marginal table.
 *
 * @param iNode index of the node
 * @return a new array with the normalized probabilities of all values of the node
 */
public double[] getMarginalProbability(int iNode) {
    validate(iNode);

    final double[] distribution = this.margins[iNode].clone();
    normalize(distribution);

    return distribution;
}
/**
 * Get the entry of a node's current marginal table for one value. Unlike the array-returning overload, the
 * entry is returned as stored, without normalizing the table.
 *
 * @param iNode index of the node
 * @param value value of the node
 * @return the stored marginal entry for the value
 */
public double getMarginalProbability(int iNode, int value) {
    validate(iNode, value);

    final double[] table = this.margins[iNode];

    return table[value];
}
/**
 * Get the graph nodes in index order as an unmodifiable list view of the internal node array.
 *
 * @return the nodes; position {@code i} holds the node with index {@code i}
 */
public List<Node> getNodes() {
    return Collections.unmodifiableList(Arrays.asList(this.graphNodes));
}
/**
 * Get the number of graph nodes.
 *
 * @return the length of the internal node array
 */
public int getNumberOfNodes() {
    final Node[] all = this.graphNodes;

    return all.length;
}
/**
 * Renders the junction tree as the textual form of its root clique, with surrounding whitespace removed.
 *
 * @return the trimmed string form of the root
 */
@Override
public String toString() {
    final String rendered = this.root.toString();

    return rendered.trim();
}
/**
 * Separator between a clique and its parent clique in the junction tree. It keeps two potential tables over
 * the separator nodes: one obtained by marginalizing the parent clique and one obtained by marginalizing the
 * child clique.
 */
private class TreeSeparator implements TetradSerializable {

    private static final long serialVersionUID = 23L;

    /** Potentials over the separator nodes, marginalized from the parent clique. */
    private final double[] parentPotentials;

    /** Potentials over the separator nodes, marginalized from the child clique. */
    private final double[] childPotentials;

    /** The separator nodes, in the order they appear in the graph. */
    private final Node[] nodes;

    private final TreeNode childNode;
    private final TreeNode parentNode;

    /**
     * Creates a separator with zero-filled potential tables sized by the joint cardinality of the separator
     * nodes.
     *
     * @param separator  the nodes shared by the two cliques
     * @param childNode  the clique below the separator
     * @param parentNode the clique above the separator
     */
    public TreeSeparator(Set<Node> separator, TreeNode childNode, TreeNode parentNode) {
        this.childNode = childNode;
        this.parentNode = parentNode;

        this.nodes = toArray(separator);

        int cardinality = getCardinality(separator);
        this.parentPotentials = new double[cardinality];
        this.childPotentials = new double[cardinality];
    }

    /**
     * Marginalize TreeNode node over all nodes outside the separator set
     *
     * @param node       one of the neighboring junction tree nodes of this separator
     * @param potentials the table to overwrite with the marginalized distribution of {@code node}
     */
    public void update(TreeNode node, double[] potentials) {
        clear(potentials);

        if (node.prob != null) {
            int size = node.nodes.length;
            int[] values = new int[size];

            // walk through every value combination of the clique and add its probability to the separator
            // entry that the combination projects onto
            for (int i = 0; i < node.cardinality; i++) {
                int indexNodeCPT = getIndexOfCPT(node.nodes, values);
                int indexSepCPT = getIndexOfCPT(this.nodes, values, node.nodes);
                potentials[indexSepCPT] += node.prob[indexNodeCPT];

                updateValues(size, values, node.nodes);
            }
        }
    }

    /** Recomputes {@code parentPotentials} from the parent clique. */
    public void updateFromParent() {
        update(this.parentNode, this.parentPotentials);
    }

    /** Recomputes {@code childPotentials} from the child clique. */
    public void updateFromChild() {
        update(this.childNode, this.childPotentials);
    }

}
private class TreeNode implements TetradSerializable {
private static final long serialVersionUID = 23L;

/**
 * Distribution over this junction node according to its potentials.
 */
private final double[] prob;

/** Marginal table of each clique node, in the order of {@code nodes}. */
private final double[][] margProb;

/** Clique potentials computed once in the constructor; {@code initializeUp()} copies them into {@code prob}. */
private final double[] potentials;

/** Child cliques, in the order they were added. */
private final List<TreeNode> children;

/** Number of joint value combinations of the clique nodes. */
private final int cardinality;

private final Set<Node> clique;

/** The clique nodes, in the order they appear in the graph. */
private final Node[] nodes;

/** Separator towards the parent clique; stays {@code null} unless {@code setParentSeparator} is called. */
private TreeSeparator parentSeparator;
/**
 * Creates the junction tree node of a clique: allocates its tables and computes its potentials.
 *
 * @param clique             the nodes of the clique
 * @param finishedCalculated nodes whose conditional probabilities were already assigned to a clique; nodes
 *                           assigned to this clique are added to it
 */
public TreeNode(Set<Node> clique, Set<Node> finishedCalculated) {
    this.clique = clique;
    this.nodes = toArray(clique);
    this.children = new LinkedList<>();

    this.cardinality = getCardinality(clique);
    this.potentials = new double[this.cardinality];
    this.prob = new double[this.cardinality];

    // one marginal table per clique node, sized by that node's number of categories
    this.margProb = new double[this.nodes.length][];
    for (int iNode = 0; iNode < this.nodes.length; iNode++) {
        this.margProb[iNode] = new double[JunctionTreeAlgorithm.this.bayesPm.getNumCategories(this.nodes[iNode])];
    }

    calculatePotentials(clique, finishedCalculated);
}
/**
 * Fills the potential table of this clique. A clique node contributes its conditional probabilities only if it
 * has not been claimed by an earlier clique and all of its parents in the DAG are members of the given node
 * set. Nodes claimed here are added to {@code finishedCalculated}. Each table entry is the product of the
 * contributing nodes' probabilities for that value combination, or 1 when no node contributes.
 *
 * @param cliques            the nodes of this clique
 * @param finishedCalculated nodes already claimed by a clique; updated in place
 */
private void calculatePotentials(Set<Node> cliques, Set<Node> finishedCalculated) {
    Graph dag = JunctionTreeAlgorithm.this.bayesIm.getDag();

    Set<Node> nodesWithParentsInCluster = new HashSet<>();
    for (Node node : this.nodes) {
        if (!finishedCalculated.contains(node) && cliques.containsAll(dag.getParents(node))) {
            nodesWithParentsInCluster.add(node);
            finishedCalculated.add(node);
        }
    }

    // fill in values
    int size = this.nodes.length;
    int[] values = new int[size];
    for (int i = 0; i < this.cardinality; i++) {
        int indexCPT = getIndexOfCPT(this.nodes, values);
        this.potentials[indexCPT] = 1.0;
        for (int iNode = 0; iNode < this.nodes.length; iNode++) {
            Node node = this.nodes[iNode];
            if (nodesWithParentsInCluster.contains(node)) {
                int nodeIndex = JunctionTreeAlgorithm.this.bayesIm.getNodeIndex(node);
                int rowIndex = getRowIndex(nodeIndex, values, this.nodes);
                this.potentials[indexCPT] *= JunctionTreeAlgorithm.this.bayesIm.getProbability(nodeIndex, rowIndex, values[iNode]);
            }
        }

        updateValues(size, values, this.nodes);
    }
}
/**
 * Upward initialization step: resets {@code prob} to the clique potentials, multiplies in the child-side
 * separator potentials of every child clique, and finally refreshes the child-side potentials of this
 * clique's own parent separator, if it has one.
 */
public void initializeUp() {
    System.arraycopy(this.potentials, 0, this.prob, 0, this.cardinality);

    final int width = this.nodes.length;
    // the counter wraps back to all zeros after a full cycle, so it can be shared by all children
    final int[] assignment = new int[width];
    for (TreeNode child : this.children) {
        TreeSeparator separator = child.parentSeparator;
        for (int step = 0; step < this.cardinality; step++) {
            int separatorIndex = getIndexOfCPT(separator.nodes, assignment, this.nodes);
            int cliqueIndex = getIndexOfCPT(this.nodes, assignment);
            this.prob[cliqueIndex] *= separator.childPotentials[separatorIndex];

            updateValues(width, assignment, this.nodes);
        }
    }

    if (this.parentSeparator != null) { // not a root node
        this.parentSeparator.updateFromChild();
    }
}
/**
 * Downward initialization step. If the clique has a parent separator, the parent-side separator potentials are
 * refreshed, every entry of {@code prob} is scaled by the ratio of parent-side to child-side separator
 * potential (or set to zero where the child-side potential is not positive), and the child-side potentials are
 * refreshed. The marginals of the clique nodes are then recomputed.
 *
 * @param recursively whether to continue with all child cliques
 */
public void initializeDown(boolean recursively) {
    final TreeSeparator separator = this.parentSeparator;
    if (separator != null) {
        separator.updateFromParent();

        final int width = this.nodes.length;
        final int[] assignment = new int[width];
        for (int step = 0; step < this.cardinality; step++) {
            int separatorIndex = getIndexOfCPT(separator.nodes, assignment, this.nodes);
            int cliqueIndex = getIndexOfCPT(this.nodes, assignment);

            double fromChild = separator.childPotentials[separatorIndex];
            if (fromChild > 0) {
                this.prob[cliqueIndex] *= (separator.parentPotentials[separatorIndex] / fromChild);
            } else {
                this.prob[cliqueIndex] = 0;
            }

            updateValues(width, assignment, this.nodes);
        }

        separator.updateFromChild();
    }

    calculateMarginalProbabilities();

    if (recursively) {
        for (TreeNode child : this.children) {
            child.initializeDown(true);
        }
    }
}
/**
 * Calculate marginal probabilities for the individual nodes in the clique.
 *
 * <p>Each node's table is the sum of {@code prob} over all value combinations that agree with the node's
 * value. The tables are then published in the enclosing algorithm's {@code margins} array under the node index
 * reported by the Bayes IM; the array objects themselves are shared, not copied.
 */
private void calculateMarginalProbabilities() {
    // reset
    for (int member = 0; member < this.nodes.length; member++) {
        clear(this.margProb[member]);
    }

    final int width = this.nodes.length;
    final int[] assignment = new int[width];
    for (int step = 0; step < this.cardinality; step++) {
        final double weight = this.prob[getIndexOfCPT(this.nodes, assignment)];
        for (int member = 0; member < width; member++) {
            this.margProb[member][assignment[member]] += weight;
        }

        updateValues(width, assignment, this.nodes);
    }

    for (int member = 0; member < width; member++) {
        int globalIndex = JunctionTreeAlgorithm.this.bayesIm.getNodeIndex(this.nodes[member]);
        JunctionTreeAlgorithm.this.margins[globalIndex] = this.margProb[member];
    }
}
/**
 * Computes the row of a node's conditional probability table that corresponds to the given values of its
 * parents. The parents are taken in the order reported by the Bayes IM and their values are looked up by
 * position in {@code nodes}.
 *
 * @param nodeIndex index of the node in the Bayes IM
 * @param values    value of every node in {@code nodes}
 * @param nodes     the nodes the values belong to
 * @return the mixed-radix row index over the parents
 */
private int getRowIndex(int nodeIndex, int[] values, Node[] nodes) {
    int row = 0;

    for (int parentIndex : JunctionTreeAlgorithm.this.bayesIm.getParents(nodeIndex)) {
        Node parentNode = JunctionTreeAlgorithm.this.bayesIm.getNode(parentIndex);
        row *= JunctionTreeAlgorithm.this.bayesPm.getNumCategories(parentNode);

        for (int position = 0; position < nodes.length; position++) {
            if (nodes[position] == parentNode) {
                row += values[position];
            }
        }
    }

    return row;
}
/**
 * Finds the position of a node within this clique's node array, comparing by reference.
 *
 * @param node the node to look for
 * @return the position of the node, or -1 if it is not a clique node
 */
private int getNodeIndex(Node node) {
    int position = 0;
    for (Node member : this.nodes) {
        if (member == node) {
            return position;
        }
        position++;
    }

    return -1;
}
/**
 * Enters evidence for a clique node: every entry of {@code prob} whose value combination disagrees with the
 * evidence is set to zero. The clique marginals are then recomputed and the change is propagated through the
 * tree.
 *
 * @param node  the observed node; must be a node of this clique
 * @param value the observed value
 * @throws IllegalArgumentException if the node is not in this clique
 */
public void setEvidence(Node node, int value) {
    final int position = getNodeIndex(node);
    if (position < 0) {
        throw new IllegalArgumentException(String.format("Unable to find node %s in clique.", node.getName()));
    }

    final int width = this.nodes.length;
    final int[] assignment = new int[width];
    for (int step = 0; step < this.cardinality; step++) {
        final boolean consistent = assignment[position] == value;
        if (!consistent) {
            this.prob[getIndexOfCPT(this.nodes, assignment)] = 0;
        }

        updateValues(width, assignment, this.nodes);
    }

    calculateMarginalProbabilities();
    updateEvidence(this);
}
/**
 * Propagates an evidence change through the tree, starting at the clique where the evidence was entered.
 *
 * @param source the clique the update arrives from; it is this clique itself on the initial call made by
 *               setEvidence, and a child clique on the recursive call made further down in this method
 */
private void updateEvidence(TreeNode source) {
// The update came from another clique: rescale this clique's distribution with that clique's parent separator.
// Within this class the only such call is the recursive one below, which is made only when the calling
// clique has a parent separator.
if (source != this) {
int size = this.nodes.length;
int[] values = new int[size];
for (int i = 0; i < this.cardinality; i++) {
int indexNodeCPT = getIndexOfCPT(this.nodes, values);
int indexChildNodeCPT = getIndexOfCPT(source.parentSeparator.nodes, values, this.nodes);
// scale by child-side / parent-side separator potential; zero out the entry where the parent side is zero
if (source.parentSeparator.parentPotentials[indexChildNodeCPT] != 0) {
this.prob[indexNodeCPT] *= source.parentSeparator.childPotentials[indexChildNodeCPT]
/ source.parentSeparator.parentPotentials[indexChildNodeCPT];
} else {
this.prob[indexNodeCPT] = 0;
}
updateValues(size, values, this.nodes);
}
calculateMarginalProbabilities();
}
// push the change down into every child subtree except the one the update came from
this.children.stream()
.filter(e -> e != source)
.forEach(e -> e.initializeDown(true));
// not a root: refresh the child-side separator potentials, let the parent clique update itself, and refresh
// the parent-side separator potentials afterwards
if (this.parentSeparator != null) {
this.parentSeparator.updateFromChild();
this.parentSeparator.parentNode.updateEvidence(this);
this.parentSeparator.updateFromParent();
}
}
/**
 * Sums the log of the child-side separator potentials for the given value assignment over this clique's
 * parent separator (if any) and, recursively, over all cliques below it.
 *
 * @param nodeValues value of every node, indexed by the node index of the Bayes IM
 * @return the summed log separator potentials of this subtree
 */
private double getLogJointSeparatorPotentials(int[] nodeValues) {
    double total = FastMath.log(1);

    final TreeSeparator separator = this.parentSeparator;
    if (separator != null) {
        final Node[] separatorNodes = separator.nodes;
        final int[] assignment = new int[separatorNodes.length];
        for (int k = 0; k < assignment.length; k++) {
            assignment[k] = nodeValues[JunctionTreeAlgorithm.this.bayesIm.getNodeIndex(separatorNodes[k])];
        }

        total += FastMath.log(separator.childPotentials[getIndexOfCPT(separatorNodes, assignment)]);
    }

    for (TreeNode child : this.children) {
        total += child.getLogJointSeparatorPotentials(nodeValues);
    }

    return total;
}
/**
 * Sums the log of the clique distribution entry for the given value assignment over this clique and,
 * recursively, over all cliques below it.
 *
 * @param nodeValues value of every node, indexed by the node index of the Bayes IM
 * @return the summed log clique potentials of this subtree
 */
private double getLogJointClusterPotentials(int[] nodeValues) {
    final int[] assignment = new int[this.nodes.length];
    for (int k = 0; k < assignment.length; k++) {
        assignment[k] = nodeValues[JunctionTreeAlgorithm.this.bayesIm.getNodeIndex(this.nodes[k])];
    }

    double total = FastMath.log(this.prob[getIndexOfCPT(this.nodes, assignment)]);
    for (TreeNode child : this.children) {
        total += child.getLogJointClusterPotentials(nodeValues);
    }

    return total;
}
/**
 * Sets the separator that connects this clique to its parent clique.
 *
 * @param parentSeparator the separator towards the parent
 */
public void setParentSeparator(TreeSeparator parentSeparator) {
    this.parentSeparator = parentSeparator;
}
/**
 * Appends a clique to the list of child cliques.
 *
 * @param child the child clique to add
 */
public void addChildClique(TreeNode child) {
    this.children.add(child);
}
/**
 * Get the nodes of this clique. The set handed to the constructor is returned as is, without copying.
 *
 * @return the clique's node set
 */
public Set<Node> getClique() {
    return this.clique;
}
/**
 * Check if the clique contains the given node.
 *
 * @param node the node to look for
 * @return {@code true} if the node is a member of the clique's node set
 */
public boolean contains(Node node) {
    final boolean member = this.clique.contains(node);

    return member;
}
/**
 * Renders this clique and, recursively, its children. Each clique node gets one line with its name and the
 * entries of its marginal table separated by spaces; every child clique is preceded by a line of dashes.
 *
 * @return the textual form of the subtree rooted at this clique
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder();

    int row = 0;
    for (Node member : this.nodes) {
        text.append(member.getName()).append(": ");

        String marginals = Arrays.stream(this.margProb[row])
                .mapToObj(String::valueOf)
                .collect(Collectors.joining(" "));
        text.append(marginals).append('\n');

        row++;
    }

    for (TreeNode child : this.children) {
        text.append("----------------\n");
        text.append(child.toString());
    }

    return text.toString();
}
}
}