All downloads are free. Search and download functionality uses the official Maven repository.

edu.cmu.tetrad.bayes.FactoredBayesStructuralEM Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * 

Implements the procedure Factored-Bayesian-SEM found on page 6 of "The * Bayesian Structural EM Algorithm" by Nir Friedman.> 0

The initial implementation simplifies the algorithm * somewhat by computing the score of each model by using the BdeMetric score, which is implemented in the bayes package * of Tetrad-4.2.> 0 * * @author Frank Wimberly * @author Robert Tillman (changes since 5-12-2008) */ public final class FactoredBayesStructuralEM { // Previous comment when there were three iterate methods. iterate2() has been // renamed to iterate(), and the former iterate() and iterate1() have been // commented out, below, all in the interests of simplifying the public API. // -jdramsey 2/2/2004 // // *

There are two obsolete methods in this class. The iterate method uses the // * BdeMetric class in an implementation of the Procedure Factored-Bayesian-SEM // * in the Friedman paper. The iterate1 method uses the BdeMetricCache class // * instead. The latter class factors each of the models it searches over // * exploiting the fact that different models may share factors whose score only // * has to be computed and stored once. Hence, the iterate and iterate1 methods // * should return the same model with the same score but iterate1 should be more // * efficient.> 0 // * // *

The iterate and iterate1 methods have been replaced by iterate2.> 0 private final BayesPm bayesPmM0; private final DataSet dataSet; private final int[] ncategories; private double tolerance; public FactoredBayesStructuralEM(DataSet dataSet, BayesPm bayesPmM0) { this.dataSet = dataSet; this.bayesPmM0 = bayesPmM0; List datasetVars = dataSet.getVariables(); this.ncategories = new int[datasetVars.size()]; // Store the number of categories for each variable in an array which will be used // to define the Bayes nets searched over in the iterative procedure. for (int i = 0; i < this.ncategories.length; i++) { this.ncategories[i] = ((DiscreteVariable) datasetVars.get(i)).getNumCategories(); } } private static double factorScoreMD(Dag dag, BdeMetricCache bdeMetricCache, BayesPm bayesPm, BayesIm bayesIm) { List nodes = dag.getNodes(); double score = 0.0; //Fast test 11/29/04 for (Node node1 : nodes) { List parents = dag.getParents(node1); Set parentsSet = new HashSet<>(parents); double fScore = bdeMetricCache.scoreLnGam(node1, parentsSet, bayesPm, bayesIm); TetradLogger.getInstance().log("details", "Score for factor " + node1.getName() + " = " + fScore); score += fScore; } return score; } /** * This method allows specification of the tolerance parameter used in Bayes EM estimation. */ public BayesIm maximization(double tolerance) { TetradLogger.getInstance().log("details", "FactoredBayesStructuralEM.maximization()"); this.tolerance = tolerance; return iterate(); } /** * This iterate2 method also uses BdeMetricCache but it uses the factorScoreMD method which can handle missing data * and latent variables. Ths method iteratively score models and finds that which contains the graph of the highest * scoring model (via its BaysPm) as well as parameters which yield the best score given the dataset by using the * EmBayesEstimator class. 
* * @return the instantiated Bayes net (BayesIm) */ public BayesIm iterate() { double start = MillisecondTimes.timeMillis(); BdeMetricCache bdeMetricCache = new BdeMetricCache(this.dataSet, this.bayesPmM0); BayesPm bayesPmMnplus1 = this.bayesPmM0; BayesPm bayesPmMn; final double oldBestScore = Double.NEGATIVE_INFINITY; final int iteration = 0; //Loop for n = 0,1,... until convergence or timeout has been exceeded TimedIterate ti = new TimedIterate(bdeMetricCache, bayesPmMnplus1, oldBestScore, iteration, start); Thread tithread = new Thread(ti); tithread.start(); try { tithread.join(); } catch (InterruptedException e) { e.printStackTrace(); } bayesPmMnplus1 = ti.bayesPmMnplus1; bayesPmMn = bayesPmMnplus1; EmBayesEstimator emBayesEst = new EmBayesEstimator(bayesPmMn, this.dataSet); return emBayesEst.maximization(this.tolerance); } /* * This scoring method uses factor caching and the log gamma scoring function that * handles missing data and latent variables. The Bayes PM contains a graph which * indicates which variables are latent. 
*/ public void scoreTest() { TetradLogger.getInstance().log("details", "scoreTest"); //System.out.println(bayesPmM0.getGraph()); BdeMetricCache bdeMetricCache; BayesPm bayesPmMn = this.bayesPmM0; EmBayesEstimator emBayesEst = new EmBayesEstimator(bayesPmMn, this.dataSet); emBayesEst.maximization(0.0001); Dag dag0 = new Dag(bayesPmMn.getDag()); Node L1 = dag0.getNode("L1"); Node X1 = dag0.getNode("X1"); Dag dag1 = new Dag(dag0); dag1.addDirectedEdge(X1, L1); BayesPm bayesPm0 = new BayesPm(dag0); EmBayesEstimator emBayesEst0 = new EmBayesEstimator(bayesPm0, this.dataSet); BayesIm bayesImMn0 = emBayesEst0.maximization(0.0001); BayesPm bayesPmTest0 = new BayesPm(dag0); TetradLogger.getInstance().log("details", "Observed conts for nodes of L1,X1,X2,X3 (no edges) " + "using the MAP parameters based on that same graph"); TetradLogger.getInstance().log("details", "Graph of PM: "); TetradLogger.getInstance().log("details", "" + bayesPmTest0.getDag()); TetradLogger.getInstance().log("details", "Graph of IM: "); TetradLogger.getInstance().log("details", "" + bayesImMn0.getBayesPm().getDag()); bdeMetricCache = new BdeMetricCache(this.dataSet, bayesPmTest0); List nodes0 = dag0.getNodes(); for (Node aNodes0 : nodes0) { double[][] counts0 = bdeMetricCache.getObservedCounts(aNodes0, bayesPmTest0, bayesImMn0); for (double[] aCounts0 : counts0) { for (int j = 0; j < counts0[0].length; j++) { System.out.print(" " + aCounts0[j]); } TetradLogger.getInstance().log("details", "\n"); } TetradLogger.getInstance().log("details", "\n"); } double score0 = FactoredBayesStructuralEM.factorScoreMD(dag0, bdeMetricCache, bayesPmTest0, bayesImMn0); TetradLogger.getInstance().log("details", "Score of L1,X1,X2,X3 (no edges) for itself = " + score0); TetradLogger.getInstance().log("details", "===============\n\n"); TetradLogger.getInstance().log("details", "Score of X1-->L1 for L1,X1,X2,X3 (no edges) = " + score0); BayesPm bayesPmTest1 = new BayesPm(dag1); TetradLogger.getInstance().log("details", 
"Observed counts for nodes of X1-->L1 for L1,X1,X2,X3 (no edges)"); TetradLogger.getInstance().log("details", "Graph of PM : "); TetradLogger.getInstance().log("details", "" + bayesPmTest1.getDag()); TetradLogger.getInstance().log("details", "Graph of IM: "); TetradLogger.getInstance().log("details", "" + bayesImMn0.getBayesPm().getDag()); bdeMetricCache = new BdeMetricCache(this.dataSet, bayesPmTest1); List nodes1 = dag0.getNodes(); for (Node aNodes1 : nodes1) { double[][] counts1 = bdeMetricCache.getObservedCounts(aNodes1, bayesPmTest1, bayesImMn0); for (double[] aCounts1 : counts1) { for (int j = 0; j < counts1[0].length; j++) { TetradLogger.getInstance().log("details", " " + aCounts1[j]); } TetradLogger.getInstance().log("details", "\n"); } TetradLogger.getInstance().log("details", "\n"); } double score1 = FactoredBayesStructuralEM.factorScoreMD(dag1, bdeMetricCache, bayesPmTest1, bayesImMn0); TetradLogger.getInstance().log("details", "Score of X1-->L1 for L1,X1,X2,X3 (no edges) = " + score1); } public DataSet getDataSet() { return this.dataSet; } private class TimedIterate implements Runnable { final BdeMetricCache bdeMetricCache; final double start; BayesPm bayesPmMnplus1; BayesPm bayesPmMn; double oldBestScore; int iteration; public TimedIterate(BdeMetricCache bdeMetricCache, BayesPm bayesPmMnplus1, double oldBestScore, int iteration, double start) { this.bdeMetricCache = bdeMetricCache; this.bayesPmMnplus1 = bayesPmMnplus1; this.bayesPmMn = null; this.oldBestScore = oldBestScore; this.iteration = iteration; this.start = start; } public void run() { while (!this.bayesPmMnplus1.equals(this.bayesPmMn)) { this.iteration++; this.bayesPmMn = this.bayesPmMnplus1; TetradLogger.getInstance().log("details", "In Factored Bayes Struct EM Iteration number " + this.iteration); //Compute the MAP parameters for Mn given o. 
TetradLogger.getInstance().log("details", "Starting EM Bayes estimator to get MAP parameters of Mn"); EmBayesEstimator emBayesEst = new EmBayesEstimator(this.bayesPmMn, FactoredBayesStructuralEM.this.dataSet); BayesIm bayesImMn = emBayesEst.maximization(FactoredBayesStructuralEM.this.tolerance); TetradLogger.getInstance().log("details", "Estimation of MAP parameters of Mn complete. \n\n"); //Perform search over models... Graph graphMn = this.bayesPmMn.getDag(); Dag dagMn = new Dag(graphMn); List models = ModelGenerator.generate(graphMn); double bestScore = FactoredBayesStructuralEM.factorScoreMD(dagMn, this.bdeMetricCache, this.bayesPmMn, bayesImMn); EdgeListGraph edges = new EdgeListGraph(dagMn); TetradLogger.getInstance().log("details", "Initial graph Mn = "); TetradLogger.getInstance().log("details", edges.toString()); TetradLogger.getInstance().log("details", "Its score = " + bestScore); for (Graph model : models) { Dag dag = new Dag(model); BayesPm bayesPmTest = new BayesPm(dag); //Having instantiated the BayesPm, set the number of categories correctly. for (int i = 0; i < FactoredBayesStructuralEM.this.dataSet.getVariables().size(); i++) { String varName = FactoredBayesStructuralEM.this.dataSet.getVariableNames().get(i); Node node = dag.getNode(varName); bayesPmTest.setNumCategories(node, FactoredBayesStructuralEM.this.ncategories[i]); } double score = FactoredBayesStructuralEM.factorScoreMD(dag, this.bdeMetricCache, bayesPmTest, bayesImMn); EdgeListGraph edgesTest = new EdgeListGraph(dag); TetradLogger.getInstance().log("details", "For the model with graph \n" + edgesTest); TetradLogger.getInstance().log("details", "Model Score = " + score); if (score <= bestScore) { continue; //This is not better than the best to date. } bestScore = score; //Let M sub n+1 be the model with the highest score amonth those encountered //during the search. 
this.bayesPmMnplus1 = bayesPmTest; } TetradLogger.getInstance().log("details", "In iteration: " + this.iteration); TetradLogger.getInstance().log("details", "bestScore, oldBestScore " + bestScore + " " + this.oldBestScore); EdgeListGraph edgesBest = new EdgeListGraph(this.bayesPmMnplus1.getDag()); TetradLogger.getInstance().log("details", "Graph of model: \n" + edgesBest); TetradLogger.getInstance().log("details", "===================================="); this.oldBestScore = bestScore; } } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy