All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.algcomparison.simulation.LeeHastieSimulation Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
package edu.cmu.tetrad.algcomparison.simulation;

import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.cmu.tetrad.util.RandomUtil;
import edu.pitt.csb.mgm.MixedUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * A version of the Lee and Hastic simulation which is guaranteed ot generate a discrete data set.
 *
 * @author josephramsey
 */
public class LeeHastieSimulation implements Simulation {

    private static final long serialVersionUID = 23L;
    private final RandomGraph randomGraph;
    private List dataSets = new ArrayList<>();
    private List graphs = new ArrayList<>();
    private DataType dataType;
    private List shuffledOrder;

    public LeeHastieSimulation(RandomGraph graph) {
        this.randomGraph = graph;
    }

    @Override
    public void createData(Parameters parameters, boolean newModel) {
//        if (parameters.getLong(Params.SEED) != -1L) {
//            RandomUtil.getInstance().setSeed(parameters.getLong(Params.SEED));
//        }

        double percentDiscrete = parameters.getDouble(Params.PERCENT_DISCRETE);

        boolean discrete = parameters.getString(Params.DATA_TYPE).equals("discrete");
        boolean continuous = parameters.getString(Params.DATA_TYPE).equals("continuous");

        if (discrete && percentDiscrete != 100.0) {
            throw new IllegalArgumentException("To simulate discrete data, 'percentDiscrete' must be set to 0.0.");
        } else if (continuous && percentDiscrete != 0.0) {
            throw new IllegalArgumentException("To simulate continuoue data, 'percentDiscrete' must be set to 100.0.");
        }

        if (discrete) {
            this.dataType = DataType.Discrete;
        }
        if (continuous) {
            this.dataType = DataType.Continuous;
        }

        this.shuffledOrder = null;

        Graph graph = this.randomGraph.createGraph(parameters);

        this.dataSets = new ArrayList<>();
        this.graphs = new ArrayList<>();

        for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
            System.out.println("Simulating dataset #" + (i + 1));

            if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
                graph = this.randomGraph.createGraph(parameters);
            }

            this.graphs.add(graph);

            DataSet dataSet = simulate(graph, parameters);
            dataSet.setName("" + (i + 1));

            if (parameters.getBoolean(Params.RANDOMIZE_COLUMNS)) {
                dataSet = DataTransforms.shuffleColumns(dataSet);
            }

            if (parameters.getDouble(Params.PROB_REMOVE_COLUMN) > 0) {
                double aDouble = parameters.getDouble(Params.PROB_REMOVE_COLUMN);
                dataSet = DataTransforms.removeRandomColumns(dataSet, aDouble);
            }

            this.dataSets.add(dataSet);
        }
    }

    @Override
    public Graph getTrueGraph(int index) {
        return this.graphs.get(index);
    }

    @Override
    public DataModel getDataModel(int index) {
        return this.dataSets.get(index);
    }

    @Override
    public String getDescription() {
        return "Lee & Hastie simulation using " + this.randomGraph.getDescription();
    }

    @Override
    public List getParameters() {
        List parameters = this.randomGraph.getParameters();
        parameters.add(Params.MIN_CATEGORIES);
        parameters.add(Params.MAX_CATEGORIES);
        parameters.add(Params.PERCENT_DISCRETE);
        parameters.add(Params.NUM_RUNS);
        parameters.add(Params.PROB_REMOVE_COLUMN);
        parameters.add(Params.DIFFERENT_GRAPHS);
        parameters.add(Params.RANDOMIZE_COLUMNS);
        parameters.add(Params.SAMPLE_SIZE);
        parameters.add(Params.SAVE_LATENT_VARS);

        parameters.add(Params.VERBOSE);
//        parameters.add(Params.SEED);

        return parameters;
    }

    @Override
    public int getNumDataModels() {
        return this.dataSets.size();
    }

    @Override
    public DataType getDataType() {
        return this.dataType;
    }

    private DataSet simulate(Graph dag, Parameters parameters) {
        HashMap nd = new HashMap<>();

        List nodes = dag.getNodes();

        List shuffledNodes = new ArrayList<>(nodes);
        RandomUtil.shuffle(shuffledNodes);

        if (this.shuffledOrder == null) {
            this.shuffledOrder = shuffledNodes;
        }

        for (int i = 0; i < this.shuffledOrder.size(); i++) {
            if (i < this.shuffledOrder.size() * parameters.getDouble(Params.PERCENT_DISCRETE) * 0.01) {
                int minNumCategories = parameters.getInt(Params.MIN_CATEGORIES);
                int maxNumCategories = parameters.getInt(Params.MAX_CATEGORIES);
                int value = pickNumCategories(minNumCategories, maxNumCategories);
                nd.put(this.shuffledOrder.get(i).getName(), value);
            } else {
                nd.put(this.shuffledOrder.get(i).getName(), 0);
            }
        }

        Graph graph = MixedUtils.makeMixedGraph(dag, nd);

        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(graph, "Split(-1.0,-.0,.0,1.0)");
        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);

        boolean saveLatentVars = parameters.getBoolean(Params.SAVE_LATENT_VARS);
        DataSet ds = im.simulateDataAvoidInfinity(parameters.getInt(Params.SAMPLE_SIZE), saveLatentVars);

        return MixedUtils.makeMixedData(ds, nd);
    }

    private int pickNumCategories(int min, int max) {
        return min + RandomUtil.getInstance().nextInt(max - min + 1);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy