// edu.cmu.tetradapp.model.GeneralAlgorithmRunner Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;
import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm;
import edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm;
import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
import edu.cmu.tetrad.algcomparison.independence.MSeparationTest;
import edu.cmu.tetrad.algcomparison.independence.TakesGraph;
import edu.cmu.tetrad.algcomparison.score.MSeparationScore;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.utils.HasKnowledge;
import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.LayoutUtil;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.test.ScoreIndTest;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.search.utils.TsUtils;
import edu.cmu.tetrad.session.ParamsResettable;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.Unmarshallable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.*;
/**
* Stores an algorithm in the format of the algorithm comparison API.
*
* @author josephramsey
*/
public class GeneralAlgorithmRunner implements AlgorithmRunner, ParamsResettable,
Unmarshallable, IndTestProducer,
KnowledgeBoxInput {
private static final long serialVersionUID = 23L;

/**
 * Algorithm selections made by the user; copied from another runner by the constructors that take one.
 * NOTE(review): the key/value types are not visible in this file, so the map is left raw -- confirm against
 * callers before parameterizing it.
 */
private final Map userAlgoSelections = new HashMap<>();

/** The data to search over; null when the runner was built from a graph source or independence facts only. */
private DataWrapper dataWrapper;

/** The display name of this runner, as set through {@link #setName(String)}. */
private String name;

/** The algorithm that {@link #execute()} runs. */
private Algorithm algorithm;

/** The parameters handed to the algorithm; never null after construction. */
private Parameters parameters;

/** The graph obtained from the graph source given at construction, if any. */
private Graph sourceGraph;

/** An externally supplied graph; only stored and returned by this class. */
private Graph externalGraph;

/** The graphs produced by the most recent call to {@link #execute()}. */
private List<Graph> graphList = new ArrayList<>();

/** The background knowledge passed to algorithms that accept knowledge. */
private Knowledge knowledge;

/** Cache for {@link #getIndependenceTest()}; transient, so it is rebuilt lazily after deserialization. */
private transient List<IndependenceTest> independenceTests;
//===========================CONSTRUCTORS===========================//
/**
 * Creates a runner that reuses the data wrapper, source graph, knowledge, algorithm and user selections of an
 * existing runner, combined with the given parameters.
 *
 * @param runner     the runner to copy from; must not be null.
 * @param parameters the parameters for the new runner; must not be null.
 */
public GeneralAlgorithmRunner(GeneralAlgorithmRunner runner, Parameters parameters) {
    this(runner.getDataWrapper(), runner, parameters, null, null);

    // The delegated constructor treats the runner as a graph source; the assignments below replace whatever
    // it derived with the original runner's own state.
    this.algorithm = runner.algorithm;
    this.knowledge = runner.knowledge;
    this.sourceGraph = runner.sourceGraph;
    this.parameters = parameters;
    userAlgoSelections.putAll(runner.userAlgoSelections);
}
/**
 * Creates a runner for the given data, with no graph source, knowledge box or independence facts.
 *
 * @param dataWrapper the data to search over.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, Parameters parameters) {
this(dataWrapper, null, parameters, null, null);
}
/**
 * Creates a runner for the given data and knowledge box, with no graph source or independence facts.
 *
 * @param dataWrapper       the data to search over.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, Parameters parameters,
KnowledgeBoxModel knowledgeBoxModel) {
this(dataWrapper, null, parameters, knowledgeBoxModel, null);
}
/**
 * Creates a runner for the given data and graph source, with no knowledge box or independence facts.
 *
 * @param dataWrapper the data to search over.
 * @param graphSource the source of the graph handed to the algorithm.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GraphSource graphSource, Parameters parameters) {
this(dataWrapper, graphSource, parameters, null, null);
}
/**
 * Creates a runner for the given data, graph source and knowledge box, with no independence facts. Note that
 * the knowledge box comes before the parameters in this overload.
 *
 * @param dataWrapper       the data to search over.
 * @param graphSource       the source of the graph handed to the algorithm.
 * @param knowledgeBoxModel the source of background knowledge.
 * @param parameters        the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GraphSource graphSource,
KnowledgeBoxModel knowledgeBoxModel,
Parameters parameters) {
this(dataWrapper, graphSource, parameters, knowledgeBoxModel, null);
}
/**
 * Creates a runner for the given data, knowledge box and independence facts, with no graph source.
 *
 * @param dataWrapper       the data to search over.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 * @param facts             the independence facts to store in the parameters.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, Parameters parameters,
KnowledgeBoxModel knowledgeBoxModel, IndependenceFactsModel facts) {
this(dataWrapper, null, parameters, knowledgeBoxModel, facts);
}
/**
 * Creates a runner for the given data that reuses the algorithm and user selections of an existing runner.
 *
 * @param dataWrapper the data to search over.
 * @param runner      the runner whose algorithm and selections are copied; must not be null.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GeneralAlgorithmRunner runner, Parameters parameters) {
    this(dataWrapper, null, parameters, null, null);
    userAlgoSelections.putAll(runner.userAlgoSelections);
    algorithm = runner.algorithm;
}
/**
 * Creates a runner for the given data and knowledge box that reuses the algorithm and user selections of an
 * existing runner.
 *
 * @param dataWrapper       the data to search over.
 * @param runner            the runner whose algorithm and selections are copied; must not be null.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GeneralAlgorithmRunner runner, Parameters parameters,
                              KnowledgeBoxModel knowledgeBoxModel) {
    this(dataWrapper, null, parameters, knowledgeBoxModel, null);
    userAlgoSelections.putAll(runner.userAlgoSelections);
    algorithm = runner.algorithm;
}
/**
 * Creates a runner for the given data and graph source that reuses the algorithm and user selections of an
 * existing runner.
 *
 * @param dataWrapper the data to search over.
 * @param graphSource the source of the graph handed to the algorithm.
 * @param runner      the runner whose algorithm and selections are copied; must not be null.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GraphSource graphSource, GeneralAlgorithmRunner runner,
                              Parameters parameters) {
    this(dataWrapper, graphSource, parameters, null, null);
    userAlgoSelections.putAll(runner.userAlgoSelections);
    algorithm = runner.algorithm;
}
/**
 * Creates a runner for the given data, graph source and knowledge box that reuses the algorithm and user
 * selections of an existing runner.
 *
 * @param dataWrapper       the data to search over.
 * @param graphSource       the source of the graph handed to the algorithm.
 * @param runner            the runner whose algorithm and selections are copied; must not be null.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GraphSource graphSource, GeneralAlgorithmRunner runner,
                              Parameters parameters,
                              KnowledgeBoxModel knowledgeBoxModel) {
    this(dataWrapper, graphSource, parameters, knowledgeBoxModel, null);
    userAlgoSelections.putAll(runner.userAlgoSelections);
    algorithm = runner.algorithm;
}
/**
 * Creates a runner for the given graph source, without data, that reuses the algorithm and user selections
 * of an existing runner.
 *
 * @param graphSource the source of the graph handed to the algorithm.
 * @param runner      the runner whose algorithm and selections are copied; must not be null.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(GraphSource graphSource, GeneralAlgorithmRunner runner, Parameters parameters) {
    this(null, graphSource, parameters, null, null);
    userAlgoSelections.putAll(runner.userAlgoSelections);
    algorithm = runner.algorithm;
}
/**
 * Creates a runner for the given graph source and knowledge box, without data or independence facts.
 *
 * @param graphSource       the source of the graph handed to the algorithm.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 */
public GeneralAlgorithmRunner(GraphSource graphSource, Parameters parameters,
KnowledgeBoxModel knowledgeBoxModel) {
this(null, graphSource, parameters, knowledgeBoxModel, null);
}
/**
 * Creates a runner for the given independence facts and knowledge box, without data or a graph source.
 *
 * @param model             the independence facts to store in the parameters.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge.
 */
public GeneralAlgorithmRunner(IndependenceFactsModel model,
Parameters parameters, KnowledgeBoxModel knowledgeBoxModel) {
this(null, null, parameters, knowledgeBoxModel, model);
}
/**
 * Constructs a runner for the given graph source, without data, knowledge box or independence facts.
 *
 * @param graphSource the source of the graph handed to the algorithm.
 * @param parameters  the search parameters; must not be null.
 */
public GeneralAlgorithmRunner(GraphSource graphSource, Parameters parameters) {
this(null, graphSource, parameters, null, null);
}
/**
 * Primary constructor; every other constructor delegates here. Any argument except {@code parameters} may be
 * null.
 *
 * <ul>
 * <li>If {@code graphSource} is itself a {@code GeneralAlgorithmRunner}, its algorithm is adopted.</li>
 * <li>If {@code dataWrapper} is a {@code Simulation} whose data model list is empty, the simulation is
 * created first.</li>
 * <li>If no data wrapper is given and {@code graphSource} is a {@code DataWrapper}, the graph source is used
 * as the data wrapper; otherwise the graph source's graph becomes the source graph.</li>
 * <li>The variable names of the data wrapper are stored in the parameters under "varNames".</li>
 * <li>Knowledge is taken from the knowledge box, or is a new empty {@code Knowledge} when there is none.</li>
 * <li>Independence facts, when given, are stored in the parameters under "independenceFacts".</li>
 * </ul>
 *
 * @param dataWrapper       the data to search over, or null.
 * @param graphSource       the source of the graph handed to the algorithm, or null.
 * @param parameters        the search parameters; must not be null.
 * @param knowledgeBoxModel the source of background knowledge, or null.
 * @param facts             independence facts to store in the parameters, or null.
 * @throws NullPointerException if {@code parameters} is null.
 */
public GeneralAlgorithmRunner(DataWrapper dataWrapper, GraphSource graphSource, Parameters parameters,
                              KnowledgeBoxModel knowledgeBoxModel, IndependenceFactsModel facts) {
    this.parameters = Objects.requireNonNull(parameters, "parameters must not be null");

    if (graphSource instanceof GeneralAlgorithmRunner) {
        this.algorithm = ((GeneralAlgorithmRunner) graphSource).getAlgorithm();
    }

    if (dataWrapper != null) {
        this.dataWrapper = dataWrapper;

        if (dataWrapper.getDataModelList().isEmpty() && dataWrapper instanceof Simulation) {
            ((Simulation) dataWrapper).createSimulation();
        }
    }

    if (graphSource != null) {
        if (dataWrapper == null && graphSource instanceof DataWrapper) {
            this.dataWrapper = (DataWrapper) graphSource;
        } else {
            this.sourceGraph = graphSource.getGraph();
        }
    }

    // NOTE(review): this tests the constructor argument, not the field, so the variable names are not
    // transferred when the graph source was adopted as the data wrapper above -- confirm this is intended.
    if (dataWrapper != null) {
        transferVarNamesToParams(this.dataWrapper.getVariableNames());
    }

    this.knowledge = knowledgeBoxModel != null ? knowledgeBoxModel.getKnowledge() : new Knowledge();

    if (facts != null) {
        getParameters().set("independenceFacts", facts.getFacts());
    }
}
//============================PUBLIC METHODS==========================//
/**
 * Runs the configured algorithm and stores the resulting graphs, retrievable through {@code getGraphs()}.
 *
 * <p>Depending on the available input, one of four paths is taken:
 * <ol>
 * <li>No data but a source graph: the graph is injected into an m-separation score/test (if the algorithm
 * uses one) and the algorithm is run once with a null data model.</li>
 * <li>A {@code MultiDataSetAlgorithm}: run "numRuns" times, each time on a random selection of
 * "randomSelectionSize" data models (all of them when that parameter is 0).</li>
 * <li>A {@code ClusterAlgorithm}: run "numRuns" times over every covariance matrix or continuous data
 * set in the data model list; other kinds of data model are skipped.</li>
 * <li>Anything else: run on the single data model, which must match the algorithm's data type.</li>
 * </ol>
 * Finally the graphs are arranged by knowledge tiers when the knowledge has tiers, and given the default
 * layout otherwise.
 *
 * @throws IllegalArgumentException if knowledge was supplied but the algorithm does not accept knowledge, if
 *                                  the random selection size exceeds the number of data sets, if a cluster
 *                                  algorithm is given a non-continuous data set, if the default path does not
 *                                  find exactly one data model, if resampling is requested on a covariance
 *                                  matrix, or if the data type does not match the algorithm.
 */
@Override
public void execute() {
    List<Graph> results = new ArrayList<>();

    if (this.independenceTests != null) {
        this.independenceTests.clear();
    }

    Algorithm algo = getAlgorithm();

    if (this.knowledge != null && !this.knowledge.isEmpty()) {
        if (algo instanceof HasKnowledge) {
            ((HasKnowledge) algo).setKnowledge(this.knowledge.copy());
        } else {
            throw new IllegalArgumentException("Knowledge has been supplied, but this algorithm does not use knowledge.");
        }
    }

    if (getDataModelList().size() == 0 && getSourceGraph() != null) {
        if (algo instanceof UsesScoreWrapper) {
            // Inject the graph into the score so that graph-based scores such as MSeparationScore work.
            ScoreWrapper scoreWrapper = ((UsesScoreWrapper) algo).getScoreWrapper();

            if (scoreWrapper instanceof MSeparationScore) {
                ((MSeparationScore) scoreWrapper).setGraph(getSourceGraph());
            }
        }

        if (algo instanceof TakesIndependenceWrapper) {
            IndependenceWrapper wrapper = ((TakesIndependenceWrapper) algo).getIndependenceWrapper();

            if (wrapper instanceof MSeparationTest) {
                ((MSeparationTest) wrapper).setGraph(getSourceGraph());
            }
        }

        if (algo instanceof TakesGraph) {
            ((TakesGraph) algo).setGraph(this.sourceGraph);
        }

        if (this.algorithm instanceof HasKnowledge) {
            // Knowledge derived from the source graph is used only when none was supplied explicitly.
            Knowledge graphKnowledge = TsUtils.getKnowledge(getSourceGraph());

            if (this.knowledge.isEmpty() && !graphKnowledge.isEmpty()) {
                ((HasKnowledge) algo).setKnowledge(graphKnowledge);
            } else {
                ((HasKnowledge) this.algorithm).setKnowledge(this.knowledge.copy());
            }
        }

        Graph graph = algo.search(null, this.parameters);
        LayoutUtil.defaultLayout(graph);
        results.add(graph);
    } else if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
        for (int k = 0; k < this.parameters.getInt("numRuns"); k++) {
            // Every data model gets the knowledge of the first one before the selection is drawn.
            Knowledge firstKnowledge = getDataModelList().get(0).getKnowledge();
            List<DataModel> dataSets = new ArrayList<>(getDataModelList());

            for (DataModel dataSet : dataSets) {
                dataSet.setKnowledge(firstKnowledge);
            }

            int randomSelectionSize = this.parameters.getInt("randomSelectionSize");

            if (randomSelectionSize == 0) {
                randomSelectionSize = dataSets.size();
            }

            if (dataSets.size() < randomSelectionSize) {
                throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than "
                        + "the number of data sets: " + randomSelectionSize + " > " + dataSets.size());
            }

            RandomUtil.shuffle(dataSets);

            List<DataModel> selection = new ArrayList<>();

            for (int j = 0; j < randomSelectionSize; j++) {
                selection.add(dataSets.get(j));
            }

            injectGraphAndKnowledge(algo);
            results.add(((MultiDataSetAlgorithm) algo).search(selection, this.parameters));
        }
    } else if (getAlgorithm() instanceof ClusterAlgorithm) {
        for (int k = 0; k < this.parameters.getInt("numRuns"); k++) {
            for (DataModel dataModel : getDataModelList()) {
                Graph graph;

                if (dataModel instanceof ICovarianceMatrix) {
                    ICovarianceMatrix covariances = (ICovarianceMatrix) dataModel;
                    injectGraphAndKnowledge(algo);
                    graph = this.algorithm.search(covariances, this.parameters);
                } else if (dataModel instanceof DataSet) {
                    DataSet dataSet = (DataSet) dataModel;

                    if (!dataSet.isContinuous()) {
                        throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                    }

                    injectGraphAndKnowledge(algo);
                    graph = this.algorithm.search(dataSet, this.parameters);
                } else {
                    // Other kinds of data model are ignored by cluster algorithms.
                    continue;
                }

                LayoutUtil.defaultLayout(graph);
                results.add(graph);
            }
        }
    } else {
        if (getDataModelList().size() != 1) {
            throw new IllegalArgumentException("Expecting a single dataset here.");
        }

        if (algo != null) {
            for (DataModel data : getDataModelList()) {
                // Knowledge attached to the data replaces the runner's knowledge when it names any variables.
                Knowledge knowledgeFromData = data.getKnowledge();

                if (knowledgeFromData != null && !knowledgeFromData.getVariables().isEmpty()) {
                    this.knowledge = knowledgeFromData;
                }

                DataType algDataType = algo.getDataType();

                injectGraphAndKnowledge(algo);

                if (data instanceof ICovarianceMatrix && this.parameters.getInt(Params.NUMBER_RESAMPLING) > 0) {
                    throw new IllegalArgumentException("Sorry, you need a tabular dataset in order to do bootstrapping.");
                }

                boolean dataTypeSupported =
                        (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed))
                                || (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed))
                                || (data.isMixed() && algDataType == DataType.Mixed);

                if (!dataTypeSupported) {
                    throw new IllegalArgumentException("The algorithm was not expecting that type of data.");
                }

                Graph graph = algo.search(data, this.parameters);
                LayoutUtil.defaultLayout(graph);
                results.add(graph);
            }
        }
    }

    if (this.knowledge != null && this.knowledge.getNumTiers() > 0) {
        for (Graph graph : results) {
            GraphSearchUtils.arrangeByKnowledgeTiers(graph, this.knowledge);
        }
    } else {
        for (Graph graph : results) {
            LayoutUtil.defaultLayout(graph);
        }
    }

    this.graphList = results;
}

/**
 * Hands the source graph to the given algorithm if it takes a graph, and a copy of the current knowledge to
 * this runner's algorithm if it accepts knowledge.
 */
private void injectGraphAndKnowledge(Algorithm algo) {
    if (algo instanceof TakesGraph) {
        ((TakesGraph) algo).setGraph(this.sourceGraph);
    }

    if (this.algorithm instanceof HasKnowledge) {
        ((HasKnowledge) this.algorithm).setKnowledge(this.knowledge.copy());
    }
}
/**
 * Reports whether the first data model contains missing values.
 *
 * <p>Returns false, rather than failing, when there are no data models at all, when the list reports empty
 * data, or when the first model is not a tabular {@code DataSet} (for example a covariance matrix).
 *
 * @return true only if the first data model is a {@code DataSet} with at least one missing value.
 */
public boolean hasMissingValues() {
    DataModelList dataModelList = getDataModelList();

    // A runner built from a graph source only has an empty list; indexing it would throw.
    if (dataModelList.isEmpty() || dataModelList.containsEmptyData()) {
        return false;
    }

    DataModel first = dataModelList.get(0);

    // Only tabular data is inspected; anything else is reported as having no missing values.
    return first instanceof DataSet && ((DataSet) first).existsMissingValue();
}
/**
 * Always returns false. By default, algorithms do not support knowledge; those that do will speak up.
 *
 * @return false.
 */
@Override
public boolean supportsKnowledge() {
return false;
}
/**
 * Always returns null; this runner does not provide a Meek rules object.
 *
 * @return null.
 */
@Override
public MeekRules getMeekRules() {
return null;
}
/**
 * Returns the graph previously stored with {@link #setExternalGraph(Graph)}.
 *
 * @return the external graph, or null if none was set.
 */
@Override
public Graph getExternalGraph() {
return this.externalGraph;
}
/**
 * Stores an external graph for later retrieval through {@link #getExternalGraph()}.
 *
 * @param graph the external graph; may be null.
 */
@Override
public void setExternalGraph(Graph graph) {
this.externalGraph = graph;
}
/**
 * Always returns null; this runner does not report an algorithm name.
 *
 * @return null.
 */
@Override
public String getAlgorithmName() {
return null;
}
/**
 * Returns the graph obtained from the graph source at construction.
 *
 * @return the source graph, or null if there is none.
 */
@Override
public final Graph getSourceGraph() {
return this.sourceGraph;
}
/**
 * Returns the result of the search; identical to {@link #getGraph()}.
 *
 * @return the first graph of the last execution, or null if there is none.
 */
@Override
public Graph getResultGraph() {
return getGraph();
}
/**
 * Returns the data this runner operates on.
 *
 * @return the single data model when the wrapper holds exactly one, the whole data model list otherwise, or
 * an empty 0-by-0 data set when there is no data wrapper.
 */
@Override
public final DataModel getDataModel() {
    if (this.dataWrapper == null) {
        // Do not throw an exception here; callers receive an empty data set instead.
        return new BoxDataSet(new VerticalDoubleDataBox(0, 0), new ArrayList<>());
    }

    DataModelList models = this.dataWrapper.getDataModelList();
    return models.size() == 1 ? models.get(0) : models;
}
/**
 * Always returns null; the parameters of this runner are available through {@link #getParameters()}.
 *
 * @return null.
 */
@Override
public Parameters getParams() {
return null;
}
/**
 * Returns the data models of the data wrapper.
 *
 * @return the wrapper's data model list, or a new empty list when there is no data wrapper.
 */
public final DataModelList getDataModelList() {
    return this.dataWrapper != null ? this.dataWrapper.getDataModelList() : new DataModelList();
}
/**
 * Returns the parameters handed to the algorithm.
 *
 * @return the parameters of this runner.
 */
public final Parameters getParameters() {
return this.parameters;
}
/**
 * Returns the object that {@link #resetParams(Object)} replaces, i.e. the current parameters.
 *
 * @return the parameters of this runner.
 */
@Override
public Object getResettableParams() {
return this.getParameters();
}
/**
 * Replaces the parameters of this runner.
 *
 * @param params the new parameters; cast to {@code Parameters}, so any other type causes a
 *               ClassCastException.
 */
@Override
public void resetParams(Object params) {
this.parameters = (Parameters) params;
}
//===========================PRIVATE METHODS==========================//
/**
 * Stores the given variable names in the parameters under the key "varNames".
 *
 * @param names the variable names to store.
 */
private void transferVarNamesToParams(List names) {
getParameters().set("varNames", names);
}
/**
 * Adds semantic checks to the default deserialization method. This method must have the standard signature for
 * a readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that,
 * any semantic checks can be specified and do not need to stay the same from version to version. A readObject
 * method of this form may be added to any class, even if Tetrad sessions were previously saved out using a
 * version of the class that didn't include it. (That's what the "s.defaultReadObject();" is for. See J. Bloch,
 * Effective Java, for help.)
 *
 * <p>At present no additional checks are performed.
 *
 * @param s the stream to read this object from.
 * @throws IOException            if the default deserialization fails with an I/O error.
 * @throws ClassNotFoundException if the class of a serialized object cannot be found.
 */
private void readObject(ObjectInputStream s)
throws IOException, ClassNotFoundException {
s.defaultReadObject();
}
/**
 * Returns an independence test matching the current configuration, creating and caching it on first use.
 *
 * <ul>
 * <li>With no data but a source graph, an m-separation test over that graph is used.</li>
 * <li>Otherwise, if the algorithm takes an independence wrapper and there is exactly one data model, the
 * wrapper's test for that model is used.</li>
 * <li>Otherwise, if the algorithm uses a score wrapper and there is exactly one data model, the wrapper's
 * score is adapted into a test.</li>
 * </ul>
 * The cache is cleared by {@code execute()}.
 *
 * @return the cached or newly created independence test.
 * @throws IllegalArgumentException if none of the cases above yields a test.
 */
@Override
public IndependenceTest getIndependenceTest() {
    // The field is transient, so it may be null after deserialization; create it lazily.
    List<IndependenceTest> tests = this.independenceTests;

    if (tests == null) {
        tests = new ArrayList<>();
        this.independenceTests = tests;
    }

    if (tests.size() == 1) {
        return tests.get(0);
    }

    Algorithm algo = getAlgorithm();
    DataModelList dataModels = getDataModelList();

    if (dataModels.size() == 0 && getSourceGraph() != null) {
        // Inject the graph into the test so that graph-based tests such as MSeparationTest work.
        IndependenceWrapper wrapper = new MSeparationTest(getSourceGraph());
        tests.add(wrapper.getTest(null, this.parameters));
    } else if (algo instanceof TakesIndependenceWrapper) {
        if (dataModels.size() == 1) {
            IndependenceWrapper wrapper = ((TakesIndependenceWrapper) algo).getIndependenceWrapper();
            tests.add(wrapper.getTest(dataModels.get(0), this.parameters));
        }
    } else if (algo instanceof UsesScoreWrapper) {
        if (dataModels.size() == 1) {
            ScoreWrapper wrapper = ((UsesScoreWrapper) algo).getScoreWrapper();
            Score score = wrapper.getScore(dataModels.get(0), this.parameters);
            tests.add(new ScoreIndTest(score));
        }
    }

    if (tests.isEmpty()) {
        throw new IllegalArgumentException("One or more of the parents was a search that didn't use "
                + "a test or a score.");
    }

    return tests.get(0);
}
/**
 * Returns the name of this runner.
 *
 * @return the name, or null if none was set.
 */
@Override
public String getName() {
return this.name;
}
/**
 * Sets the name of this runner.
 *
 * @param name the new name.
 */
@Override
public void setName(String name) {
this.name = name;
}
/**
 * Returns the algorithm that {@link #execute()} runs.
 *
 * @return the algorithm, or null if none has been set.
 */
public Algorithm getAlgorithm() {
return this.algorithm;
}
/**
 * Sets the algorithm that {@link #execute()} runs.
 *
 * @param algorithm the algorithm; must not be null.
 * @throws NullPointerException if {@code algorithm} is null.
 */
public void setAlgorithm(Algorithm algorithm) {
    this.algorithm = Objects.requireNonNull(algorithm, "Algorithm not specified");
}
/**
 * Returns the names of the triple classifications this runner provides; there are none.
 *
 * @return an immutable empty list.
 */
@Override
public List getTriplesClassificationTypes() {
    // The generic factory avoids the unchecked conversion that the raw EMPTY_LIST constant causes.
    return Collections.emptyList();
}
@Override
public List> getTriplesLists(Node node) {
return Collections.EMPTY_LIST;
}
/**
 * Returns the parameter settings of this runner; none are reported.
 *
 * @return an immutable empty map.
 */
@Override
public Map getParamSettings() {
    // The generic factory avoids the unchecked conversion that the raw EMPTY_MAP constant causes.
    return Collections.emptyMap();
}
/**
 * Returns all parameter settings of this runner; none are reported.
 *
 * @return an immutable empty map.
 */
@Override
public Map getAllParamSettings() {
    // The generic factory avoids the unchecked conversion that the raw EMPTY_MAP constant causes.
    return Collections.emptyMap();
}
/**
 * Does nothing; the given settings are ignored.
 *
 * @param paramSettings the settings to ignore.
 */
@Override
public void setAllParamSettings(Map paramSettings) {
}
/**
 * Returns the first graph produced by the last execution.
 *
 * @return the first result graph, or null when there are no results.
 */
@Override
public Graph getGraph() {
    if (this.graphList != null && !this.graphList.isEmpty()) {
        return this.graphList.get(0);
    }

    return null;
}
/**
 * Returns all graphs produced by the last execution.
 *
 * @return the internal list of result graphs (not a copy).
 */
@Override
public List getGraphs() {
return this.graphList;
}
/**
 * Returns the background knowledge currently held by this runner.
 *
 * @return the knowledge.
 */
public Knowledge getKnowledge() {
return this.knowledge;
}
/**
 * Returns the data wrapper this runner was constructed with.
 *
 * @return the data wrapper, or null if there is none.
 */
public DataWrapper getDataWrapper() {
return this.dataWrapper;
}
/**
 * Returns the variables of this runner; none are reported.
 *
 * @return an immutable empty list.
 */
@Override
public List getVariables() {
    // The generic factory avoids the unchecked conversion that the raw EMPTY_LIST constant causes.
    return Collections.emptyList();
}
/**
 * Returns the variable names of this runner; none are reported.
 *
 * @return an immutable empty list.
 */
@Override
public List getVariableNames() {
    // The generic factory avoids the unchecked conversion that the raw EMPTY_LIST constant causes.
    return Collections.emptyList();
}
/**
 * Maps each given graph to the comparison graph the algorithm derives from it.
 *
 * @param graphs the graphs to convert; must not be null.
 * @return a new list holding one comparison graph per input graph, in the same order.
 * @throws NullPointerException if {@code graphs} is null.
 */
public List<Graph> getCompareGraphs(List<Graph> graphs) {
    Objects.requireNonNull(graphs, "graphs must not be null");

    List<Graph> compareGraphs = new ArrayList<>(graphs.size());

    for (Graph graph : graphs) {
        compareGraphs.add(this.algorithm.getComparisonGraph(graph));
    }

    return compareGraphs;
}
/**
 * Returns the user's algorithm selections.
 *
 * @return the internal, mutable map of selections (not a copy).
 */
public Map getUserAlgoSelections() {
return this.userAlgoSelections;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy