// edu.cmu.tetradapp.editor.AlgorithmParameterPanel Maven / Gradle / Ivy
/*
* Copyright (C) 2017 University of Pittsburgh.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
package edu.cmu.tetradapp.editor;
import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.PagSampleRfci;
import edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.RfciBsc;
import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.annotation.Score;
import edu.cmu.tetrad.annotation.TestOfIndependence;
import edu.cmu.tetrad.util.ParamDescription;
import edu.cmu.tetrad.util.ParamDescriptions;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.cmu.tetradapp.model.GeneralAlgorithmRunner;
import edu.cmu.tetradapp.ui.PaddingPanel;
import edu.cmu.tetradapp.util.DoubleTextField;
import edu.cmu.tetradapp.util.IntTextField;
import edu.cmu.tetradapp.util.LongTextField;
import edu.cmu.tetradapp.util.StringTextField;
import javax.swing.*;
import java.awt.*;
import java.text.DecimalFormat;
import java.util.List;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Swing panel that shows one titled sub-panel per group of algorithm
 * parameters, with an editable widget for every parameter. Edits made in the
 * widgets are written straight back into the {@link Parameters} object
 * obtained from the algorithm runner.
 *
 * Dec 4, 2017 5:05:42 PM
 *
 * @author Kevin V. Bui ([email protected])
 */
public class AlgorithmParameterPanel extends JPanel {

    private static final long serialVersionUID = 274638263704283474L;

    /** Height, in pixels, of the spacer placed between rows and after each sub-panel. */
    private static final int VERTICAL_GAP = 10;

    /** Container that holds the stacked parameter sub-panels. */
    protected final JPanel mainPanel = new JPanel();

    /**
     * Creates an empty panel; call {@link #addToPanel(GeneralAlgorithmRunner)}
     * to populate it.
     */
    public AlgorithmParameterPanel() {
        initComponents();
    }

    /**
     * Stacks the sub-panels vertically inside {@code mainPanel} and pins that
     * panel to the top (NORTH) of this component.
     */
    private void initComponents() {
        this.mainPanel.setLayout(new BoxLayout(this.mainPanel, BoxLayout.Y_AXIS));
        setLayout(new BorderLayout());
        add(this.mainPanel, BorderLayout.NORTH);
    }

    /**
     * Rebuilds the panel for the algorithm held by the given runner. Any
     * previously shown sub-panels are discarded first.
     *
     * @param algorithmRunner supplies the algorithm, its parameter values and
     *                        the (possibly null) source graph
     */
    public void addToPanel(GeneralAlgorithmRunner algorithmRunner) {
        this.mainPanel.removeAll();

        Algorithm algorithm = algorithmRunner.getAlgorithm();
        Parameters parameters = algorithmRunner.getParameters();

        if (algorithm instanceof RfciBsc) {
            addRfciBscPanels(algorithm, parameters);
        } else if (algorithm instanceof PagSampleRfci) {
            addPagSampleRfciPanels(algorithm, parameters);
        } else {
            addGenericPanels(algorithmRunner, algorithm, parameters);
        }

        // The component hierarchy changed; an already displayed container
        // must be laid out and painted again to show the new content.
        this.mainPanel.revalidate();
        this.mainPanel.repaint();
    }

    /**
     * Adds the hard-coded parameter groups for Rfci-Bsc.
     */
    private void addRfciBscPanels(Algorithm algorithm, Parameters parameters) {
        // RFCI parameters, shown under the algorithm's own name.
        Set<String> rfciParams = new LinkedHashSet<>(Arrays.asList(
                Params.DEPTH,
                Params.MAX_PATH_LENGTH,
                Params.COMPLETE_RULE_SET_USED,
                Params.VERBOSE));
        addSubPanel(algorithmTitle(algorithm), rfciParams, parameters);

        // Stage one: PAG and constraints candidates searching.
        Set<String> stageOneParams = new LinkedHashSet<>(Arrays.asList(
                Params.NUM_RANDOMIZED_SEARCH_MODELS));
        addSubPanel("Stage One: PAG and constraints candidates Searching", stageOneParams, parameters);

        // Stage two: Bayesian scoring of constraints.
        // Params.CUTOFF_CONSTRAIN_SEARCH is currently not exposed here.
        Set<String> stageTwoParams = new LinkedHashSet<>(Arrays.asList(
                Params.NUM_BSC_BOOTSTRAP_SAMPLES,
                Params.THRESHOLD_NO_RANDOM_CONSTRAIN_SEARCH,
                Params.LOWER_BOUND,
                Params.UPPER_BOUND,
                Params.OUTPUT_RBD));
        addSubPanel("Stage Two: Bayesian Scoring of Constraints", stageTwoParams, parameters);
    }

    /**
     * Adds the hard-coded parameter groups for PagSampleRfci.
     */
    private void addPagSampleRfciPanels(Algorithm algorithm, Parameters parameters) {
        Set<String> searchParams = new LinkedHashSet<>(Arrays.asList(
                Params.NUM_RANDOMIZED_SEARCH_MODELS,
                Params.VERBOSE));
        addSubPanel(algorithmTitle(algorithm), searchParams, parameters);

        Set<String> rfciParams = new LinkedHashSet<>(PagSampleRfci.RFCI_PARAMETERS);
        addSubPanel("RFCI Parameters", rfciParams, parameters);

        Set<String> testParams = new LinkedHashSet<>(PagSampleRfci.PROBABILISTIC_TEST_PARAMETERS);
        addSubPanel("Probabilistic Test Parameters", testParams, parameters);
    }

    /**
     * Adds the algorithm, score, independence-test and bootstrapping groups
     * reported by {@link Params} for the algorithm; empty groups are skipped.
     */
    private void addGenericPanels(GeneralAlgorithmRunner algorithmRunner, Algorithm algorithm,
            Parameters parameters) {
        Set<String> algorithmParams = Params.getAlgorithmParameters(algorithm);
        if (!algorithmParams.isEmpty()) {
            addSubPanel(algorithmTitle(algorithm), algorithmParams, parameters);
        }

        Set<String> scoreParams = Params.getScoreParameters(algorithm);
        if (!scoreParams.isEmpty()) {
            // The cast is only attempted when score parameters were reported.
            Class<?> scoreClass = ((UsesScoreWrapper) algorithm).getScoreWrapper().getClass();
            Score scoreAnnotation = scoreClass.getAnnotation(Score.class);
            String title = (scoreAnnotation == null) ? scoreClass.getSimpleName() : scoreAnnotation.name();
            addSubPanel(title, scoreParams, parameters);
        }

        Set<String> testParams = Params.getTestParameters(algorithm);
        if (!testParams.isEmpty()) {
            // The cast is only attempted when test parameters were reported.
            Class<?> testClass = ((TakesIndependenceWrapper) algorithm).getIndependenceWrapper().getClass();
            TestOfIndependence testAnnotation = testClass.getAnnotation(TestOfIndependence.class);
            String title = (testAnnotation == null) ? testClass.getSimpleName() : testAnnotation.name();
            addSubPanel(title, testParams, parameters);
        }

        // Bootstrapping options are offered only when the runner has no source graph.
        if (algorithmRunner.getSourceGraph() == null) {
            Set<String> bootstrapParams = Params.getBootstrappingParameters(algorithm);
            if (!bootstrapParams.isEmpty()) {
                addSubPanel("Bootstrapping", bootstrapParams, parameters);
            }
        }
    }

    /**
     * Appends a titled sub-panel for the given parameters to {@code mainPanel},
     * followed by a vertical spacer.
     */
    private void addSubPanel(String title, Set<String> params, Parameters parameters) {
        this.mainPanel.add(createSubPanel(title, params, parameters));
        this.mainPanel.add(Box.createVerticalStrut(VERTICAL_GAP));
    }

    /**
     * Returns the display name of the algorithm: the {@code name} of its
     * algorithm annotation, or the simple class name when the class carries no
     * such annotation.
     */
    private static String algorithmTitle(Algorithm algorithm) {
        Class<?> algorithmClass = algorithm.getClass();
        edu.cmu.tetrad.annotation.Algorithm annotation
                = algorithmClass.getAnnotation(edu.cmu.tetrad.annotation.Algorithm.class);
        return (annotation == null) ? algorithmClass.getSimpleName() : annotation.name();
    }

    /**
     * Orders the parameter rows for display: rows of non-boolean parameters
     * first, rows of boolean parameters last. Within each group the iteration
     * order of the given map is kept.
     *
     * @param parameterComponents parameter name mapped to its display row
     * @return the rows in display order
     */
    protected Box[] toArray(Map<String, Box> parameterComponents) {
        ParamDescriptions paramDescs = ParamDescriptions.getInstance();

        List<Box> boolComps = new ArrayList<>();
        List<Box> otherComps = new ArrayList<>();
        parameterComponents.forEach((name, row) -> {
            if (paramDescs.get(name).getDefaultValue() instanceof Boolean) {
                boolComps.add(row);
            } else {
                otherComps.add(row);
            }
        });

        return Stream.concat(otherComps.stream(), boolComps.stream())
                .toArray(Box[]::new);
    }

    /**
     * Builds one display row per parameter.
     *
     * @param params     names of the parameters to show
     * @param parameters store that provides current values and receives edits
     * @return a sorted map (natural order of the parameter names) from
     *         parameter name to its row
     */
    protected Map<String, Box> createParameterComponents(Set<String> params, Parameters parameters) {
        ParamDescriptions paramDescs = ParamDescriptions.getInstance();

        return params.stream()
                .collect(Collectors.toMap(
                        Function.identity(),
                        e -> createParameterComponent(e, parameters, paramDescs.get(e)),
                        (u, v) -> {
                            throw new IllegalStateException(String.format("Duplicate key %s.", u));
                        },
                        TreeMap::new));
    }

    /**
     * Creates a single row: a label with the short description on the left
     * and, pushed to the right, an input widget chosen from the type of the
     * parameter's default value (Double, Integer, Long, Boolean or String).
     *
     * @param parameter  name of the parameter
     * @param parameters store that provides the current value and receives edits
     * @param paramDesc  description of the parameter; must not be null
     * @return the assembled row
     * @throws NullPointerException     if {@code paramDesc} is null
     * @throws IllegalArgumentException if the default value is null or of an
     *                                  unsupported type
     */
    protected Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc) {
        Objects.requireNonNull(paramDesc, "No description available for parameter: " + parameter);

        JComponent component;
        Object defaultValue = paramDesc.getDefaultValue();
        if (defaultValue instanceof Double) {
            double lowerBoundDouble = paramDesc.getLowerBoundDouble();
            double upperBoundDouble = paramDesc.getUpperBoundDouble();
            component = getDoubleField(parameter, parameters, (Double) defaultValue, lowerBoundDouble, upperBoundDouble);
        } else if (defaultValue instanceof Integer) {
            int lowerBoundInt = paramDesc.getLowerBoundInt();
            int upperBoundInt = paramDesc.getUpperBoundInt();
            component = getIntTextField(parameter, parameters, (Integer) defaultValue, lowerBoundInt, upperBoundInt);
        } else if (defaultValue instanceof Long) {
            long lowerBoundLong = paramDesc.getLowerBoundLong();
            long upperBoundLong = paramDesc.getUpperBoundLong();
            component = getLongTextField(parameter, parameters, (Long) defaultValue, lowerBoundLong, upperBoundLong);
        } else if (defaultValue instanceof Boolean) {
            component = getBooleanSelectionBox(parameter, parameters, (Boolean) defaultValue);
        } else if (defaultValue instanceof String) {
            component = getStringField(parameter, parameters, (String) defaultValue);
        } else {
            // A null default must not turn this report into a NullPointerException.
            String typeName = (defaultValue == null) ? "null" : String.valueOf(defaultValue.getClass());
            throw new IllegalArgumentException("Unexpected type: " + typeName);
        }

        JLabel paramLabel = new JLabel(paramDesc.getShortDescription());
        String longDescription = paramDesc.getLongDescription();
        if (longDescription != null) {
            paramLabel.setToolTipText(longDescription);
        }

        Box paramRow = Box.createHorizontalBox();
        paramRow.add(paramLabel);
        paramRow.add(Box.createHorizontalGlue());
        paramRow.add(component);

        return paramRow;
    }

    /**
     * Creates a titled panel containing one row per parameter, separated by
     * vertical spacers. An empty parameter set yields a titled panel without
     * rows.
     *
     * @param title      border title of the panel
     * @param params     names of the parameters to show
     * @param parameters store that provides current values and receives edits
     * @return the titled panel
     */
    protected JPanel createSubPanel(String title, Set<String> params, Parameters parameters) {
        JPanel panel = new JPanel(new BorderLayout());
        panel.setBorder(BorderFactory.createTitledBorder(title));

        Box paramsBox = Box.createVerticalBox();
        Box[] boxes = toArray(createParameterComponents(params, parameters));
        for (int i = 0; i < boxes.length; i++) {
            // Spacers go between rows only, never before the first or after the last.
            if (i > 0) {
                paramsBox.add(Box.createVerticalStrut(VERTICAL_GAP));
            }
            paramsBox.add(boxes[i]);
        }

        panel.add(new PaddingPanel(paramsBox), BorderLayout.CENTER);

        return panel;
    }

    /**
     * Creates a text field for a double parameter. An edit is accepted and
     * stored in {@code parameters} only if it differs from the field's current
     * value and lies within {@code [lowerBound, upperBound]}; otherwise the
     * filter answers with the old value.
     *
     * @param parameter    name of the parameter
     * @param parameters   store that provides the initial value and receives edits
     * @param defaultValue value shown when the store has none for the parameter
     * @param lowerBound   smallest accepted value (inclusive)
     * @param upperBound   largest accepted value (inclusive)
     * @return the configured field
     */
    protected DoubleTextField getDoubleField(String parameter, Parameters parameters,
            double defaultValue, double lowerBound, double upperBound) {
        DoubleTextField field = new DoubleTextField(parameters.getDouble(parameter, defaultValue),
                8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001);

        field.setFilter((value, oldValue) -> {
            if (value == field.getValue() || value < lowerBound || value > upperBound) {
                return oldValue;
            }

            try {
                parameters.set(parameter, value);
            } catch (Exception ignored) {
                // Best effort: a failure to store the value is deliberately not
                // allowed to interrupt editing.
            }

            return value;
        });

        return field;
    }

    /**
     * Creates a text field for an int parameter. An edit is accepted and
     * stored in {@code parameters} only if it differs from the field's current
     * value and lies within {@code [lowerBound, upperBound]}; otherwise the
     * filter answers with the old value.
     *
     * @param parameter    name of the parameter
     * @param parameters   store that provides the initial value and receives edits
     * @param defaultValue value shown when the store has none for the parameter
     * @param lowerBound   smallest accepted value (inclusive); declared as double
     * @param upperBound   largest accepted value (inclusive); declared as double
     * @return the configured field
     */
    protected IntTextField getIntTextField(String parameter, Parameters parameters,
            int defaultValue, double lowerBound, double upperBound) {
        IntTextField field = new IntTextField(parameters.getInt(parameter, defaultValue), 8);

        field.setFilter((value, oldValue) -> {
            if (value == field.getValue() || value < lowerBound || value > upperBound) {
                return oldValue;
            }

            try {
                parameters.set(parameter, value);
            } catch (Exception ignored) {
                // Best effort: a failure to store the value is deliberately not
                // allowed to interrupt editing.
            }

            return value;
        });

        return field;
    }

    /**
     * Creates a text field for a long parameter. An edit is accepted and
     * stored in {@code parameters} only if it differs from the field's current
     * value and lies within {@code [lowerBound, upperBound]}; otherwise the
     * filter answers with the old value.
     *
     * @param parameter    name of the parameter
     * @param parameters   store that provides the initial value and receives edits
     * @param defaultValue value shown when the store has none for the parameter
     * @param lowerBound   smallest accepted value (inclusive)
     * @param upperBound   largest accepted value (inclusive)
     * @return the configured field
     */
    protected LongTextField getLongTextField(String parameter, Parameters parameters,
            long defaultValue, long lowerBound, long upperBound) {
        LongTextField field = new LongTextField(parameters.getLong(parameter, defaultValue), 8);

        field.setFilter((value, oldValue) -> {
            if (value == field.getValue() || value < lowerBound || value > upperBound) {
                return oldValue;
            }

            try {
                parameters.set(parameter, value);
            } catch (Exception ignored) {
                // Best effort: a failure to store the value is deliberately not
                // allowed to interrupt editing.
            }

            return value;
        });

        return field;
    }

    /**
     * Creates a Yes/No radio-button pair for a boolean parameter. The initial
     * selection reflects the stored value (or the default); selecting a button
     * writes {@code true} or {@code false} into {@code parameters}.
     *
     * @param parameter    name of the parameter
     * @param parameters   store that provides the initial value and receives edits
     * @param defaultValue value used when the store has none for the parameter
     * @return a horizontal box holding both buttons
     */
    protected Box getBooleanSelectionBox(String parameter, Parameters parameters, boolean defaultValue) {
        Box selectionBox = Box.createHorizontalBox();

        JRadioButton yesButton = new JRadioButton("Yes");
        JRadioButton noButton = new JRadioButton("No");

        // Button group to ensure only one option can be selected
        ButtonGroup selectionBtnGrp = new ButtonGroup();
        selectionBtnGrp.add(yesButton);
        selectionBtnGrp.add(noButton);

        // Preselect the button matching the current value
        boolean currentValue = parameters.getBoolean(parameter, defaultValue);
        if (currentValue) {
            yesButton.setSelected(true);
        } else {
            noButton.setSelected(true);
        }

        // Add to containing box
        selectionBox.add(yesButton);
        selectionBox.add(noButton);

        // Store true when "Yes" becomes selected
        yesButton.addActionListener((e) -> {
            JRadioButton button = (JRadioButton) e.getSource();
            if (button.isSelected()) {
                parameters.set(parameter, true);
            }
        });

        // Store false when "No" becomes selected
        noButton.addActionListener((e) -> {
            JRadioButton button = (JRadioButton) e.getSource();
            if (button.isSelected()) {
                parameters.set(parameter, false);
            }
        });

        return selectionBox;
    }

    /**
     * Creates a text field for a String parameter. An edit that equals the
     * trimmed current field value is answered with the old value; any other
     * edit is stored in {@code parameters} and accepted.
     *
     * @param parameter    name of the parameter
     * @param parameters   store that provides the initial value and receives edits
     * @param defaultValue value shown when the store has none for the parameter
     * @return the configured field
     */
    protected StringTextField getStringField(String parameter, Parameters parameters, String defaultValue) {
        StringTextField field = new StringTextField(parameters.getString(parameter, defaultValue), 20);

        field.setFilter((value, oldValue) -> {
            if (value.equals(field.getValue().trim())) {
                return oldValue;
            }

            try {
                parameters.set(parameter, value);
            } catch (Exception ignored) {
                // Best effort: a failure to store the value is deliberately not
                // allowed to interrupt editing.
            }

            return value;
        });

        return field;
    }

}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy