All downloads are free. Search and download functionality uses the official Maven repository.

edu.cmu.tetrad.bayes.DirichletEstimator Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;

import java.util.Objects;

/**
 * Estimates a DirichletBayesIm from a DirichletBayesIm (the prior) and a data set.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public final class DirichletEstimator {

    /**
     * Creates a new instance of DirichletEstimator.
     */
    public DirichletEstimator() {
    }

    /**
     * 

estimate.

* * @param prior a {@link edu.cmu.tetrad.bayes.DirichletBayesIm} object * @param dataSet a {@link edu.cmu.tetrad.data.DataSet} object * @return a {@link edu.cmu.tetrad.bayes.DirichletBayesIm} object */ public static DirichletBayesIm estimate(DirichletBayesIm prior, DataSet dataSet) { if (prior == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } // if (DataUtils.containsMissingValue(dataSet)) { // throw new IllegalArgumentException("Please remove or impute missing values."); // } // Make sure all of the variables in the PM are in the data set; // otherwise, estimation is impossible. BayesUtils.ensureVarsInData(prior.getVariables(), dataSet); // Create the posterior. BayesPm bayesPm = prior.getBayesPm(); DirichletBayesIm posterior = DirichletBayesIm.blankDirichletIm(bayesPm); // Number of rows of data int numPoints = dataSet.getNumRows(); // Loop over all nodes in prior. for (int n = 0; n < prior.getNumNodes(); ++n) { // Make any easy access table of node data @ 0 and parent data @ p+1. int[] varIndices = new int[prior.getNumParents(n) + 1]; Node node = prior.getNode(n); String name = node.getName(); varIndices[0] = dataSet.getColumn(dataSet.getVariable(name)); for (int p = 0; p < prior.getNumParents(n); p++) { Node parentNode = prior.getNode(prior.getParent(n, p)); name = parentNode.getName(); varIndices[p + 1] = dataSet.getColumn(dataSet.getVariable(name)); } // Loop over conditioning set. for (int row = 0; row < prior.getNumRows(n); row++) { int numCategories = bayesPm.getNumCategories(node); // Count occurrences of category. int[] nCount = new int[numCategories]; int[] pVals = prior.getParentValues(n, row); // Count the occurrence of each category satisfying the // various condition in the data. 
for (int i = 0; i < numPoints; i++) { //first make sure conditions are satisfied boolean satisfied = true; for (int p = 0; p < prior.getNumParents(n); p++) { // Ignore cases where one of the parents has a // missing value. if (dataSet.getInt(i, varIndices[p + 1]) == DiscreteVariable.MISSING_VALUE) { satisfied = false; break; } if (pVals[p] != dataSet.getInt(i, varIndices[p + 1])) { satisfied = false; break; } } if (dataSet.getInt(i, varIndices[0]) == DiscreteVariable .MISSING_VALUE) { satisfied = false; } if (satisfied) { nCount[dataSet.getInt(i, varIndices[0])]++; } } // include prior for (int i = 0; i < numCategories; ++i) { double priorValue = prior.getPseudocount(n, row, i); double value = nCount[i] + priorValue; posterior.setPseudocount(n, row, i, value); } } } return posterior; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy