weka.classifiers.bayes.net.MarginCalculator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* MarginCalculator.java
* Copyright (C) 2007-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.bayes.net;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
public class MarginCalculator implements Serializable, RevisionHandler {
/** for serialization */
private static final long serialVersionUID = 650278019241175534L;
/**
 * when true, diagnostic messages are printed to System.out and getCliques()
 * additionally verifies that every returned clique is fully connected
 */
boolean m_debug = false;
/**
 * root of the junction tree built by the last call to process(); null until
 * process() has run
 */
public JunctionTreeNode m_root = null;
/**
 * junction tree nodes indexed by Bayes net node: entry i is the node for the
 * clique stored at cliques[i] in process(), or null where there is no such
 * clique
 */
JunctionTreeNode[] jtNodes;
/**
 * Looks up the index of a node by its name.
 *
 * The names are taken from the attributes of the instances of the Bayes
 * network behind the junction tree root, so a junction tree must have been
 * built (calcMargins() or calcFullMargins()) before calling this.
 *
 * @param sNodeName name of the node to look for
 * @return index of the first attribute with that name, or -1 if none matches
 */
public int getNode(String sNodeName) {
  for (int iAttribute = 0; iAttribute < m_root.m_bayesNet.m_Instances
    .numAttributes(); iAttribute++) {
    String sName = m_root.m_bayesNet.m_Instances.attribute(iAttribute).name();
    if (sName.equals(sNodeName)) {
      return iAttribute;
    }
  }
  // no match: callers get -1 rather than an exception
  return -1;
}
/**
 * Serializes the Bayes network behind the junction tree root in XML BIF 0.3
 * format by delegating to the network itself.
 *
 * @return XML BIF 0.3 representation of the network
 */
public String toXMLBIF03() {
  BayesNet network = m_root.m_bayesNet;
  return network.toXMLBIF03();
}
/**
 * Calculates the marginal distributions of the nodes in a Bayesian network.
 *
 * The network's DAG is moralized into an undirected graph, which process()
 * then compiles into a junction tree; afterwards the margins are available
 * through getMargin().
 *
 * Note that a connected network is assumed. Unconnected networks may give
 * unexpected results.
 *
 * @param bayesNet the Bayesian network to compute the margins for
 * @throws Exception if processing the network fails, e.g. when an
 *           inconsistent clique tree is detected
 */
public void calcMargins(BayesNet bayesNet) throws Exception {
  boolean[][] bMoralGraph = moralize(bayesNet);
  process(bMoralGraph, bayesNet);
} // calcMargins
/**
 * Calculates the marginal distributions like calcMargins(), but compiles the
 * complete graph instead of the moral graph: every pair of nodes is treated
 * as adjacent, which yields a single clique containing all nodes, i.e. the
 * full joint distribution.
 *
 * @param bayesNet the Bayesian network to compute the margins for
 * @throws Exception if processing the network fails
 */
public void calcFullMargins(BayesNet bayesNet) throws Exception {
  int nNodes = bayesNet.getNrOfNodes();
  boolean[][] bComplete = new boolean[nNodes][nNodes];
  // mark every entry (including the diagonal) as an edge
  for (boolean[] bRow : bComplete) {
    for (int iCol = 0; iCol < bRow.length; iCol++) {
      bRow[iCol] = true;
    }
  }
  process(bComplete, bayesNet);
} // calcFullMargins
/**
 * Compiles an undirected graph over the nodes of a Bayes network into a
 * junction tree and runs the initial propagation, after which the margins of
 * all nodes are available through getMargin().
 *
 * The graph is triangulated with fillIn() under a maximum cardinality
 * ordering, the ordering is recomputed on the triangulated graph, and the
 * resulting cliques are linked into a tree through their separators.
 *
 * @param bAdjacencyMatrix adjacency matrix of the graph to compile (normally
 *          the moral graph); fill-in edges are added to it in place
 * @param bayesNet network supplying node names, cardinalities and
 *          conditional distributions
 * @throws Exception if a clique without a parent clique has a non-empty
 *           separator, i.e. the clique tree is inconsistent
 */
public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet)
  throws Exception {
  int[] order = getMaxCardOrder(bAdjacencyMatrix);
  bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix);
  // cliques are extracted under an ordering of the triangulated graph
  order = getMaxCardOrder(bAdjacencyMatrix);
  Set[] cliques = getCliques(order, bAdjacencyMatrix);
  Set[] separators = getSeparators(order, cliques);
  int[] parentCliques = getCliqueTree(order, cliques, separators);
  int nNodes = bAdjacencyMatrix.length;
  if (m_debug) {
    reportCliques(order, cliques, separators, parentCliques, bayesNet);
  }
  // validate before touching jtNodes / m_root / m_Margins, so a failure
  // leaves the previously built junction tree intact
  checkCliqueTree(order, cliques, separators, parentCliques);

  jtNodes = getJunctionTree(cliques, separators, parentCliques, order,
    bayesNet);
  // the root is the first clique (by node index) without a parent clique
  m_root = null;
  for (int iNode = 0; iNode < nNodes; iNode++) {
    if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) {
      m_root = jtNodes[iNode];
      break;
    }
  }
  // filled in by JunctionTreeNode.calcMarginalProbabilities() during
  // initialize()
  m_Margins = new double[nNodes][];
  initialize(jtNodes, order, cliques, separators, parentCliques);
} // process

/**
 * Prints every clique with its separator and parent clique (debug output).
 */
private void reportCliques(int[] order, Set[] cliques, Set[] separators,
  int[] parentCliques, BayesNet bayesNet) {
  for (int iNode : order) {
    if (cliques[iNode] != null) {
      System.out.print("Clique " + iNode + " (");
      printNodeSet(cliques[iNode], bayesNet);
      System.out.print(") S(");
      printNodeSet(separators[iNode], bayesNet);
      System.out.println(") parent clique " + parentCliques[iNode]);
    }
  }
}

/**
 * Prints "index name" for every node in the set, separated by commas.
 */
private void printNodeSet(Set<Integer> nodes, BayesNet bayesNet) {
  Iterator<Integer> it = nodes.iterator();
  while (it.hasNext()) {
    int iNode = it.next();
    System.out.print(iNode + " " + bayesNet.getNodeName(iNode));
    if (it.hasNext()) {
      System.out.print(",");
    }
  }
}

/**
 * Checks that only cliques with an empty separator are left without a
 * parent clique.
 *
 * @throws Exception if a parentless clique has a non-empty separator
 */
private void checkCliqueTree(int[] order, Set[] cliques, Set[] separators,
  int[] parentCliques) throws Exception {
  for (int iNode : order) {
    if (cliques[iNode] != null && parentCliques[iNode] == -1
      && separators[iNode].size() > 0) {
      throw new Exception("Something wrong in clique tree");
    }
  }
}
/**
 * Runs the initial propagation over the junction tree.
 *
 * First an upward pass calls initializeUp() on the cliques in reverse order.
 * Then a downward pass calls initializeDown(false) on them in order.
 * initializeUp() reads the separator messages of a clique's children, so the
 * upward pass relies on children coming after their parent in the ordering.
 * NOTE(review): presumably guaranteed by getCliqueTree(); verify.
 *
 * @param jtNodes junction tree nodes indexed by node, null where no clique
 * @param order node ordering the cliques were built from
 * @param cliques not used here
 * @param separators not used here
 * @param parentCliques not used here
 */
void initialize(JunctionTreeNode[] jtNodes, int[] order,
  Set[] cliques, Set[] separators, int[] parentCliques) {
  // upward (collect) pass, last clique in the ordering first
  int iPos = order.length;
  while (iPos-- > 0) {
    JunctionTreeNode jtNode = jtNodes[order[iPos]];
    if (jtNode != null) {
      jtNode.initializeUp();
    }
  }
  // downward (distribute) pass, first clique in the ordering first
  for (int iNode : order) {
    JunctionTreeNode jtNode = jtNodes[iNode];
    if (jtNode != null) {
      jtNode.initializeDown(false);
    }
  }
} // initialize
/**
 * Creates a junction tree node for every clique and links each clique to its
 * parent clique through a separator.
 *
 * @param cliques cliques indexed by node, null where a node has no clique
 * @param separators separators indexed by node, as from getSeparators()
 * @param parentCliques parent clique index per node, -1 for roots
 * @param order node ordering the cliques were built from
 * @param bayesNet network supplying cardinalities and distributions
 * @return junction tree nodes indexed by node, null where no clique
 */
JunctionTreeNode[] getJunctionTree(Set[] cliques,
  Set[] separators, int[] parentCliques, int[] order,
  BayesNet bayesNet) {
  int nNodes = order.length;
  JunctionTreeNode[] jtns = new JunctionTreeNode[nNodes];
  // shared by all cliques: a node's conditional distribution is assigned to
  // at most one clique (the first one that claims it)
  boolean[] bDone = new boolean[nNodes];
  // create junction tree nodes
  for (int i = 0; i < nNodes; i++) {
    int iNode = order[i];
    if (cliques[iNode] != null) {
      jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone);
    }
  }
  // create junction tree separators; parentCliques uses -1 for "no parent",
  // so 0 is a valid parent clique index (consistent with root detection in
  // process(), which tests for < 0)
  for (int i = 0; i < nNodes; i++) {
    int iNode = order[i];
    int iParent = parentCliques[iNode];
    if (cliques[iNode] != null && iParent >= 0) {
      JunctionTreeNode parent = jtns[iParent];
      JunctionTreeSeparator jts = new JunctionTreeSeparator(
        separators[iNode], bayesNet, jtns[iNode], parent);
      jtns[iNode].setParentSeparator(jts);
      parent.addChildClique(jtns[iNode]);
    }
  }
  return jtns;
} // getJunctionTree
/**
 * Separator between a clique (the child) and its parent clique in the
 * junction tree.
 *
 * It holds the normalized distribution over the separator nodes as obtained
 * from either side: m_fiParent is marginalized from the parent clique,
 * m_fiChild from the child clique. Both are indexed by the joint value
 * combination of m_nNodes as computed by getCPT().
 */
public class JunctionTreeSeparator implements Serializable, RevisionHandler {
  private static final long serialVersionUID = 6502780192411755343L;
  /** Bayes net node indices of the separator nodes */
  int[] m_nNodes;
  /** number of joint value combinations of the separator nodes */
  int m_nCardinality;
  /**
   * separator distribution marginalized from the parent clique; null if the
   * parent clique has no distribution yet
   */
  double[] m_fiParent;
  /**
   * separator distribution marginalized from the child clique; null if the
   * child clique has no distribution yet
   */
  double[] m_fiChild;
  /** parent clique of this separator */
  JunctionTreeNode m_parentNode;
  /** child clique of this separator */
  JunctionTreeNode m_childNode;
  /** network supplying the node cardinalities */
  BayesNet m_bayesNet;

  /**
   * Creates a separator between a child clique and its parent clique.
   *
   * @param separator Bayes net node indices shared by both cliques
   * @param bayesNet network supplying the node cardinalities
   * @param childNode child clique
   * @param parentNode parent clique
   */
  JunctionTreeSeparator(Set separator, BayesNet bayesNet,
    JunctionTreeNode childNode, JunctionTreeNode parentNode) {
    // ////////////////////
    // initialize node set
    m_nNodes = new int[separator.size()];
    int iPos = 0;
    m_nCardinality = 1;
    // iterate via Object + cast so this compiles for raw and Set<Integer>
    for (Object element : separator) {
      int iNode = (Integer) element;
      m_nNodes[iPos++] = iNode;
      m_nCardinality *= bayesNet.getCardinality(iNode);
    }
    m_parentNode = parentNode;
    m_childNode = childNode;
    m_bayesNet = bayesNet;
  } // c'tor

  /**
   * Marginalizes the parent clique's distribution over all nodes outside the
   * separator and stores the normalized result in m_fiParent (null if the
   * parent clique has no distribution yet).
   */
  public void updateFromParent() {
    m_fiParent = normalize(update(m_parentNode));
  } // updateFromParent

  /**
   * Marginalizes the child clique's distribution over all nodes outside the
   * separator and stores the normalized result in m_fiChild (null if the
   * child clique has no distribution yet).
   */
  public void updateFromChild() {
    m_fiChild = normalize(update(m_childNode));
  } // updateFromChild

  /**
   * Normalizes the first m_nCardinality entries of fi in place so they sum
   * to one. A null argument is passed through. NOTE: an all-zero input
   * yields NaN entries.
   *
   * @param fi distribution over the separator values, may be null
   * @return fi itself
   */
  private double[] normalize(double[] fi) {
    if (fi == null) {
      return null;
    }
    double sum = 0;
    for (int iPos = 0; iPos < m_nCardinality; iPos++) {
      sum += fi[iPos];
    }
    for (int iPos = 0; iPos < m_nCardinality; iPos++) {
      fi[iPos] /= sum;
    }
    return fi;
  }

  /**
   * Marginalizes a junction tree node's distribution over all nodes outside
   * the separator (unnormalized).
   *
   * @param node one of the two cliques adjacent to this separator; its nodes
   *          must include all separator nodes
   * @return the marginal over the separator values, or null if the node has
   *         no distribution (m_P) yet
   */
  public double[] update(JunctionTreeNode node) {
    if (node.m_P == null) {
      return null;
    }
    double[] fi = new double[m_nCardinality];
    // odometer over the clique's variables; order maps a Bayes net node
    // index to its position in node.m_nNodes / values
    int[] values = new int[node.m_nNodes.length];
    int[] order = new int[m_bayesNet.getNrOfNodes()];
    for (int iNode = 0; iNode < node.m_nNodes.length; iNode++) {
      order[node.m_nNodes[iNode]] = iNode;
    }
    // fill in the values
    for (int iPos = 0; iPos < node.m_nCardinality; iPos++) {
      int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values,
        order, m_bayesNet);
      int iSepCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
        m_bayesNet);
      fi[iSepCPT] += node.m_P[iNodeCPT];
      // update values
      int i = 0;
      values[i]++;
      while (i < node.m_nNodes.length
        && values[i] == m_bayesNet.getCardinality(node.m_nNodes[i])) {
        values[i] = 0;
        i++;
        if (i < node.m_nNodes.length) {
          values[i]++;
        }
      }
    }
    return fi;
  } // update

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10154 $");
  }
} // class JunctionTreeSeparator
public class JunctionTreeNode implements Serializable, RevisionHandler {
/** for serialization */
private static final long serialVersionUID = 650278019241175536L;
/**
 * reference Bayes net for information about variables like name,
 * cardinality, etc. but not for relations between nodes
 **/
BayesNet m_bayesNet;
/** Bayes net node indices of the variables in this clique **/
public int[] m_nNodes;
/**
 * number of joint value combinations of the variables in this clique (the
 * product of their cardinalities)
 **/
int m_nCardinality;
/**
 * clique potential: product of the conditional probabilities assigned to
 * this clique, indexed by joint value combination (see getCPT)
 **/
double[] m_fi;
/** normalized distribution over the joint values of this clique **/
double[] m_P;
/**
 * per-variable marginals derived from m_P; the first index is the position
 * of the variable in m_nNodes
 **/
double[][] m_MarginalP;
/** separator towards the parent clique; null for the root of the tree **/
JunctionTreeSeparator m_parentSeparator;
/**
 * Connects this clique to its parent clique.
 *
 * @param parentSeparator separator shared with the parent clique
 */
public void setParentSeparator(JunctionTreeSeparator parentSeparator) {
  this.m_parentSeparator = parentSeparator;
}
/** child cliques of this clique in the junction tree */
public Vector m_children;
/**
 * Registers a child clique of this clique; children are appended in the
 * order they are added.
 *
 * @param child clique whose parent separator leads to this clique
 */
public void addChildClique(JunctionTreeNode child) {
  this.m_children.addElement(child);
}
/**
 * Upward (collect) step of the initial propagation.
 *
 * Sets m_P to the clique potential m_fi multiplied by the m_fiChild message
 * of every child's separator. It then normalizes m_P and, unless this is the
 * root, refreshes the parent separator's m_fiChild from the result.
 *
 * The children must have been initialized first, since their separators'
 * m_fiChild arrays are read here.
 */
public void initializeUp() {
m_P = new double[m_nCardinality];
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
m_P[iPos] = m_fi[iPos];
}
// values is an odometer over this clique's variables (first variable
// fastest); order maps a Bayes net node index to its position in m_nNodes
int[] values = new int[m_nNodes.length];
int[] order = new int[m_bayesNet.getNrOfNodes()];
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
order[m_nNodes[iNode]] = iNode;
}
// values is not reset between children: a full sweep of m_nCardinality
// increments wraps the odometer back to all zeros
for (JunctionTreeNode element : m_children) {
JunctionTreeNode childNode = element;
// the child's separator nodes are contained in this clique (the parent
// clique contains the whole separator), so order[] resolves them
JunctionTreeSeparator separator = childNode.m_parentSeparator;
// Update the values
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length,
values, order, m_bayesNet);
int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
m_bayesNet);
m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT];
// update values
int i = 0;
values[i]++;
while (i < m_nNodes.length
&& values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
values[i] = 0;
i++;
if (i < m_nNodes.length) {
values[i]++;
}
}
}
}
// normalize
double sum = 0;
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
sum += m_P[iPos];
}
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
m_P[iPos] /= sum;
}
if (m_parentSeparator != null) { // not a root node
// send this clique's separator marginal up to the parent
m_parentSeparator.updateFromChild();
}
} // initializeUp
/**
 * Downward (distribute) step of the propagation.
 *
 * For the root, only the per-variable marginals are computed. Any other
 * clique goes through these steps:
 * 1. Refresh the parent separator's m_fiParent from the parent clique.
 * 2. Multiply m_P by m_fiParent / m_fiChild. Entries whose m_fiChild is 0
 *    are set to 0.
 * 3. Renormalize m_P.
 * 4. Refresh m_fiChild from the updated m_P.
 * 5. Recompute the marginals.
 *
 * @param recursively whether to continue with all child cliques afterwards
 */
public void initializeDown(boolean recursively) {
if (m_parentSeparator == null) { // a root node
calcMarginalProbabilities();
} else {
// pull the current separator marginal from the parent clique
m_parentSeparator.updateFromParent();
// odometer over this clique's variables; order maps Bayes net node index
// to position in m_nNodes
int[] values = new int[m_nNodes.length];
int[] order = new int[m_bayesNet.getNrOfNodes()];
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
order[m_nNodes[iNode]] = iNode;
}
// Update the values
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
int iSepCPT = getCPT(m_parentSeparator.m_nNodes,
m_parentSeparator.m_nNodes.length, values, order, m_bayesNet);
int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
m_bayesNet);
// rescale by the ratio of the parent's separator marginal to the
// marginal last sent up from this clique
if (m_parentSeparator.m_fiChild[iSepCPT] > 0) {
m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT]
/ m_parentSeparator.m_fiChild[iSepCPT];
} else {
m_P[iNodeCPT] = 0;
}
// update values
int i = 0;
values[i]++;
while (i < m_nNodes.length
&& values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
values[i] = 0;
i++;
if (i < m_nNodes.length) {
values[i]++;
}
}
}
// normalize
double sum = 0;
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
sum += m_P[iPos];
}
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
m_P[iPos] /= sum;
}
// keep the child-side separator message in sync with the updated m_P
m_parentSeparator.updateFromChild();
calcMarginalProbabilities();
}
if (recursively) {
for (Object element : m_children) {
JunctionTreeNode childNode = (JunctionTreeNode) element;
childNode.initializeDown(true);
}
}
} // initializeDown
/**
 * Calculates the marginal distribution of every variable in this clique by
 * summing m_P over all other variables.
 *
 * The results go to m_MarginalP (first index: position in m_nNodes). The same
 * arrays are also stored in the enclosing calculator's m_Margins, indexed by
 * Bayes net node.
 */
void calcMarginalProbabilities() {
  int nVars = m_nNodes.length;
  int[] position = new int[m_bayesNet.getNrOfNodes()];
  m_MarginalP = new double[nVars][];
  for (int iVar = 0; iVar < nVars; iVar++) {
    position[m_nNodes[iVar]] = iVar;
    m_MarginalP[iVar] = new double[m_bayesNet.getCardinality(m_nNodes[iVar])];
  }
  // enumerate all joint assignments, first variable changing fastest
  int[] assignment = new int[nVars];
  for (int iPos = 0; iPos < m_nCardinality; iPos++) {
    double p = m_P[getCPT(m_nNodes, nVars, assignment, position, m_bayesNet)];
    for (int iVar = 0; iVar < nVars; iVar++) {
      m_MarginalP[iVar][assignment[iVar]] += p;
    }
    // advance the assignment like an odometer
    int iDigit = 0;
    assignment[iDigit]++;
    while (iDigit < nVars
      && assignment[iDigit] == m_bayesNet.getCardinality(m_nNodes[iDigit])) {
      assignment[iDigit] = 0;
      iDigit++;
      if (iDigit < nVars) {
        assignment[iDigit]++;
      }
    }
  }
  // publish the marginals per Bayes net node
  for (int iVar = 0; iVar < nVars; iVar++) {
    m_Margins[m_nNodes[iVar]] = m_MarginalP[iVar];
  }
} // calcMarginalProbabilities
/**
 * Renders the marginals of this clique, one line per variable
 * ("name: p0 p1 ..."), followed by the rendering of each child clique
 * preceded by a dashed separator line.
 *
 * @return textual dump of this subtree's marginals
 */
@Override
public String toString() {
  StringBuilder sb = new StringBuilder();
  for (int iVar = 0; iVar < m_nNodes.length; iVar++) {
    sb.append(m_bayesNet.getNodeName(m_nNodes[iVar])).append(": ");
    for (double p : m_MarginalP[iVar]) {
      sb.append(p).append(' ');
    }
    sb.append('\n');
  }
  for (Object child : m_children) {
    JunctionTreeNode childNode = (JunctionTreeNode) child;
    sb.append("----------------\n").append(childNode.toString());
  }
  return sb.toString();
} // toString
/**
 * Computes the clique potential m_fi.
 *
 * A variable of the clique contributes its conditional probability when all
 * of its parents are in the clique and it has not already been claimed by
 * another clique (bDone). Variables claimed here are marked in bDone, so
 * each variable's conditional distribution is used by at most one clique.
 * For every joint value combination, m_fi is the product of the contributing
 * conditional probabilities (1.0 if none contribute).
 *
 * @param bayesNet network supplying parents and conditional distributions
 * @param clique Bayes net node indices in this clique
 * @param bDone per Bayes net node, whether its distribution has already been
 *          assigned to a clique; updated in place
 */
void calculatePotentials(BayesNet bayesNet, Set clique,
boolean[] bDone) {
m_fi = new double[m_nCardinality];
// odometer over the clique's variables; order maps Bayes net node index to
// position in m_nNodes
int[] values = new int[m_nNodes.length];
int[] order = new int[bayesNet.getNrOfNodes()];
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
order[m_nNodes[iNode]] = iNode;
}
// find conditional probabilities that need to be taken in account
boolean[] bIsContained = new boolean[m_nNodes.length];
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
int nNode = m_nNodes[iNode];
bIsContained[iNode] = !bDone[nNode];
for (int iParent = 0; iParent < bayesNet.getNrOfParents(nNode); iParent++) {
int nParent = bayesNet.getParent(nNode, iParent);
if (!clique.contains(nParent)) {
bIsContained[iNode] = false;
}
}
if (bIsContained[iNode]) {
bDone[nNode] = true;
if (m_debug) {
System.out.println("adding node " + nNode);
}
}
}
// fill in the values
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
int iCPT = getCPT(m_nNodes, m_nNodes.length, values, order, bayesNet);
m_fi[iCPT] = 1.0;
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
if (bIsContained[iNode]) {
int nNode = m_nNodes[iNode];
// index of the current parent configuration; all parents are in this
// clique, so order[] resolves them
int[] nNodes = bayesNet.getParentSet(nNode).getParents();
int iCPT2 = getCPT(nNodes, bayesNet.getNrOfParents(nNode), values,
order, bayesNet);
// probability of the variable's current value given that configuration
double f = bayesNet.getDistributions()[nNode][iCPT2]
.getProbability(values[iNode]);
m_fi[iCPT] *= f;
}
}
// update values
int i = 0;
values[i]++;
while (i < m_nNodes.length
&& values[i] == bayesNet.getCardinality(m_nNodes[i])) {
values[i] = 0;
i++;
if (i < m_nNodes.length) {
values[i]++;
}
}
}
} // calculatePotentials
/**
 * Creates a junction tree node for a clique and computes its potential.
 *
 * @param clique Bayes net node indices forming the clique
 * @param bayesNet network supplying cardinalities and distributions
 * @param bDone per Bayes net node, whether its conditional distribution has
 *          already been assigned to a clique; updated in place
 */
JunctionTreeNode(Set clique, BayesNet bayesNet, boolean[] bDone) {
  m_bayesNet = bayesNet;
  m_children = new Vector();
  // copy the clique's nodes and accumulate the joint cardinality
  m_nNodes = new int[clique.size()];
  m_nCardinality = 1;
  int iPos = 0;
  Iterator it = clique.iterator();
  while (it.hasNext()) {
    int iNode = (Integer) it.next();
    m_nNodes[iPos] = iNode;
    iPos++;
    m_nCardinality *= bayesNet.getCardinality(iNode);
  }
  // potential function over the clique
  calculatePotentials(bayesNet, clique, bDone);
} // JunctionTreeNode c'tor
/**
 * Checks whether a Bayes net node is one of the variables of this clique.
 *
 * @param nNode Bayes net node index
 * @return true if nNode occurs in m_nNodes
 */
boolean contains(int nNode) {
  int iVar = 0;
  while (iVar < m_nNodes.length && m_nNodes[iVar] != nNode) {
    iVar++;
  }
  return iVar < m_nNodes.length;
} // contains
/**
 * Enters hard evidence nNode = iValue into this clique and propagates it
 * through the junction tree.
 *
 * All entries of m_P that disagree with the evidence are zeroed. m_P is then
 * renormalized and the marginals are recomputed. Finally, updateEvidence(this)
 * pushes the change to the child cliques and towards the root.
 *
 * NOTE(review): if the evidence has zero probability under the current
 * distribution, the normalizing sum is 0 and m_P becomes NaN.
 *
 * @param nNode Bayes net index of the observed node; must be in this clique
 * @param iValue index of the observed value
 * @throws Exception if nNode is not part of this clique
 */
public void setEvidence(int nNode, int iValue) throws Exception {
int[] values = new int[m_nNodes.length];
int[] order = new int[m_bayesNet.getNrOfNodes()];
// position of the observed node within this clique
int nNodeIdx = -1;
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
order[m_nNodes[iNode]] = iNode;
if (m_nNodes[iNode] == nNode) {
nNodeIdx = iNode;
}
}
if (nNodeIdx < 0) {
throw new Exception("setEvidence: Node " + nNode
+ " not found in this clique");
}
// zero out every joint value combination inconsistent with the evidence
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
if (values[nNodeIdx] != iValue) {
int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
m_bayesNet);
m_P[iNodeCPT] = 0;
}
// update values
int i = 0;
values[i]++;
while (i < m_nNodes.length
&& values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
values[i] = 0;
i++;
if (i < m_nNodes.length) {
values[i]++;
}
}
}
// normalize
double sum = 0;
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
sum += m_P[iPos];
}
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
m_P[iPos] /= sum;
}
calcMarginalProbabilities();
// propagate from this clique to the rest of the tree
updateEvidence(this);
} // setEvidence
/**
 * Propagates a change in a clique's distribution through the junction tree.
 *
 * If source is a child clique, this clique first absorbs the child's new
 * separator marginal. The change is then sent down to every child except
 * source, and finally up through the parent separator to the parent clique.
 *
 * @param source the clique the change comes from: either this clique itself
 *          (evidence was entered here) or a child clique whose distribution
 *          changed
 */
void updateEvidence(JunctionTreeNode source) {
if (source != this) {
// odometer over this clique's variables; order maps Bayes net node index
// to position in m_nNodes
int[] values = new int[m_nNodes.length];
int[] order = new int[m_bayesNet.getNrOfNodes()];
for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
order[m_nNodes[iNode]] = iNode;
}
// nodes of the separator between source and this clique
int[] nChildNodes = source.m_parentSeparator.m_nNodes;
int nNumChildNodes = nChildNodes.length;
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
m_bayesNet);
int iChildCPT = getCPT(nChildNodes, nNumChildNodes, values, order,
m_bayesNet);
// rescale by new child-side marginal (m_fiChild) over the marginal
// previously computed from this clique (m_fiParent)
if (source.m_parentSeparator.m_fiParent[iChildCPT] != 0) {
m_P[iNodeCPT] *= source.m_parentSeparator.m_fiChild[iChildCPT]
/ source.m_parentSeparator.m_fiParent[iChildCPT];
} else {
m_P[iNodeCPT] = 0;
}
// update values
int i = 0;
values[i]++;
while (i < m_nNodes.length
&& values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
values[i] = 0;
i++;
if (i < m_nNodes.length) {
values[i]++;
}
}
}
// normalize
double sum = 0;
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
sum += m_P[iPos];
}
for (int iPos = 0; iPos < m_nCardinality; iPos++) {
m_P[iPos] /= sum;
}
calcMarginalProbabilities();
}
// distribute downwards to every child except the one the change came from
for (Object element : m_children) {
JunctionTreeNode childNode = (JunctionTreeNode) element;
if (childNode != source) {
childNode.initializeDown(true);
}
}
// pass upwards: refresh the child-side message, let the parent absorb it,
// then refresh the parent-side message from the updated parent
if (m_parentSeparator != null) {
m_parentSeparator.updateFromChild();
m_parentSeparator.m_parentNode.updateEvidence(this);
m_parentSeparator.updateFromParent();
}
} // updateEvidence
/**
 * Returns the revision string of this class.
 *
 * @return the revision extracted from the version-control keyword
 */
@Override
public String getRevision() {
  final String sRevisionKeyword = "$Revision: 10154 $";
  return RevisionUtils.extract(sRevisionKeyword);
}
} // class JunctionTreeNode
/**
 * Computes the position of a joint value assignment in a flattened table
 * over a set of nodes. The index is built in mixed radix, with the first
 * node of nodeSet as the most significant digit.
 *
 * @param nodeSet Bayes net node indices spanning the table
 * @param nNodes number of leading entries of nodeSet to use
 * @param values current value per variable, indexed by position (see order)
 * @param order maps a Bayes net node index to its position in values
 * @param bayesNet network supplying the node cardinalities
 * @return index into the flattened table
 */
int getCPT(int[] nodeSet, int nNodes, int[] values, int[] order,
  BayesNet bayesNet) {
  int iIndex = 0;
  int iPos = 0;
  while (iPos < nNodes) {
    int nNode = nodeSet[iPos++];
    iIndex = iIndex * bayesNet.getCardinality(nNode) + values[order[nNode]];
  }
  return iIndex;
} // getCPT
/**
 * Links the cliques into a tree. Each clique with a non-empty separator gets
 * as parent the first clique, in the given ordering, that contains the whole
 * separator.
 *
 * @param order node ordering the cliques were built from
 * @param cliques cliques indexed by node, null where a node has no clique
 * @param separators separators indexed by node, as from getSeparators()
 * @return parent clique index per node; -1 for cliques without a parent and
 *         for nodes that have no clique
 */
int[] getCliqueTree(int[] order, Set[] cliques,
  Set[] separators) {
  int nNodes = order.length;
  int[] parentCliques = new int[nNodes];
  for (int i = 0; i < nNodes; i++) {
    int iNode = order[i];
    parentCliques[iNode] = -1;
    if (cliques[iNode] != null && separators[iNode].size() > 0) {
      for (int j = 0; j < nNodes; j++) {
        int iNode2 = order[j];
        if (iNode != iNode2 && cliques[iNode2] != null
          && cliques[iNode2].containsAll(separators[iNode])) {
          parentCliques[iNode] = iNode2;
          // the first match in the ordering is the parent
          break;
        }
      }
    }
  }
  return parentCliques;
} // getCliqueTree
/**
 * Calculates the separator of every clique in the clique tree: the nodes a
 * clique shares with all cliques that precede it in the ordering.
 *
 * @param order maximum cardinality ordering of the graph
 * @param cliques cliques indexed by node, null where a node has no clique
 * @return separators indexed by node, null where the node has no clique
 */
Set[] getSeparators(int[] order, Set[] cliques) {
  @SuppressWarnings("unchecked")
  Set[] separators = new HashSet[order.length];
  // union of all cliques visited so far
  Set seenNodes = new HashSet();
  for (int iNode : order) {
    Set clique = cliques[iNode];
    if (clique == null) {
      continue;
    }
    Set separator = new HashSet();
    separator.addAll(clique);
    separator.retainAll(seenNodes);
    separators[iNode] = separator;
    seenNodes.addAll(clique);
  }
  return separators;
} // getSeparators
/**
 * Gets the maximal cliques of a decomposable (triangulated) graph given as an
 * adjacency matrix.
 *
 * For every node, the candidate clique is the node together with its
 * neighbours that precede it in the ordering. Candidates contained in
 * another candidate are discarded afterwards.
 *
 * @param order maximum cardinality ordering of the graph
 * @param bAdjacencyMatrix adjacency matrix of the decomposable graph
 * @return cliques indexed by node: the candidate clique of that node, or
 *         null if it was contained in another candidate
 * @throws Exception in debug mode, if a returned clique is not fully
 *           connected in bAdjacencyMatrix
 */
Set[] getCliques(int[] order, boolean[][] bAdjacencyMatrix)
  throws Exception {
  int nNodes = bAdjacencyMatrix.length;
  @SuppressWarnings("unchecked")
  Set[] cliques = new HashSet[nNodes];
  // consult nodes in reverse order; each candidate is the node plus its
  // earlier-ordered neighbours
  for (int i = nNodes - 1; i >= 0; i--) {
    int iNode = order[i];
    Set clique = new HashSet();
    clique.add(iNode);
    for (int j = 0; j < i; j++) {
      int iNode2 = order[j];
      if (bAdjacencyMatrix[iNode][iNode2]) {
        clique.add(iNode2);
      }
    }
    cliques[iNode] = clique;
  }
  // keep only maximal candidates
  for (int iNode = 0; iNode < nNodes; iNode++) {
    for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
      if (iNode != iNode2 && cliques[iNode] != null
        && cliques[iNode2] != null
        && cliques[iNode].containsAll(cliques[iNode2])) {
        cliques[iNode2] = null;
      }
    }
  }
  // sanity check: every pair of nodes in a clique must be adjacent
  if (m_debug) {
    for (int iNode = 0; iNode < nNodes; iNode++) {
      if (cliques[iNode] != null) {
        int[] nNodeSet = new int[cliques[iNode].size()];
        int k = 0;
        for (Object element : cliques[iNode]) {
          nNodeSet[k++] = (Integer) element;
        }
        for (int i = 0; i < nNodeSet.length; i++) {
          for (int j = 0; j < nNodeSet.length; j++) {
            if (i != j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) {
              throw new Exception("Non clique: nodes " + nNodeSet[i]
                + " and " + nNodeSet[j] + " in clique " + iNode
                + " are not adjacent");
            }
          }
        }
      }
    }
  }
  return cliques;
} // getCliques
/**
 * Moralizes the DAG of a Bayes network. Every node is connected to its
 * parents and the parents of every node are connected to each other,
 * turning the directed acyclic graph into an undirected graph.
 *
 * @param bayesNet Bayes network to process
 * @return symmetric adjacency matrix of the moral graph
 */
public boolean[][] moralize(BayesNet bayesNet) {
  int nNodes = bayesNet.getNrOfNodes();
  boolean[][] bMoralGraph = new boolean[nNodes][nNodes];
  for (int iChild = 0; iChild < nNodes; iChild++) {
    moralizeNode(bayesNet.getParentSets()[iChild], iChild, bMoralGraph);
  }
  return bMoralGraph;
} // moralize
/**
 * Adds the moral-graph edges contributed by one node. These are an
 * undirected edge between the node and each parent, plus an edge between
 * every pair of its parents.
 *
 * @param parents parent set of the node
 * @param iNode index of the node
 * @param bAdjacencyMatrix adjacency matrix, updated in place
 */
private void moralizeNode(ParentSet parents, int iNode,
  boolean[][] bAdjacencyMatrix) {
  for (int iFirst = 0; iFirst < parents.getNrOfParents(); iFirst++) {
    int nParentA = parents.getParent(iFirst);
    // child -- parent edge
    if (m_debug && !bAdjacencyMatrix[iNode][nParentA]) {
      System.out.println("Insert " + iNode + "--" + nParentA);
    }
    bAdjacencyMatrix[iNode][nParentA] = true;
    bAdjacencyMatrix[nParentA][iNode] = true;
    // "marry" this parent to every later parent
    for (int iSecond = iFirst + 1; iSecond < parents.getNrOfParents(); iSecond++) {
      int nParentB = parents.getParent(iSecond);
      if (m_debug && !bAdjacencyMatrix[nParentB][nParentA]) {
        System.out.println("Mary " + nParentA + "--" + nParentB);
      }
      bAdjacencyMatrix[nParentB][nParentA] = true;
      bAdjacencyMatrix[nParentA][nParentB] = true;
    }
  }
} // moralizeNode
/**
 * Applies the Tarjan and Yannakakis (1984) fill-in algorithm for graph
 * triangulation. Nodes are visited in reverse order. For each one, edges are
 * inserted between every pair of its neighbours that come earlier in the
 * ordering.
 *
 * Side effect: the input matrix is modified in place and returned.
 *
 * @param order node ordering
 * @param bAdjacencyMatrix boolean matrix representing the graph
 * @return the same matrix, now including the fill-in edges
 */
public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) {
  int nNodes = bAdjacencyMatrix.length;
  // consult nodes in reverse order
  for (int i = nNodes - 1; i >= 0; i--) {
    int iNode = order[i];
    // find pairs of neighbours with lower order
    for (int j = 0; j < i; j++) {
      int iNode2 = order[j];
      if (!bAdjacencyMatrix[iNode][iNode2]) {
        continue;
      }
      for (int k = j + 1; k < i; k++) {
        int iNode3 = order[k];
        if (bAdjacencyMatrix[iNode][iNode3]) {
          // fill in
          if (m_debug
            && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2])) {
            System.out.println("Fill in " + iNode2 + "--" + iNode3);
          }
          bAdjacencyMatrix[iNode2][iNode3] = true;
          bAdjacencyMatrix[iNode3][iNode2] = true;
        }
      }
    }
  }
  return bAdjacencyMatrix;
} // fillIn
/**
 * Calculates a maximum cardinality ordering. Node 0 comes first. Each
 * following position is taken by the not yet ordered node with the most
 * already-ordered neighbours; ties go to the lowest node index.
 *
 * This implementation does not assume the graph is connected: a node without
 * ordered neighbours can still be chosen (with a count of 0).
 *
 * @param bAdjacencyMatrix n by n matrix with adjacencies in graph of n nodes
 * @return the ordering; order[i] is the node at position i
 */
int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) {
  int nNodes = bAdjacencyMatrix.length;
  int[] order = new int[nNodes];
  if (nNodes == 0) {
    return order;
  }
  boolean[] bOrdered = new boolean[nNodes];
  order[0] = 0;
  bOrdered[0] = true;
  for (int iPos = 1; iPos < nNodes; iPos++) {
    int iBestNode = -1;
    int nBestCount = -1;
    for (int iCandidate = 0; iCandidate < nNodes; iCandidate++) {
      if (bOrdered[iCandidate]) {
        continue;
      }
      // number of neighbours of iCandidate that are already ordered
      boolean[] bNeighbours = bAdjacencyMatrix[iCandidate];
      int nCount = 0;
      for (int iOther = 0; iOther < nNodes; iOther++) {
        if (bNeighbours[iOther] && bOrdered[iOther]) {
          nCount++;
        }
      }
      if (nCount > nBestCount) {
        nBestCount = nCount;
        iBestNode = iCandidate;
      }
    }
    order[iPos] = iBestNode;
    bOrdered[iBestNode] = true;
  }
  return order;
} // getMaxCardOrder
/**
 * Enters hard evidence nNode = iValue and propagates it through the junction
 * tree. The evidence is entered into the first clique (by node index) that
 * contains the node.
 *
 * @param nNode index of the observed node
 * @param iValue index of the observed value
 * @throws Exception if no junction tree has been built yet, if no clique
 *           contains the node, or if the clique rejects the evidence
 */
public void setEvidence(int nNode, int iValue) throws Exception {
  if (m_root == null) {
    throw new Exception("Junction tree not initialize yet");
  }
  for (JunctionTreeNode jtNode : jtNodes) {
    if (jtNode != null && jtNode.contains(nNode)) {
      jtNode.setEvidence(nNode, iValue);
      return;
    }
  }
  throw new Exception("Could not find node " + nNode + " in junction tree");
} // setEvidence
/**
 * Returns a textual dump of the per-node marginals of all cliques, starting
 * at the root of the junction tree. A junction tree must have been built
 * first (m_root is null before that).
 *
 * @return the marginals as text
 */
@Override
public String toString() {
  JunctionTreeNode root = m_root;
  return root.toString();
} // toString
/** marginal distribution per Bayes net node, filled in during propagation */
double[][] m_Margins;
/**
 * Returns the marginal distribution of a node.
 *
 * The returned array is the one held by the clique that computed it. It is
 * not a copy and should not be modified.
 *
 * @param iNode index of the node
 * @return probability per value of the node
 */
public double[] getMargin(int iNode) {
  double[] margin = m_Margins[iNode];
  return margin;
} // getMargin
/**
 * Returns the revision string of this class.
 *
 * @return the revision extracted from the version-control keyword
 */
@Override
public String getRevision() {
  final String sRevisionKeyword = "$Revision: 10154 $";
  return RevisionUtils.extract(sRevisionKeyword);
}
/**
 * Command-line driver for manual testing. It loads a network from a BIF
 * file, computes the margins, enters evidence for two nodes and prints the
 * resulting margins. This is done first with the junction tree built from
 * the moral graph, then with the single-clique tree from calcFullMargins().
 *
 * Usage: MarginCalculator &lt;BIF file&gt; [node1 value1 [node2 value2]]
 * Node and value indices default to node 2 = value 0 and node 4 = value 0.
 *
 * @param args command-line arguments, see above
 */
public static void main(String[] args) {
  if (args.length < 1) {
    System.err.println("Usage: java " + MarginCalculator.class.getName()
      + " <BIF file> [node1 value1 [node2 value2]]");
    return;
  }
  try {
    // evidence to enter; optional arguments override the defaults
    int iNode = args.length > 1 ? Integer.parseInt(args[1]) : 2;
    int iValue = args.length > 2 ? Integer.parseInt(args[2]) : 0;
    int iNode2 = args.length > 3 ? Integer.parseInt(args[3]) : 4;
    int iValue2 = args.length > 4 ? Integer.parseInt(args[4]) : 0;
    BIFReader bayesNet = new BIFReader();
    bayesNet.processFile(args[0]);
    MarginCalculator dc = new MarginCalculator();
    dc.calcMargins(bayesNet);
    dc.setEvidence(iNode, iValue);
    dc.setEvidence(iNode2, iValue2);
    System.out.print(dc.toString());
    dc.calcFullMargins(bayesNet);
    dc.setEvidence(iNode, iValue);
    dc.setEvidence(iNode2, iValue2);
    System.out.println("==============");
    System.out.print(dc.toString());
  } catch (Exception e) {
    e.printStackTrace();
  }
} // main
} // class MarginCalculator