All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.pitt.dbmi.algo.bayesian.constraint.search.PagSamplingRfci Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
package edu.pitt.dbmi.algo.bayesian.constraint.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.search.Rfci;
import edu.cmu.tetrad.search.test.IndTestProbabilistic;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.util.GraphSampling;

import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.*;

/**
 * Jan 29, 2023 4:10:52 PM
 *
 * @author Kevin V. Bui ([email protected])
 */
public class PagSamplingRfci implements IGraphSearch {

    private final int NUM_THREADS = 10;
    private final DataSet dataSet;
    // PagSamplingRfci
    private int numRandomizedSearchModels = 10;
    private boolean verbose = false;
    // Rfci parameters
    private int depth = -1;
    private int maxPathLength = -1;
    // IndTestProbabilistic parameters
    private boolean threshold = true;
    private double cutoff = 0.5;
    private double priorEquivalentSampleSize = 10;
    /**
     * The background knowledge.
     */
    private Knowledge knowledge;

    public PagSamplingRfci(DataSet dataSet) {
        this.dataSet = dataSet;
    }

    @Override
    public Graph search() {
        List graphs = runSearches();

        return GraphSampling.createGraphWithHighProbabilityEdges(graphs);
    }

    List> createTasks(int numOfTasks) {
        List> callableTasks = new LinkedList<>();

        for (int i = 0; i < numOfTasks; i++) {
            callableTasks.add(new RfciSearchTask());
        }

        return callableTasks;
    }

    private List runSearches() {
        List graphs = new LinkedList<>();

        ExecutorService pool = Executors.newFixedThreadPool(NUM_THREADS);
        try {
            while (graphs.size() < numRandomizedSearchModels && !Thread.currentThread().isInterrupted()) {
                List> callableTasks = createTasks(numRandomizedSearchModels - graphs.size());
                List> completedTasks = pool.invokeAll(callableTasks);
                for (Future completedTask : completedTasks) {
                    try {
                        Graph graph = completedTask.get();
                        if (graph != null && GraphSearchUtils.isLegalPag(graph).isLegalPag()) {
                            graphs.add(graph);
                        }
                    } catch (ExecutionException exception) {
                        exception.printStackTrace(System.err);
                    }
                }
            }
        } catch (InterruptedException exception) {
            exception.printStackTrace(System.err);
        } finally {
            shutdownAndAwaitTermination(pool);
        }

        return graphs;
    }

    /**
     * Call shutdown to reject incoming tasks, and then calling shutdownNow, if necessary, to cancel any lingering
     * tasks.
     */
    private void shutdownAndAwaitTermination(ExecutorService pool) {
        pool.shutdown();
        try {
            if (!pool.awaitTermination(5, TimeUnit.SECONDS)) {
                pool.shutdownNow();
                if (!pool.awaitTermination(5, TimeUnit.SECONDS)) {
                    System.err.println("Pool did not terminate");
                }
            }
        } catch (InterruptedException ie) {
            pool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    public void setNumRandomizedSearchModels(int numRandomizedSearchModels) {
        this.numRandomizedSearchModels = numRandomizedSearchModels;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public void setDepth(int depth) {
        this.depth = depth;
    }

    public void setMaxPathLength(int maxPathLength) {
        this.maxPathLength = maxPathLength;
    }

    public void setThreshold(boolean threshold) {
        this.threshold = threshold;
    }

    public void setCutoff(double cutoff) {
        this.cutoff = cutoff;
    }

    public void setPriorEquivalentSampleSize(double priorEquivalentSampleSize) {
        this.priorEquivalentSampleSize = priorEquivalentSampleSize;
    }

    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
    }

    private class RfciSearchTask implements Callable {

        public RfciSearchTask() {
        }

        @Override
        public Graph call() throws Exception {
            try {
                IndTestProbabilistic independenceTest = new IndTestProbabilistic(dataSet);
                independenceTest.setThreshold(threshold);
                independenceTest.setCutoff(cutoff);
                independenceTest.setPriorEquivalentSampleSize(priorEquivalentSampleSize);
                independenceTest.setVerbose(verbose);

                Rfci rfci = new Rfci(independenceTest);
                if (knowledge != null) {
                    rfci.setKnowledge(knowledge);
                }
                rfci.setDepth(depth);
                rfci.setMaxPathLength(maxPathLength);
                rfci.setVerbose(verbose);

                return rfci.search();
            } catch (Exception exception) {
                exception.printStackTrace(System.err);

                return null;
            }
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy