edu.cmu.tetrad.search.work_in_progress.MixtureModel Maven / Gradle / Ivy
The newest version!
package edu.cmu.tetrad.search.work_in_progress;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.util.Matrix;
/**
* Represents a Gaussian mixture model -- a dataset with data sampled from two or more multivariate Gaussian
* distributions.
*
* @author Madelyn Glymour
* @version $Id: $Id
*/
public class MixtureModel {
// The mixed data set
private final DataSet data;
// The individual data sets
private final int[] cases;
// The number of cases in each individual data set
private final int[] caseCounts;
// The data set in array form
private final double[][] dataArray;
// The means matrix
private final double[][] meansArray;
// The weights array
private final double[] weightsArray;
// The gamma matrix
private final double[][] gammaArray;
// The variance matrix
private final Matrix[] variancesArray;
// The number of models in the mixture
private final int numModels;
/**
* Constructs a mixture model from a mixed data set, a means matrix, a weights array, a variance matrix, and a gamma
* matrix.
*
* @param data the mixed data set
* @param dataArray the mixed data set in array form
* @param meansArray the means matrix
* @param weightsArray the weights array
* @param variancesArray the variance matrix
* @param gammaArray the gamma matrix
*/
public MixtureModel(DataSet data, double[][] dataArray, double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, double[][] gammaArray) {
this.data = data;
this.dataArray = dataArray;
this.meansArray = meansArray;
this.weightsArray = weightsArray;
this.variancesArray = variancesArray;
this.numModels = weightsArray.length;
this.gammaArray = gammaArray;
this.cases = new int[data.getNumRows()];
// set the individual model for each case
for (int i = 0; i < cases.length; i++) {
cases[i] = getDistribution(i);
}
this.caseCounts = new int[numModels];
// count the number of cases in each individual data set
for (int i = 0; i < numModels; i++) {
caseCounts[i] = 0;
}
for (int aCase : cases) {
for (int j = 0; j < numModels; j++) {
if (aCase == j) {
caseCounts[j]++;
break;
}
}
}
}
/**
* Getter for the field data
.
*
* @return the mixed data set in array form
*/
public double[][] getData() {
return dataArray;
}
/**
* getMeans.
*
* @return the means matrix
*/
public double[][] getMeans() {
return meansArray;
}
/**
* getWeights.
*
* @return the weights array
*/
public double[] getWeights() {
return weightsArray;
}
/**
* getVariances.
*
* @return the variance matrix
*/
public Matrix[] getVariances() {
return variancesArray;
}
/**
* Getter for the field cases
.
*
* @return an array assigning each case an integer corresponding to a model
*/
public int[] getCases() {
return cases;
}
/**
* Classifies a given case into a model, based on which model has the highest gamma value for that case.
*
* @param caseNum a int
* @return a int
*/
public int getDistribution(int caseNum) {
// hard classification
int dist = 0;
double highest = 0;
for (int i = 0; i < numModels; i++) {
if (gammaArray[i][caseNum] > highest) {
highest = gammaArray[i][caseNum];
dist = i;
}
}
return dist;
// soft classification, deprecated because it doesn't classify as well
/*int gammaSum = 0;
for (int i = 0; i < k; i++) {
gammaSum += gammaArray[i][caseNum];
}
Random rand = new Random();
double test = gammaSum * rand.nextDouble();
if(test < gammaArray[0][caseNum]){
return 0;
}
double sum = gammaArray[0][caseNum];
for (int i = 1; i < k; i++){
sum = sum+gammaArray[i][caseNum];
if(test < sum){
return i;
}
}
return k - 1; */
}
/*
* Sort the mixed data set into its component data sets.
*
* @return a list of data sets
*/
/**
* getDemixedData.
*
* @return an array of {@link edu.cmu.tetrad.data.DataSet} objects
*/
public DataSet[] getDemixedData() {
DoubleDataBox[] dataBoxes = new DoubleDataBox[numModels];
int[] caseIndices = new int[numModels];
for (int i = 0; i < numModels; i++) {
dataBoxes[i] = new DoubleDataBox(caseCounts[i], data.getNumColumns());
caseIndices[i] = 0;
}
int index;
DoubleDataBox box;
int count;
for (int i = 0; i < cases.length; i++) {
// get the correct data set and corresponding case count for this case
index = cases[i];
box = dataBoxes[index];
count = caseIndices[index];
// set the [count]th row of the given data set to the ith row of the mixed data set
for (int j = 0; j < data.getNumColumns(); j++) {
box.set(count, j, data.getDouble(i, j));
}
dataBoxes[index] = box; //make sure that the changes get carried to the next iteration of the loop
caseIndices[index] = count + 1; //increment case count of this data set
}
// create list of data sets
DataSet[] dataSets = new DataSet[numModels];
for (int i = 0; i < numModels; i++) {
dataSets[i] = new BoxDataSet(dataBoxes[i], data.getVariables());
}
return dataSets;
}
/**
* Perform an FGES search on each of the demixed data sets.
*
* @return the BIC scores of the graphs returned by searches.
*/
public double[] searchDemixedData() {
DataSet[] dataSets = getDemixedData();
SemBicScore score;
edu.cmu.tetrad.search.Fges fges;
DataSet dataSet;
double bic;
double[] bicScores = new double[numModels];
for (int i = 0; i < numModels; i++) {
dataSet = dataSets[i];
score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
score.setPenaltyDiscount(2.0);
fges = new edu.cmu.tetrad.search.Fges(score);
fges.search();
bic = fges.getModelScore();
bicScores[i] = bic;
}
return bicScores;
}
}