All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.bayes.JunctionTreeUpdater Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 2020 University of Pittsburgh.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA 02110-1301  USA
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.TetradLogger;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serial;
import java.util.List;

/**
 * Jan 21, 2020 11:03:09 AM
 *
 * @author Kevin V. Bui ([email protected])
 * @version $Id: $Id
 */
public class JunctionTreeUpdater implements ManipulatingBayesUpdater {
    @Serial
    private static final long serialVersionUID = 23L;
    /**
     * The BayesIm which this updater modifies.
     */
    private final BayesIm bayesIm;
    /**
     * Stores evidence for all variables.
     */
    private Evidence evidence;
    /**
     * The last manipulated BayesIm.
     */
    private BayesIm manipulatedBayesIm;
    /**
     * The BayesIm after update, if this was calculated.
     */
    private BayesIm updatedBayesIm;
    /**
     * Calculates probabilities from the manipulated Bayes IM.
     */
    private JunctionTreeAlgorithm jta;

    /**
     * 

Constructor for JunctionTreeUpdater.

* * @param bayesIm a {@link edu.cmu.tetrad.bayes.BayesIm} object */ public JunctionTreeUpdater(BayesIm bayesIm) { this(bayesIm, Evidence.tautology(bayesIm)); } /** *

Constructor for JunctionTreeUpdater.

* * @param bayesIm a {@link edu.cmu.tetrad.bayes.BayesIm} object * @param evidence a {@link edu.cmu.tetrad.bayes.Evidence} object */ public JunctionTreeUpdater(BayesIm bayesIm, Evidence evidence) { if (bayesIm == null) { throw new NullPointerException(); } this.bayesIm = bayesIm; setEvidence(evidence); } /** * {@inheritDoc} */ @Override public BayesIm getManipulatedBayesIm() { return this.manipulatedBayesIm; } /** * {@inheritDoc} */ @Override public Graph getManipulatedGraph() { return getManipulatedBayesIm().getDag(); } /** * {@inheritDoc} */ @Override public Evidence getEvidence() { return new Evidence(this.evidence); } /** * {@inheritDoc} */ @Override public void setEvidence(Evidence evidence) { if (evidence == null) { throw new NullPointerException(); } if (evidence.isIncompatibleWith(this.bayesIm)) { throw new IllegalArgumentException("The variable list for the " + "given bayesIm must be compatible with the variable list " + "for this evidence."); } this.evidence = evidence; Graph graph = this.bayesIm.getBayesPm().getDag(); Dag manipulatedGraph = createManipulatedGraph(graph); BayesPm manipulatedPm = createUpdatedBayesPm(manipulatedGraph); this.manipulatedBayesIm = createdUpdatedBayesIm(manipulatedPm); Evidence evidence2 = new Evidence(evidence, this.manipulatedBayesIm); this.updatedBayesIm = new UpdatedBayesIm(this.manipulatedBayesIm, evidence2); this.jta = new JunctionTreeAlgorithm(this.updatedBayesIm); } /** * {@inheritDoc} */ @Override public BayesIm getUpdatedBayesIm() { if (this.updatedBayesIm == null) { updateAll(); } return this.updatedBayesIm; } /** * {@inheritDoc} */ @Override public double getMarginal(int variable, int category) { Proposition assertion = Proposition.tautology(this.manipulatedBayesIm); Proposition condition = new Proposition(this.manipulatedBayesIm, this.evidence.getProposition()); assertion.setCategory(variable, category); if (condition.existsCombination()) { return this.jta.getMarginalProbability(variable, category); } else { return Double.NaN; } } /** * {@inheritDoc} */ @Override public boolean isJointMarginalSupported() { return true; } /** * {@inheritDoc} */ @Override public double getJointMarginal(int[] variables, int[] values) { if (variables.length != values.length) { throw new IllegalArgumentException("Values must match variables."); } Proposition assertion = Proposition.tautology(this.manipulatedBayesIm); Proposition condition = new Proposition(this.manipulatedBayesIm, this.evidence.getProposition()); for (int i = 0; i < variables.length; i++) { assertion.setCategory(variables[i], values[i]); } if (condition.existsCombination()) { double joint = 1.0; for (int i = 0; i < variables.length; i++) { joint *= this.jta.getMarginalProbability(variables[i], values[i]); } return joint; } else { return Double.NaN; } } /** * {@inheritDoc} */ @Override public BayesIm getBayesIm() { return this.bayesIm; } /** * {@inheritDoc} */ @Override public double[] calculatePriorMarginals(int nodeIndex) { Evidence evidence = getEvidence(); setEvidence(Evidence.tautology(evidence.getVariableSource())); double[] marginals = new double[evidence.getNumCategories(nodeIndex)]; for (int i = 0; i < getBayesIm().getNumColumns(nodeIndex); i++) { marginals[i] = getMarginal(nodeIndex, i); } setEvidence(evidence); return marginals; } /** * {@inheritDoc} */ @Override public double[] calculateUpdatedMarginals(int nodeIndex) { double[] marginals = new double[this.evidence.getNumCategories(nodeIndex)]; for (int i = 0; i < getBayesIm().getNumColumns(nodeIndex); i++) { marginals[i] = getMarginal(nodeIndex, i); } return marginals; } /** * {@inheritDoc} */ @Override public String toString() { return "Junction tree updater, evidence = " + this.evidence; } private void updateAll() { this.updatedBayesIm = new MlBayesIm(this.manipulatedBayesIm); int numNodes = this.manipulatedBayesIm.getNumNodes(); Proposition assertion = Proposition.tautology(this.manipulatedBayesIm); Proposition condition = Proposition.tautology(this.manipulatedBayesIm); Evidence evidence2 = new Evidence(this.evidence, this.manipulatedBayesIm); for (int node = 0; node < numNodes; node++) { int numRows = this.manipulatedBayesIm.getNumRows(node); int numCols = this.manipulatedBayesIm.getNumColumns(node); int[] parents = this.manipulatedBayesIm.getParents(node); for (int row = 0; row < numRows; row++) { int[] parentValues = this.manipulatedBayesIm.getParentValues(node, row); for (int col = 0; col < numCols; col++) { assertion.setToTautology(); condition.setToTautology(); for (int i = 0; i < numNodes; i++) { for (int j = 0; j < evidence2.getNumCategories(i); j++) { if (!evidence2.getProposition().isAllowed(i, j)) { condition.removeCategory(i, j); } } } assertion.disallowComplement(node, col); for (int k = 0; k < parents.length; k++) { condition.disallowComplement(parents[k], parentValues[k]); } if (condition.existsCombination()) { double p = (parents.length > 0) ? this.jta.getConditionalProbability(node, col, parents, parentValues) : this.jta.getMarginalProbability(node, col); this.updatedBayesIm.setProbability(node, row, col, p); } else { this.updatedBayesIm.setProbability(node, row, col, Double.NaN); } } } } } private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { return new BayesPm(updatedGraph, this.bayesIm.getBayesPm()); } private Dag createManipulatedGraph(Graph graph) { Dag updatedGraph = new Dag(graph); // alters graph for manipulated evidenceItems for (int i = 0; i < this.evidence.getNumNodes(); ++i) { if (this.evidence.isManipulated(i)) { Node node = updatedGraph.getNode(this.evidence.getNode(i).getName()); List parents = updatedGraph.getParents(node); for (Node parent1 : parents) { updatedGraph.removeEdge(node, parent1); } } } return updatedGraph; } /** * Writes the object to the specified ObjectOutputStream. * * @param out The ObjectOutputStream to write the object to. * @throws IOException If an I/O error occurs. */ @Serial private void writeObject(ObjectOutputStream out) throws IOException { try { out.defaultWriteObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to serialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } /** * Reads the object from the specified ObjectInputStream. This method is used during deserialization * to restore the state of the object. * * @param in The ObjectInputStream to read the object from. * @throws IOException If an I/O error occurs. * @throws ClassNotFoundException If the class of the serialized object cannot be found. */ @Serial private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { try { in.defaultReadObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to deserialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy