All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.bayes.net.MarginCalculator Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    MarginCalculator.java
 *    Copyright (C) 2007-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.bayes.net;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.Vector;

import weka.classifiers.bayes.BayesNet;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

public class MarginCalculator implements Serializable, RevisionHandler {
  /** for serialization */
  private static final long serialVersionUID = 650278019241175534L;

  /**
   * when true, diagnostic messages are printed to System.out and getCliques
   * runs its extra consistency check
   */
  boolean m_debug = false;
  /**
   * root of the junction tree: the first junction tree node without a parent
   * clique found by process(); null until process() has selected one
   */
  public JunctionTreeNode m_root = null;
  /**
   * junction tree nodes as returned by getJunctionTree(), indexed by node
   * number; entries are null for nodes that do not own a clique
   */
  JunctionTreeNode[] jtNodes;

  /**
   * Looks up an attribute by name in the data set attached to the Bayes net
   * of the junction tree root.
   * 
   * @param sNodeName name to search for
   * @return index of the first attribute whose name equals sNodeName, or -1
   *         when no attribute has that name
   */
  public int getNode(String sNodeName) {
    for (int iNode = 0; iNode < m_root.m_bayesNet.m_Instances.numAttributes(); iNode++) {
      String sName = m_root.m_bayesNet.m_Instances.attribute(iNode).name();
      boolean bMatches = sName.equals(sNodeName);
      if (bMatches) {
        return iNode;
      }
    }
    // nothing matched; callers get -1 rather than an exception
    return -1;
  }

  /**
   * Delegates to the toXMLBIF03() method of the Bayes net referenced by the
   * junction tree root.
   * 
   * @return whatever the network's toXMLBIF03() returns
   */
  public String toXMLBIF03() {
    final BayesNet bayesNet = m_root.m_bayesNet;
    return bayesNet.toXMLBIF03();
  }

  /**
   * Calculates the marginal distributions of the nodes in a Bayesian network:
   * the network is moralized and the resulting undirected graph is handed to
   * process(). Note that a connected network is assumed. Unconnected networks
   * may give unexpected results.
   * 
   * @param bayesNet network to calculate the margins for
   * @throws Exception if process() fails
   */
  public void calcMargins(BayesNet bayesNet) throws Exception {
    process(moralize(bayesNet), bayesNet);
  } // calcMargins

  /**
   * Calculates the margins using an adjacency matrix in which every entry is
   * set (diagonal included) instead of the moralized network structure.
   * 
   * @param bayesNet network to calculate the margins for
   * @throws Exception if process() fails
   */
  public void calcFullMargins(BayesNet bayesNet) throws Exception {
    final int nNodes = bayesNet.getNrOfNodes();
    final boolean[][] bFullyConnected = new boolean[nNodes][nNodes];
    for (boolean[] bRow : bFullyConnected) {
      for (int iCol = 0; iCol < bRow.length; iCol++) {
        bRow[iCol] = true;
      }
    }
    process(bFullyConnected, bayesNet);
  } // calcFullMargins

  /**
   * Builds the junction tree for an undirected graph over the nodes of a
   * Bayes net and calculates the marginal distribution of every node.
   * <p>
   * The graph is ordered with getMaxCardOrder(), filled in with fillIn() and
   * ordered again; cliques, separators and the clique tree are derived from
   * that, the junction tree nodes are created and initialize() runs the
   * upward and downward passes.
   * <p>
   * Side effects: jtNodes, m_root and m_Margins are replaced, and the
   * adjacency matrix passed in is modified by fillIn().
   * 
   * @param bAdjacencyMatrix undirected graph over the network nodes
   * @param bayesNet network supplying node names, cardinalities and
   *          distributions
   * @throws Exception if a clique without parent has a non-empty separator,
   *           or if getCliques() reports a problem
   */
  public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet)
    throws Exception {
    int[] order = getMaxCardOrder(bAdjacencyMatrix);
    bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix);
    order = getMaxCardOrder(bAdjacencyMatrix);
    @SuppressWarnings("unchecked")
    Set<Integer>[] cliques = getCliques(order, bAdjacencyMatrix);
    @SuppressWarnings("unchecked")
    Set<Integer>[] separators = getSeparators(order, cliques);
    int[] parentCliques = getCliqueTree(order, cliques, separators);
    int nNodes = bAdjacencyMatrix.length;
    if (m_debug) {
      // report cliques
      for (int i = 0; i < nNodes; i++) {
        int iNode = order[i];
        if (cliques[iNode] != null) {
          System.out.print("Clique " + iNode + " (");
          System.out.print(describeNodes(cliques[iNode], bayesNet));
          System.out.print(") S(");
          System.out.print(describeNodes(separators[iNode], bayesNet));
          System.out.println(") parent clique " + parentCliques[iNode]);
        }
      }
    }

    jtNodes = getJunctionTree(cliques, separators, parentCliques, order,
      bayesNet);
    // the root is the first clique (by node number) that has no parent
    m_root = null;
    for (int iNode = 0; iNode < nNodes; iNode++) {
      if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) {
        m_root = jtNodes[iNode];
        break;
      }
    }
    m_Margins = new double[nNodes][];
    initialize(jtNodes, order, cliques, separators, parentCliques);

    // sanity check: a clique that shares nodes with earlier cliques must
    // have been given a parent
    for (int i = 0; i < nNodes; i++) {
      int iNode = order[i];
      if (cliques[iNode] != null) {
        if (parentCliques[iNode] == -1 && separators[iNode].size() > 0) {
          throw new Exception("Something wrong in clique tree");
        }
      }
    }
  } // process

  /**
   * Formats a node set as "index name,index name,..." for debug output.
   * 
   * @param nodes node numbers to list
   * @param bayesNet network supplying the node names
   * @return comma separated list, empty for an empty set
   */
  private String describeNodes(Set<Integer> nodes, BayesNet bayesNet) {
    StringBuilder buf = new StringBuilder();
    Iterator<Integer> it = nodes.iterator();
    while (it.hasNext()) {
      int iNode = it.next();
      buf.append(iNode).append(' ').append(bayesNet.getNodeName(iNode));
      if (it.hasNext()) {
        buf.append(',');
      }
    }
    return buf.toString();
  } // describeNodes

  /**
   * Runs the two initialization passes over the junction tree: first
   * initializeUp() on every junction tree node, visiting the ordering from
   * back to front, then initializeDown(false) on every node from front to
   * back. Nodes without a junction tree node are skipped.
   * 
   * @param jtNodes junction tree nodes indexed by node number (may hold nulls)
   * @param order node ordering that determines the visiting sequence
   * @param cliques not used by this method
   * @param separators not used by this method
   * @param parentCliques not used by this method
   */
  void initialize(JunctionTreeNode[] jtNodes, int[] order,
    Set[] cliques, Set[] separators, int[] parentCliques) {
    // upward pass
    int iPos = order.length;
    while (--iPos >= 0) {
      JunctionTreeNode jtNode = jtNodes[order[iPos]];
      if (jtNode != null) {
        jtNode.initializeUp();
      }
    }
    // downward pass
    for (int iNode : order) {
      JunctionTreeNode jtNode = jtNodes[iNode];
      if (jtNode == null) {
        continue;
      }
      jtNode.initializeDown(false);
    }
  } // initialize

  /**
   * Creates the junction tree nodes for all cliques and links every clique
   * that has a parent to that parent through a JunctionTreeSeparator.
   * 
   * @param cliques cliques indexed by node number; null where a node owns no
   *          clique
   * @param separators separator sets indexed like cliques
   * @param parentCliques index of the parent clique per node, -1 when there
   *          is no parent
   * @param order node ordering used to visit the cliques
   * @param bayesNet network the cliques refer to
   * @return junction tree nodes indexed by node number, null where no clique
   *         exists
   */
  JunctionTreeNode[] getJunctionTree(Set<Integer>[] cliques,
    Set<Integer>[] separators, int[] parentCliques, int[] order,
    BayesNet bayesNet) {
    int nNodes = order.length;
    JunctionTreeNode[] jtns = new JunctionTreeNode[nNodes];
    // shared between all nodes; the JunctionTreeNode constructor receives it
    // so that the cliques can coordinate through it
    boolean[] bDone = new boolean[nNodes];
    // create junction tree nodes
    for (int i = 0; i < nNodes; i++) {
      int iNode = order[i];
      if (cliques[iNode] != null) {
        jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone);
      }
    }
    // create junction tree separators
    for (int i = 0; i < nNodes; i++) {
      int iNode = order[i];
      if (cliques[iNode] != null) {
        int iParent = parentCliques[iNode];
        // -1 marks "no parent"; 0 is a legal clique index and must not be
        // mistaken for the sentinel
        if (iParent >= 0) {
          JunctionTreeNode parent = jtns[iParent];
          JunctionTreeSeparator jts = new JunctionTreeSeparator(
            separators[iNode], bayesNet, jtns[iNode], parent);
          jtns[iNode].setParentSeparator(jts);
          parent.addChildClique(jtns[iNode]);
        }
      }
    }
    return jtns;
  } // getJunctionTree

  /**
   * Separator between a clique of the junction tree (the child) and its
   * parent clique. It keeps, for both neighbours, the normalized marginal of
   * the neighbour's distribution over the separator nodes.
   */
  public class JunctionTreeSeparator implements Serializable, RevisionHandler {

    private static final long serialVersionUID = 6502780192411755343L;
    /** nodes of the Bayes net in this separator */
    int[] m_nNodes;
    /** product of the cardinalities of the separator nodes */
    int m_nCardinality;
    /** normalized marginal of the parent clique over the separator nodes */
    double[] m_fiParent;
    /** normalized marginal of the child clique over the separator nodes */
    double[] m_fiChild;
    /** clique on the parent side of this separator */
    JunctionTreeNode m_parentNode;
    /** clique on the child side of this separator */
    JunctionTreeNode m_childNode;
    /** network used to look up cardinalities */
    BayesNet m_bayesNet;

    /**
     * @param separator node numbers that make up the separator
     * @param bayesNet network the node numbers refer to
     * @param childNode clique on the child side
     * @param parentNode clique on the parent side
     */
    JunctionTreeSeparator(Set<Integer> separator, BayesNet bayesNet,
      JunctionTreeNode childNode, JunctionTreeNode parentNode) {
      // ////////////////////
      // initialize node set
      m_nNodes = new int[separator.size()];
      int iPos = 0;
      m_nCardinality = 1;
      for (int iNode : separator) {
        m_nNodes[iPos++] = iNode;
        m_nCardinality *= bayesNet.getCardinality(iNode);
      }
      m_parentNode = parentNode;
      m_childNode = childNode;
      m_bayesNet = bayesNet;
    } // c'tor

    /**
     * marginalize the parent clique over all nodes outside the separator set
     * and store the normalized result in m_fiParent (null when the parent has
     * no distribution yet)
     */
    public void updateFromParent() {
      m_fiParent = normalize(update(m_parentNode));
    } // updateFromParent

    /**
     * marginalize the child clique over all nodes outside the separator set
     * and store the normalized result in m_fiChild (null when the child has
     * no distribution yet)
     */
    public void updateFromChild() {
      m_fiChild = normalize(update(m_childNode));
    } // updateFromChild

    /**
     * Divides the first m_nCardinality entries of the array by their sum.
     * The array is changed in place.
     * 
     * @param fis values to normalize, may be null
     * @return the array passed in, or null if it was null
     */
    private double[] normalize(double[] fis) {
      if (fis == null) {
        return null;
      }
      double sum = 0;
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        sum += fis[iPos];
      }
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        fis[iPos] /= sum;
      }
      return fis;
    } // normalize

    /**
     * marginalize a junction tree node over all nodes outside the separator
     * set
     * 
     * @param node one of the neighboring junction tree nodes of this separator
     * @return unnormalized sums of node.m_P per separator configuration, or
     *         null if node.m_P is null
     */
    public double[] update(JunctionTreeNode node) {
      if (node.m_P == null) {
        return null;
      }
      double[] fi = new double[m_nCardinality];

      // values[k] is the current value of the k-th variable of the clique;
      // order maps a node number to its position k in the clique
      int[] values = new int[node.m_nNodes.length];
      int[] order = new int[m_bayesNet.getNrOfNodes()];
      for (int iNode = 0; iNode < node.m_nNodes.length; iNode++) {
        order[node.m_nNodes[iNode]] = iNode;
      }
      // fill in the values
      for (int iPos = 0; iPos < node.m_nCardinality; iPos++) {
        int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values,
          order, m_bayesNet);
        int iSepCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
          m_bayesNet);
        fi[iSepCPT] += node.m_P[iNodeCPT];
        // update values: advance to the next configuration, carrying over
        // whenever a variable reaches its cardinality
        int i = 0;
        values[i]++;
        while (i < node.m_nNodes.length
          && values[i] == m_bayesNet.getCardinality(node.m_nNodes[i])) {
          values[i] = 0;
          i++;
          if (i < node.m_nNodes.length) {
            values[i]++;
          }
        }
      }
      return fi;
    } // update

    /**
     * Returns the revision string.
     * 
     * @return the revision
     */
    @Override
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 10154 $");
    }

  } // class JunctionTreeSeparator

  public class JunctionTreeNode implements Serializable, RevisionHandler {

    /** for serialization */
    private static final long serialVersionUID = 650278019241175536L;
    /**
     * reference Bayes net for information about variables like name,
     * cardinality, etc. but not for relations between nodes
     **/
    BayesNet m_bayesNet;
    /** nodes of the Bayes net in this junction node **/
    public int[] m_nNodes;
    /**
     * product of the cardinalities of the variables in this junction node;
     * m_fi and m_P are allocated with this length
     **/
    int m_nCardinality;
    /** potentials of this junction node, filled by calculatePotentials() **/
    double[] m_fi;

    /**
     * distribution over this junction node; allocated in initializeUp() and
     * kept normalized by the methods that change it
     **/
    double[] m_P;

    /**
     * marginal distribution per variable of this junction node, indexed like
     * m_nNodes; filled by calcMarginalProbabilities()
     **/
    double[][] m_MarginalP;

    /** separator towards the parent clique; null for a root node **/
    JunctionTreeSeparator m_parentSeparator;

    /**
     * Sets the separator that connects this junction node to its parent.
     * 
     * @param parentSeparator separator towards the parent clique
     */
    public void setParentSeparator(JunctionTreeSeparator parentSeparator) {
      m_parentSeparator = parentSeparator;
    }

    /** child cliques of this junction node */
    public Vector<JunctionTreeNode> m_children;

    /**
     * Registers a clique as child of this junction node.
     * 
     * @param child clique to add to m_children
     */
    public void addChildClique(JunctionTreeNode child) {
      m_children.add(child);
    }

    /**
     * Upward pass for this junction node: m_P starts as a copy of the
     * potentials m_fi, is multiplied with the separator marginal m_fiChild of
     * every child clique and is then normalized. If this node has a parent,
     * the parent separator is refreshed from the new distribution.
     */
    public void initializeUp() {
      m_P = new double[m_nCardinality];
      System.arraycopy(m_fi, 0, m_P, 0, m_nCardinality);
      // values[k] is the current value of the k-th variable of this clique;
      // order maps a node number to its position k
      int[] values = new int[m_nNodes.length];
      int[] order = new int[m_bayesNet.getNrOfNodes()];
      for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
        order[m_nNodes[iNode]] = iNode;
      }
      // iterate as Object and cast, like the other methods of this class do,
      // so the loop compiles whatever the declared element type of m_children
      for (Object element : m_children) {
        JunctionTreeNode childNode = (JunctionTreeNode) element;
        JunctionTreeSeparator separator = childNode.m_parentSeparator;
        // Update the values
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length,
            values, order, m_bayesNet);
          int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
            m_bayesNet);
          m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT];
          // update values: next configuration, with carry; after
          // m_nCardinality steps all values are back at zero
          int i = 0;
          values[i]++;
          while (i < m_nNodes.length
            && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
            values[i] = 0;
            i++;
            if (i < m_nNodes.length) {
              values[i]++;
            }
          }
        }
      }
      // normalize
      double sum = 0;
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        sum += m_P[iPos];
      }
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        m_P[iPos] /= sum;
      }

      if (m_parentSeparator != null) { // not a root node
        m_parentSeparator.updateFromChild();
      }
    } // initializeUp

    /**
     * Downward pass for this junction node. A root node only recalculates its
     * marginals. Any other node first lets its parent separator recompute
     * m_fiParent, rescales m_P by m_fiParent / m_fiChild per separator
     * configuration, normalizes m_P, lets the separator recompute m_fiChild
     * from the new m_P and finally recalculates its marginals.
     * 
     * @param recursively if true, initializeDown(true) is also called on all
     *          child cliques afterwards
     */
    public void initializeDown(boolean recursively) {
      if (m_parentSeparator == null) { // a root node
        calcMarginalProbabilities();
      } else {
        m_parentSeparator.updateFromParent();
        // values[k] is the current value of the k-th variable of this clique;
        // order maps a node number to its position k
        int[] values = new int[m_nNodes.length];
        int[] order = new int[m_bayesNet.getNrOfNodes()];
        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
          order[m_nNodes[iNode]] = iNode;
        }

        // Update the values
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          int iSepCPT = getCPT(m_parentSeparator.m_nNodes,
            m_parentSeparator.m_nNodes.length, values, order, m_bayesNet);
          int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
            m_bayesNet);
          // avoid dividing by a zero child marginal: such configurations get
          // probability zero
          if (m_parentSeparator.m_fiChild[iSepCPT] > 0) {
            m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT]
              / m_parentSeparator.m_fiChild[iSepCPT];
          } else {
            m_P[iNodeCPT] = 0;
          }
          // update values: advance to the next configuration, carrying over
          // whenever a variable reaches its cardinality
          int i = 0;
          values[i]++;
          while (i < m_nNodes.length
            && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
            values[i] = 0;
            i++;
            if (i < m_nNodes.length) {
              values[i]++;
            }
          }
        }
        // normalize
        double sum = 0;
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          sum += m_P[iPos];
        }
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          m_P[iPos] /= sum;
        }
        // m_P changed, so the child side of the separator is recomputed
        m_parentSeparator.updateFromChild();
        calcMarginalProbabilities();
      }
      if (recursively) {
        for (Object element : m_children) {
          JunctionTreeNode childNode = (JunctionTreeNode) element;
          childNode.initializeDown(true);
        }
      }
    } // initializeDown

    /**
     * Sums m_P over all configurations of this clique to obtain the marginal
     * distribution of each individual variable. The results are stored in
     * m_MarginalP (indexed like m_nNodes) and the same arrays are published
     * in m_Margins of the enclosing calculator under the node numbers.
     */
    void calcMarginalProbabilities() {
      final int nVars = m_nNodes.length;
      // current configuration of the clique variables
      int[] values = new int[nVars];
      // position of every node number inside this clique
      int[] order = new int[m_bayesNet.getNrOfNodes()];
      m_MarginalP = new double[nVars][];
      for (int iVar = 0; iVar < nVars; iVar++) {
        final int nNode = m_nNodes[iVar];
        order[nNode] = iVar;
        m_MarginalP[iVar] = new double[m_bayesNet.getCardinality(nNode)];
      }
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        final double p = m_P[getCPT(m_nNodes, nVars, values, order,
          m_bayesNet)];
        for (int iVar = 0; iVar < nVars; iVar++) {
          m_MarginalP[iVar][values[iVar]] += p;
        }
        // step to the next configuration, carrying over full digits
        values[0]++;
        for (int iVar = 0; iVar < nVars
          && values[iVar] == m_bayesNet.getCardinality(m_nNodes[iVar]); iVar++) {
          values[iVar] = 0;
          if (iVar + 1 < nVars) {
            values[iVar + 1]++;
          }
        }
      }

      for (int iVar = 0; iVar < nVars; iVar++) {
        m_Margins[m_nNodes[iVar]] = m_MarginalP[iVar];
      }
    } // calcMarginalProbabilities

    /**
     * Lists, one line per variable of this clique, the node name followed by
     * its marginal probabilities, then appends the same description for every
     * child clique, each preceded by a dashed line.
     * 
     * @return textual description of this subtree of the junction tree
     */
    @Override
    public String toString() {
      StringBuilder text = new StringBuilder();
      for (int iVar = 0; iVar < m_nNodes.length; iVar++) {
        text.append(m_bayesNet.getNodeName(m_nNodes[iVar])).append(": ");
        for (double p : m_MarginalP[iVar]) {
          text.append(p).append(" ");
        }
        text.append('\n');
      }
      for (Object element : m_children) {
        text.append("----------------\n");
        text.append(((JunctionTreeNode) element).toString());
      }
      return text.toString();
    } // toString

    /**
     * Fills m_fi with the potential of every configuration of this clique.
     * A variable contributes its conditional probability when it has not
     * been marked in bDone yet and all of its parents are part of the clique;
     * such variables are marked in bDone here. The potential of a
     * configuration is the product of the contributing probabilities (1.0 if
     * there are none).
     * 
     * @param bayesNet network supplying parents, cardinalities and
     *          distributions
     * @param clique node numbers of this clique
     * @param bDone per node number, whether its distribution has already been
     *          assigned to a clique; updated by this method
     */
    void calculatePotentials(BayesNet bayesNet, Set clique,
      boolean[] bDone) {
      m_fi = new double[m_nCardinality];

      final int nVars = m_nNodes.length;
      // current configuration of the clique variables
      int[] values = new int[nVars];
      // position of every node number inside this clique
      int[] order = new int[bayesNet.getNrOfNodes()];
      for (int iVar = 0; iVar < nVars; iVar++) {
        order[m_nNodes[iVar]] = iVar;
      }

      // decide which conditional probability tables this clique accounts for
      boolean[] bOwnsCPT = new boolean[nVars];
      for (int iVar = 0; iVar < nVars; iVar++) {
        final int nNode = m_nNodes[iVar];
        boolean bOwns = !bDone[nNode];
        int iParent = 0;
        while (iParent < bayesNet.getNrOfParents(nNode)) {
          int nParent = bayesNet.getParent(nNode, iParent);
          if (!clique.contains(nParent)) {
            bOwns = false;
          }
          iParent++;
        }
        bOwnsCPT[iVar] = bOwns;
        if (!bOwns) {
          continue;
        }
        bDone[nNode] = true;
        if (m_debug) {
          System.out.println("adding node " + nNode);
        }
      }

      // one potential per configuration
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        final int iCPT = getCPT(m_nNodes, nVars, values, order, bayesNet);
        double fPotential = 1.0;
        for (int iVar = 0; iVar < nVars; iVar++) {
          if (!bOwnsCPT[iVar]) {
            continue;
          }
          final int nNode = m_nNodes[iVar];
          final int[] nParents = bayesNet.getParentSet(nNode).getParents();
          final int iParentCPT = getCPT(nParents,
            bayesNet.getNrOfParents(nNode), values, order, bayesNet);
          fPotential *= bayesNet.getDistributions()[nNode][iParentCPT]
            .getProbability(values[iVar]);
        }
        m_fi[iCPT] = fPotential;

        // step to the next configuration, carrying over full digits
        values[0]++;
        for (int iVar = 0; iVar < nVars
          && values[iVar] == bayesNet.getCardinality(m_nNodes[iVar]); iVar++) {
          values[iVar] = 0;
          if (iVar + 1 < nVars) {
            values[iVar + 1]++;
          }
        }
      }
    } // calculatePotentials

    /**
     * Creates the junction tree node for a clique: stores the clique members,
     * calculates the size of its configuration space and its potentials.
     * 
     * @param clique node numbers that make up the clique
     * @param bayesNet network the node numbers refer to
     * @param bDone per node number, whether its distribution has already been
     *          assigned to a clique; passed on to calculatePotentials()
     */
    JunctionTreeNode(Set<Integer> clique, BayesNet bayesNet, boolean[] bDone) {
      m_bayesNet = bayesNet;
      m_children = new Vector<JunctionTreeNode>();
      // ////////////////////
      // initialize node set
      m_nNodes = new int[clique.size()];
      int iPos = 0;
      m_nCardinality = 1;
      for (int iNode : clique) {
        m_nNodes[iPos++] = iNode;
        m_nCardinality *= bayesNet.getCardinality(iNode);
      }
      // //////////////////////////////
      // initialize potential function
      calculatePotentials(bayesNet, clique, bDone);
    } // JunctionTreeNode c'tor

    /**
     * Checks whether this junction tree node contains node nNode.
     * 
     * @param nNode node number to look for
     * @return true if nNode is one of the entries of m_nNodes
     */
    boolean contains(int nNode) {
      for (int iPos = 0; iPos < m_nNodes.length; iPos++) {
        if (nNode == m_nNodes[iPos]) {
          return true;
        }
      }
      return false;
    } // contains

    /**
     * Enters evidence for one variable of this clique: every configuration in
     * which the variable has a value other than iValue gets probability zero,
     * m_P is normalized again, the marginals are recalculated and the change
     * is propagated through the tree with updateEvidence().
     * 
     * @param nNode node number of the observed variable
     * @param iValue observed value
     * @throws Exception if nNode is not part of this clique
     */
    public void setEvidence(int nNode, int iValue) throws Exception {
      final int nVars = m_nNodes.length;
      // current configuration of the clique variables
      int[] values = new int[nVars];
      // position of every node number inside this clique
      int[] order = new int[m_bayesNet.getNrOfNodes()];

      int iEvidenceVar = -1;
      for (int iVar = 0; iVar < nVars; iVar++) {
        final int nMember = m_nNodes[iVar];
        order[nMember] = iVar;
        if (nMember == nNode) {
          iEvidenceVar = iVar;
        }
      }
      if (iEvidenceVar < 0) {
        throw new Exception("setEvidence: Node " + nNode
          + " not found in this clique");
      }
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        final boolean bContradictsEvidence = values[iEvidenceVar] != iValue;
        if (bContradictsEvidence) {
          m_P[getCPT(m_nNodes, nVars, values, order, m_bayesNet)] = 0;
        }
        // step to the next configuration, carrying over full digits
        values[0]++;
        for (int iVar = 0; iVar < nVars
          && values[iVar] == m_bayesNet.getCardinality(m_nNodes[iVar]); iVar++) {
          values[iVar] = 0;
          if (iVar + 1 < nVars) {
            values[iVar + 1]++;
          }
        }
      }
      // normalize
      double total = 0;
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        total += m_P[iPos];
      }
      for (int iPos = 0; iPos < m_nCardinality; iPos++) {
        m_P[iPos] /= total;
      }
      calcMarginalProbabilities();
      updateEvidence(this);
    } // setEvidence

    /**
     * Propagates a change of distribution through the junction tree.
     * <p>
     * When called on the node whose distribution changed (source == this),
     * the node's own m_P is left alone. When called on a parent by one of its
     * children, m_P is first rescaled by m_fiChild / m_fiParent of the
     * separator between source and this node, normalized, and the marginals
     * are recalculated. In both cases all children other than source are then
     * updated with initializeDown(true) and the call continues with the
     * parent of this node, if there is one.
     * 
     * @param source the node that triggered the update: this node itself or
     *          the child clique the update arrives from
     */
    void updateEvidence(JunctionTreeNode source) {
      if (source != this) {
        // values[k] is the current value of the k-th variable of this clique;
        // order maps a node number to its position k
        int[] values = new int[m_nNodes.length];
        int[] order = new int[m_bayesNet.getNrOfNodes()];
        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
          order[m_nNodes[iNode]] = iNode;
        }
        // nodes of the separator between source and this node
        int[] nChildNodes = source.m_parentSeparator.m_nNodes;
        int nNumChildNodes = nChildNodes.length;
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          int iNodeCPT = getCPT(m_nNodes, m_nNodes.length, values, order,
            m_bayesNet);
          int iChildCPT = getCPT(nChildNodes, nNumChildNodes, values, order,
            m_bayesNet);
          // avoid dividing by a zero parent marginal: such configurations get
          // probability zero
          if (source.m_parentSeparator.m_fiParent[iChildCPT] != 0) {
            m_P[iNodeCPT] *= source.m_parentSeparator.m_fiChild[iChildCPT]
              / source.m_parentSeparator.m_fiParent[iChildCPT];
          } else {
            m_P[iNodeCPT] = 0;
          }
          // update values: advance to the next configuration, carrying over
          // whenever a variable reaches its cardinality
          int i = 0;
          values[i]++;
          while (i < m_nNodes.length
            && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
            values[i] = 0;
            i++;
            if (i < m_nNodes.length) {
              values[i]++;
            }
          }
        }
        // normalize
        double sum = 0;
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          sum += m_P[iPos];
        }
        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
          m_P[iPos] /= sum;
        }
        calcMarginalProbabilities();
      }
      // push the new distribution down into every subtree except the one the
      // update came from
      for (Object element : m_children) {
        JunctionTreeNode childNode = (JunctionTreeNode) element;
        if (childNode != source) {
          childNode.initializeDown(true);
        }
      }
      if (m_parentSeparator != null) {
        // refresh the child side of the separator before the parent reads it,
        // and the parent side after the parent has been updated
        m_parentSeparator.updateFromChild();
        m_parentSeparator.m_parentNode.updateEvidence(this);
        m_parentSeparator.updateFromParent();
      }
    } // updateEvidence

    /**
     * Returns the revision string.
     * 
     * @return the revision extracted from the revision keyword of this class
     */
    @Override
    public String getRevision() {
      final String sRevisionKeyword = "$Revision: 10154 $";
      return RevisionUtils.extract(sRevisionKeyword);
    }

  } // class JunctionTreeNode

  /**
   * Calculates the index of a configuration of a node set in a table that
   * is laid out with the first node of the set as most significant digit.
   * 
   * @param nodeSet node numbers, most significant first
   * @param nNodes number of entries of nodeSet to use
   * @param values current value per position; the value of node n is read
   *          from values[order[n]]
   * @param order position in values per node number
   * @param bayesNet network supplying the cardinality of every node
   * @return index of the configuration
   */
  int getCPT(int[] nodeSet, int nNodes, int[] values, int[] order,
    BayesNet bayesNet) {
    int iCPT = 0;
    int iPos = 0;
    while (iPos < nNodes) {
      final int nNode = nodeSet[iPos++];
      iCPT = iCPT * bayesNet.getCardinality(nNode) + values[order[nNode]];
    }
    return iCPT;
  } // getCPT

  /**
   * Determines the parent of every clique. The parent of a clique with a
   * non-empty separator is the first clique, in the given ordering, other
   * than the clique itself that contains the whole separator.
   * 
   * @param order node ordering that determines the search sequence
   * @param cliques cliques indexed by node number; null where a node owns no
   *          clique
   * @param separators separator sets indexed like cliques
   * @return per node number the node number of the parent clique, or -1 if
   *         the node owns no clique, its separator is empty or no clique
   *         contains the separator
   */
  int[] getCliqueTree(int[] order, Set<Integer>[] cliques,
    Set<Integer>[] separators) {
    int nNodes = order.length;
    int[] parentCliques = new int[nNodes];
    for (int i = 0; i < nNodes; i++) {
      int iNode = order[i];
      parentCliques[iNode] = -1;
      if (cliques[iNode] != null && separators[iNode].size() > 0) {
        for (int j = 0; j < nNodes; j++) {
          int iNode2 = order[j];
          if (iNode != iNode2 && cliques[iNode2] != null
            && cliques[iNode2].containsAll(separators[iNode])) {
            parentCliques[iNode] = iNode2;
            // the first match wins
            break;
          }
        }
      }
    }
    return parentCliques;
  } // getCliqueTree

  /**
   * calculate separator sets in clique tree: visiting the cliques in the
   * given ordering, the separator of a clique is its intersection with the
   * union of all cliques visited before it
   * 
   * @param order maximum cardinality ordering of the graph
   * @param cliques set of cliques, indexed by node number; null where a node
   *          owns no clique
   * @return set of separator sets, indexed like cliques and null wherever
   *         cliques is null
   */
  Set<Integer>[] getSeparators(int[] order, Set<Integer>[] cliques) {
    int nNodes = order.length;
    // generic arrays cannot be created directly; only Set<Integer> instances
    // are ever stored in this array
    @SuppressWarnings("unchecked")
    Set<Integer>[] separators = new HashSet[nNodes];
    Set<Integer> processedNodes = new HashSet<Integer>();
    for (int i = 0; i < nNodes; i++) {
      int iNode = order[i];
      if (cliques[iNode] != null) {
        Set<Integer> separator = new HashSet<Integer>(cliques[iNode]);
        separator.retainAll(processedNodes);
        separators[iNode] = separator;
        processedNodes.addAll(cliques[iNode]);
      }
    }
    return separators;
  } // getSeparators

  /**
   * get cliques in a decomposable graph represented by an adjacency matrix.
   * For every node a candidate clique is formed from the node itself and all
   * its neighbours that come earlier in the ordering; candidates that are
   * contained in another candidate are discarded.
   * 
   * @param order maximum cardinality ordering of the graph
   * @param bAdjacencyMatrix decomposable graph
   * @return set of cliques, indexed by node number; null for nodes whose
   *         candidate clique was discarded
   * @throws Exception if m_debug is set and a returned set contains two nodes
   *           that are not adjacent
   */
  Set<Integer>[] getCliques(int[] order, boolean[][] bAdjacencyMatrix)
    throws Exception {
    int nNodes = bAdjacencyMatrix.length;
    // generic arrays cannot be created directly; only Set<Integer> instances
    // are ever stored in this array
    @SuppressWarnings("unchecked")
    Set<Integer>[] cliques = new HashSet[nNodes];
    // consult nodes in reverse order
    for (int i = nNodes - 1; i >= 0; i--) {
      int iNode = order[i];
      Set<Integer> clique = new HashSet<Integer>();
      clique.add(iNode);
      for (int j = 0; j < i; j++) {
        int iNode2 = order[j];
        if (bAdjacencyMatrix[iNode][iNode2]) {
          clique.add(iNode2);
        }
      }
      cliques[iNode] = clique;
    }
    // remove candidates that are subsets of another candidate
    for (int iNode = 0; iNode < nNodes; iNode++) {
      for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
        if (iNode != iNode2 && cliques[iNode] != null
          && cliques[iNode2] != null
          && cliques[iNode].containsAll(cliques[iNode2])) {
          cliques[iNode2] = null;
        }
      }
    }
    // sanity check: all members of a clique must be pairwise adjacent
    if (m_debug) {
      int[] nNodeSet = new int[nNodes];
      for (int iNode = 0; iNode < nNodes; iNode++) {
        if (cliques[iNode] != null) {
          int k = 0;
          for (int nMember : cliques[iNode]) {
            nNodeSet[k++] = nMember;
          }
          for (int i = 0; i < cliques[iNode].size(); i++) {
            for (int j = 0; j < cliques[iNode].size(); j++) {
              if (i != j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) {
                throw new Exception("Non clique" + i + " " + j);
              }
            }
          }
        }
      }
    }
    return cliques;
  } // getCliques

  /**
   * Moralizes the DAG of a Bayes network and returns the result as adjacency
   * matrix, effectively converting the directed acyclic graph into an
   * undirected graph. The edges are inserted node by node through
   * moralizeNode().
   * 
   * @param bayesNet Bayes Network to process
   * @return adjacencies in boolean matrix format
   */
  public boolean[][] moralize(BayesNet bayesNet) {
    final int nNodes = bayesNet.getNrOfNodes();
    final boolean[][] bMoralGraph = new boolean[nNodes][nNodes];
    int iNode = 0;
    while (iNode < nNodes) {
      moralizeNode(bayesNet.getParentSets()[iNode], iNode, bMoralGraph);
      iNode++;
    }
    return bMoralGraph;
  } // moralize

  /**
   * Adds the moral edges for a single node to the adjacency matrix: an
   * undirected edge between the node and each of its parents, and an
   * undirected edge between every pair of its parents. With m_debug set,
   * each edge that was not present yet is reported on System.out.
   * 
   * @param parents parent set of the node
   * @param iNode node number
   * @param bAdjacencyMatrix matrix to update in place
   */
  private void moralizeNode(ParentSet parents, int iNode,
    boolean[][] bAdjacencyMatrix) {
    int iFirst = 0;
    while (iFirst < parents.getNrOfParents()) {
      final int nFirst = parents.getParent(iFirst);
      // connect the node with this parent
      if (m_debug && !bAdjacencyMatrix[iNode][nFirst]) {
        System.out.println("Insert " + iNode + "--" + nFirst);
      }
      bAdjacencyMatrix[iNode][nFirst] = true;
      bAdjacencyMatrix[nFirst][iNode] = true;
      // connect this parent with all parents that follow it
      int iSecond = iFirst + 1;
      while (iSecond < parents.getNrOfParents()) {
        final int nSecond = parents.getParent(iSecond);
        if (m_debug && !bAdjacencyMatrix[nSecond][nFirst]) {
          System.out.println("Mary " + nFirst + "--" + nSecond);
        }
        bAdjacencyMatrix[nSecond][nFirst] = true;
        bAdjacencyMatrix[nFirst][nSecond] = true;
        iSecond++;
      }
      iFirst++;
    }
  } // moralizeNode

  /**
   * Apply Tarjan and Yannakakis (1984) fill in algorithm for graph
   * triangulation. In reverse order, insert edges between any non-adjacent
   * neighbors that are lower numbered in the ordering.
   * 
   * Side effect: input matrix is used as output
   * 
   * @param order node ordering
   * @param bAdjacencyMatrix boolean matrix representing the graph
   * @return boolean matrix representing the graph with fill ins; this is the
   *         same array that was passed in
   */
  public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) {
    int nNodes = bAdjacencyMatrix.length;
    // consult nodes in reverse order
    for (int i = nNodes - 1; i >= 0; i--) {
      int iNode = order[i];
      // find pairs of neighbors with lower order
      for (int j = 0; j < i; j++) {
        int iNode2 = order[j];
        if (bAdjacencyMatrix[iNode][iNode2]) {
          for (int k = j + 1; k < i; k++) {
            int iNode3 = order[k];
            if (bAdjacencyMatrix[iNode][iNode3]) {
              // fill in
              if (m_debug
                && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2])) {
                System.out.println("Fill in " + iNode2 + "--" + iNode3);
              }
              bAdjacencyMatrix[iNode2][iNode3] = true;
              bAdjacencyMatrix[iNode3][iNode2] = true;
            }
          }
        }
      }
    }
    return bAdjacencyMatrix;
  } // fillIn

  /**
   * Calculates a maximum cardinality ordering. The ordering starts with node
   * 0; after that the node with the most neighbours among the nodes already
   * ordered is appended, until every node is part of the ordering. Ties go
   * to the node with the lowest number.
   * 
   * This implementation does not assume the graph is connected
   * 
   * @param bAdjacencyMatrix n by n matrix with adjacencies in graph of n
   *          nodes
   * @return maximum cardinality ordering; an empty array for an empty graph
   */
  int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) {
    final int nNodes = bAdjacencyMatrix.length;
    final int[] order = new int[nNodes];
    if (nNodes == 0) {
      return order;
    }
    final boolean[] bOrdered = new boolean[nNodes];
    // node 0 always comes first
    order[0] = 0;
    bOrdered[0] = true;
    for (int iPos = 1; iPos < nNodes; iPos++) {
      int iBest = -1;
      int nBestCard = -1;
      for (int iCandidate = 0; iCandidate < nNodes; iCandidate++) {
        if (bOrdered[iCandidate]) {
          continue;
        }
        // count the neighbours of the candidate that are ordered already
        final boolean[] bNeighbours = bAdjacencyMatrix[iCandidate];
        int nCard = 0;
        for (int iOther = 0; iOther < nNodes; iOther++) {
          if (bNeighbours[iOther] && bOrdered[iOther]) {
            nCard++;
          }
        }
        if (nCard > nBestCard) {
          nBestCard = nCard;
          iBest = iCandidate;
        }
      }
      order[iPos] = iBest;
      bOrdered[iBest] = true;
    }
    return order;
  } // getMaxCardOrder

  /**
   * Enters evidence for a node: the first junction tree node that contains
   * the node receives the evidence and propagates it through the tree.
   * 
   * @param nNode node number of the observed variable
   * @param iValue observed value
   * @throws Exception if no junction tree root is available, if no junction
   *           tree node contains nNode, or if the junction tree node rejects
   *           the evidence
   */
  public void setEvidence(int nNode, int iValue) throws Exception {
    if (m_root == null) {
      throw new Exception("Junction tree not initialize yet");
    }
    for (JunctionTreeNode jtNode : jtNodes) {
      if (jtNode != null && jtNode.contains(nNode)) {
        jtNode.setEvidence(nNode, iValue);
        return;
      }
    }
    throw new Exception("Could not find node " + nNode + " in junction tree");
  } // setEvidence

  /**
   * Describes the junction tree starting at its root.
   * 
   * @return the description produced by the root node, or a short notice
   *         when no junction tree has been built yet
   */
  @Override
  public String toString() {
    // toString() is called by debuggers and string concatenation; it must
    // not fail just because the margins have not been calculated yet
    if (m_root == null) {
      return "Junction tree not initialized yet";
    }
    return m_root.toString();
  } // toString

  /**
   * marginal distribution per node number; allocated by process() and
   * filled by the junction tree nodes
   */
  double[][] m_Margins;

  /**
   * Gives access to the marginal distribution stored for a node.
   * 
   * @param iNode node number
   * @return the array stored in m_Margins for that node (not a copy)
   */
  public double[] getMargin(int iNode) {
    final double[] margin = m_Margins[iNode];
    return margin;
  } // getMargin

  /**
   * Returns the revision string.
   * 
   * @return the revision extracted from the revision keyword of this class
   */
  @Override
  public String getRevision() {
    final String sRevisionKeyword = "$Revision: 10154 $";
    return RevisionUtils.extract(sRevisionKeyword);
  }

  /**
   * Small demonstration: reads a network from the file named by the first
   * argument, calculates the margins, sets evidence on nodes 2 and 4 and
   * prints the junction tree; then repeats this with calcFullMargins().
   * 
   * @param args args[0] is the file handed to BIFReader.processFile()
   */
  public static void main(String[] args) {
    // without a file name there is nothing to do; say so instead of failing
    // with an ArrayIndexOutOfBoundsException stack trace
    if (args == null || args.length < 1) {
      System.err.println("Usage: " + MarginCalculator.class.getName()
        + " <BIF file>");
      return;
    }
    try {
      BIFReader bayesNet = new BIFReader();
      bayesNet.processFile(args[0]);

      MarginCalculator dc = new MarginCalculator();
      dc.calcMargins(bayesNet);
      int iNode = 2;
      int iValue = 0;
      int iNode2 = 4;
      int iValue2 = 0;
      dc.setEvidence(iNode, iValue);
      dc.setEvidence(iNode2, iValue2);
      System.out.print(dc.toString());

      dc.calcFullMargins(bayesNet);
      dc.setEvidence(iNode, iValue);
      dc.setEvidence(iNode2, iValue2);
      System.out.println("==============");
      System.out.print(dc.toString());

    } catch (Exception e) {
      e.printStackTrace();
    }
  } // main

} // class MarginCalculator




© 2015 - 2025 Weber Informatics LLC | Privacy Policy