// edu.cmu.tetradapp.model.GridSearchModel Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;
import edu.cmu.tetrad.algcomparison.Comparison;
import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.Algorithms;
import edu.cmu.tetrad.algcomparison.graph.RandomForward;
import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.algcomparison.simulation.Simulations;
import edu.cmu.tetrad.algcomparison.simulation.SingleDatasetSimulation;
import edu.cmu.tetrad.algcomparison.statistic.ParameterColumn;
import edu.cmu.tetrad.algcomparison.statistic.Statistic;
import edu.cmu.tetrad.algcomparison.statistic.Statistics;
import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.annotation.AnnotatedClass;
import edu.cmu.tetrad.annotation.Score;
import edu.cmu.tetrad.annotation.TestOfIndependence;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.util.*;
import edu.cmu.tetradapp.session.SessionModel;
import edu.cmu.tetradapp.ui.model.*;
import org.jetbrains.annotations.NotNull;
import org.reflections.Reflections;
import org.reflections.scanners.Scanners;
import java.io.*;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
import java.util.prefs.Preferences;
/**
* The GridSearchModel class is a session model that allows for running comparisons of algorithms. It provides methods
* for selecting algorithms, simulations, statistics, and parameters, and then running the comparison.
* <p>
* The reference is here:
* <p>
* Ramsey, J. D., Malinsky, D., &amp; Bui, K. V. (2020). Algcomparison: Comparing the performance of graphical structure
* learning algorithms with tetrad. Journal of Machine Learning Research, 21(238), 1-6.
*
* @author josephramsey
*/
public class GridSearchModel implements SessionModel, GraphSource {
@Serial
private static final long serialVersionUID = 23L;
/**
* A private final variable that holds a Parameters object.
*/
private final Parameters parameters;
/**
* The knowledge to be used for the GridSearchModel.
*/
private final Knowledge knowledge;
/**
* The data to be used for the GridSearchModel.
*/
private DataSet suppliedData = null;
/**
* The graph to be used for the GridSearchModel.
*/
private Graph suppliedGraph = null;
/**
* The list of statistic names.
*/
private List statNames;
/**
* The list of simulation names.
*/
private List simNames;
/**
* The list of simulation classes.
*/
private List> simulationClasses;
/**
* The list of statistic classes.
*/
private List> statisticsClasses;
/**
* The list of algorithm classes.
*/
private List> algorithmClasses;
/**
* The list of algorithm names.
*/
private List algNames;
/**
* The last comparison text displayed.
*/
private String lastComparisonText = "";
/**
* The last verbose output displayed.
*/
private String lastVerboseOutputText = "";
/**
* The name of the GridSearchModel.
*/
private String name = "Grid Search";
/**
* The variable resultsPath represents the path to the result folder. This is set after a comparison has been run
* and can be used to add additional files to the comparison results.
*/
private String resultsPath = null;
/**
* This variable represents the currently selected graph.
*/
private Graph selectedGraph = null;
/**
* The selectedSimulation variable represents the index of the currently selected simulation.
* This variable is used to keep track of the selected simulation in a collection of simulations.
* The index is zero-based, where 0 represents the first simulation.
*
* By default, the value of selectedSimulation is 0, indicating that the first simulation is selected.
*
* The value of selectedSimulation can be modified externally to change the selected simulation.
*
* @see Simulation
*/
private int selectedSimulation = 0;
/**
* The selectedAlgorithm variable holds the index of the currently selected algorithm.
*
* The value of selectedAlgorithm represents the index of the algorithm in a collection
* or an array of algorithms.
*
* The default value of selectedAlgorithm is 0, indicating that the first algorithm in
* the collection or array is selected by default.
*
* The value of selectedAlgorithm can be changed to select a different algorithm by
* assigning a different index to it.
*
* @since 1.0
*/
private int selectedAlgorithm = 0;
/**
* The index of the selected graph.
*/
private int selectedGraphIndex = 0;
/**
* Verbose output is sent here.
*/
private transient PrintStream verboseOut;
/**
 * Creates a model that has neither knowledge nor supplied data nor a supplied graph.
 *
 * @param parameters the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if parameters is null.
 */
public GridSearchModel(Parameters parameters) {
    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }

    // suppliedData and suppliedGraph keep their null field initializers.
    this.knowledge = null;
    this.parameters = parameters;

    initializeIfNull();
}
/**
 * Creates a model that uses the knowledge of the given knowledge box and has no supplied data or graph.
 *
 * @param knowledge  the knowledge box whose knowledge is used for the comparison; must not be null.
 * @param parameters the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if knowledge or parameters is null.
 */
public GridSearchModel(KnowledgeBoxModel knowledge, Parameters parameters) {
    if (knowledge == null) {
        throw new IllegalArgumentException("Knowledge must not be null.");
    }

    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }

    // suppliedData and suppliedGraph keep their null field initializers.
    this.knowledge = knowledge.getKnowledge();
    this.parameters = parameters;

    initializeIfNull();
}
/**
 * Creates a model whose supplied graph is taken from the given graph source; no knowledge, no supplied data.
 *
 * @param graphSource the source of the supplied graph; must not be null.
 * @param parameters  the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if graphSource or parameters is null.
 */
public GridSearchModel(GraphSource graphSource, Parameters parameters) {
    if (graphSource == null) {
        throw new IllegalArgumentException("Graph source must not be null.");
    }

    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }

    // suppliedData keeps its null field initializer.
    this.suppliedGraph = graphSource.getGraph();
    this.knowledge = null;
    this.parameters = parameters;

    initializeIfNull();
}
/**
 * Creates a model with a supplied graph and knowledge; no supplied data.
 *
 * @param graphSource the source of the supplied graph; must not be null.
 * @param knowledge   the knowledge box whose knowledge is used for the comparison; must not be null.
 * @param parameters  the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if graphSource, knowledge, or parameters is null.
 */
public GridSearchModel(GraphSource graphSource, KnowledgeBoxModel knowledge, Parameters parameters) {
    if (graphSource == null) {
        throw new IllegalArgumentException("Graph source must not be null.");
    }

    if (knowledge == null) {
        throw new IllegalArgumentException("Knowledge must not be null.");
    }

    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }

    this.suppliedGraph = graphSource.getGraph();
    this.knowledge = knowledge.getKnowledge();
    this.parameters = parameters;

    initializeIfNull();
}
/**
 * Creates a model whose supplied data is the selected data model of the given wrapper; no knowledge, no supplied
 * graph.
 *
 * @param dataWrapper the data wrapper holding the selected data model; must not be null. The selected data model is
 *                    cast to a DataSet.
 * @param parameters  the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if dataWrapper or parameters is null.
 */
public GridSearchModel(DataWrapper dataWrapper, Parameters parameters) {
    if (dataWrapper == null) {
        throw new IllegalArgumentException("Data wrapper must not be null.");
    }

    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }

    // suppliedGraph keeps its null field initializer.
    this.suppliedData = (DataSet) dataWrapper.getSelectedDataModel();
    this.knowledge = null;
    this.parameters = parameters;

    initializeIfNull();
}
/**
 * Creates a model with supplied data and knowledge; no supplied graph.
 *
 * @param dataWrapper the data wrapper holding the selected data model; must not be null. The selected data model is
 *                    cast to a DataSet.
 * @param knowledge   the knowledge box whose knowledge is used for the comparison; must not be null.
 * @param parameters  the parameters backing this model; must not be null.
 * @throws IllegalArgumentException if dataWrapper, knowledge, or parameters is null.
 */
public GridSearchModel(DataWrapper dataWrapper, KnowledgeBoxModel knowledge, Parameters parameters) {
    if (dataWrapper == null) {
        throw new IllegalArgumentException("Data wrapper must not be null.");
    }
    if (knowledge == null) {
        throw new IllegalArgumentException("Knowledge must not be null.");
    }
    if (parameters == null) {
        throw new IllegalArgumentException("Parameters must not be null.");
    }
    this.parameters = parameters;
    this.knowledge = knowledge.getKnowledge();
    // A leftover debug print of the variable names used to follow this assignment. It dereferenced suppliedData
    // and so threw a NullPointerException when the wrapper had no selected data model, which the
    // (DataWrapper, Parameters) constructor tolerates. It has been removed.
    this.suppliedData = (DataSet) dataWrapper.getSelectedDataModel();
    initializeIfNull();
}
/**
* Finds and returns a list of algorithm classes that implement the Algorithm interface.
*
* @return A list of algorithm classes.
*/
@NotNull
private static List> findAlgorithmClasses() {
Set> _algorithms = findImplementations("edu.cmu.tetrad.algcomparison.algorithm",
Algorithm.class);
final List> algorithmClasses = new ArrayList<>(_algorithms);
algorithmClasses.sort(Comparator.comparing(Class::getName));
return algorithmClasses;
}
/**
* Finds and returns a set of classes that implement a given interface within a specified package.
*
* @param packageName The name of the package to search in.
* @param interfaceClazz The interface class to find implementations of.
* @return A set of classes that implement the specified interface.
*/
private static Set> findImplementations(String packageName, Class interfaceClazz) {
Reflections reflections = new Reflections(packageName, Scanners.SubTypes);
return reflections.getSubTypesOf(interfaceClazz);
}
/**
 * Sorts the given table columns in place. Equal columns compare as equal; a PARAMETER column sorts before a
 * STATISTIC column; all other pairs are ordered by column name, ignoring case.
 *
 * @param selectedTableColumns the columns to sort; the list itself is modified.
 */
public static void sortTableColumns(List<MyTableColumn> selectedTableColumns) {
    selectedTableColumns.sort((o1, o2) -> {
        if (o1.equals(o2)) {
            return 0;
        } else if (o1.getType() == MyTableColumn.ColumnType.PARAMETER
                   && o2.getType() == MyTableColumn.ColumnType.STATISTIC) {
            return -1;
        } else if (o1.getType() == MyTableColumn.ColumnType.STATISTIC
                   && o2.getType() == MyTableColumn.ColumnType.PARAMETER) {
            return 1;
        } else {
            return String.CASE_INSENSITIVE_ORDER.compare(o1.getColumnName(), o2.getColumnName());
        }
    });
}
/**
 * Collects the parameter names reported by every simulation in the list.
 *
 * @param simulations the simulations to inspect.
 * @return the union of the parameter names of all given simulations; empty if the list is empty.
 */
@NotNull
public static Set<String> getAllSimulationParameters(List<Simulation> simulations) {
    Set<String> paramNamesSet = new HashSet<>();
    for (Simulation simulation : simulations) {
        paramNamesSet.addAll(simulation.getParameters());
    }
    return paramNamesSet;
}
/**
 * Collects the parameter names reported by the algorithm implementation of every spec in the list.
 *
 * @param algorithms the algorithm specs to inspect.
 * @return the union of the parameter names of all given algorithms; empty if the list is empty.
 */
@NotNull
public static Set<String> getAllAlgorithmParameters(List<AlgorithmSpec> algorithms) {
    Set<String> paramNamesSet = new HashSet<>();
    for (AlgorithmSpec algorithm : algorithms) {
        paramNamesSet.addAll(algorithm.getAlgorithmImpl().getParameters());
    }
    return paramNamesSet;
}
/**
 * Collects the parameter names of the independence wrappers used by the given algorithms. Algorithms whose
 * implementation does not take an independence wrapper contribute nothing.
 *
 * @param algorithms the algorithm specs to inspect.
 * @return the union of the independence-test parameter names; empty if no algorithm takes a test.
 */
@NotNull
public static Set<String> getAllTestParameters(List<AlgorithmSpec> algorithms) {
    Set<String> paramNamesSet = new HashSet<>();
    for (AlgorithmSpec algorithm : algorithms) {
        Algorithm algorithmImpl = algorithm.getAlgorithmImpl();
        if (algorithmImpl instanceof TakesIndependenceWrapper) {
            IndependenceWrapper wrapper = ((TakesIndependenceWrapper) algorithmImpl).getIndependenceWrapper();
            paramNamesSet.addAll(wrapper.getParameters());
        }
    }
    return paramNamesSet;
}
/**
 * Collects the parameter names of the score wrappers used by the given algorithms. Algorithms whose implementation
 * does not use a score wrapper contribute nothing.
 *
 * @param algorithms the algorithm specs to inspect.
 * @return the union of the score parameter names; empty if no algorithm uses a score.
 */
@NotNull
public static Set<String> getAllScoreParameters(List<AlgorithmSpec> algorithms) {
    Set<String> paramNamesSet = new HashSet<>();
    for (AlgorithmSpec algorithm : algorithms) {
        Algorithm algorithmImpl = algorithm.getAlgorithmImpl();
        if (algorithmImpl instanceof UsesScoreWrapper) {
            ScoreWrapper wrapper = ((UsesScoreWrapper) algorithmImpl).getScoreWrapper();
            paramNamesSet.addAll(wrapper.getParameters());
        }
    }
    return paramNamesSet;
}
/**
 * Collects the bootstrapping parameter names that Params reports for each of the given algorithms.
 *
 * @param algorithms the algorithm specs to inspect.
 * @return the union of the bootstrapping parameter names; empty if the list is empty.
 */
@NotNull
public static Set<String> getAllBootstrapParameters(List<AlgorithmSpec> algorithms) {
    Set<String> paramNamesSet = new HashSet<>();
    for (AlgorithmSpec algorithm : algorithms) {
        paramNamesSet.addAll(Params.getBootstrappingParameters(algorithm.getAlgorithmImpl()));
    }
    return paramNamesSet;
}
/**
 * Runs the comparison of the selected simulations, algorithms, and statistics. Results are written to the first
 * not-yet-existing directory "comparison-results/comparison-N" (N = 1, 2, ...) under the user's home directory; the
 * path of that directory is remembered once the comparison has completed.
 *
 * @param ps1 the first print stream handed to the comparison (NOTE(review): exact role is defined by
 *            Comparison.compareFromSimulations - confirm).
 * @param ps2 the second print stream handed to the comparison (same caveat).
 * @throws IllegalStateException    if the results directory could not be created.
 * @throws IllegalArgumentException if the "algcomparisonGraphType" parameter does not name a comparison graph type.
 */
public void runComparison(PrintStream ps1, PrintStream ps2) {
    initializeIfNull();

    // Supplied data takes precedence over any selected simulations.
    Simulations sims = new Simulations();
    if (suppliedData == null) {
        for (SimulationSpec spec : getSelectedSimulationsSpecs()) {
            sims.add(spec.getSimulationImpl());
        }
    } else {
        sims.add(new SingleDatasetSimulation(suppliedData));
    }

    Algorithms algs = new Algorithms();
    for (AlgorithmSpec spec : getSelectedAlgorithmSpecs()) {
        algs.add(spec.getAlgorithmImpl());
    }

    Comparison comparison = new Comparison();
    comparison.setSaveData(parameters.getBoolean("algcomparisonSaveData"));
    comparison.setSaveGraphs(parameters.getBoolean("algcomparisonSaveGraphs"));
    comparison.setSaveCPDAGs(parameters.getBoolean("algcomparisonSaveCPDAGs"));
    comparison.setSavePags(parameters.getBoolean("algcomparisonSavePAGs"));
    comparison.setSortByUtility(parameters.getBoolean("algcomparisonSortByUtility"));
    comparison.setShowUtilities(parameters.getBoolean("algcomparisonShowUtilities"));
    comparison.setSetAlgorithmKnowledge(parameters.getBoolean("algcomparisonSetAlgorithmKnowledge"));
    comparison.setParallelism(parameters.getInt("algcomparisonParallelism"));
    comparison.setKnowledge(knowledge);

    ComparisonGraphType graphType = ComparisonGraphType.valueOf(
            parameters.getString("algcomparisonGraphType", "DAG"));
    Comparison.ComparisonGraph comparisonGraph = switch (graphType) {
        case DAG -> Comparison.ComparisonGraph.true_DAG;
        case CPDAG -> Comparison.ComparisonGraph.CPDAG_of_the_true_DAG;
        case PAG -> Comparison.ComparisonGraph.PAG_of_the_true_DAG;
        default -> throw new IllegalArgumentException("Invalid value for comparison graph: " + graphType);
    };
    comparison.setComparisonGraph(comparisonGraph);

    // Claim the first comparison-N directory that does not exist yet.
    String freshResultsPath;
    int index = 1;
    while (true) {
        String candidate = System.getProperty("user.home") + "/comparison-results/comparison-" + index;
        File candidateDir = new File(candidate);
        if (!candidateDir.exists()) {
            if (!candidateDir.mkdirs()) {
                throw new IllegalStateException("Could not create directory: " + candidateDir);
            }
            freshResultsPath = candidate;
            break;
        }
        index++;
    }

    // A copy of the parameters is sent to Comparison since Comparison iterates over the parameters and
    // modifies them.
    Parameters parametersCopy = new Parameters(parameters);
    String outputFileName = "Comparison.txt";
    comparison.compareFromSimulations(freshResultsPath, sims, outputFileName, ps1, ps2,
            algs, getSelectedStatistics(), parametersCopy);

    // Only remembered when the comparison did not throw.
    this.resultsPath = freshResultsPath;
}
private LinkedList getSelectedAlgorithmSpecs() {
if (!(parameters.get("algcomparison.selectedAlgorithms") instanceof LinkedList>)) {
parameters.set("algcomparison.selectedAlgorithms", new LinkedList());
}
return (LinkedList) parameters.get("algcomparison.selectedAlgorithms");
}
private LinkedList getSelectedSimulationsSpecs() {
if (!(parameters.get("algcomparison.selectedSimulations") instanceof LinkedList>)) {
parameters.set("algcomparison.selectedSimulations", new LinkedList());
}
return (LinkedList) parameters.get("algcomparison.selectedSimulations");
}
/**
 * Returns the names of the possible simulations.
 *
 * @return the internal list of simulation names (not a copy).
 */
public List<String> getSimulationName() {
    return simNames;
}
/**
 * Returns the names of the possible algorithms.
 *
 * @return the internal list of algorithm names (not a copy).
 */
public List<String> getAlgorithmsName() {
    return algNames;
}
/**
 * Returns the names of the possible statistics.
 *
 * @return the internal list of statistic names (not a copy).
 */
public List<String> getStatisticsNames() {
    return statNames;
}
/**
 * Returns the parameters backing this model.
 *
 * @return the parameters object given at construction (not a copy).
 */
public Parameters getParameters() {
    return this.parameters;
}
/**
 * Appends a simulation spec to the list of selected simulations.
 *
 * @param simulation the simulation spec to append.
 */
public void addSimulationSpec(SimulationSpec simulation) {
    initializeIfNull();

    LinkedList<SimulationSpec> selected = getSelectedSimulationsSpecs();
    selected.add(simulation);
}
/**
 * Removes the most recently added simulation spec from the list of selected simulations; does nothing when the
 * list is empty.
 */
public void removeLastSimulation() {
    initializeIfNull();

    LinkedList<SimulationSpec> selected = getSelectedSimulationsSpecs();
    if (selected.isEmpty()) {
        return;
    }

    selected.removeLast();
}
/**
 * Appends an algorithm spec to the list of selected algorithms.
 *
 * @param algorithm the algorithm spec to append.
 */
public void addAlgorithm(AlgorithmSpec algorithm) {
    initializeIfNull();

    LinkedList<AlgorithmSpec> selected = getSelectedAlgorithmSpecs();
    selected.add(algorithm);
}
/**
 * Removes the most recently added algorithm spec from the list of selected algorithms; does nothing when the list
 * is empty.
 */
public void removeLastAlgorithm() {
    initializeIfNull();

    LinkedList<AlgorithmSpec> selectedAlgorithmSpecs = getSelectedAlgorithmSpecs();
    if (selectedAlgorithmSpecs.isEmpty()) {
        return;
    }

    selectedAlgorithmSpecs.removeLast();
}
/**
 * Adds a table column to the selected table columns and re-sorts them. A column that is already selected is
 * ignored.
 *
 * @param tableColumn the table column to add.
 */
public void addTableColumn(MyTableColumn tableColumn) {
    LinkedList<MyTableColumn> selectedColumns = getSelectedTableColumnsPrivate();

    boolean alreadySelected = selectedColumns.contains(tableColumn);
    if (!alreadySelected) {
        initializeIfNull();
        selectedColumns.add(tableColumn);
        GridSearchModel.sortTableColumns(selectedColumns);
    }
}
/**
 * Removes the last entry of the selected table columns; does nothing when none are selected.
 */
public void removeLastTableColumn() {
    initializeIfNull();

    LinkedList<MyTableColumn> selectedColumns = getSelectedTableColumnsPrivate();
    if (selectedColumns.isEmpty()) {
        return;
    }

    selectedColumns.removeLast();
}
/**
* Returns the list of statistics classes.
*/
public List> getStatisticsClasses() {
initializeIfNull();
return new ArrayList<>(statisticsClasses);
}
/**
 * Returns the selected graph, or a new empty graph when no graph has been selected.
 *
 * @return the selected graph; never null.
 */
@Override
public Graph getGraph() {
    if (selectedGraph != null) {
        return selectedGraph;
    }

    return new EdgeListGraph();
}
/**
 * Returns the name of this session model.
 *
 * @return the current name.
 */
@Override
public String getName() {
    return this.name;
}
/**
 * Sets the name of this session model.
 *
 * @param name the new name; must be neither null nor blank.
 * @throws IllegalArgumentException if the name is null or blank.
 */
@Override
public void setName(String name) {
    if (name == null) {
        throw new IllegalArgumentException("Name must not be null.");
    } else if (name.isBlank()) {
        throw new IllegalArgumentException("Name must not be blank.");
    }

    this.name = name;
}
/**
 * Builds the simulations to compare. When a data set was supplied, the result holds exactly one single-dataset
 * simulation wrapping it; otherwise it holds the implementation of every selected simulation spec.
 *
 * @return a newly built Simulations object.
 */
public Simulations getSelectedSimulations() {
    initializeIfNull();

    Simulations result = new Simulations();

    if (suppliedData == null) {
        for (SimulationSpec spec : getSelectedSimulationsSpecs()) {
            result.add(spec.getSimulationImpl());
        }
    } else {
        result.add(new SingleDatasetSimulation(suppliedData));
    }

    return result;
}
/**
* A private instance variable that holds a list of selected Algorithm objects.
*/
public List getSelectedAlgorithms() {
if (!(parameters.get("algcomparison.selectedAlgorithms") instanceof LinkedList>)) {
parameters.set("algcomparison.selectedAlgorithms", new LinkedList());
}
return (LinkedList) parameters.get("algcomparison.selectedAlgorithms");
}
/**
 * Sorts the stored selected table columns in place and returns a copy of them.
 *
 * @return a new list holding the selected table columns in sorted order.
 */
public List<MyTableColumn> getSelectedTableColumns() {
    LinkedList<MyTableColumn> selectedColumns = getSelectedTableColumnsPrivate();
    GridSearchModel.sortTableColumns(selectedColumns);
    return new ArrayList<>(selectedColumns);
}
private LinkedList getSelectedTableColumnsPrivate() {
if (!(parameters.get("algcomparison.selectedTableColumns") instanceof LinkedList>)) {
parameters.set("algcomparison.selectedTableColumns", new LinkedList());
}
return (LinkedList) parameters.get("algcomparison.selectedTableColumns");
}
/**
 * Lazily initializes the class lists and name lists.
 * <p>
 * If any of simulationClasses, algorithmClasses, or statisticsClasses is null, initializeClasses() is called. Then,
 * if any of algNames, statNames, or simNames is null, initializeNames() is called. When everything is already
 * populated this method does nothing, so the classpath scans are not repeated on every call.
 */
private void initializeIfNull() {
    if (simulationClasses == null || algorithmClasses == null || statisticsClasses == null) {
        initializeClasses();
    }

    // The names are derived from the class lists, so this check must come second.
    if (algNames == null || statNames == null || simNames == null) {
        initializeNames();
    }
}
private List getSelectedParameters() {
if (!(parameters.get("algcomparison.selectedParameters") instanceof LinkedList>)) {
parameters.set("algcomparison.selectedParameters", new LinkedList());
}
return (LinkedList) parameters.get("algcomparison.selectedParameters");
}
/**
 * Populates the simulation, algorithm, and statistics class lists (in that order) from the finder methods.
 */
private void initializeClasses() {
    this.simulationClasses = findSimulationClasses();
    this.algorithmClasses = findAlgorithmClasses();
    this.statisticsClasses = findStatisticsClasses();
}
/**
 * Derives the algorithm, statistic, and simulation name lists from the corresponding class lists and sorts each
 * of them case-insensitively.
 */
private void initializeNames() {
    this.algNames = getAlgorithmNamesFromAnnotations(algorithmClasses);
    this.statNames = getStatisticsNamesFromImplementations(statisticsClasses);
    this.simNames = getSimulationNamesFromImplementations(simulationClasses);

    // Same order as the assignments above: algorithms, statistics, simulations.
    for (List<String> names : List.of(algNames, statNames, simNames)) {
        names.sort(String.CASE_INSENSITIVE_ORDER);
    }
}
/**
* Finds and returns a list of simulation classes that implement the Simulation interface.
*
* @return A list of simulation classes.
*/
@NotNull
private List> findSimulationClasses() {
Set> _simulations = findImplementations("edu.cmu.tetrad.algcomparison.simulation",
Simulation.class);
final List> simulationClasses = new ArrayList<>(_simulations);
simulationClasses.sort(Comparator.comparing(Class::getName));
return simulationClasses;
}
/**
* Finds and returns a list of classes that implement the Statistic interface within the specified package.
*
* @return A list of Statistic classes.
*/
private List> findStatisticsClasses() {
Set> _statistics = findImplementations("edu.cmu.tetrad.algcomparison.statistic",
Statistic.class);
final List> statisticsClasses = new ArrayList<>(_statistics);
statisticsClasses.sort(Comparator.comparing(Class::getName));
return statisticsClasses;
}
/**
* For each algorithm class, use reflection to get the annotation for that class, and add the name of the algorithm
* to a list of algorithm names.
*/
private List getAlgorithmNamesFromAnnotations(List> algorithmClasses) {
List algorithmNames = new ArrayList<>();
for (Class extends Algorithm> algorithm : algorithmClasses) {
edu.cmu.tetrad.annotation.Algorithm algAnnotation = algorithm.getAnnotation(edu.cmu.tetrad.annotation.Algorithm.class);
if (algAnnotation != null) {
String _name = algAnnotation.name();
algorithmNames.add(_name);
}
}
return algorithmNames;
}
/**
* Retrieves the abbreviations of statistics from a list of implementation classes.
*
* @param algorithmClasses The list of implementation classes for statistics.
* @return The abbreviations of the statistics.
*/
private List getStatisticsNamesFromImplementations(List> algorithmClasses) {
List statisticsNames = new ArrayList<>();
for (Class extends Statistic> statistic : algorithmClasses) {
try {
Constructor>[] constructors = statistic.getDeclaredConstructors();
boolean hasNoArgConstructor = false;
for (Constructor> constructor : constructors) {
if (constructor.getParameterCount() == 0) {
hasNoArgConstructor = true;
break;
}
}
if (hasNoArgConstructor) {
Statistic _statistic = statistic.getConstructor().newInstance();
String abbreviation = _statistic.getAbbreviation();
statisticsNames.add(abbreviation);
}
} catch (NoSuchMethodException | InvocationTargetException | InstantiationException |
IllegalAccessException e) {
TetradLogger.getInstance().log("Error creating statistic: " + e.getMessage());
e.printStackTrace();
}
}
return statisticsNames;
}
/**
* Retrieves the names of simulations from a list of implementation classes.
*
* @param algorithmClasses The list of implementation classes for simulations.
* @return The names of the simulations.
*/
private List getSimulationNamesFromImplementations(List> algorithmClasses) {
List simulationNames = new ArrayList<>();
RandomGraph graph = new RandomForward();
for (Class extends Simulation> statistic : algorithmClasses) {
try {
Simulation _statistic = statistic.getConstructor(RandomGraph.class).newInstance(graph);
String shortName = _statistic.getShortName();
simulationNames.add(shortName);
} catch (NoSuchMethodException | InvocationTargetException | InstantiationException |
IllegalAccessException e) {
// Skip.
}
}
return simulationNames;
}
public Statistics getSelectedStatistics() {
LinkedList selectedTableColumns = getSelectedTableColumnsPrivate();
Statistics selectedStatistics = new Statistics();
List lastStatisticsUsed = new ArrayList<>();
for (MyTableColumn column : selectedTableColumns) {
if (column.getType() == MyTableColumn.ColumnType.STATISTIC) {
try {
Constructor>[] constructors = column.getStatistic().getDeclaredConstructors();
boolean hasNoArgConstructor = false;
for (Constructor> constructor : constructors) {
if (constructor.getParameterCount() == 0) {
hasNoArgConstructor = true;
break;
}
}
if (hasNoArgConstructor) {
Statistic statistic = column.getStatistic().getConstructor().newInstance();
selectedStatistics.add(statistic);
lastStatisticsUsed.add(statistic);
}
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
NoSuchMethodException ex) {
System.out.println("Error creating statistic: " + ex.getMessage());
}
} else if (column.getType() == MyTableColumn.ColumnType.PARAMETER) {
String parameter = column.getParameter();
selectedStatistics.add(new ParameterColumn(parameter));
}
}
for (Statistic statistic : selectedStatistics.getStatistics()) {
double weight = 0;
try {
weight = parameters.getDouble("algcomparison." + statistic.getAbbreviation());
} catch (Exception e) {
// Skip.
}
selectedStatistics.setWeight(statistic.getAbbreviation(), weight);
}
setLastStatisticsUsed(lastStatisticsUsed);
return selectedStatistics;
}
@NotNull
public List getAllTableColumns() {
List allTableColumns = new ArrayList<>();
List simulations = getSelectedSimulations().getSimulations();
List algorithms = getSelectedAlgorithms();
for (String name : getAllSimulationParameters(simulations)) {
ParamDescription paramDescription = ParamDescriptions.getInstance().get(name);
String shortDescriptiom = paramDescription.getShortDescription();
String description = paramDescription.getLongDescription();
MyTableColumn column = new MyTableColumn(shortDescriptiom, description, name);
column.setSetByUser(paramSetByUser(name));
allTableColumns.add(column);
}
for (String name : getAllAlgorithmParameters(algorithms)) {
ParamDescription paramDescription = ParamDescriptions.getInstance().get(name);
String shortDescriptiom = paramDescription.getShortDescription();
String description = paramDescription.getLongDescription();
MyTableColumn column = new MyTableColumn(shortDescriptiom, description, name);
column.setSetByUser(paramSetByUser(name));
allTableColumns.add(column);
}
for (String name : getAllTestParameters(algorithms)) {
ParamDescription paramDescription = ParamDescriptions.getInstance().get(name);
String shortDescriptiom = paramDescription.getShortDescription();
String description = paramDescription.getLongDescription();
MyTableColumn column = new MyTableColumn(shortDescriptiom, description, name);
column.setSetByUser(paramSetByUser(name));
allTableColumns.add(column);
}
for (String name : getAllScoreParameters(algorithms)) {
ParamDescription paramDescription = ParamDescriptions.getInstance().get(name);
String shortDescriptiom = paramDescription.getShortDescription();
String description = paramDescription.getLongDescription();
MyTableColumn column = new MyTableColumn(shortDescriptiom, description, name);
column.setSetByUser(paramSetByUser(name));
allTableColumns.add(column);
}
for (String name : getAllBootstrapParameters(algorithms)) {
ParamDescription paramDescription = ParamDescriptions.getInstance().get(name);
String shortDescriptiom = paramDescription.getShortDescription();
String description = paramDescription.getLongDescription();
MyTableColumn column = new MyTableColumn(shortDescriptiom, description, name);
column.setSetByUser(paramSetByUser(name));
allTableColumns.add(column);
}
List> statisticClasses = getStatisticsClasses();
for (Class extends Statistic> statisticClass : statisticClasses) {
try {
Constructor>[] constructors = statisticClass.getDeclaredConstructors();
boolean hasNoArgConstructor = false;
for (Constructor> constructor : constructors) {
if (constructor.getParameterCount() == 0) {
hasNoArgConstructor = true;
break;
}
}
if (hasNoArgConstructor) {
Statistic statistic = statisticClass.getConstructor().newInstance();
MyTableColumn column = new MyTableColumn(statistic.getAbbreviation(), statistic.getDescription(), statisticClass);
allTableColumns.add(column);
}
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
NoSuchMethodException ex) {
System.out.println("Error creating statistic: " + ex.getMessage());
}
}
return allTableColumns;
}
/**
 * Tells whether a parameter deviates from its registered default. It counts as not set by the user only when
 * exactly one value is stored for it and that value equals the default.
 *
 * @param columnName the parameter name.
 * @return false if the single stored value equals the default; true otherwise.
 */
private boolean paramSetByUser(String columnName) {
    Object defaultValue = ParamDescriptions.getInstance().get(columnName).getDefaultValue();
    Object[] storedValues = parameters.getValues(columnName);

    if (storedValues == null || storedValues.length != 1) {
        return true;
    }

    return !storedValues[0].equals(defaultValue);
}
/**
 * Returns the abbreviations of the statistics used last, as stored (semicolon-separated) in the user preferences
 * under "lastAlgcomparisonStatisticsUsed".
 *
 * @return the stored abbreviations; an empty list when nothing has been stored.
 */
public List<String> getLastStatisticsUsed() {
    String stored = Preferences.userRoot().get("lastAlgcomparisonStatisticsUsed", "");

    // "".split(";") yields a one-element array holding "", which would show up as a bogus empty name.
    if (stored.isEmpty()) {
        return new ArrayList<>();
    }

    return Arrays.asList(stored.split(";"));
}
/**
 * Stores the abbreviations of the given statistics in the user preferences under
 * "lastAlgcomparisonStatisticsUsed", each followed by a semicolon.
 *
 * @param lastStatisticsUsed the statistics to remember.
 */
public void setLastStatisticsUsed(List<Statistic> lastStatisticsUsed) {
    StringBuilder sb = new StringBuilder();
    for (Statistic statistic : lastStatisticsUsed) {
        sb.append(statistic.getAbbreviation()).append(";");
    }
    Preferences.userRoot().put("lastAlgcomparisonStatisticsUsed", sb.toString());
}
/**
 * Returns the name stored in the user preferences under "lastAlgcomparisonIndependenceTestUsed".
 *
 * @return the stored name, or the empty string when nothing has been stored.
 */
public String getLastIndependenceTest() {
    Preferences prefs = Preferences.userRoot();
    return prefs.get("lastAlgcomparisonIndependenceTestUsed", "");
}
/**
 * Stores the given independence test name in the user preferences under the key
 * "lastAlgcomparisonIndependenceTestUsed", provided a model with that name is listed by
 * {@code IndependenceTestModels}.
 *
 * @param name The name of the independence test to remember.
 * @throws IllegalArgumentException If no listed independence test model has that name.
 */
public void setLastIndependenceTest(String name) {
    // Iterate the model list directly so the element type comes from getModels() itself.
    for (IndependenceTestModel model : IndependenceTestModels.getInstance().getModels()) {
        if (model.getName().equals(name)) {
            Preferences.userRoot().put("lastAlgcomparisonIndependenceTestUsed", name);
            return;
        }
    }

    throw new IllegalArgumentException("Independence test by that name not found: " + name);
}
/**
 * Returns the score name stored in the user preferences under the key "lastAlgcomparisonScoreUsed".
 *
 * @return The stored name, or the empty string if nothing is stored.
 */
public String getLastScore() {
    Preferences prefs = Preferences.userRoot();
    return prefs.get("lastAlgcomparisonScoreUsed", "");
}
/**
 * Stores the given score name in the user preferences under the key "lastAlgcomparisonScoreUsed", provided a
 * model with that name is listed by {@code ScoreModels}.
 *
 * @param name The name of the score to remember.
 * @throws IllegalArgumentException If no listed score model has that name.
 */
public void setLastScore(String name) {
    // Iterate the model list directly so the element type comes from getModels() itself.
    for (ScoreModel model : ScoreModels.getInstance().getModels()) {
        if (model.getName().equals(name)) {
            Preferences.userRoot().put("lastAlgcomparisonScoreUsed", name);
            return;
        }
    }

    throw new IllegalArgumentException("Score by that name not found: " + name);
}
/**
 * Returns the graph choice stored in the user preferences under the key "lastAlgcomparisonGraphChoice".
 *
 * @return The stored choice, or the empty string if nothing is stored.
 */
public String getLastGraphChoice() {
    Preferences prefs = Preferences.userRoot();
    return prefs.get("lastAlgcomparisonGraphChoice", "");
}
/**
 * Stores the given graph choice in the user preferences under the key "lastAlgcomparisonGraphChoice".
 *
 * @param name The graph choice to remember.
 */
public void setLastGraphChoice(String name) {
    Preferences prefs = Preferences.userRoot();
    prefs.put("lastAlgcomparisonGraphChoice", name);
}
/**
 * Returns the algorithm choice stored in the user preferences under the key
 * "lastAlgcomparisonAlgorithmChoice".
 *
 * @return The stored choice, or the empty string if nothing is stored.
 */
public String getLastAlgorithmChoice() {
    Preferences prefs = Preferences.userRoot();
    return prefs.get("lastAlgcomparisonAlgorithmChoice", "");
}
/**
 * Stores the given algorithm choice in the user preferences under the key "lastAlgcomparisonAlgorithmChoice".
 *
 * @param name The algorithm choice to remember.
 */
public void setLastAlgorithmChoice(String name) {
    Preferences prefs = Preferences.userRoot();
    prefs.put("lastAlgcomparisonAlgorithmChoice", name);
}
/**
 * Returns the simulation choice stored in the user preferences under the key
 * "lastAlgcomparisonSimulationChoice".
 *
 * @return The stored choice as a string (declared as Object), or the empty string if nothing is stored.
 */
public Object getLastSimulationChoice() {
    Preferences prefs = Preferences.userRoot();
    return prefs.get("lastAlgcomparisonSimulationChoice", "");
}
/**
 * Stores the given simulation choice in the user preferences under the key
 * "lastAlgcomparisonSimulationChoice".
 *
 * @param selectedItem The simulation choice to remember.
 */
public void setLastSimulationChoice(String selectedItem) {
    Preferences prefs = Preferences.userRoot();
    prefs.put("lastAlgcomparisonSimulationChoice", selectedItem);
}
/**
 * Returns the graph supplied by the user, which is offered as an option in the user interface.
 *
 * @return The supplied graph, as currently held by this model.
 */
public Graph getSuppliedGraph() {
    return this.suppliedGraph;
}
/**
 * Returns the last comparison text displayed.
 *
 * @return The stored comparison text, or the empty string if none has been stored.
 */
public String getLastComparisonText() {
    if (lastComparisonText == null) {
        return "";
    }

    return lastComparisonText;
}
/**
 * Sets the last comparison text displayed.
 *
 * @param lastComparisonText The comparison text to store.
 */
public void setLastComparisonText(String lastComparisonText) {
this.lastComparisonText = lastComparisonText;
}
/**
 * Returns the last verbose output displayed.
 *
 * @return The stored verbose output text, or the empty string if none has been stored.
 */
public String getLastVerboseOutputText() {
    if (lastVerboseOutputText == null) {
        return "";
    }

    return lastVerboseOutputText;
}
/**
 * Sets the last verbose output displayed.
 *
 * @param lastVerboseOutputText The verbose output text to store.
 */
public void setLastVerboseOutputText(String lastVerboseOutputText) {
this.lastVerboseOutputText = lastVerboseOutputText;
}
/**
 * Serializes this object using default serialization. If that fails, the failure is logged through
 * {@code TetradLogger} and the original exception is rethrown.
 *
 * @param out The ObjectOutputStream to write the object to.
 * @throws IOException If an I/O error occurs while writing.
 */
@Serial
private void writeObject(ObjectOutputStream out) throws IOException {
    try {
        out.defaultWriteObject();
    } catch (IOException failure) {
        TetradLogger logger = TetradLogger.getInstance();
        String message = "Failed to serialize object: " + getClass().getCanonicalName()
                         + ", " + failure.getMessage();
        logger.log(message);
        throw failure;
    }
}
/**
 * Restores this object's state using default deserialization. If an I/O failure occurs, it is logged through
 * {@code TetradLogger} and the original exception is rethrown.
 *
 * @param in The ObjectInputStream to read the object from.
 * @throws IOException            If an I/O error occurs while reading.
 * @throws ClassNotFoundException If the class of a serialized object cannot be found.
 */
@Serial
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    try {
        in.defaultReadObject();
    } catch (IOException failure) {
        TetradLogger logger = TetradLogger.getInstance();
        String message = "Failed to deserialize object: " + getClass().getCanonicalName()
                         + ", " + failure.getMessage();
        logger.log(message);
        throw failure;
    }
}
/**
 * Returns the knowledge held by this model.
 *
 * @return The knowledge.
 */
public Knowledge getKnowledge() {
    return this.knowledge;
}
/**
 * Returns the supplied dataset, which can be used in place of a simulated dataset for analysis.
 * <p>
 * Using a supplied dataset restricts the analysis to those statistics that do not require a true graph.
 *
 * @return The supplied dataset, or null if no dataset was supplied.
 */
public DataSet getSuppliedData() {
    return this.suppliedData;
}
/**
 * Returns the path to the results folder. The path is set after a comparison has been run and can be used to
 * add additional files to the comparison results.
 *
 * @return The results path, as last set.
 */
public String getResultsPath() {
    return this.resultsPath;
}
/**
 * Sets the path to the results folder.
 *
 * @param resultsPath The results path to store.
 */
public void setResultsPath(String resultsPath) {
this.resultsPath = resultsPath;
}
/**
 * Sets the selected graph.
 *
 * @param graph The graph to store as the selected graph.
 */
public void setSelectedGraph(Graph graph) {
this.selectedGraph = graph;
}
/**
 * Returns the stored selected-simulation value.
 *
 * @return The value last set via {@code setSelectedSimulation}.
 */
public int getSelectedSimulation() {
    return this.selectedSimulation;
}
/**
 * Sets the stored selected-simulation value.
 *
 * @param selectedSimulation The value to store.
 */
public void setSelectedSimulation(int selectedSimulation) {
this.selectedSimulation = selectedSimulation;
}
/**
 * Returns the stored selected-algorithm value.
 *
 * @return The value last set via {@code setSelectedAlgorithm}.
 */
public int getSelectedAlgorithm() {
    return this.selectedAlgorithm;
}
/**
 * Sets the stored selected-algorithm value.
 *
 * @param selectedAlgorithm The value to store.
 */
public void setSelectedAlgorithm(int selectedAlgorithm) {
this.selectedAlgorithm = selectedAlgorithm;
}
/**
 * Returns the stored selected-graph index.
 *
 * @return The value last set via {@code setSelectedGraphIndex}.
 */
public int getSelectedGraphIndex() {
    return this.selectedGraphIndex;
}
/**
 * Sets the stored selected-graph index.
 *
 * @param selectedGraphIndex The value to store.
 */
public void setSelectedGraphIndex(int selectedGraphIndex) {
this.selectedGraphIndex = selectedGraphIndex;
}
/**
 * Stores the given stream in this model's {@code verboseOut} field.
 * <p>
 * NOTE(review): despite the "get" prefix, this method returns nothing and acts as a setter. The name is left
 * unchanged here because callers outside this view may depend on it.
 *
 * @param printStream The stream to assign to {@code verboseOut}.
 */
public void getVerboseOut(PrintStream printStream) {
this.verboseOut = printStream;
}
/**
 * An enumeration of the comparison graph types: DAG (Directed Acyclic Graph), CPDAG (Completed Partially
 * Directed Acyclic Graph), and PAG (Partial Ancestral Graph).
 */
public enum ComparisonGraphType {
/**
 * Directed Acyclic Graph (DAG).
 */
DAG,
/**
 * Completed Partially Directed Acyclic Graph (CPDAG).
 */
CPDAG,
/**
 * Partial Ancestral Graph (PAG).
 */
PAG
}
public static class MyTableColumn implements TetradSerializable {
@Serial
private static final long serialVersionUID = 23L;
/**
* The name of the column.
*/
private final String columnName;
/**
* The description of the column.
*/
private final String description;
/**
* The statistic class.
*/
private final Class extends Statistic> statistic;
/**
* The parameter name.
*/
private final String parameter;
/**
* The type of the column.
*/
private final ColumnType type;
/**
* A boolean that indicates whether the column was set by the user.
*/
private boolean setByUser = false;
public MyTableColumn(String name, String description, Class extends Statistic> statistic) {
this.columnName = name;
this.description = description;
this.statistic = statistic;
this.parameter = null;
this.type = ColumnType.STATISTIC;
}
public MyTableColumn(String name, String description, String parameter) {
this.columnName = name;
this.description = description;
this.statistic = null;
this.parameter = parameter;
this.type = ColumnType.PARAMETER;
}
public String getDescription() {
return description;
}
public String getColumnName() {
return columnName;
}
public String getdescription() {
return description;
}
public ColumnType getType() {
return type;
}
public Class extends Statistic> getStatistic() {
if (type != ColumnType.STATISTIC) throw new IllegalStateException("Not a statistic column");
return statistic;
}
public String getParameter() {
if (type != ColumnType.PARAMETER) throw new IllegalStateException("Not a parameter column");
return parameter;
}
public boolean isSetByUser() {
return setByUser;
}
public void setSetByUser(boolean setByUser) {
this.setByUser = setByUser;
}
public int hashCode() {
return columnName.hashCode();
}
public boolean equals(Object obj) {
if (obj instanceof MyTableColumn other) {
return columnName.equals(other.columnName);
}
return false;
}
public enum ColumnType {
STATISTIC,
PARAMETER
}
}
/**
 * A named specification of an algorithm: the algorithm model plus the optional independence test and score
 * from which a runnable {@link Algorithm} instance can be created.
 */
public static class AlgorithmSpec implements TetradSerializable {
    @Serial
    private static final long serialVersionUID = 23L;
    /**
     * The name of the algorithm.
     */
    private final String name;
    /**
     * The algorithm model.
     */
    private final AlgorithmModel algorithm;
    /**
     * The test of independence; may be null.
     */
    private final AnnotatedClass test;
    /**
     * The score; may be null.
     */
    private final AnnotatedClass score;

    /**
     * Constructs a new AlgorithmSpec object with the specified name, algorithm model, test of independence, and
     * score.
     *
     * @param name      The name of the algorithm.
     * @param algorithm The algorithm model.
     * @param test      The test of independence.
     * @param score     The score.
     */
    public AlgorithmSpec(String name, AlgorithmModel algorithm,
                         AnnotatedClass test, AnnotatedClass score) {
        this.name = name;
        this.algorithm = algorithm;
        this.test = test;
        this.score = score;
    }

    /**
     * Returns the name of the algorithm.
     *
     * @return The name.
     */
    public String getName() {
        return name;
    }

    /**
     * Returns the algorithm model.
     *
     * @return The algorithm model.
     */
    public AlgorithmModel getAlgorithm() {
        return algorithm;
    }

    /**
     * Returns the test of independence.
     *
     * @return The test of independence; may be null.
     */
    public AnnotatedClass getTest() {
        return test;
    }

    /**
     * Returns the score.
     *
     * @return The score; may be null.
     */
    public AnnotatedClass getScore() {
        return score;
    }

    /**
     * Creates a new algorithm instance through the no-argument constructor of the algorithm class. When a
     * test and/or score is present, a wrapper instance is created the same way and handed to the algorithm
     * if the algorithm accepts that kind of wrapper.
     *
     * @return The newly created algorithm.
     * @throws RuntimeException If any of the classes cannot be instantiated reflectively; the reflective
     *                          exception is attached as the cause.
     */
    public Algorithm getAlgorithmImpl() {
        try {
            IndependenceWrapper independenceWrapper = null;
            ScoreWrapper scoreWrapper = null;

            if (test != null) {
                independenceWrapper = (IndependenceWrapper) test.clazz().getConstructor().newInstance();
            }

            if (score != null) {
                scoreWrapper = (ScoreWrapper) score.clazz().getConstructor().newInstance();
            }

            Class<?> algorithmClass = algorithm.getAlgorithm().clazz();
            Algorithm algorithmImpl = (Algorithm) algorithmClass.getConstructor().newInstance();

            // Each wrapper is handed over once.
            if (independenceWrapper != null && algorithmImpl instanceof TakesIndependenceWrapper takesTest) {
                takesTest.setIndependenceWrapper(independenceWrapper);
            }

            if (scoreWrapper != null && algorithmImpl instanceof UsesScoreWrapper usesScore) {
                usesScore.setScoreWrapper(scoreWrapper);
            }

            return algorithmImpl;
        } catch (InstantiationException | IllegalAccessException | InvocationTargetException |
                 NoSuchMethodException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Returns the name of the algorithm.
     *
     * @return The name.
     */
    @Override
    public String toString() {
        return name;
    }
}
public static class SimulationSpec implements TetradSerializable {
@Serial
private static final long serialVersionUID = 23L;
/**
* The name of the simulation.
*/
private final String name;
/**
* The class of the graph.
*/
private final Class extends RandomGraph> graphClass;
/**
* The class of the simulation.
*/
private final Class extends Simulation> simulationClass;
public SimulationSpec(String name, Class extends RandomGraph> graph,
Class extends Simulation> simulation) {
this.name = name;
this.graphClass = graph;
this.simulationClass = simulation;
}
public String getName() {
return name;
}
public Simulation getSimulationImpl() {
try {
RandomGraph randomGraph = graphClass.getConstructor().newInstance();
return simulationClass.getConstructor(RandomGraph.class).newInstance(randomGraph);
} catch (InstantiationException | IllegalAccessException | InvocationTargetException |
NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
public String toString() {
return name;
}
}
}