// Source: edu.cmu.tetrad.algcomparison.algorithm.multi.FaskVote (Maven / Gradle / Ivy)
package edu.cmu.tetrad.algcomparison.algorithm.multi;
import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm;
import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.utils.HasKnowledge;
import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.annotation.AlgType;
import edu.cmu.tetrad.annotation.Bootstrapping;
import edu.cmu.tetrad.annotation.Experimental;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Algorithm-comparison wrapper around the FASK-Vote search
 * ({@code edu.cmu.tetrad.search.work_in_progress.FaskVote}) for one or more continuous data sets.
 *
 * <p>Requires that the parameter 'randomSelectionSize' be set to indicate how many datasets should be taken at a
 * time (randomly). This cannot be given multiple values.
 *
 * <p>When {@code Params.NUMBER_RESAMPLING} is at least 1, the search is run under
 * {@link GeneralResamplingTest}. Otherwise the underlying search is run once on the supplied data.
 *
 * @author mglymour
 * @author josephramsey
 */
@edu.cmu.tetrad.annotation.Algorithm(
        name = "FASK-Vote",
        command = "fask-vote",
        algoType = AlgType.forbid_latent_common_causes,
        dataType = DataType.Continuous
)
@Bootstrapping
@Experimental
public class FaskVote implements MultiDataSetAlgorithm, HasKnowledge, UsesScoreWrapper, TakesIndependenceWrapper {

    private static final long serialVersionUID = 23L;

    /** Background knowledge forwarded to the underlying search; defensively copied in {@link #setKnowledge}. */
    private Knowledge knowledge = new Knowledge();

    /** Score wrapper forwarded to the underlying search; null unless supplied. */
    private ScoreWrapper score;

    /** Independence-test wrapper forwarded to the underlying search; null unless supplied. */
    private IndependenceWrapper test;

    /**
     * Creates a wrapper that uses the given score and no independence test.
     *
     * @param score the score wrapper to forward to the search.
     */
    public FaskVote(ScoreWrapper score) {
        this.score = score;
    }

    /** Creates a wrapper with no score or test; set them later via the setters. */
    public FaskVote() {
    }

    /**
     * Creates a wrapper with both an independence test and a score.
     *
     * @param test  the independence-test wrapper to forward to the search.
     * @param score the score wrapper to forward to the search.
     */
    public FaskVote(IndependenceWrapper test, ScoreWrapper score) {
        this.test = test;
        this.score = score;
    }

    /**
     * Runs FASK-Vote over the given data sets, optionally under resampling.
     *
     * @param dataSets   the input data models; each must be a {@link DataSet} with no missing values.
     * @param parameters the parameters; {@code Params.NUMBER_RESAMPLING} selects a plain or resampled run.
     * @return the resulting graph.
     * @throws IllegalArgumentException if an input is not a {@link DataSet} or contains missing values.
     */
    @Override
    public Graph search(List<DataModel> dataSets, Parameters parameters) {
        List<DataSet> continuousDataSets = toDataSets(dataSets);

        if (parameters.getInt(Params.NUMBER_RESAMPLING) >= 1) {
            return resample(continuousDataSets, parameters);
        }

        edu.cmu.tetrad.search.work_in_progress.FaskVote search =
                new edu.cmu.tetrad.search.work_in_progress.FaskVote(continuousDataSets, this.score, this.test);
        search.setKnowledge(this.knowledge);
        return search.search(parameters);
    }

    /**
     * Runs FASK-Vote on a single data set by converting it to continuous data and delegating to
     * {@link #search(List, Parameters)}. Resampling is therefore configured exactly as in the multi-dataset
     * case, including the configured score and independence test.
     *
     * @param dataSet    the input data model.
     * @param parameters the search parameters.
     * @return the resulting graph.
     */
    @Override
    public Graph search(DataModel dataSet, Parameters parameters) {
        List<DataModel> dataSets =
                Collections.<DataModel>singletonList(SimpleDataLoader.getContinuousDataSet(dataSet));
        return search(dataSets, parameters);
    }

    /**
     * Validates the inputs and narrows them to {@link DataSet}s.
     *
     * @param dataModels the raw inputs.
     * @return a new list holding the same data sets, in order.
     * @throws IllegalArgumentException if any input is not a {@link DataSet} or has missing values.
     */
    private static List<DataSet> toDataSets(List<DataModel> dataModels) {
        List<DataSet> dataSets = new ArrayList<>(dataModels.size());
        for (DataModel model : dataModels) {
            if (!(model instanceof DataSet)) {
                throw new IllegalArgumentException("FASK-Vote expects tabular data sets but got: "
                        + (model == null ? "null" : model.getClass().getSimpleName()));
            }
            DataSet dataSet = (DataSet) model;
            if (dataSet.existsMissingValue()) {
                throw new IllegalArgumentException("Please remove or impute missing values.");
            }
            dataSets.add(dataSet);
        }
        return dataSets;
    }

    /**
     * Runs this algorithm under {@link GeneralResamplingTest}, using a fresh wrapper that carries the same
     * independence test, score and knowledge.
     *
     * @param dataSets   the validated input data sets.
     * @param parameters the search and resampling parameters.
     * @return the aggregated resampling graph.
     */
    private Graph resample(List<DataSet> dataSets, Parameters parameters) {
        FaskVote algorithm = new FaskVote(this.test, this.score);
        algorithm.setKnowledge(this.knowledge);

        GeneralResamplingTest search = new GeneralResamplingTest(
                dataSets,
                algorithm,
                parameters.getInt(Params.NUMBER_RESAMPLING),
                parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE),
                parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT),
                parameters.getInt(Params.RESAMPLING_ENSEMBLE),
                parameters.getBoolean(Params.ADD_ORIGINAL_DATASET));
        search.setKnowledge(this.knowledge);
        search.setScoreWrapper(this.score);
        search.setParameters(parameters);
        search.setVerbose(parameters.getBoolean(Params.VERBOSE));
        return search.search();
    }

    /**
     * Returns a copy of the true graph for comparison.
     *
     * @param graph the true graph.
     * @return a new {@link EdgeListGraph} copy of {@code graph}.
     */
    @Override
    public Graph getComparisonGraph(Graph graph) {
        return new EdgeListGraph(graph);
    }

    @Override
    public String getDescription() {
        return "FASK-Vote";
    }

    @Override
    public DataType getDataType() {
        return DataType.Continuous;
    }

    /**
     * Returns the parameters used by this algorithm: those of {@code Images} followed by those of {@code Fask}.
     * The result is a fresh mutable list, so the list returned by {@code Images} is never modified.
     *
     * @return the parameter names.
     */
    @Override
    public List<String> getParameters() {
        List<String> parameters = new ArrayList<>(new Images().getParameters());
        parameters.addAll(new Fask().getParameters());
        return parameters;
    }

    @Override
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Stores a copy of the given knowledge.
     *
     * @param knowledge the knowledge to copy.
     */
    @Override
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Sets the independence-test wrapper. This is equivalent to {@link #setIndependenceWrapper}.
     *
     * @param test the independence-test wrapper.
     */
    @Override
    public void setIndTestWrapper(IndependenceWrapper test) {
        this.test = test;
    }

    @Override
    public ScoreWrapper getScoreWrapper() {
        return this.score;
    }

    @Override
    public void setScoreWrapper(ScoreWrapper score) {
        this.score = score;
    }

    @Override
    public IndependenceWrapper getIndependenceWrapper() {
        return this.test;
    }

    @Override
    public void setIndependenceWrapper(IndependenceWrapper independenceWrapper) {
        this.test = independenceWrapper;
    }
}