All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.classify.ClassifierMbDiscrete Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.classify;

import edu.cmu.tetrad.bayes.*;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Pc;
import edu.cmu.tetrad.search.PcMb;
import edu.cmu.tetrad.search.test.IndTestChiSquare;
import edu.cmu.tetrad.search.utils.MbUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradLogger;
import edu.pitt.dbmi.data.reader.Delimiter;

import java.io.File;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.*;

/**
 * Performs a Bayesian classification of a test set based on a given training set. PC-MB is used to select a Markov
 * blanket DAG of the target; this DAG is used to estimate a Bayes model using the training data. The Bayes model is
 * then updated for each case in the test data to produce classifications.
 *
 * @author Frank Wimberly
 * @author josephramsey
 * @version $Id: $Id
 */
public class ClassifierMbDiscrete implements ClassifierDiscrete {

    /**
     * Train data.
     */
    private DataSet train;

    /**
     * Test data.
     */
    private DataSet test;

    /**
     * Target variable.
     */
    private Node target;
    private double alpha;

    /**
     * Depth for PC-MB search.
     */
    private int depth;

    /**
     * Prior for Dirichlet estimator.
     */
    private double prior;

    /**
     * Maximum number of missing values for a test case.
     */
    private int maxMissing;

    /**
     * Target variable.
     */
    private DiscreteVariable targetVariable;

    /**
     * Percent correct.
     */
    private double percentCorrect;

    /**
     * Cross-tabulation.
     */
    private int[][] crossTabulation;

    //============================CONSTRUCTOR===========================//

    /**
     * Constructs a new ClassifierMbDiscrete object using the given training and test data, target variable, alpha
     * value,
     *
     * @param trainPath        the path to the training data file
     * @param testPath         the path to the test data file
     * @param targetString     the name of the target variable
     * @param alphaString      the alpha value for the Dirichlet estimator
     * @param depthString      the depth for the PC-MB search
     * @param priorString      the prior for the Dirichlet estimator
     * @param maxMissingString the maximum number of missing values for a test case
     */
    public ClassifierMbDiscrete(String trainPath, String testPath, String targetString,
                                String alphaString, String depthString, String priorString, String maxMissingString) {
        try {
            String s = "MbClassify " +
                       trainPath + " " +
                       testPath + " " +
                       targetString + " " +
                       alphaString + " " +
                       depthString + " " +
                       priorString + " " +
                       maxMissingString + " ";

            TetradLogger.getInstance().log(s);

            DataSet train = SimpleDataLoader.loadContinuousData(new File(trainPath), "//", '\"',
                    "*", true, Delimiter.TAB, false);
            DataSet test = SimpleDataLoader.loadContinuousData(new File(testPath), "//", '\"',
                    "*", true, Delimiter.TAB, false);

            double alpha = Double.parseDouble(alphaString);
            int depth = Integer.parseInt(depthString);
            double prior = Double.parseDouble(priorString);
            int maxMissing = Integer.parseInt(maxMissingString);

            setup(train, test, target, alpha, depth, prior, maxMissing);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs MbClassify using moves-line arguments. The syntax is:
     * 
     * java MbClassify train.dat test.dat target alpha depth
     * 
* * @param args train.dat test.dat alpha depth dirichlet_prior max_missing */ public static void main(String[] args) { String trainPath = args[0]; String testPath = args[1]; String targetString = args[2]; String alphaString = args[3]; String depthString = args[4]; String priorString = args[5]; String maxMissingString = args[6]; new ClassifierMbDiscrete(trainPath, testPath, targetString, alphaString, depthString, priorString, maxMissingString); } //============================PUBLIC METHODS=========================// private void setup(DataSet train, DataSet test, Node target, double alpha, int depth, double prior, int maxMissing) { this.train = train; this.test = test; this.alpha = alpha; this.target = target; this.depth = depth; this.prior = prior; this.maxMissing = maxMissing; this.targetVariable = (DiscreteVariable) target; if (this.targetVariable == null) { throw new IllegalArgumentException("Target variable not in data: " + target); } } /** * Classifies the test data by Bayesian updating. The procedure is as follows. First, PC-MB is run on the training * data to estimate an MB CPDAG. Bidirected edges are removed; an MB DAG G is selected from the CPDAG that remains. * Second, a Bayes model B is estimated using this G and the training data. Third, for each case in the test data, * the marginal for the target variable in B is calculated conditioning on values of the other varialbes in B in the * test data; these are reported as classifications. Estimation of B is done using a Dirichlet estimator, with a * symmetric prior, with the given alpha value. Updating is done using a row-summing exact updater. *

* One consequence of using the row-summing exact updater is that classification will be fast except for cases in * which there are lots of missing values. The reason for this is that for such cases the number of rows that need * to be summed over will be exponential in the number of missing values for that case. Hence the parameter for max * num missing values. A good default for this is like 5. Any test case with more than that number of missing values * will be skipped. * * @return The classifications. */ public int[] classify() { IndependenceTest indTest = new IndTestChiSquare(this.train, this.alpha); PcMb search = new PcMb(indTest, this.depth); search.setDepth(this.depth); Set mbPlusTarget = search.findMb(this.target); mbPlusTarget.add(this.target); ArrayList vars = new ArrayList<>(mbPlusTarget); Collections.sort(vars); DataSet subset = this.train.subsetColumns(vars); System.out.println("subset vars = " + subset.getVariables()); Pc cpdagSearch = new Pc(new IndTestChiSquare(subset, 0.05)); Graph mbCPDAG = cpdagSearch.search(); TetradLogger.getInstance().log("CPDAG = " + mbCPDAG); MbUtils.trimToMbNodes(mbCPDAG, this.target, true); TetradLogger.getInstance().log("Trimmed CPDAG = " + mbCPDAG); // Removing bidirected edges from the CPDAG before selecting a DAG. 4 for (Edge edge : mbCPDAG.getEdges()) { if (Edges.isBidirectedEdge(edge)) { mbCPDAG.removeEdge(edge); } } Graph selectedDag = MbUtils.getOneMbDag(mbCPDAG); TetradLogger.getInstance().log("Selected DAG = " + selectedDag); String message1 = "Vars = " + selectedDag.getNodes(); TetradLogger.getInstance().log(message1); TetradLogger.getInstance().log("\nClassification using selected MB DAG:"); NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat(); List mbNodes = selectedDag.getNodes(); //The Markov blanket nodes will correspond to a subset of the variables //in the training dataset. Find the subset dataset. DataSet trainDataSubset = this.train.subsetColumns(mbNodes); //To create a Bayes net for the Markov blanket we need the DAG. BayesPm bayesPm = new BayesPm(selectedDag); //To parameterize the Bayes net we need the number of values //of each variable. List varsTrain = trainDataSubset.getVariables(); for (int i1 = 0; i1 < varsTrain.size(); i1++) { DiscreteVariable trainingVar = (DiscreteVariable) varsTrain.get(i1); bayesPm.setCategories(mbNodes.get(i1), trainingVar.getCategories()); } //Create an updater for the instantiated Bayes net. TetradLogger.getInstance().log("Estimating Bayes net; please wait..."); DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPm, this.prior); BayesIm bayesIm = DirichletEstimator.estimate(prior, trainDataSubset); RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm); //The subset dataset of the dataset to be classified containing //the variables in the Markov blanket. DataSet testSubset = this.test.subsetColumns(mbNodes); //Get the raw data from the dataset to be classified, the number //of variables, and the number of cases. int numCases = testSubset.getNumRows(); int[] estimatedCategories = new int[numCases]; Arrays.fill(estimatedCategories, -1); //The variables in the dataset. List varsClassify = testSubset.getVariables(); //For each case in the dataset to be classified compute the estimated //value of the target variable and increment the appropriate element //of the crosstabulation array. for (int k = 0; k < numCases; k++) { //Create an Evidence instance for the instantiated Bayes net //which will allow that updating. Proposition proposition = Proposition.tautology(bayesIm); //Restrict all other variables to their observed values in //this case. int numMissing = 0; for (int testIndex = 0; testIndex < varsClassify.size(); testIndex++) { DiscreteVariable var = (DiscreteVariable) varsClassify.get(testIndex); // If it's the target, ignore it. if (var.equals(this.targetVariable)) { continue; } int trainIndex = proposition.getNodeIndex(var.getName()); // If it's not in the train subset, ignore it. if (trainIndex == -99) { continue; } int testValue = testSubset.getInt(k, testIndex); if (testValue == -99) { numMissing++; } else { proposition.setCategory(trainIndex, testValue); } } if (numMissing > this.maxMissing) { TetradLogger.getInstance().log("classification(" + k + ") = " + "not done since number of missing values too high " + "(" + numMissing + ")."); continue; } Evidence evidence = Evidence.tautology(bayesIm); evidence.getProposition().restrictToProposition(proposition); updater.setEvidence(evidence); // for each possible value of target compute its probability in // the updated Bayes net. Select the value with the highest // probability as the estimated getValue. int targetIndex = proposition.getNodeIndex(this.targetVariable.getName()); //Straw man values--to be replaced. double highestProb = -0.1; int _category = -1; for (int category = 0; category < this.targetVariable.getNumCategories(); category++) { double marginal = updater.getMarginal(targetIndex, category); if (marginal > highestProb) { highestProb = marginal; _category = category; } } //Sometimes the marginal cannot be computed because certain //combinations of values of the variables do not occur in the //training dataset. If that happens skip the case. if (_category < 0) { System.out.println("classification(" + k + ") is undefined " + "(undefined marginals)."); continue; } String estimatedCategory = this.targetVariable.getCategories().get(_category); TetradLogger.getInstance().log("classification(" + k + ") = " + estimatedCategory); estimatedCategories[k] = _category; } //Create a crosstabulation table to store the coefs of observed //versus estimated occurrences of each value of the target variable. int targetIndex = varsClassify.indexOf(this.targetVariable); int numCategories = this.targetVariable.getNumCategories(); int[][] crossTabs = new int[numCategories][numCategories]; //Will count the number of cases where the target variable //is correctly classified. int numberCorrect = 0; int numberCounted = 0; for (int k = 0; k < numCases; k++) { int estimatedCategory = estimatedCategories[k]; int observedValue = testSubset.getInt(k, targetIndex); if (estimatedCategory < 0) { continue; } crossTabs[observedValue][estimatedCategory]++; numberCounted++; if (observedValue == estimatedCategory) { numberCorrect++; } } double percentCorrect1 = 100.0 * ((double) numberCorrect) / ((double) numberCounted); // Print the cross classification. TetradLogger.getInstance().log(""); TetradLogger.getInstance().log("\t\t\tEstimated\t"); TetradLogger.getInstance().log("Observed\t"); StringBuilder buf0 = new StringBuilder(); buf0.append("\t"); for (int m = 0; m < numCategories; m++) { buf0.append(this.targetVariable.getCategory(m)).append("\t"); } TetradLogger.getInstance().log(buf0.toString()); for (int k = 0; k < numCategories; k++) { StringBuilder buf = new StringBuilder(); buf.append(this.targetVariable.getCategory(k)).append("\t"); for (int m = 0; m < numCategories; m++) buf.append(crossTabs[k][m]).append("\t"); TetradLogger.getInstance().log(buf.toString()); } TetradLogger.getInstance().log(""); TetradLogger.getInstance().log("Number correct = " + numberCorrect); TetradLogger.getInstance().log("Number counted = " + numberCounted); String message = "Percent correct = " + nf.format(percentCorrect1) + "%"; TetradLogger.getInstance().log(message); this.crossTabulation = crossTabs; this.percentCorrect = percentCorrect1; return estimatedCategories; } /** *

crossTabulation.

* * @return the cross-tabulation from the classify method. The classify method must be run first. */ public int[][] crossTabulation() { return this.crossTabulation; } /** *

Getter for the field percentCorrect.

* * @return the percent correct from the classify method. The classify method must be run first. */ public double getPercentCorrect() { return this.percentCorrect; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy