// All downloads are free. Search and download functionality uses the official Maven repository.
//
// edu.cmu.tetrad.bayes.BayesPm Maven / Gradle / Ivy
//
// The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.VariableSource;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.util.Pm;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.util.FastMath;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serial;
import java.util.*;

/**
 * Implements a discrete Bayes parametric model--that is, a DAG together with a map from the nodes in the graph to a set
 * of discrete variables, specifying the number of categories for each variable and the name of each category for each
 * variable. This is all the information one needs to know in order to determine the parametric form of a Bayes net up
 * to actual values of parameters. Specific values for the Bayes net are stored in a BayesIM object (see).
 *
 * @author josephramsey
 * @version $Id: $Id
 * @see edu.cmu.tetrad.graph.Dag
 * @see BayesIm
 */
public final class BayesPm implements Pm, VariableSource {
    @Serial
    private static final long serialVersionUID = 23L;

    /**
     * The underlying graph that's being parameterized.
     *
     * @serial Cannot be null.
     */
    private final Graph dag;

    /**
     * The map from nodes to variables.
     *
     * @serial Cannot be null.
     */
    private final Map nodesToVariables;

    //=========================CONSTRUCTORS=============================//

    /**
     * Construct a new BayesPm using the given DAG, assigning each variable two values named "value1" and "value2"
     * unless nodes are discrete variables with categories already defined.
     *
     * @param graph Ibid.
     */
    public BayesPm(Graph graph) {
        if (graph == null) {
            throw new NullPointerException("The graph must not be null.");
        }
        this.dag = new EdgeListGraph(graph);
        this.nodesToVariables = new HashMap<>();

        boolean allDiscreteVars = true;
        for (Node node : graph.getNodes()) {
            if (!(node instanceof DiscreteVariable)) {
                allDiscreteVars = false;
                break;
            }
        }

        if (!allDiscreteVars) {
            initializeValues(2, 2);
        } else {
            for (Node node : this.dag.getNodes()) {
                this.nodesToVariables.put(node, (DiscreteVariable) node);
            }
        }
    }

    /**
     * Constructs a new BayesPm using a given DAG, using as much information from the old BayesPm as possible.
     *
     * @param graph      Ibid.
     * @param oldBayesPm Ibid.
     */
    public BayesPm(Graph graph, BayesPm oldBayesPm) {
        this(graph, oldBayesPm, 2, 2);
    }

    /**
     * Constructs a new BayesPm from the given DAG, assigning each variable a random number of values between
     * lowerBound and
     * upperBound. Uses a fixed number of values if lowerBound ==
     * upperBound. The values are named "value1" ... "valuen".
     *
     * @param graph      Ibid.
     * @param lowerBound Ibid.
     * @param upperBound Ibid.
     */
    public BayesPm(Graph graph, int lowerBound, int upperBound) {
        if (graph == null) {
            throw new NullPointerException("The graph must not be null.");
        }

        this.dag = new EdgeListGraph(graph);
        this.nodesToVariables = new HashMap<>();
        initializeValues(lowerBound, upperBound);
    }

    /**
     * Constructs a new BayesPm from the given DAG, using as much information from the old BayesPm as possible. For
     * variables not in the old BayesPm, assigns each variable a random number of values between
     * lowerBound and upperBound. Uses a fixed number
     * of values if lowerBound == upperBound. The values are named "value1" ... "valuen".
     *
     * @param graph      Ibid.
     * @param oldBayesPm Ibid.
     * @param lowerBound Ibid.
     * @param upperBound Ibid.
     */
    public BayesPm(Graph graph, BayesPm oldBayesPm, int lowerBound,
                   int upperBound) {

        // Should be OK wrt variable mismatch problems. jdramsey 2004/1/21

        if (graph == null) {
            throw new NullPointerException("The graph must not be null.");
        }

        if (oldBayesPm == null) {
            throw new NullPointerException("BayesPm must not be null.");
        }

        if (graph.getNumNodes() == 0) {
            throw new IllegalArgumentException(
                    "The graph must have at least " + "one node in it.");
        }

        this.dag = new EdgeListGraph(graph);
        this.nodesToVariables = new HashMap<>();
        copyAvailableInformationFromOldBayesPm(oldBayesPm, lowerBound,
                upperBound);
    }

    /**
     * Copy constructor.
     *
     * @param bayesPm Ibid.
     */
    public BayesPm(BayesPm bayesPm) {
        this.dag = bayesPm.dag;
        this.nodesToVariables = new HashMap<>();

        for (Node node : bayesPm.nodesToVariables.keySet()) {
            DiscreteVariable variable = bayesPm.nodesToVariables.get(node);
            DiscreteVariable newVariable = new DiscreteVariable(variable);

            newVariable.setNodeType(node.getNodeType());

            this.nodesToVariables.put(node, newVariable);
        }
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     *
     * @return Ibid.
     */
    public static BayesPm serializableInstance() {
        return new BayesPm(Dag.serializableInstance());
    }


    //=========================PUBLIC METHODS=============================//

    /**
     * Returns the parameter names.
     *
     * @return Ibid.
     */
    public static List getParameterNames() {
        List parameters = new ArrayList<>();
        parameters.add("minCategories");
        parameters.add("maxCategories");
        return parameters;
    }

    private static int pickNumVals(int lowerBound, int upperBound) {
        if (lowerBound < 2) {
            throw new IllegalArgumentException(
                    "Lower bound must be >= 2: " + lowerBound);
        }

        if (upperBound < lowerBound) {
            throw new IllegalArgumentException(
                    "Upper bound for number of categories must be >= lower " + "bound.");
        }

        int difference = upperBound - lowerBound;
        RandomUtil randomUtil = RandomUtil.getInstance();
        return randomUtil.nextInt(difference + 1) + lowerBound;
    }

    /**
     * Returns the DAG.
     *
     * @return the DAG as a Graph.
     */
    public Graph getDag() {
        return this.dag;
    }

    /**
     * Returns the number of values for the given node.
     *
     * @param node Ibid.
     * @return the number of values for the given node.
     */
    public int getNumCategories(Node node) {
        DiscreteVariable variable = this.nodesToVariables.get(node);

        if (variable == null) {
            return 0;
        }

        return variable.getNumCategories();
    }

    /**
     * Returns the index'th value for the given node.
     *
     * @param node  Ibid.
     * @param index Ibid.
     * @return the index'th value for the given node.
     */
    public String getCategory(Node node, int index) {
        DiscreteVariable variable = this.nodesToVariables.get(node);

        if (variable != null) {
            return variable.getCategory(index);
        }

        for (DiscreteVariable _node : this.nodesToVariables.values()) {
            if (_node == null) {
                continue;
            }

            if (_node.getName().equals(node.getName())) {
                return _node.getCategory(index);
            }
        }

        throw new IllegalStateException();
    }

    /**
     * Returns the index of the given category for the given node.
     *
     * @param node     Ibid.
     * @param category Ibid.
     * @return the index of the given category for the given node.
     */
    public int getCategoryIndex(Node node, String category) {
        DiscreteVariable variable = this.nodesToVariables.get(node);
        return variable.getIndex(category);
    }

    /**
     * Sets the number of values for the given node to the given number.
     *
     * @param node          Ibid.
     * @param numCategories Ibid.
     */
    public void setNumCategories(Node node, int numCategories) {
        if (!this.nodesToVariables.containsKey(node)) {
            throw new IllegalArgumentException("Node not in BayesPm: " + node);
        }

        if (numCategories < 1) {
            throw new IllegalArgumentException(
                    "Number of categories must be >= 1: " + numCategories);
        }

        DiscreteVariable variable = this.nodesToVariables.get(node);

        List oldCategories = variable.getCategories();
        List newCategories = new LinkedList<>();
        int min = FastMath.min(numCategories, oldCategories.size());

        for (int i = 0; i < min; i++) {
            newCategories.add(oldCategories.get(i));
        }

        for (int i = min; i < numCategories; i++) {
            String proposedName = DataUtils.defaultCategory(i);

            if (newCategories.contains(proposedName)) {
                throw new IllegalArgumentException("Default name already in " +
                                                   "list of categories: " + proposedName);
            }

            newCategories.add(proposedName);
        }

        mapNodeToVariable(node, newCategories);
    }

    /**
     * {@inheritDoc}
     * 

* Will return true if the argument is a BayesPm with the same graph and variables. */ public boolean equals(Object o) { if (o == null) { return false; } if (!(o instanceof BayesPm bayesPm)) { return false; } return bayesPm.dag.equals(this.dag) && bayesPm.nodesToVariables.equals(this.nodesToVariables); } /** * Sets the number of values for the given node to the given number. * * @param node Ibid. * @param categories Ibid. */ public void setCategories(Node node, List categories) { mapNodeToVariable(node, categories); } /** *

getVariables.

* * @return a {@link java.util.List} object */ public List getVariables() { List variables = new LinkedList<>(); for (Node node : this.nodesToVariables.keySet()) { variables.add(this.nodesToVariables.get(node)); } return variables; } /** * Returns the variable names. * * @return Ibid. */ public List getVariableNames() { List variables = getVariables(); List names = new ArrayList<>(); for (Node variable : variables) { DiscreteVariable discreteVariable = (DiscreteVariable) variable; names.add(discreteVariable.getName()); } return names; } /** * Returns the variable for the given node. * * @param node Ibid. * @return Ibid. */ public Node getVariable(Node node) { return this.nodesToVariables.get(node); } /** * Returns the measured nodes. * * @return the list of measured variableNodes. */ public List getMeasuredNodes() { List measuredNodes = new ArrayList<>(); for (Node variable : getVariables()) { if (variable.getNodeType() == NodeType.MEASURED) { measuredNodes.add(variable); } } return measuredNodes; } /** * Prints out the list of values for each node. * * @return Ibid. */ public String toString() { StringBuilder buf = new StringBuilder(); for (Node node1 : this.nodesToVariables.keySet()) { buf.append("\n"); buf.append((node1)); buf.append(": "); DiscreteVariable variable = this.nodesToVariables.get((node1)); for (int j = 0; j < variable.getNumCategories(); j++) { buf.append(variable.getCategory(j)); if (j < variable.getNumCategories() - 1) { buf.append(", "); } } } return buf.toString(); } /** * Returns the node by the given name. * * @param nodeName Ibid. * @return Ibid. */ public Node getNode(String nodeName) { return this.dag.getNode(nodeName); } /** * Returns the node at the given index. * * @param index Ibid. * @return Ibid. */ public Node getNode(int index) { return getVariables().get(index); } /** * Returns the node index. * * @return -1. */ public int getNodeIndex() { return -1; } /** * Returns the number of nodes. * * @return Ibid. 
*/ public int getNumNodes() { return this.dag.getNumNodes(); } //=========================PRIVATE METHODS=============================// private void copyAvailableInformationFromOldBayesPm(BayesPm oldbayesPm, int lowerBound, int upperBound) { Graph newGraph = getDag(); Graph oldGraph = oldbayesPm.getDag(); for (Node node1 : newGraph.getNodes()) { if (oldGraph.containsNode(node1)) { copyOldValues(oldbayesPm, node1, node1, lowerBound, upperBound); } else { setNewValues(node1, lowerBound, upperBound); } } for (Node node2 : newGraph.getNodes()) { if (oldGraph.containsNode(node2)) { Node _node2 = this.dag.getNode(node2.getName()); DiscreteVariable oldNode2 = oldbayesPm.nodesToVariables.get(_node2); oldNode2.setNodeType(node2.getNodeType()); this.nodesToVariables.put(_node2, oldNode2); } else { setNewValues(node2, lowerBound, upperBound); } } } private void copyOldValues(BayesPm oldBayesPm, Node oldNode, Node node, int lowerBound, int upperBound) { List values = new ArrayList<>(); List oldNames = new LinkedList<>(); List oldNodes = oldBayesPm.getDag().getNodes(); for (Node oldNode1 : oldNodes) { oldNames.add(oldNode1.getName()); } int numVals; if (oldNames.contains(node.getName())) { Node oldNode2 = oldBayesPm.getDag().getNode(node.getName()); numVals = oldBayesPm.getNumCategories(oldNode2); } else { numVals = BayesPm.pickNumVals(lowerBound, upperBound); } int min = FastMath.min(oldBayesPm.getNumCategories(oldNode), numVals); for (int i = 0; i < min; i++) { values.add(oldBayesPm.getCategory(oldNode, i)); } for (int i = min; i < numVals; i++) { String proposedName = DataUtils.defaultCategory(i); if (values.contains(proposedName)) { throw new IllegalArgumentException("Default name already in " + "list of values: " + proposedName); } values.add(proposedName); } mapNodeToVariable(node, values); } private void setNewValues(Node node, int lowerBound, int upperBound) { if (node == null) { throw new NullPointerException("Node must not be null."); } List valueList = new 
ArrayList<>(); for (int i = 0; i < BayesPm.pickNumVals(lowerBound, upperBound); i++) { valueList.add(DataUtils.defaultCategory(i)); } mapNodeToVariable(node, valueList); } private void mapNodeToVariable(Node node, List categories) { if (categories.size() != new HashSet<>(categories).size()) { throw new IllegalArgumentException("Duplicate variable names."); } DiscreteVariable variable = new DiscreteVariable(node.getName(), categories); variable.setNodeType(node.getNodeType()); this.nodesToVariables.put(node, variable); } private void initializeValues(int lowerBound, int upperBound) { for (Node node : this.dag.getNodes()) { setNewValues(node, lowerBound, upperBound); } } /** * Writes the object to the specified ObjectOutputStream. * * @param out The ObjectOutputStream to write the object to. * @throws IOException If an I/O error occurs. */ @Serial private void writeObject(ObjectOutputStream out) throws IOException { try { out.defaultWriteObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to serialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } /** * Reads the object from the specified ObjectInputStream. This method is used during deserialization * to restore the state of the object. * * @param in The ObjectInputStream to read the object from. * @throws IOException If an I/O error occurs. * @throws ClassNotFoundException If the class of the serialized object cannot be found. */ @Serial private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { try { in.defaultReadObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to deserialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } }




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy