// edu.cmu.tetradapp.model.FgesRunner Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.search.*;
import edu.cmu.tetrad.session.DoNotAddOldModel;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.TetradSerializableUtils;
import java.beans.PropertyChangeEvent;
import java.beans.PropertyChangeListener;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Extends AbstractAlgorithmRunner to produce a wrapper for the GES algorithm.
*
* @author Ricardo Silva
*/
public class FgesRunner extends AbstractAlgorithmRunner implements IFgesRunner,
PropertyChangeListener, IGesRunner, Indexable, DoNotAddOldModel {
static final long serialVersionUID = 23L;
public enum Type {CONTINUOUS, DISCRETE, MIXED, GRAPH}
private transient List listeners;
private List topGraphs;
private int index;
private transient Fges fges;
private transient Graph externalGraph;
//============================CONSTRUCTORS============================//
public FgesRunner(DataWrapper[] dataWrappers, Parameters params, KnowledgeBoxModel knowledgeBoxModel) {
super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel);
}
public FgesRunner(DataWrapper[] dataWrappers, Parameters params) {
super(new MergeDatasetsWrapper(dataWrappers, params), params, null);
}
public FgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params) {
super(new MergeDatasetsWrapper(dataWrappers, params), params, null);
if (graph == this) throw new IllegalArgumentException();
this.externalGraph = graph.getGraph();
}
public FgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params, KnowledgeBoxModel knowledgeBoxModel) {
super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel);
if (graph == this) throw new IllegalArgumentException();
this.externalGraph = graph.getGraph();
}
public FgesRunner(GraphWrapper graphWrapper, Parameters params, KnowledgeBoxModel knowledgeBoxModel) {
super(graphWrapper.getGraph(), params, knowledgeBoxModel);
}
public FgesRunner(GraphWrapper graphWrapper, Parameters params) {
super(graphWrapper.getGraph(), params, null);
}
/**
* Generates a simple exemplar of this class to test serialization.
*
* @see TetradSerializableUtils
*/
public static DataWrapper serializableInstance() {
return new DataWrapper(new Parameters());
}
//============================PUBLIC METHODS==========================//
/**
* Executes the algorithm, producing (at least) a result workbench. Must be
* implemented in the extending class.
*/
public void execute() {
System.out.println("A");
Object model = getDataModel();
if (model == null && getSourceGraph() != null) {
model = getSourceGraph();
}
if (model == null) {
throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" +
"then click Save, and then right click on them and select Propagate Downstream. \n" +
"The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" +
"file when you save the session. It can, however, be recreated from the saved seed.");
}
Parameters params = getParams();
if (model instanceof Graph) {
GraphScore gesScore = new GraphScore((Graph) model);
this.fges = new Fges(gesScore);
this.fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2()));
this.fges.setVerbose(true);
} else {
double penaltyDiscount = params.getDouble("penaltyDiscount", 4);
if (model instanceof DataSet) {
DataSet dataSet = (DataSet) model;
if (dataSet.isContinuous()) {
SemBicScore gesScore = new SemBicScore(new CovarianceMatrix((DataSet) model));
gesScore.setPenaltyDiscount(penaltyDiscount);
System.out.println("Score done");
this.fges = new Fges(gesScore);
} else if (dataSet.isDiscrete()) {
double samplePrior = getParams().getDouble("samplePrior", 1);
double structurePrior = getParams().getDouble("structurePrior", 1);
BDeuScore score = new BDeuScore(dataSet);
score.setSamplePrior(samplePrior);
score.setStructurePrior(structurePrior);
this.fges = new Fges(score);
} else {
ConditionalGaussianScore gesScore = new ConditionalGaussianScore(dataSet, 1, 1, false);
gesScore.setPenaltyDiscount(penaltyDiscount);
this.fges = new Fges(gesScore);
}
} else if (model instanceof ICovarianceMatrix) {
SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model);
gesScore.setPenaltyDiscount(penaltyDiscount);
gesScore.setPenaltyDiscount(penaltyDiscount);
this.fges = new Fges(gesScore);
} else if (model instanceof DataModelList) {
DataModelList list = (DataModelList) model;
for (DataModel dataModel : list) {
if (!(dataModel instanceof DataSet || dataModel instanceof ICovarianceMatrix)) {
throw new IllegalArgumentException("Need a combination of all continuous data sets or " +
"covariance matrices, or else all discrete data sets, or else a single externalGraph.");
}
}
if (list.size() != 1) {
throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or externalGraph " +
"as input. For multiple data sets as input, use IMaGES.");
}
if (allContinuous(list)) {
double penalty = getParams().getDouble("penaltyDiscount", 4);
if (params.getBoolean("firstNontriangular", false)) {
SemBicScoreImages fgesScore = new SemBicScoreImages(list);
fgesScore.setPenaltyDiscount(penalty);
this.fges = new Fges(fgesScore);
} else {
SemBicScoreImages fgesScore = new SemBicScoreImages(list);
fgesScore.setPenaltyDiscount(penalty);
this.fges = new Fges(fgesScore);
}
} else if (allDiscrete(list)) {
double structurePrior = getParams().getDouble("structurePrior", 1);
double samplePrior = getParams().getDouble("samplePrior", 1);
BdeuScoreImages fgesScore = new BdeuScoreImages(list);
fgesScore.setSamplePrior(samplePrior);
fgesScore.setStructurePrior(structurePrior);
if (params.getBoolean("firstNontriangular", false)) {
this.fges = new Fges(fgesScore);
} else {
this.fges = new Fges(fgesScore);
}
} else {
throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
}
} else {
System.out.println("No viable input.");
}
}
this.fges.setExternalGraph(this.externalGraph);
this.fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2()));
this.fges.setVerbose(true);
this.fges.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true));
Graph graph = this.fges.search();
if (getSourceGraph() != null) {
GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
} else if (((IKnowledge) getParams().get("knowledge", new Knowledge2())).isDefaultToKnowledgeLayout()) {
SearchGraphUtils.arrangeByKnowledgeTiers(graph, (IKnowledge) getParams().get("knowledge", new Knowledge2()));
} else {
GraphUtils.circleLayout(graph, 200, 200, 150);
}
setResultGraph(graph);
this.topGraphs = new ArrayList<>(this.fges.getTopGraphs());
if (this.topGraphs.isEmpty()) {
this.topGraphs.add(new ScoredGraph(getResultGraph(), Double.NaN));
}
setIndex(this.topGraphs.size() - 1);
}
/**
* Executes the algorithm, producing (at least) a result workbench. Must be
* implemented in the extending class.
*/
public Type getType() {
Object model = getDataModel();
if (model == null && getSourceGraph() != null) {
model = getSourceGraph();
}
if (model == null) {
throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" +
"then click Save, and then right click on them and select Propagate Downstream. \n" +
"The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" +
"file when you save the session. It can, however, be recreated from the saved seed.");
}
Type type;
if (model instanceof Graph) {
type = Type.GRAPH;
} else if (model instanceof DataSet) {
DataSet dataSet = (DataSet) model;
if (dataSet.isContinuous()) {
type = Type.CONTINUOUS;
} else if (dataSet.isDiscrete()) {
type = Type.DISCRETE;
} else {
type = Type.MIXED;
// throw new IllegalStateException("Data set must either be continuous or discrete.");
}
} else if (model instanceof ICovarianceMatrix) {
type = Type.CONTINUOUS;
} else if (model instanceof DataModelList) {
DataModelList list = (DataModelList) model;
if (allContinuous(list)) {
type = Type.CONTINUOUS;
} else if (allDiscrete(list)) {
type = Type.DISCRETE;
} else {
type = Type.MIXED;
// throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
}
} else {
throw new IllegalArgumentException("Unrecognized data type.");
}
return type;
}
private boolean allContinuous(List dataModels) {
for (DataModel dataModel : dataModels) {
if (dataModel instanceof DataSet) {
if (!dataModel.isContinuous() || dataModel instanceof ICovarianceMatrix) {
return false;
}
}
}
return true;
}
private boolean allDiscrete(List dataModels) {
for (DataModel dataModel : dataModels) {
if (dataModel instanceof DataSet) {
if (!dataModel.isDiscrete()) {
return false;
}
}
}
return true;
}
public void setIndex(int index) {
if (index < -1) {
throw new IllegalArgumentException("Must be in >= -1: " + index);
}
this.index = index;
}
public int getIndex() {
return this.index;
}
public Graph getGraph() {
if (getIndex() >= 0) {
return getTopGraphs().get(getIndex()).getGraph();
} else {
return getResultGraph();
}
}
/**
* @return the names of the triple classifications. Coordinates with
*/
public List getTriplesClassificationTypes() {
return new ArrayList<>();
}
/**
* @return the list of triples corresponding to getTripleClassificationNames
.
*/
public List> getTriplesLists(Node node) {
return new ArrayList<>();
}
public boolean supportsKnowledge() {
return true;
}
public ImpliedOrientation getMeekRules() {
MeekRules rules = new MeekRules();
rules.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2()));
return rules;
}
@Override
public Map getParamSettings() {
super.getParamSettings();
Parameters params = getParams();
this.paramSettings.put("Penalty Discount", new DecimalFormat("0.0").format(params.getDouble("penaltyDiscount", 4)));
return this.paramSettings;
}
@Override
public String getAlgorithmName() {
return "FGES";
}
public void propertyChange(PropertyChangeEvent evt) {
firePropertyChange(evt);
}
private void firePropertyChange(PropertyChangeEvent evt) {
for (PropertyChangeListener l : getListeners()) {
l.propertyChange(evt);
}
}
private List getListeners() {
if (this.listeners == null) {
this.listeners = new ArrayList<>();
}
return this.listeners;
}
public void addPropertyChangeListener(PropertyChangeListener l) {
if (!getListeners().contains(l)) getListeners().add(l);
}
public List getTopGraphs() {
return this.topGraphs;
}
public String getBayesFactorsReport(Graph dag) {
if (this.fges == null) {
return "Please re-run IMaGES.";
} else {
return this.fges.logEdgeBayesFactorsString(dag);
}
}
public GraphScorer getGraphScorer() {
return this.fges;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy