All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.search.MixtureModel Maven / Gradle / Ivy

package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.util.Matrix;

/**
 * Represents a Gaussian mixture model -- a dataset with data sampled from two or more multivariate Gaussian
 * distributions.
 *
 * @author Madelyn Glymour
 */
public class MixtureModel {
    /** Penalty discount used by {@link #searchDemixedData()} when none is supplied. */
    private static final double DEFAULT_PENALTY_DISCOUNT = 2.0;

    private final DataSet data;
    private final int[] cases;             // model index assigned to each row of data
    private final int[] caseCounts;        // number of rows assigned to each model
    private final double[][] dataArray;    // v-by-n data matrix
    private final double[][] meansArray;   // k-by-v matrix representing means for each variable for each of k models
    private final double[] weightsArray;   // array of length k representing weights for each model
    private final double[][] gammaArray;   // k-by-n matrix representing gamma for each data case in each model
    private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models
    private final int numModels;           // number of models in mixture

    /**
     * Builds a mixture model and hard-assigns every case (row of {@code data}) to the model with the
     * largest gamma value for that case.
     *
     * @param data           the mixed data set; its row count is the number of cases
     * @param dataArray      the mixed data in array form, stored as given and returned by {@link #getData()}
     * @param meansArray     per-model means, stored as given
     * @param weightsArray   per-model weights; its length defines the number of models
     * @param variancesArray per-model covariance matrices, stored as given
     * @param gammaArray     gamma values indexed as {@code gammaArray[model][case]}; presumably
     *                       k-by-n, matching the field comment (TODO confirm callers guarantee this)
     */
    public MixtureModel(DataSet data, double[][] dataArray, double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, double[][] gammaArray) {
        this.data = data;
        this.dataArray = dataArray;
        this.meansArray = meansArray;
        this.weightsArray = weightsArray;
        this.variancesArray = variancesArray;
        this.numModels = weightsArray.length;
        this.gammaArray = gammaArray;
        this.cases = new int[data.getNumRows()];
        this.caseCounts = new int[numModels]; // Java zero-initializes int arrays

        // Classify each case via the private helper, never the overridable public method,
        // so a subclass override cannot run against a partially constructed object.
        for (int i = 0; i < cases.length; i++) {
            int model = classify(i);
            cases[i] = model;
            // classify() always returns a value in [0, numModels) when numModels > 0; the
            // guard only matters for an empty mixture, where nothing should be counted.
            if (model < numModels) {
                caseCounts[model]++;
            }
        }
    }

    /**
     * @return the mixed data set in array form (the array passed to the constructor, not a copy)
     */
    public double[][] getData() {
        return dataArray;
    }

    /**
     * @return the means matrix (the array passed to the constructor, not a copy)
     */
    public double[][] getMeans() {
        return meansArray;
    }

    /**
     * @return the weights array (the array passed to the constructor, not a copy)
     */
    public double[] getWeights() {
        return weightsArray;
    }

    /**
     * @return the per-model covariance matrices (the array passed to the constructor, not a copy)
     */
    public Matrix[] getVariances() {
        return variancesArray;
    }

    /**
     * @return an array assigning each case an integer corresponding to a model (the internal array, not a copy)
     */
    public int[] getCases() {
        return cases;
    }

    /**
     * Classifies a given case into a model, based on which model has the highest gamma value for that case.
     * Ties resolve to the lowest model index; with no models, 0 is returned.
     *
     * @param caseNum index of the case (row of the data set / column of the gamma matrix)
     * @return the index of the model with the largest gamma for {@code caseNum}
     */
    public int getDistribution(int caseNum) {
        return classify(caseNum);
    }

    /** Hard classification: argmax of gamma over models for one case. */
    private int classify(int caseNum) {
        int best = 0;
        // Start below every real value so the true argmax is found even if all gammas are <= 0.
        double highest = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < numModels; i++) {
            double gamma = gammaArray[i][caseNum];
            if (gamma > highest) {
                highest = gamma;
                best = i;
            }
        }

        return best;
    }

    /**
     * Sorts the mixed data set into its component data sets, one per model, according to the hard
     * case assignments computed at construction. Within each component, rows keep their original
     * relative order.
     *
     * @return an array of {@code numModels} data sets sharing the mixed data set's variables
     */
    public DataSet[] getDemixedData() {
        int numColumns = data.getNumColumns();
        DoubleDataBox[] dataBoxes = new DoubleDataBox[numModels];
        int[] nextRow = new int[numModels]; // next free row in each component box

        for (int i = 0; i < numModels; i++) {
            dataBoxes[i] = new DoubleDataBox(caseCounts[i], numColumns);
        }

        for (int i = 0; i < cases.length; i++) {
            int model = cases[i];
            DoubleDataBox box = dataBoxes[model];
            int row = nextRow[model]++;

            // copy the ith row of the mixed data set into the next free row of its component
            for (int j = 0; j < numColumns; j++) {
                box.set(row, j, data.getDouble(i, j));
            }
        }

        DataSet[] dataSets = new DataSet[numModels];
        for (int i = 0; i < numModels; i++) {
            dataSets[i] = new BoxDataSet(dataBoxes[i], data.getVariables());
        }

        return dataSets;
    }

    /**
     * Performs an FGES search on each of the demixed data sets using a penalty discount of 2.0.
     *
     * @return for each component, the value of {@code Fges.getModelScore()} after the search
     */
    public double[] searchDemixedData() {
        return searchDemixedData(DEFAULT_PENALTY_DISCOUNT);
    }

    /**
     * Performs an FGES search on each of the demixed data sets, scoring with a {@link SemBicScore}
     * built over a {@link CovarianceMatrixOnTheFly} of that component.
     *
     * @param penaltyDiscount penalty discount passed to {@link SemBicScore#setPenaltyDiscount(double)}
     * @return for each component, the value of {@code Fges.getModelScore()} after the search
     */
    public double[] searchDemixedData(double penaltyDiscount) {
        DataSet[] dataSets = getDemixedData();
        double[] bicScores = new double[numModels];

        for (int i = 0; i < numModels; i++) {
            SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSets[i]));
            score.setPenaltyDiscount(penaltyDiscount);
            edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(score);
            fges.search();
            bicScores[i] = fges.getModelScore();
        }

        return bicScores;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy