/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * edu.cmu.tetrad.study.performance.ComparisonScript (Maven / Gradle / Ivy)
 *
 * There is a newer version: 7.6.5
 * Show newest version
 */
package edu.cmu.tetrad.study.performance;

import edu.cmu.tetrad.sem.ScoreType;
import edu.cmu.tetrad.util.TextTable;

import java.util.ArrayList;
import java.util.List;

/**
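 * Runs the selected algorithms on simulated data sets, on data sets read from file, or on true graphs alone,
 * and prints error statistics per run together with a final summary table.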
 *
 * @author josephramsey 2016.03.24
 * @author dmalinsky 2016.03.28
 */
public class ComparisonScript {

    /**
     * Entry point. Runs the comparison using the settings hard-coded in {@code runFromSimulation()}.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String... args) {
        new ComparisonScript().runFromSimulation();
    }

    /**
     * Configures a {@link ComparisonParameters} instance, runs every algorithm in the algorithm list for each
     * sample size (or graph file number) and trial, prints the per-run summary table returned by
     * {@code Comparison2.summarize}, and finally prints one table that collects the last row of each of those
     * summaries.
     *
     * @throws IllegalArgumentException if both "data from file" and "no data" are switched on
     */
    private void runFromSimulation() {
        ComparisonParameters params = new ComparisonParameters();
        params.setDataType(ComparisonParameters.DataType.Continuous); // Continuous or Discrete
        params.setNumVars(12); // number of variables
        params.setNumEdges(12); // number of edges
        int minSample = 100; // smallest sample size to generate
        int maxSample = 1000; // largest sample size to generate
        int increment = 100; // ramp up sample size by this increment
        int numTrials = 2; // number of data sets per sample size; results are averaged

        // --- To run data sets from file instead of generating random graphs ---
        params.setDataFromFile(false); // set this to true
        final int maxGraphs = 2; // how many true graphs are in your directory?
        final int dataSetsPerGraph = 3; // how many data sets per true graph are in your directory?
        // Remember: the path to the data directory is set in Comparison2.java.

        // --- To run on NO DATA, i.e. run each algorithm directly on random true graphs ---
        params.setNoData(false); // set this to true
        // The number of random graphs equals numTrials, set above.

        if (params.isDataFromFile() && params.isNoData()) {
            throw new IllegalArgumentException("Cannot have setDataFromFile and setNoData both be true!");
        }

        // Statistics to report, in column order. The declared type is kept as ArrayList (as in the original)
        // because the parameter type of Comparison2.summarize is not visible from this file.
        ArrayList<Comparison2.TableColumn> tableColumns = new ArrayList<>();
        tableColumns.add(Comparison2.TableColumn.AdjPrec);
        tableColumns.add(Comparison2.TableColumn.AdjRec);
        tableColumns.add(Comparison2.TableColumn.AhdPrec);
        tableColumns.add(Comparison2.TableColumn.AhdRec);
        tableColumns.add(Comparison2.TableColumn.SHD);
        tableColumns.add(Comparison2.TableColumn.Elapsed);

        // Algorithms to compare; add or comment out entries as needed. The element type must be declared:
        // with a raw List the typed for-each loop below does not compile.
        List<ComparisonParameters.Algorithm> algList = new ArrayList<>();
        algList.add(ComparisonParameters.Algorithm.SVARFCI);

        // ---------------------------------------------------------------
        // Users should not need to change anything below this line.
        // ---------------------------------------------------------------

        if (params.isDataFromFile()) {
            System.out.println("running algorithm on data from input files");
            // The outer loop then walks graph file numbers 1..maxGraphs instead of sample sizes.
            minSample = 1;
            maxSample = maxGraphs;
            increment = 1;
            numTrials = dataSetsPerGraph;
        }

        if (params.isNoData()) {
            System.out.println("running algorithm on NO DATA, only true graph");
            // A single pass of the outer loop; numTrials still controls the number of runs.
            minSample = 1;
            maxSample = 1;
            increment = 1;
        }

        // One row per (sample size, algorithm) pair, plus the header row and one spare row.
        int numSampleSteps = (maxSample - minSample) / increment + 1;
        TextTable avgTable = new TextTable(numSampleSteps * algList.size() + 2, tableColumns.size() + 1);
        int row = 0; // row 0 is reserved for the header, so data rows start at 1

        for (int sampleSize = minSample; sampleSize <= maxSample; sampleSize += increment) {
            params.setSampleSize(sampleSize);
            if (params.isDataFromFile()) {
                params.setGraphNum(sampleSize);
                System.out.println("graph file number = " + sampleSize);
            } else {
                System.out.println("sample size = " + sampleSize);
            }

            for (ComparisonParameters.Algorithm alg : algList) {
                row++;
                params.setAlgorithm(alg);
                params.setIndependenceTest(ComparisonParameters.IndependenceTestType.FisherZ);
                params.setScore(ScoreType.SemBic);

                // NOTE(review): left as a raw List because its element type (the return type of
                // Comparison2.compare) is not visible from this file; parameterize it once confirmed.
                List resultsTrials = new ArrayList<>();
                for (int trial = 1; trial <= numTrials; trial++) {
                    params.setTrial(trial);
                    resultsTrials.add(Comparison2.compare(params));
                }

                TextTable summary = Comparison2.summarize(resultsTrials, tableColumns);
                System.out.println(summary);

                // Copy the last row of the per-trial summary into this run's row of the overall table.
                // NOTE(review): presumably the last row holds the averages over trials -- confirm
                // against Comparison2.summarize.
                int lastRow = summary.getNumRows() - 1;
                int numColumns = summary.getNumColumns();
                for (int k = 0; k < numColumns; k++) {
                    avgTable.setToken(row, k, summary.getTokenAt(lastRow, k));
                }

                // Column 0 is then overwritten with a label identifying the run.
                String label;
                if (params.isDataFromFile()) {
                    label = "G=" + sampleSize + ", alg = " + alg;
                } else if (params.isNoData()) {
                    label = "N=" + 0 + ", alg = " + alg;
                } else {
                    label = "N=" + sampleSize + ", alg = " + alg;
                }
                avgTable.setToken(row, 0, label);
            } // loop over algorithms in algList
        } // loop over sample sizes

        // Add column names, then print the table.
        // NOTE(review): headers are written to columns 0..size-1 while column 0 of the data rows holds the
        // run label; whether this lines up depends on the column layout produced by Comparison2.summarize.
        for (int j = 0; j < tableColumns.size(); j++) {
            avgTable.setToken(0, j, tableColumns.get(j).toString());
        }
        System.out.println(avgTable);
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy