All downloads are free. Search and download functionality uses the official Maven repository.

edu.pitt.dbmi.algo.resampling.GeneralResamplingSearch Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
package edu.pitt.dbmi.algo.resampling;

import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.task.GeneralResamplingSearchRunnable;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.apache.commons.math3.random.Well44497b;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

/**
 * Sep 7, 2018 1:38:50 PM
 *
 * @author Chirayu Kong Wongchokprasitti, PhD ([email protected])
 */
public class GeneralResamplingSearch {

    private final int numberResampling;
    private final List graphs = Collections.synchronizedList(new ArrayList<>());
    private final ForkJoinPool pool;
    private Algorithm algorithm;
    private MultiDataSetAlgorithm multiDataSetAlgorithm;
    private double percentResampleSize = 100.;
    private boolean resamplingWithReplacement = true;
    private boolean runParallel;
    private boolean addOriginalDataset;
    private boolean verbose;
    private DataSet data;

    private List dataSets;

    /**
     * Specification of forbidden and required edges.
     */
    private Knowledge knowledge = new Knowledge();

    private PrintStream out = System.out;

    private Parameters parameters;

    /**
     * An initial graph to start from.
     */
    private Graph externalGraph;
    private int numNograph = 0;
    private ScoreWrapper scoreWrapper;

    public GeneralResamplingSearch(DataSet data, int numberResampling) {
        this.data = data;
        this.pool = ForkJoinPool.commonPool();
        this.numberResampling = numberResampling;
    }

    public GeneralResamplingSearch(List dataSets, int numberResampling) {
        this.dataSets = dataSets;
        this.pool = ForkJoinPool.commonPool();
        this.numberResampling = numberResampling;
    }

    public void setAlgorithm(Algorithm algorithm) {
        this.algorithm = algorithm;
        this.multiDataSetAlgorithm = null;
    }

    public void setMultiDataSetAlgorithm(MultiDataSetAlgorithm multiDataSetAlgorithm) {
        this.multiDataSetAlgorithm = multiDataSetAlgorithm;
        this.algorithm = null;
    }

    public void setPercentResampleSize(double percentResampleSize) {
        this.percentResampleSize = percentResampleSize;
    }

    public void setResamplingWithReplacement(boolean resamplingWithReplacement) {
        this.resamplingWithReplacement = resamplingWithReplacement;
    }

    public void setRunParallel(boolean runParallel) {
        this.runParallel = runParallel;
    }

    public void setAddOriginalDataset(boolean addOriginalDataset) {
        this.addOriginalDataset = addOriginalDataset;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public void setData(DataSet data) {
        this.data = data;
    }

    /**
     * Sets the background knowledge.
     *
     * @param knowledge the knowledge object, specifying forbidden and required edges.
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge((Knowledge) knowledge);
    }

    public void setExternalGraph(Graph externalGraph) {
        this.externalGraph = externalGraph;
    }

    /**
     * @return the output stream that output (except for log output) should be sent to.
     */
    public PrintStream getOut() {
        return this.out;
    }

    /**
     * Sets the output stream that output (except for log output) should be sent to. By default System.out.
     */
    public void setOut(PrintStream out) {
        this.out = out;
    }

    public void setParameters(Parameters parameters) {
        this.parameters = parameters;
    }

    public List search() {

        this.graphs.clear();
        this.parameters.set("numberResampling", 0); // This needs to be set to zero to not loop indefinitely

        List> tasks = new ArrayList<>();

        // Running in the sequential form
        if (this.verbose) {
            this.out.println("Running Resamplings in Sequential Mode, numberResampling = " + this.numberResampling);
        }

        if (this.data != null) {
            Long seed = (parameters == null || parameters.get(Params.SEED) == null) ? null : (Long) parameters.get(Params.SEED);
            RandomGenerator randomGenerator = (seed == null || seed < 0) ? null : new SynchronizedRandomGenerator(new Well44497b(seed));
            for (int i1 = 0; i1 < this.numberResampling; i1++) {
                DataSet dataSet;

                if (this.resamplingWithReplacement) {
                    if ((randomGenerator == null)) {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        dataSet = DataTransforms.getBootstrapSample(data, sampleSize);
                    } else {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        dataSet = DataTransforms.getBootstrapSample(data, sampleSize, randomGenerator);
                    }
                } else {
                    if ((randomGenerator == null)) {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        dataSet = DataTransforms.getResamplingDataset(data, sampleSize);
                    } else {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        dataSet = DataTransforms.getResamplingDataset(data, sampleSize, randomGenerator);
                    }
                }

                dataSet.setKnowledge(data.getKnowledge());

                GeneralResamplingSearchRunnable task = new GeneralResamplingSearchRunnable(dataSet, this.algorithm, this.parameters, this, this.verbose);
                task.setExternalGraph(this.externalGraph);
                task.setKnowledge(this.knowledge);
                tasks.add(task);
                task.setScoreWrapper(scoreWrapper);
            }

            if (addOriginalDataset) {
                GeneralResamplingSearchRunnable task = new GeneralResamplingSearchRunnable(data.copy(),
                        this.algorithm, this.parameters, this,
                        this.verbose);
                task.setExternalGraph(this.externalGraph);
                task.setKnowledge(this.knowledge);
                tasks.add(task);
                task.setScoreWrapper(scoreWrapper);
            }
        } else {
            for (int i1 = 0; i1 < this.numberResampling; i1++) {
                List dataModels = new ArrayList<>();

                for (DataSet data : this.dataSets) {

                    if (this.resamplingWithReplacement) {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        DataSet bootstrapSample = DataTransforms.getBootstrapSample(data, sampleSize);
                        bootstrapSample.setKnowledge(data.getKnowledge());
                        dataModels.add(bootstrapSample);
                    } else {
                        int sampleSize = (int) (data.getNumRows() * this.percentResampleSize / 100.0);
                        DataSet resamplingDataset = DataTransforms.getResamplingDataset(data, sampleSize);
                        resamplingDataset.setKnowledge(data.getKnowledge());
                        dataModels.add(resamplingDataset);
                    }
                }

                GeneralResamplingSearchRunnable task = new GeneralResamplingSearchRunnable(dataModels,
                        this.multiDataSetAlgorithm, this.parameters, this,
                        this.verbose);
                task.setExternalGraph(this.externalGraph);
                task.setKnowledge(dataModels.get(0).getKnowledge());
                task.setScoreWrapper(scoreWrapper);

                tasks.add(task);
            }
        }

        int numNoGraph = 0;

        if (this.runParallel) {
            List> futures = this.pool.invokeAll(tasks);
            for (Future future : futures) {
                Graph graph;
                try {
                    graph = future.get();

                    if (graph == null) {
                        numNograph++;
                    } else {
                        this.graphs.add(graph);
                    }
                } catch (InterruptedException | ExecutionException e) {
                    e.printStackTrace();
                }
            }
        } else {
            for (Callable callable : tasks) {
                try {
                    Graph graph = callable.call();

                    if (graph == null) {
                        numNoGraph++;
                    } else {
                        this.graphs.add(graph);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        this.parameters.set("numberResampling", this.numberResampling);
        this.numNograph = numNoGraph;

        return this.graphs;
    }

    public int getNumNograph() {
        return numNograph;
    }

    public void setScoreWrapper(ScoreWrapper scoreWrapper) {
        this.scoreWrapper = scoreWrapper;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy