// edu.cmu.tetradapp.editor.PlotMatrix Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.editor;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Histogram;
import edu.cmu.tetrad.graph.Node;
import javax.swing.*;
import java.awt.*;
import java.awt.event.InputEvent;
import java.awt.event.KeyEvent;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Implements a matrix of scatterplots and histograms for variables that users can select from a list.
 * <p>
 * One chart is drawn for every (row variable, column variable) pair selected in the "Rows" and "Cols" lists: a
 * histogram where the row and column index coincide, a scatterplot otherwise. A "Settings" menu controls trend
 * lines, removal of zero points, the histogram bin count, the jitter style and the conditioning variables.
 * Clicking a chart while several are shown switches to that single chart; clicking the single chart switches
 * back to the selection that was remembered beforehand.
 *
 * @author Adrian Tang
 * @author josephramsey
 */
public class PlotMatrix extends JPanel {

    /**
     * Smallest histogram bin count offered in the settings menu.
     */
    private static final int MIN_BINS = 2;

    /**
     * Largest histogram bin count offered in the settings menu.
     */
    private static final int MAX_BINS = 30;

    private final DataSet dataSet;
    private final List<Node> nodes;
    private final JPanel charts;
    private final JList<Node> rowSelector;
    private final JList<Node> colSelector;
    private int numBins = 9;
    private boolean addRegressionLines = false;
    private boolean removeZeroPointsPerPlot = false;
    // Selection to switch to on the next click on a single displayed chart.
    private int[] lastRows = new int[]{0};
    private int[] lastCols = new int[]{0};
    // NOTE(review): the value type is not visible in this file; ConditioningPanel is presumed to be the common
    // supertype of the discrete and continuous conditioning panels - confirm against VariableConditioningEditor.
    private Map<Node, VariableConditioningEditor.ConditioningPanel> conditioningPanelMap = new HashMap<>();
    private ScatterPlot.JitterStyle jitterStyle = ScatterPlot.JitterStyle.None;

    /**
     * Builds the plot matrix for the given data set, initially showing the chart for the first variable (in
     * sorted order) against itself.
     *
     * @param dataSet the data set whose variables are offered for plotting.
     */
    public PlotMatrix(DataSet dataSet) {
        this.dataSet = dataSet;
        setLayout(new BorderLayout());

        // NOTE(review): this sorts the list returned by getVariables() in place - confirm that the data set
        // hands out a copy rather than its internal list.
        this.nodes = dataSet.getVariables();
        Collections.sort(this.nodes);

        Node[] vars = this.nodes.toArray(new Node[0]);
        this.rowSelector = new JList<>(vars);
        this.colSelector = new JList<>(vars);
        this.rowSelector.setSelectedIndex(0);
        this.colSelector.setSelectedIndex(0);

        this.charts = new JPanel();

        this.rowSelector.addListSelectionListener(e -> constructPlotMatrix());
        this.colSelector.addListSelectionListener(e -> constructPlotMatrix());
        constructPlotMatrix();

        add(createMenuBar(), BorderLayout.NORTH);

        Box content = Box.createHorizontalBox();
        JScrollPane chartScroll = new JScrollPane(this.charts);
        chartScroll.setPreferredSize(new Dimension(750, 750));
        content.add(chartScroll);
        content.add(createSelectorBox("Rows", this.rowSelector));
        content.add(createSelectorBox("Cols", this.colSelector));
        add(content, BorderLayout.CENTER);

        setPreferredSize(new Dimension(750, 450));
    }

    /**
     * Creates the menu bar holding the "Settings" menu. Every setting rebuilds the plot matrix when changed.
     */
    private JMenuBar createMenuBar() {
        JMenu settings = new JMenu("Settings");

        JMenuItem addTrendLines = new JCheckBoxMenuItem("Add Trend Lines");
        addTrendLines.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_Y, InputEvent.CTRL_DOWN_MASK));
        addTrendLines.setSelected(this.addRegressionLines);
        addTrendLines.addActionListener(e -> {
            // Follow the check box rather than flipping the flag, so the two cannot drift apart.
            setAddRegressionLines(addTrendLines.isSelected());
            constructPlotMatrix();
        });
        settings.add(addTrendLines);

        JMenuItem removeZeroPoints = new JCheckBoxMenuItem("Remove Zero Points Per Plot");
        removeZeroPoints.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.CTRL_DOWN_MASK));
        removeZeroPoints.setSelected(this.removeZeroPointsPerPlot);
        removeZeroPoints.addActionListener(e -> {
            this.removeZeroPointsPerPlot = removeZeroPoints.isSelected();
            constructPlotMatrix();
        });
        settings.add(removeZeroPoints);

        settings.add(createBinsMenu());
        settings.add(createJitterMenu());

        JMenuItem editConditioning = new JMenuItem("Edit Conditioning Variables...");
        editConditioning.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_P, InputEvent.CTRL_DOWN_MASK));
        editConditioning.addActionListener(e -> {
            VariableConditioningEditor conditioningEditor
                    = new VariableConditioningEditor(this.dataSet, this.conditioningPanelMap);
            conditioningEditor.setPreferredSize(new Dimension(300, 300));
            JOptionPane.showMessageDialog(PlotMatrix.this, new JScrollPane(conditioningEditor));
            this.conditioningPanelMap = conditioningEditor.getConditioningPanelMap();
            constructPlotMatrix();
        });
        settings.add(editConditioning);

        JMenuBar menuBar = new JMenuBar();
        menuBar.add(settings);
        return menuBar;
    }

    /**
     * Creates the submenu for choosing the histogram bin count; the entries are mutually exclusive.
     */
    private JMenu createBinsMenu() {
        JMenu binsMenu = new JMenu("Set number of Bins for Histograms");
        ButtonGroup binsGroup = new ButtonGroup();

        for (int bins = MIN_BINS; bins <= MAX_BINS; bins++) {
            final int choice = bins;
            JMenuItem item = new JCheckBoxMenuItem(String.valueOf(bins));
            binsMenu.add(item);
            binsGroup.add(item);

            if (bins == this.numBins) {
                item.setSelected(true);
            }

            item.addActionListener(e -> {
                setNumBins(choice);
                constructPlotMatrix();
            });
        }

        return binsMenu;
    }

    /**
     * Creates the submenu for choosing the jitter style; the entries are mutually exclusive and "None" starts
     * out selected.
     */
    private JMenu createJitterMenu() {
        JMenuItem gaussian = createJitterItem(ScatterPlot.JitterStyle.Gaussian, KeyEvent.VK_U);
        JMenuItem uniform = createJitterItem(ScatterPlot.JitterStyle.Uniform, KeyEvent.VK_I);
        JMenuItem none = createJitterItem(ScatterPlot.JitterStyle.None, KeyEvent.VK_O);

        ButtonGroup jitterGroup = new ButtonGroup();
        jitterGroup.add(gaussian);
        jitterGroup.add(uniform);
        jitterGroup.add(none);
        none.setSelected(true);

        JMenu jitterMenu = new JMenu("Jitter Style (Display Only)");
        jitterMenu.add(gaussian);
        jitterMenu.add(uniform);
        jitterMenu.add(none);
        return jitterMenu;
    }

    /**
     * Creates one jitter-style menu entry that applies {@code style} and rebuilds the matrix when chosen.
     *
     * @param style   the jitter style the entry stands for.
     * @param keyCode the key that, together with Ctrl, triggers the entry.
     */
    private JMenuItem createJitterItem(ScatterPlot.JitterStyle style, int keyCode) {
        JMenuItem item = new JCheckBoxMenuItem(style.toString());
        item.setAccelerator(KeyStroke.getKeyStroke(keyCode, InputEvent.CTRL_DOWN_MASK));
        item.addActionListener(e -> {
            this.jitterStyle = style;
            constructPlotMatrix();
        });
        return item;
    }

    /**
     * Wraps a variable selector list in a vertical box with a title label above it.
     */
    private static Box createSelectorBox(String title, JList<Node> selector) {
        Box box = Box.createVerticalBox();
        box.add(new JLabel(title));
        box.add(new JScrollPane(selector));
        return box;
    }

    /**
     * Discards all charts and rebuilds one chart per selected (row, column) pair using the current settings.
     */
    private void constructPlotMatrix() {
        int[] rowIndices = this.rowSelector.getSelectedIndices();
        int[] colIndices = this.colSelector.getSelectedIndices();

        this.charts.removeAll();

        // With nothing selected on either axis there is nothing to draw; in particular GridLayout rejects a
        // 0 x 0 grid with an IllegalArgumentException, so the layout is only created for a non-empty matrix.
        if (rowIndices.length > 0 && colIndices.length > 0) {
            this.charts.setLayout(new GridLayout(rowIndices.length, colIndices.length));

            boolean singleChart = rowIndices.length == 1 && colIndices.length == 1;
            int pointSize = pointSizeFor(rowIndices.length, colIndices.length);

            for (int rowIndex : rowIndices) {
                for (int colIndex : colIndices) {
                    JPanel panel = rowIndex == colIndex
                            ? createHistogramPanel(rowIndex, singleChart)
                            : createScatterplotPanel(rowIndex, colIndex, singleChart, pointSize);
                    panel.setMinimumSize(new Dimension(10, 10));
                    addPanelListener(rowIndex, colIndex, panel);
                    this.charts.add(panel);
                }
            }
        }

        revalidate();
        repaint();
    }

    /**
     * Picks the scatterplot point size: the more rows or columns are shown, the smaller the points.
     */
    private static int pointSizeFor(int rowCount, int colCount) {
        int largest = Math.max(rowCount, colCount);

        if (largest > 5) {
            return 2;
        }
        if (largest > 3) {
            return 3;
        }
        if (largest > 2) {
            return 4;
        }
        return 5;
    }

    /**
     * Creates the histogram chart for the variable at {@code index}, restricted by the conditioning variables.
     *
     * @param index       index into the sorted variable list.
     * @param singleChart whether this is the only chart shown; passed through to the histogram panel.
     */
    private JPanel createHistogramPanel(int index, boolean singleChart) {
        Node variable = this.nodes.get(index);
        Histogram histogram = new Histogram(this.dataSet, variable.getName(), this.removeZeroPointsPerPlot);

        this.conditioningPanelMap.forEach((node, conditioning) -> {
            if (node instanceof ContinuousVariable) {
                VariableConditioningEditor.ContinuousConditioningPanel range
                        = (VariableConditioningEditor.ContinuousConditioningPanel) conditioning;
                histogram.addConditioningVariable(node.getName(), range.getLow(), range.getHigh());
            } else if (node instanceof DiscreteVariable) {
                VariableConditioningEditor.DiscreteConditioningPanel category
                        = (VariableConditioningEditor.DiscreteConditioningPanel) conditioning;
                histogram.addConditioningVariable(node.getName(), category.getIndex());
            }
        });

        // The configured bin count is applied to every variable that is not discrete.
        if (!(variable instanceof DiscreteVariable)) {
            histogram.setNumBins(this.numBins);
        }

        return new HistogramPanel(histogram, singleChart);
    }

    /**
     * Creates the scatterplot chart of the column variable against the row variable, restricted by the
     * conditioning variables.
     *
     * @param rowIndex    index of the row variable in the sorted variable list.
     * @param colIndex    index of the column variable in the sorted variable list.
     * @param singleChart whether this is the only chart shown; axes are drawn only in that case.
     * @param pointSize   the point size to draw with.
     */
    private JPanel createScatterplotPanel(int rowIndex, int colIndex, boolean singleChart, int pointSize) {
        ScatterPlot scatterPlot = new ScatterPlot(this.dataSet, this.addRegressionLines,
                this.nodes.get(colIndex).getName(), this.nodes.get(rowIndex).getName(),
                this.removeZeroPointsPerPlot);

        this.conditioningPanelMap.forEach((node, conditioning) -> {
            if (node instanceof ContinuousVariable) {
                VariableConditioningEditor.ContinuousConditioningPanel range
                        = (VariableConditioningEditor.ContinuousConditioningPanel) conditioning;
                scatterPlot.addConditioningVariable(node.getName(), range.getLow(), range.getHigh());
            } else if (node instanceof DiscreteVariable) {
                VariableConditioningEditor.DiscreteConditioningPanel category
                        = (VariableConditioningEditor.DiscreteConditioningPanel) conditioning;
                scatterPlot.addConditioningVariable(node.getName(), category.getIndex());
            }
        });

        scatterPlot.setJitterStyle(this.jitterStyle);

        ScatterplotPanel panel = new ScatterplotPanel(scatterPlot, this.removeZeroPointsPerPlot);
        panel.setDrawAxes(singleChart);
        panel.setPointSize(pointSize);
        return panel;
    }

    /**
     * Makes a chart clickable. Clicking while exactly one row and one column are selected switches to the
     * remembered selection and remembers this chart instead; clicking otherwise remembers the current selection
     * and switches to this chart alone.
     *
     * @param rowIndex index of the chart's row variable.
     * @param colIndex index of the chart's column variable.
     * @param panel    the chart panel to attach the listener to.
     */
    private void addPanelListener(int rowIndex, int colIndex, JPanel panel) {
        panel.addMouseListener(new MouseAdapter() {
            @Override
            public void mouseClicked(MouseEvent e) {
                boolean singleChart = rowSelector.getSelectedIndices().length == 1
                        && colSelector.getSelectedIndices().length == 1;

                if (singleChart) {
                    rowSelector.setSelectedIndices(lastRows);
                    colSelector.setSelectedIndices(lastCols);
                    lastRows = new int[]{rowIndex};
                    lastCols = new int[]{colIndex};
                } else {
                    lastRows = rowSelector.getSelectedIndices();
                    lastCols = colSelector.getSelectedIndices();
                    rowSelector.setSelectedIndex(rowIndex);
                    colSelector.setSelectedIndex(colIndex);
                }

                // Rebuild explicitly as well: the selection calls above fire no event if nothing changed.
                constructPlotMatrix();
            }
        });
    }

    /**
     * @return the number of bins used for histograms of non-discrete variables.
     */
    public int getNumBins() {
        return this.numBins;
    }

    /**
     * Sets the number of bins used for histograms of non-discrete variables. The charts are not rebuilt by this
     * call.
     *
     * @param numBins the new bin count.
     */
    public void setNumBins(int numBins) {
        this.numBins = numBins;
    }

    /**
     * @return true if scatterplots are created with regression (trend) lines.
     */
    public boolean isAddRegressionLines() {
        return this.addRegressionLines;
    }

    /**
     * Sets whether scatterplots are created with regression (trend) lines. The charts are not rebuilt by this
     * call.
     *
     * @param addRegressionLines true to add regression lines.
     */
    public void setAddRegressionLines(boolean addRegressionLines) {
        this.addRegressionLines = addRegressionLines;
    }

    /**
     * Despite its name, this reports the "Remove Zero Points Per Plot" setting, not anything about trend lines.
     * The name is kept so that existing callers keep working.
     *
     * @return true if zero points are removed per plot.
     */
    public boolean isRemoveTrendLinesPerPlot() {
        return this.removeZeroPointsPerPlot;
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy