All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.bayes.net.MarginCalculator Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA) is a machine learning workbench. This is the stable version; apart from bug fixes, it receives no further updates.

There is a newer version: 3.8.6
Show newest version
package weka.classifiers.bayes.net;

import weka.classifiers.bayes.BayesNet;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.Vector;


/**
 * Computes marginal distributions of the nodes of a Bayesian network with a
 * junction tree.
 *
 * The network is moralized (or treated as fully connected, see
 * {@link #calcFullMargins(BayesNet)}) and triangulated. The cliques of the
 * resulting graph are connected into a clique tree, and potentials are
 * propagated up and down that tree. After that, evidence can be entered with
 * {@link #setEvidence(int, int)}, and {@link #getMargin(int)} returns the
 * current marginal of a node.
 */
public class MarginCalculator implements Serializable, RevisionHandler {
	/** for serialization */
	private static final long serialVersionUID = 650278019241175534L;

	/** when true, diagnostic output is written to System.out */
	boolean m_debug = false;

	/**
	 * First junction tree node (in node index order) without a parent clique;
	 * null until calcMargins/calcFullMargins has run or when the network has no nodes.
	 */
	public JunctionTreeNode m_root = null;

	/**
	 * Junction tree nodes indexed by the Bayes net node that defined the clique;
	 * entries are null for nodes whose candidate clique was absorbed by another clique.
	 */
	JunctionTreeNode [] jtNodes;

	/**
	 * Marginal distribution per Bayes net node, indexed by node then value.
	 * Entries are the same arrays as the JunctionTreeNode.m_MarginalP rows that
	 * last wrote them.
	 */
	double [][] m_Margins;

	/**
	 * Looks up the index of a node by name in the network the junction tree was built from.
	 * Requires calcMargins or calcFullMargins to have been called (m_root must be non-null).
	 *
	 * @param sNodeName name of the node (attribute) to look up
	 * @return index of the first attribute with that name, or -1 if there is none
	 */
	public int getNode(String sNodeName) {
		int nAttributes = m_root.m_bayesNet.m_Instances.numAttributes();
		for (int iNode = 0; iNode < nAttributes; iNode++) {
			if (m_root.m_bayesNet.m_Instances.attribute(iNode).name().equals(sNodeName)) {
				return iNode;
			}
		}
		return -1;
	} // getNode

	/**
	 * Requires calcMargins or calcFullMargins to have been called (m_root must be non-null).
	 *
	 * @return XMLBIF 0.3 representation of the network the junction tree was built from
	 */
	public String toXMLBIF03() {
		return m_root.m_bayesNet.toXMLBIF03();
	} // toXMLBIF03

	/**
	 * Calculates the marginal distributions of the nodes in a Bayesian network
	 * using a junction tree built from the moralized network.
	 * Note that a connected network is assumed;
	 * unconnected networks may give unexpected results.
	 *
	 * @param bayesNet the network to process
	 * @throws Exception if the resulting clique tree fails its sanity check
	 */
	public void calcMargins(BayesNet bayesNet) throws Exception {
		boolean[][] bAdjacencyMatrix = moralize(bayesNet);
		process(bAdjacencyMatrix, bayesNet);
	} // calcMargins

	/**
	 * Calculates the marginal distributions treating every pair of nodes as
	 * adjacent, so the junction tree collapses into a single clique holding all
	 * nodes. The clique table has one entry per joint configuration of all
	 * nodes, so this is only feasible for small networks.
	 *
	 * @param bayesNet the network to process
	 * @throws Exception if the resulting clique tree fails its sanity check
	 */
	public void calcFullMargins(BayesNet bayesNet) throws Exception {
		int nNodes = bayesNet.getNrOfNodes();
		boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];
		for (int iNode = 0; iNode < nNodes; iNode++) {
			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
				bAdjacencyMatrix[iNode][iNode2] = true;
			}
		}
		process(bAdjacencyMatrix, bayesNet);
	} // calcFullMargins

	/**
	 * Builds the junction tree for the given undirected graph and initializes
	 * all clique distributions and node marginals.
	 *
	 * Side effect: bAdjacencyMatrix receives the fill-in edges.
	 *
	 * @param bAdjacencyMatrix undirected graph over the network's nodes
	 * @param bayesNet network supplying cardinalities and conditional distributions
	 * @throws Exception if a root clique has a non-empty separator
	 */
	public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet) throws Exception {
		int[] order = getMaxCardOrder(bAdjacencyMatrix);
		bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix);
		// recompute the ordering on the triangulated graph
		order = getMaxCardOrder(bAdjacencyMatrix);
		Set [] cliques = getCliques(order, bAdjacencyMatrix);
		Set [] separators = getSeparators(order, cliques);
		int [] parentCliques = getCliqueTree(order, cliques, separators);
		int nNodes = bAdjacencyMatrix.length;
		if (m_debug) {
			printCliques(order, cliques, separators, parentCliques, bayesNet);
		}

		jtNodes = getJunctionTree(cliques, separators, parentCliques, order, bayesNet);
		m_root = null;
		for (int iNode = 0; iNode < nNodes; iNode++) {
			if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) {
				m_root = jtNodes[iNode];
				break;
			}
		}
		m_Margins = new double[nNodes][];
		initialize(jtNodes, order, cliques, separators, parentCliques);

		// sanity check: a clique without a parent must not share nodes with earlier cliques
		for (int i = 0; i < nNodes; i++) {
			int iNode = order[i];
			if (cliques[iNode] != null && parentCliques[iNode] == -1 && separators[iNode].size() > 0) {
				throw new Exception("Something wrong in clique tree");
			}
		}
	} // process

	/** Prints every clique with its separator and parent clique (debug output). */
	private void printCliques(int [] order, Set [] cliques, Set [] separators, int [] parentCliques, BayesNet bayesNet) {
		for (int i = 0; i < order.length; i++) {
			int iNode = order[i];
			if (cliques[iNode] == null) {
				continue;
			}
			System.out.print("Clique " + iNode + " (");
			printNodeSet(cliques[iNode], bayesNet);
			System.out.print(") S(");
			printNodeSet(separators[iNode], bayesNet);
			System.out.println(") parent clique " + parentCliques[iNode]);
		}
	} // printCliques

	/** Prints a comma-separated list of "index name" pairs for the nodes in nodeSet. */
	private static void printNodeSet(Set nodeSet, BayesNet bayesNet) {
		Iterator nodes = nodeSet.iterator();
		while (nodes.hasNext()) {
			int iNode = (Integer) nodes.next();
			System.out.print(iNode + " " + bayesNet.getNodeName(iNode));
			if (nodes.hasNext()) {
				System.out.print(",");
			}
		}
	} // printNodeSet

	/**
	 * Runs the upward pass (children before parents, i.e. reverse ordering) and
	 * then the downward pass (parents before children) over all junction tree nodes.
	 * Only jtNodes and order are used; the remaining parameters are kept for
	 * interface compatibility.
	 */
	void initialize(JunctionTreeNode [] jtNodes, int [] order, Set [] cliques, Set [] separators, int [] parentCliques) {
		int nNodes = order.length;
		for (int i = nNodes - 1; i >= 0; i--) {
			JunctionTreeNode jtNode = jtNodes[order[i]];
			if (jtNode != null) {
				jtNode.initializeUp();
			}
		}
		for (int i = 0; i < nNodes; i++) {
			JunctionTreeNode jtNode = jtNodes[order[i]];
			if (jtNode != null) {
				jtNode.initializeDown(false);
			}
		}
	} // initialize

	/**
	 * Creates a junction tree node for every clique and links each clique to
	 * its parent clique through a separator.
	 *
	 * @param cliques cliques indexed by defining node (null where absorbed)
	 * @param separators separator sets indexed like cliques
	 * @param parentCliques parent clique index per clique, -1 for a root
	 * @param order maximum cardinality ordering of the nodes
	 * @param bayesNet network supplying cardinalities and distributions
	 * @return junction tree nodes indexed like cliques
	 */
	JunctionTreeNode [] getJunctionTree(Set [] cliques, Set [] separators, int [] parentCliques, int [] order, BayesNet bayesNet) {
		int nNodes = order.length;
		JunctionTreeNode [] jtns = new JunctionTreeNode[nNodes];
		// shared across cliques so each node's CPT is multiplied into exactly one clique potential
		boolean [] bDone = new boolean[nNodes];
		// create junction tree nodes
		for (int i = 0; i < nNodes; i++) {
			int iNode = order[i];
			if (cliques[iNode] != null) {
				jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone);
			}
		}
		// create junction tree separators
		for (int i = 0; i < nNodes; i++) {
			int iNode = order[i];
			int iParent = parentCliques[iNode];
			// -1 marks a root clique; 0 is a valid clique index and must be linked too
			if (cliques[iNode] != null && iParent >= 0) {
				JunctionTreeSeparator jts = new JunctionTreeSeparator(separators[iNode], bayesNet, jtns[iNode], jtns[iParent]);
				jtns[iNode].setParentSeparator(jts);
				jtns[iParent].addChildClique(jtns[iNode]);
			}
		}
		return jtns;
	} // getJunctionTree

	/**
	 * Separator between a clique and its parent clique in the junction tree.
	 * Holds the normalized distribution over the separator nodes as last
	 * computed from the parent clique (m_fiParent) and from the child clique (m_fiChild).
	 */
	public class JunctionTreeSeparator implements Serializable, RevisionHandler {

		private static final long serialVersionUID = 6502780192411755343L;
		/** Bayes net nodes in this separator */
		int [] m_nNodes;
		/** number of joint configurations of the separator nodes */
		int m_nCardinality;
		/** normalized distribution over the separator derived from the parent clique */
		double [] m_fiParent;
		/** normalized distribution over the separator derived from the child clique */
		double [] m_fiChild;
		JunctionTreeNode m_parentNode;
		JunctionTreeNode m_childNode;
		BayesNet m_bayesNet;

		JunctionTreeSeparator(Set separator, BayesNet bayesNet, JunctionTreeNode childNode, JunctionTreeNode parentNode) {
			m_nNodes = new int[separator.size()];
			m_nCardinality = 1;
			int iPos = 0;
			for (Iterator nodes = separator.iterator(); nodes.hasNext();) {
				int iNode = (Integer) nodes.next();
				m_nNodes[iPos++] = iNode;
				m_nCardinality *= bayesNet.getCardinality(iNode);
			}
			m_parentNode = parentNode;
			m_childNode = childNode;
			m_bayesNet = bayesNet;
		} // c'tor

		/**
		 * Recomputes m_fiParent by marginalizing the parent clique's distribution
		 * onto the separator nodes and normalizing (null if the parent has no distribution yet).
		 */
		public void updateFromParent() {
			m_fiParent = update(m_parentNode);
			if (m_fiParent != null) {
				normalize(m_fiParent);
			}
		} // updateFromParent

		/**
		 * Recomputes m_fiChild by marginalizing the child clique's distribution
		 * onto the separator nodes and normalizing (null if the child has no distribution yet).
		 */
		public void updateFromChild() {
			m_fiChild = update(m_childNode);
			if (m_fiChild != null) {
				normalize(m_fiChild);
			}
		} // updateFromChild

		/**
		 * Sums the distribution of a neighbouring clique over all nodes outside
		 * this separator.
		 *
		 * @param node one of the two junction tree nodes adjacent to this separator
		 * @return unnormalized distribution over the separator configurations,
		 *         or null if node.m_P has not been computed yet
		 */
		public double [] update(JunctionTreeNode node) {
			if (node.m_P == null) {
				return null;
			}
			double [] fi = new double[m_nCardinality];
			int [] values = new int[node.m_nNodes.length];
			int [] order = positionLookup(node.m_nNodes, m_bayesNet.getNrOfNodes());
			for (int iPos = 0; iPos < node.m_nCardinality; iPos++) {
				int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values, order, m_bayesNet);
				int iSepCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
				fi[iSepCPT] += node.m_P[iNodeCPT];
				nextConfiguration(values, node.m_nNodes, m_bayesNet);
			}
			return fi;
		} // update

		/**
		 * Returns the revision string.
		 *
		 * @return		the revision
		 */
		public String getRevision() {
			return RevisionUtils.extract("$Revision: 1.2 $");
		}

	} // class JunctionTreeSeparator

	/** A clique in the junction tree together with its distributions. */
	public class JunctionTreeNode implements Serializable, RevisionHandler {

		private static final long serialVersionUID = 650278019241175536L;
		/** reference Bayes net for information about variables like name, cardinality, etc.
		 * but not for relations between nodes **/
		BayesNet m_bayesNet;
		/** nodes of the Bayes net in this junction node **/
		public int [] m_nNodes;
		/** number of joint configurations of the variables in this junction node **/
		int m_nCardinality;
		/** clique potential: product of the CPTs assigned to this clique, indexed with getCPT **/
		double [] m_fi;

		/** normalized distribution over the joint configurations of this clique, indexed with getCPT **/
		double [] m_P;

		/** m_MarginalP[i] is the marginal distribution of node m_nNodes[i] derived from m_P **/
		double [][] m_MarginalP;

		/** separator to the parent clique; null for a root clique **/
		JunctionTreeSeparator m_parentSeparator;
		public void setParentSeparator(JunctionTreeSeparator parentSeparator) {m_parentSeparator = parentSeparator;}
		/** child cliques (JunctionTreeNode instances) **/
		public Vector m_children;
		public void addChildClique(JunctionTreeNode child) {m_children.add(child);}

		/**
		 * Upward pass: sets m_P to the clique potential multiplied by the
		 * messages (m_fiChild) of all child separators, normalizes it, and
		 * refreshes the message towards the parent clique.
		 */
		public void initializeUp() {
			m_P = m_fi.clone();
			int [] values = new int[m_nNodes.length];
			int [] order = positionLookup(m_nNodes, m_bayesNet.getNrOfNodes());
			for (Iterator child = m_children.iterator(); child.hasNext();) {
				JunctionTreeSeparator separator = ((JunctionTreeNode) child.next()).m_parentSeparator;
				// a full sweep wraps values back to all zeros, so it can be reused for the next child
				for (int iPos = 0; iPos < m_nCardinality; iPos++) {
					int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length, values, order, m_bayesNet);
					int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
					m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT];
					nextConfiguration(values, m_nNodes, m_bayesNet);
				}
			}
			normalize(m_P);
			if (m_parentSeparator != null) { // not a root node
				m_parentSeparator.updateFromChild();
			}
		} // initializeUp

		/**
		 * Downward pass: for a non-root clique, rescales m_P by the ratio of the
		 * parent's separator distribution to this clique's own separator
		 * distribution, then normalizes and refreshes the separator. In all
		 * cases the node marginals are recomputed.
		 *
		 * @param recursively whether to continue the pass into all child cliques
		 */
		public void initializeDown(boolean recursively) {
			if (m_parentSeparator == null) { // a root node: m_P is not rescaled
				calcMarginalProbabilities();
			} else {
				m_parentSeparator.updateFromParent();
				int [] values = new int[m_nNodes.length];
				int [] order = positionLookup(m_nNodes, m_bayesNet.getNrOfNodes());
				for (int iPos = 0; iPos < m_nCardinality; iPos++) {
					int iSepCPT = getCPT(m_parentSeparator.m_nNodes, m_parentSeparator.m_nNodes.length, values, order, m_bayesNet);
					int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
					if (m_parentSeparator.m_fiChild[iSepCPT] > 0) {
						m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT] / m_parentSeparator.m_fiChild[iSepCPT];
					} else {
						m_P[iNodeCPT] = 0;
					}
					nextConfiguration(values, m_nNodes, m_bayesNet);
				}
				normalize(m_P);
				m_parentSeparator.updateFromChild();
				calcMarginalProbabilities();
			}
			if (recursively) {
				for (Iterator child = m_children.iterator(); child.hasNext();) {
					((JunctionTreeNode) child.next()).initializeDown(true);
				}
			}
		} // initializeDown

		/** Calculates the marginal probabilities of the individual nodes in the clique.
		 * Stores the results in m_MarginalP and publishes them in the outer m_Margins.
		 */
		void calcMarginalProbabilities() {
			int [] values = new int[m_nNodes.length];
			int [] order = positionLookup(m_nNodes, m_bayesNet.getNrOfNodes());
			m_MarginalP = new double[m_nNodes.length][];
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				m_MarginalP[iNode] = new double[m_bayesNet.getCardinality(m_nNodes[iNode])];
			}
			for (int iPos = 0; iPos < m_nCardinality; iPos++) {
				int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
					m_MarginalP[iNode][values[iNode]] += m_P[iNodeCPT];
				}
				nextConfiguration(values, m_nNodes, m_bayesNet);
			}
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				m_Margins[m_nNodes[iNode]] = m_MarginalP[iNode];
			}
		} // calcMarginalProbabilities

		/**
		 * Lists the marginals of this clique's nodes, followed by those of all
		 * descendant cliques separated by dashed lines.
		 * Requires calcMarginalProbabilities to have run.
		 */
		@Override
		public String toString() {
			StringBuilder buf = new StringBuilder();
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				buf.append(m_bayesNet.getNodeName(m_nNodes[iNode]) + ": ");
				for (int iValue = 0; iValue < m_MarginalP[iNode].length; iValue++) {
					buf.append(m_MarginalP[iNode][iValue] + " ");
				}
				buf.append('\n');
			}
			for (Iterator child = m_children.iterator(); child.hasNext();) {
				buf.append("----------------\n");
				buf.append(child.next().toString());
			}
			return buf.toString();
		} // toString

		/**
		 * Computes the clique potential m_fi. A node's CPT is multiplied in when
		 * it has not been claimed by an earlier clique (bDone) and all of its
		 * parents lie in this clique; such nodes are then marked in bDone.
		 *
		 * @param bayesNet network supplying parents and distributions
		 * @param clique set of node indices forming this clique
		 * @param bDone per-node flags shared between all cliques; updated in place
		 */
		void calculatePotentials(BayesNet bayesNet, Set clique, boolean [] bDone) {
			m_fi = new double[m_nCardinality];
			int [] values = new int[m_nNodes.length];
			int [] order = positionLookup(m_nNodes, bayesNet.getNrOfNodes());

			// find conditional probabilities that need to be taken into account
			boolean [] bIsContained = new boolean[m_nNodes.length];
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				int nNode = m_nNodes[iNode];
				bIsContained[iNode] = !bDone[nNode];
				for (int iParent = 0; iParent < bayesNet.getNrOfParents(nNode); iParent++) {
					int nParent = bayesNet.getParent(nNode, iParent);
					if (!clique.contains(nParent)) {
						bIsContained[iNode] = false;
					}
				}
				if (bIsContained[iNode]) {
					bDone[nNode] = true;
					if (m_debug) {
						System.out.println("adding node " +nNode);
					}
				}
			}

			// fill in the values
			for (int iPos = 0; iPos < m_nCardinality; iPos++) {
				int iCPT = getCPT(m_nNodes, m_nNodes.length, values, order, bayesNet);
				m_fi[iCPT] = 1.0;
				for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
					if (bIsContained[iNode]) {
						int nNode = m_nNodes[iNode];
						// the parent array may be longer than the number of parents; getCPT is bounded explicitly
						int [] parents = bayesNet.getParentSet(nNode).getParents();
						int iParentCPT = getCPT(parents, bayesNet.getNrOfParents(nNode), values, order, bayesNet);
						m_fi[iCPT] *= bayesNet.getDistributions()[nNode][iParentCPT].getProbability(values[iNode]);
					}
				}
				nextConfiguration(values, m_nNodes, bayesNet);
			}
		} // calculatePotentials

		/**
		 * Creates a junction tree node for a clique and computes its potential.
		 *
		 * @param clique set of node indices (Integer) forming the clique
		 * @param bayesNet network supplying cardinalities and distributions
		 * @param bDone per-node flags marking CPTs already assigned to a clique; updated in place
		 */
		JunctionTreeNode(Set clique, BayesNet bayesNet, boolean [] bDone) {
			m_bayesNet = bayesNet;
			m_children = new Vector();
			m_nNodes = new int[clique.size()];
			m_nCardinality = 1;
			int iPos = 0;
			for (Iterator nodes = clique.iterator(); nodes.hasNext();) {
				int iNode = (Integer) nodes.next();
				m_nNodes[iPos++] = iNode;
				m_nCardinality *= bayesNet.getCardinality(iNode);
			}
			calculatePotentials(bayesNet, clique, bDone);
		} // JunctionTreeNode c'tor

		/** @return whether this junction tree node contains Bayes net node nNode */
		boolean contains(int nNode) {
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				if (m_nNodes[iNode] == nNode) {
					return true;
				}
			}
			return false;
		} // contains

		/**
		 * Enters hard evidence nNode = iValue: zeroes all configurations of this
		 * clique that disagree, renormalizes, and propagates the change through
		 * the rest of the tree. If the evidence has zero probability, the
		 * normalization produces NaN entries.
		 *
		 * @param nNode Bayes net node observed
		 * @param iValue observed value index
		 * @throws Exception if nNode is not part of this clique
		 */
		public void setEvidence(int nNode, int iValue) throws Exception {
			int [] values = new int[m_nNodes.length];
			int [] order = positionLookup(m_nNodes, m_bayesNet.getNrOfNodes());
			int nNodeIdx = -1;
			for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
				if (m_nNodes[iNode] == nNode) {
					nNodeIdx = iNode;
				}
			}
			if (nNodeIdx < 0) {
				throw new Exception("setEvidence: Node " + nNode + " not found in this clique");
			}
			for (int iPos = 0; iPos < m_nCardinality; iPos++) {
				if (values[nNodeIdx] != iValue) {
					int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
					m_P[iNodeCPT] = 0;
				}
				nextConfiguration(values, m_nNodes, m_bayesNet);
			}
			normalize(m_P);
			calcMarginalProbabilities();
			updateEvidence(this);
		} // setEvidence

		/**
		 * Propagates an evidence change that originated in source. Unless this
		 * node is the source, m_P is first rescaled by the ratio of the child
		 * source's new separator distribution to the old one. Then all other
		 * children are updated downwards, and the change is passed on to the
		 * parent clique.
		 *
		 * @param source this node (origin of the evidence) or a child clique that was just updated
		 */
		void updateEvidence(JunctionTreeNode source) {
			if (source != this) {
				int [] values = new int[m_nNodes.length];
				int [] order = positionLookup(m_nNodes, m_bayesNet.getNrOfNodes());
				JunctionTreeSeparator childSeparator = source.m_parentSeparator;
				int [] nChildNodes = childSeparator.m_nNodes;
				int nNumChildNodes = nChildNodes.length;
				for (int iPos = 0; iPos < m_nCardinality; iPos++) {
					int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
					int iChildCPT = getCPT(nChildNodes, nNumChildNodes, values, order, m_bayesNet);
					if (childSeparator.m_fiParent[iChildCPT] != 0) {
						m_P[iNodeCPT] *= childSeparator.m_fiChild[iChildCPT]/childSeparator.m_fiParent[iChildCPT];
					} else {
						m_P[iNodeCPT] = 0;
					}
					nextConfiguration(values, m_nNodes, m_bayesNet);
				}
				normalize(m_P);
				calcMarginalProbabilities();
			}
			for (Iterator child = m_children.iterator(); child.hasNext();) {
				JunctionTreeNode childNode = (JunctionTreeNode) child.next();
				if (childNode != source) {
					childNode.initializeDown(true);
				}
			}
			if (m_parentSeparator != null) {
				m_parentSeparator.updateFromChild();
				m_parentSeparator.m_parentNode.updateEvidence(this);
				m_parentSeparator.updateFromParent();
			}
		} // updateEvidence

		/**
		 * Returns the revision string.
		 *
		 * @return		the revision
		 */
		public String getRevision() {
			return RevisionUtils.extract("$Revision: 1.2 $");
		}

	} // class JunctionTreeNode

	/**
	 * Builds a lookup table mapping a Bayes net node index to its position in
	 * nodes, as expected by getCPT. Entries for nodes not in nodes stay 0.
	 *
	 * @param nodes node indices of a clique
	 * @param nNrOfNodes total number of nodes in the network
	 * @return lookup table of length nNrOfNodes
	 */
	private static int [] positionLookup(int [] nodes, int nNrOfNodes) {
		int [] order = new int[nNrOfNodes];
		for (int iNode = 0; iNode < nodes.length; iNode++) {
			order[nodes[iNode]] = iNode;
		}
		return order;
	} // positionLookup

	/**
	 * Advances values to the next joint configuration of the variables in nodes.
	 * values is treated as a mixed-radix counter whose first digit changes
	 * fastest; digit i has radix bayesNet.getCardinality(nodes[i]). After the
	 * last configuration the counter wraps around to all zeros.
	 */
	private static void nextConfiguration(int [] values, int [] nodes, BayesNet bayesNet) {
		int i = 0;
		values[i]++;
		while (i < nodes.length && values[i] == bayesNet.getCardinality(nodes[i])) {
			values[i] = 0;
			i++;
			if (i < nodes.length) {
				values[i]++;
			}
		}
	} // nextConfiguration

	/**
	 * Scales p in place so that its entries sum to one.
	 * If all entries are zero they become NaN (0/0).
	 */
	private static void normalize(double [] p) {
		double sum = 0;
		for (int i = 0; i < p.length; i++) {
			sum += p[i];
		}
		for (int i = 0; i < p.length; i++) {
			p[i] /= sum;
		}
	} // normalize

	/**
	 * Computes the index of a joint configuration in a table over nodeSet.
	 * The first node is the most significant digit.
	 *
	 * @param nodeSet node indices spanning the table
	 * @param nNodes number of leading entries of nodeSet to use
	 * @param values current value per position, looked up through order
	 * @param order maps a node index to its position in values
	 * @param bayesNet network supplying cardinalities
	 * @return table index of the configuration
	 */
	int getCPT(int [] nodeSet, int nNodes, int[] values, int[] order, BayesNet bayesNet) {
		int iCPTnew = 0;
		for (int iNode = 0; iNode < nNodes; iNode++) {
			int nNode = nodeSet[iNode];
			iCPTnew = iCPTnew * bayesNet.getCardinality(nNode);
			iCPTnew += values[order[nNode]];
		}
		return iCPTnew;
	} // getCPT

	/**
	 * Connects the cliques into a tree. For every clique with a non-empty
	 * separator, the first other clique in the ordering that contains the
	 * whole separator becomes its parent.
	 *
	 * @param order maximum cardinality ordering of the graph
	 * @param cliques cliques indexed by defining node (null where absorbed)
	 * @param separators separator sets indexed like cliques
	 * @return parent clique index per node, -1 for roots and non-cliques
	 */
	int [] getCliqueTree(int [] order, Set [] cliques, Set [] separators) {
		int nNodes = order.length;
		int [] parentCliques = new int[nNodes];
		for (int i = 0; i < nNodes; i++) {
			int iNode = order[i];
			parentCliques[iNode] = -1;
			if (cliques[iNode] == null || separators[iNode].size() == 0) {
				continue;
			}
			for (int j = 0; j < nNodes; j++) {
				int iNode2 = order[j];
				if (iNode != iNode2 && cliques[iNode2] != null && cliques[iNode2].containsAll(separators[iNode])) {
					parentCliques[iNode] = iNode2;
					break;
				}
			}
		}
		return parentCliques;
	} // getCliqueTree

	/** Calculates the separator sets in the clique tree: each clique's
	 * intersection with the union of all cliques earlier in the ordering.
	 *
	 * @param order maximum cardinality ordering of the graph
	 * @param cliques set of cliques
	 * @return separator sets indexed like cliques (null where cliques is null)
	 */
	Set [] getSeparators(int [] order, Set [] cliques) {
		int nNodes = order.length;
		Set [] separators = new HashSet[nNodes];
		Set processedNodes = new HashSet();
		for (int i = 0; i < nNodes; i++) {
			int iNode = order[i];
			if (cliques[iNode] != null) {
				Set separator = new HashSet(cliques[iNode]);
				separator.retainAll(processedNodes);
				separators[iNode] = separator;
				processedNodes.addAll(cliques[iNode]);
			}
		}
		return separators;
	} // getSeparators

	/**
	 * Gets the cliques in a decomposable graph represented by an adjacency matrix.
	 * Each node together with its neighbours earlier in the ordering forms a
	 * candidate. Candidates contained in another candidate are discarded.
	 *
	 * @param order maximum cardinality ordering of the graph
	 * @param bAdjacencyMatrix decomposable graph
	 * @return cliques indexed by defining node; null for discarded candidates
	 * @throws Exception in debug mode, if a reported clique is not fully connected
	 */
	Set [] getCliques(int[] order, boolean[][] bAdjacencyMatrix) throws Exception {
		int nNodes = bAdjacencyMatrix.length;
		Set [] cliques = new HashSet[nNodes];
		// consult nodes in reverse order
		for (int i = nNodes - 1; i >= 0; i--) {
			int iNode = order[i];
			Set clique = new HashSet();
			clique.add(iNode);
			for (int j = 0; j < i; j++) {
				int iNode2 = order[j];
				if (bAdjacencyMatrix[iNode][iNode2]) {
					clique.add(iNode2);
				}
			}
			cliques[iNode] = clique;
		}
		// discard candidates that are subsets of another candidate
		for (int iNode = 0; iNode < nNodes; iNode++) {
			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
				if (iNode != iNode2 && cliques[iNode] != null && cliques[iNode2] != null && cliques[iNode].containsAll(cliques[iNode2])) {
					cliques[iNode2] = null;
				}
			}
		}
		// sanity check: every pair of nodes in a clique must be adjacent
		if (m_debug) {
			int [] nNodeSet = new int[nNodes];
			for (int iNode = 0; iNode < nNodes; iNode++) {
				if (cliques[iNode] != null) {
					Iterator it = cliques[iNode].iterator();
					int k = 0;
					while (it.hasNext()) {
						nNodeSet[k++] = (Integer) it.next();
					}
					for (int i = 0; i < cliques[iNode].size(); i++) {
						for (int j = 0; j < cliques[iNode].size(); j++) {
							if (i != j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) {
								throw new Exception("Non clique" + i + " " + j);
							}
						}
					}
				}
			}
		}
		return cliques;
	} // getCliques

	/**
	 * Moralizes the DAG and calculates the adjacency matrix of the resulting
	 * undirected graph: every node is linked to its parents, and all parents
	 * of a node are linked to each other.
	 *
	 * @param bayesNet Bayes network to process
	 * @return symmetric adjacency matrix
	 */
	public boolean[][] moralize(BayesNet bayesNet) {
		int nNodes = bayesNet.getNrOfNodes();
		boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];
		for (int iNode = 0; iNode < nNodes; iNode++) {
			moralizeNode(bayesNet.getParentSets()[iNode], iNode, bAdjacencyMatrix);
		}
		return bAdjacencyMatrix;
	} // moralize

	/** Adds undirected edges node--parent and parent--parent for one node. */
	private void moralizeNode(ParentSet parents, int iNode, boolean[][] bAdjacencyMatrix) {
		for (int iParent = 0; iParent < parents.getNrOfParents(); iParent++) {
			int nParent = parents.getParent(iParent);
			if (m_debug && !bAdjacencyMatrix[iNode][nParent]) {
				System.out.println("Insert " + iNode + "--" + nParent);
			}
			bAdjacencyMatrix[iNode][nParent] = true;
			bAdjacencyMatrix[nParent][iNode] = true;
			for (int iParent2 = iParent + 1; iParent2 < parents.getNrOfParents(); iParent2++) {
				int nParent2 = parents.getParent(iParent2);
				if (m_debug && !bAdjacencyMatrix[nParent2][nParent]) {
					System.out.println("Mary " + nParent + "--" + nParent2);
				}
				bAdjacencyMatrix[nParent2][nParent] = true;
				bAdjacencyMatrix[nParent][nParent2] = true;
			}
		}
	} // moralizeNode

	/**
	 * Applies the Tarjan and Yannakakis (1984) fill-in algorithm for graph
	 * triangulation. In reverse order, inserts edges between any non-adjacent
	 * neighbors that are lower numbered in the ordering.
	 *
	 * Side effect: the input matrix is modified and returned.
	 *
	 * @param order node ordering
	 * @param bAdjacencyMatrix boolean matrix representing the graph
	 * @return the same matrix, now including the fill-in edges
	 */
	public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) {
		int nNodes = bAdjacencyMatrix.length;
		for (int i = nNodes - 1; i >= 0; i--) {
			int iNode = order[i];
			// connect every pair of lower-ordered neighbours of iNode
			for (int j = 0; j < i; j++) {
				int iNode2 = order[j];
				if (!bAdjacencyMatrix[iNode][iNode2]) {
					continue;
				}
				for (int k = j + 1; k < i; k++) {
					int iNode3 = order[k];
					if (bAdjacencyMatrix[iNode][iNode3]) {
						if (m_debug && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2])) {
							System.out.println("Fill in " + iNode2 + "--" + iNode3);
						}
						bAdjacencyMatrix[iNode2][iNode3] = true;
						bAdjacencyMatrix[iNode3][iNode2] = true;
					}
				}
			}
		}
		return bAdjacencyMatrix;
	} // fillIn

	/**
	 * Calculates a maximum cardinality ordering. Starts with node 0, then
	 * repeatedly adds the node with the most neighbours already ordered (ties
	 * go to the lowest index) until all nodes are ordered.
	 *
	 * This implementation does not assume the graph is connected.
	 *
	 * @param bAdjacencyMatrix n by n matrix with adjacencies in graph of n nodes
	 * @return maximum cardinality ordering (order[i] is the i-th node)
	 */
	int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) {
		int nNodes = bAdjacencyMatrix.length;
		int[] order = new int[nNodes];
		if (nNodes == 0) {
			return order;
		}
		boolean[] bDone = new boolean[nNodes];
		order[0] = 0;
		bDone[0] = true;
		for (int iNode = 1; iNode < nNodes; iNode++) {
			int nMaxCard = -1;
			int iBestNode = -1;
			for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
				if (bDone[iNode2]) {
					continue;
				}
				int nCard = 0;
				for (int iNode3 = 0; iNode3 < nNodes; iNode3++) {
					if (bAdjacencyMatrix[iNode2][iNode3] && bDone[iNode3]) {
						nCard++;
					}
				}
				if (nCard > nMaxCard) {
					nMaxCard = nCard;
					iBestNode = iNode2;
				}
			}
			order[iNode] = iBestNode;
			bDone[iBestNode] = true;
		}
		return order;
	} // getMaxCardOrder

	/**
	 * Enters hard evidence nNode = iValue into the first clique containing
	 * nNode and propagates it through the junction tree.
	 *
	 * @param nNode Bayes net node observed
	 * @param iValue observed value index
	 * @throws Exception if the junction tree has not been built or no clique contains nNode
	 */
	public void setEvidence(int nNode, int iValue) throws Exception {
		if (m_root == null) {
			throw new Exception("Junction tree not initialize yet");
		}
		int iJtNode = 0;
		while (iJtNode < jtNodes.length && (jtNodes[iJtNode] == null || !jtNodes[iJtNode].contains(nNode))) {
			iJtNode++;
		}
		if (jtNodes.length == iJtNode) {
			throw new Exception("Could not find node " + nNode + " in junction tree");
		}
		jtNodes[iJtNode].setEvidence(nNode, iValue);
	} // setEvidence

	/**
	 * Lists the marginals of all nodes reachable from m_root.
	 * Requires calcMargins or calcFullMargins to have been called.
	 */
	@Override
	public String toString() {
		return m_root.toString();
	} // toString

	/**
	 * @param iNode Bayes net node index
	 * @return current marginal distribution of the node, indexed by value; this
	 *         is the internal array, not a copy
	 */
	public double [] getMargin(int iNode) {
		return m_Margins[iNode];
	} // getMargin

	/**
	 * Returns the revision string.
	 *
	 * @return		the revision
	 */
	public String getRevision() {
		return RevisionUtils.extract("$Revision: 1.2 $");
	}

	/**
	 * Demo: reads a BIF file (args[0]), enters evidence on nodes 2 and 4, and
	 * prints the marginals, first with calcMargins and then with calcFullMargins.
	 */
	public static void main(String[] args) {
		try {
			BIFReader bayesNet = new BIFReader();
			bayesNet.processFile(args[0]);

			MarginCalculator dc = new MarginCalculator();
			dc.calcMargins(bayesNet);
			int iNode = 2;
			int iValue = 0;
			int iNode2 = 4;
			int iValue2 = 0;
			dc.setEvidence(iNode, iValue);
			dc.setEvidence(iNode2, iValue2);
			System.out.print(dc.toString());

			dc.calcFullMargins(bayesNet);
			dc.setEvidence(iNode, iValue);
			dc.setEvidence(iNode2, iValue2);
			System.out.println("==============");
			System.out.print(dc.toString());
		} catch (Exception e) {
			e.printStackTrace();
		}
	} // main

} // class MarginCalculator




© 2015 - 2025 Weber Informatics LLC | Privacy Policy