// edu.cmu.tetradapp.editor.ScatterPlot Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.editor;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import org.apache.commons.math3.util.FastMath;
import java.awt.geom.Point2D;
import java.util.*;
import static org.apache.commons.math3.util.FastMath.abs;
/**
 * Model class holding the information needed to draw a scatter plot of two variables of a data set. Each plotted
 * sample is represented as a {@link Point2D.Double} whose x coordinate comes from the predictor variable and whose y
 * coordinate comes from the response variable.
 * <p>
 * The plot may be restricted to the rows that satisfy a set of conditioning constraints (value ranges for continuous
 * variables, exact values for discrete variables). The plotted points, the regression line, and the correlation
 * statistics are all computed from that same set of rows.
 *
 * @author Adrian Tang
 * @author josephramsey
 */
public class ScatterPlot {
    /**
     * Fraction of a variable's observed range used to scale jitter: the standard deviation for Gaussian jitter, and
     * half of the half-width for uniform jitter.
     */
    private static final double JITTER_FRACTION = 0.03;

    private final String x;
    private final String y;
    private final boolean includeLine;
    private final DataSet dataSet;
    // Conditioning constraints: continuous node -> {low, high} (inclusive), discrete node -> required value.
    private final Map<Node, double[]> continuousIntervals;
    private final Map<Node, Integer> discreteValues;
    private final Node _x;
    private final Node _y;
    private final boolean removeZeroPointsPerPlot;
    private JitterStyle jitterStyle = JitterStyle.None;

    /**
     * Constructor.
     *
     * @param dataSet                 the data set the two variables are read from.
     * @param includeLine             whether to include the regression line in the plot.
     * @param x                       x-axis (predictor) variable name.
     * @param y                       y-axis (response) variable name.
     * @param removeZeroPointsPerPlot whether points with a zero x or y value are left out of the correlation
     *                                statistics.
     * @throws NullPointerException if dataSet is null.
     */
    public ScatterPlot(DataSet dataSet, boolean includeLine, String x, String y, boolean removeZeroPointsPerPlot) {
        this.dataSet = Objects.requireNonNull(dataSet, "dataSet");
        this.x = x;
        this.y = y;
        _x = this.dataSet.getVariable(this.x);
        _y = this.dataSet.getVariable(this.y);
        this.includeLine = includeLine;
        this.continuousIntervals = new HashMap<>();
        this.discreteValues = new HashMap<>();
        this.removeZeroPointsPerPlot = removeZeroPointsPerPlot;
    }

    /**
     * Sets the kind of random jitter added to each plotted point.
     *
     * @param jitterStyle the jitter style to use; anything other than Gaussian or Uniform adds no jitter.
     */
    public void setJitterStyle(JitterStyle jitterStyle) {
        this.jitterStyle = jitterStyle;
    }

    /**
     * @return the correlation of the x and y variables over the rows that satisfy the conditioning constraints
     * (after optional zero-point removal), clamped to [-1, 1]; NaN if fewer than two points remain.
     */
    public double getCorrelationCoeff() {
        return correlation(getCorrelationSample());
    }

    /**
     * @return the two-sided p-value of the correlation coefficient, obtained by comparing its Fisher Z statistic to
     * a standard normal distribution; NaN if the statistic is undefined for the sample (for example, too few
     * points).
     */
    public double getCorrelationPValue() {
        Result sample = getCorrelationSample();
        // r and n must describe the same sample, so both are taken from the one Result.
        double fisherZ = fisherz(correlation(sample), sample.xdata.length);

        if (Double.isNaN(fisherZ)) {
            return Double.NaN;
        }

        if (Double.isInfinite(fisherZ)) {
            return 0;
        }

        return 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0, 1, abs(fisherZ)));
    }

    /**
     * @return the minimum x-axis value from the set of sample values. The sample values are recomputed (including
     * any jitter) on each call.
     */
    public double getXmin() {
        double min = Double.POSITIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            min = FastMath.min(min, point.getX());
        }
        return min;
    }

    /**
     * @return the minimum y-axis value from the set of sample values. The sample values are recomputed (including
     * any jitter) on each call.
     */
    public double getYmin() {
        double min = Double.POSITIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            min = FastMath.min(min, point.getY());
        }
        return min;
    }

    /**
     * @return the maximum x-axis value from the set of sample values. The sample values are recomputed (including
     * any jitter) on each call.
     */
    public double getXmax() {
        double max = Double.NEGATIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            max = FastMath.max(max, point.getX());
        }
        return max;
    }

    /**
     * @return the maximum y-axis value from the set of sample values. The sample values are recomputed (including
     * any jitter) on each call.
     */
    public double getYmax() {
        double max = Double.NEGATIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            max = FastMath.max(max, point.getY());
        }
        return max;
    }

    /**
     * Sieves through the sample values and grabs only the values for the response and predictor variables, for the
     * rows that satisfy the conditioning constraints, with the current jitter style applied.
     *
     * @return a vector containing the filtered values, one point per row.
     */
    public Vector<Point2D.Double> getSievedValues() {
        return pairs(this.x, this.y);
    }

    /**
     * @return the name of the predictor variable.
     */
    public String getXvar() {
        return this.x;
    }

    /**
     * @return the name of the response variable.
     */
    public String getYvar() {
        return this.y;
    }

    /**
     * @return whether to include the regression line.
     */
    public boolean isIncludeLine() {
        return this.includeLine;
    }

    /**
     * @return the regression coefficient of the x variable in the regression of y on x over the conditioned rows.
     */
    public double getRegressionCoeff() {
        return getRegressionResult().getCoef()[1];
    }

    /**
     * @return the zero intercept of the regression equation.
     */
    public double getRegressionIntercept() {
        return getRegressionResult().getCoef()[0];
    }

    /**
     * @return the data set this plot reads from.
     */
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Adds a continuous conditioning variable, conditioning on a range of values (both ends inclusive).
     *
     * @param variable The name of the variable in the data set.
     * @param low      The low end of the conditioning range.
     * @param high     The high end of the conditioning range.
     * @throws IllegalArgumentException if low is not less than high, if the variable is not continuous, or if the
     *                                  variable is already being conditioned on.
     */
    public void addConditioningVariable(String variable, double low, double high) {
        // Written as !(low < high) so that NaN bounds are rejected as well.
        if (!(low < high)) throw new IllegalArgumentException("Low must be less than high: " + low + " >= " + high);

        Node node = this.dataSet.getVariable(variable);
        if (!(node instanceof ContinuousVariable)) throw new IllegalArgumentException("Variable must be continuous.");

        if (this.continuousIntervals.containsKey(node))
            throw new IllegalArgumentException("Please remove conditioning variable first.");

        this.continuousIntervals.put(node, new double[]{low, high});
    }

    /**
     * Adds a discrete conditioning variable, conditioning on a particular value. If the variable is already being
     * conditioned on, its value is replaced.
     *
     * @param variable The name of the variable in the data set.
     * @param value    The value to condition on.
     * @throws IllegalArgumentException if the variable is not discrete.
     */
    public void addConditioningVariable(String variable, int value) {
        Node node = this.dataSet.getVariable(variable);
        if (!(node instanceof DiscreteVariable)) throw new IllegalArgumentException("Variable must be discrete.");
        this.discreteValues.put(node, value);
    }

    /**
     * Stops conditioning on the given variable, whether it was added as a continuous or a discrete conditioning
     * variable. Does nothing if the variable is not currently being conditioned on.
     *
     * @param variable The name of the variable in the data set.
     */
    public void removeConditioningVariable(String variable) {
        Node node = this.dataSet.getVariable(variable);
        this.continuousIntervals.remove(node);
        this.discreteValues.remove(node);
    }

    //======================================PRIVATE METHODS=======================================//

    // Regresses y on x using only the rows that satisfy the conditioning constraints.
    private RegressionResult getRegressionResult() {
        List<Node> regressors = new ArrayList<>();
        regressors.add(_x);

        List<Integer> conditionedRows = getConditionedRows();
        int[] rows = new int[conditionedRows.size()];
        for (int i = 0; i < rows.length; i++) {
            rows[i] = conditionedRows.get(i);
        }

        RegressionDataset regression = new RegressionDataset(this.dataSet);
        regression.setRows(rows);
        return regression.regress(_y, regressors);
    }

    // Collects the x and y values the correlation statistics are computed from: the conditioned rows, with zero
    // points removed if requested.
    private Result getCorrelationSample() {
        List<Integer> rows = getConditionedRows();
        double[] xdata = toArray(columnValues(this.x, rows));
        double[] ydata = toArray(columnValues(this.y, rows));
        return new Result(xdata, ydata, this.removeZeroPointsPerPlot);
    }

    // Correlation of the sample, clamped to [-1, 1] to absorb floating-point overshoot.
    private static double correlation(Result sample) {
        // A correlation is undefined for fewer than two points; report that explicitly rather than relying on how
        // the statistics routine treats degenerate input.
        if (sample.xdata.length < 2) {
            return Double.NaN;
        }

        double r = StatUtils.correlation(sample.xdata, sample.ydata);
        if (r > 1) return 1;
        if (r < -1) return -1;
        return r;
    }

    // Fisher Z statistic for correlation r computed from n points. NaN when n < 3 (square root of a negative
    // number); infinite when r is exactly +/-1 and n > 3.
    private static double fisherz(double r, int n) {
        return 0.5 * FastMath.sqrt(n - 3.0) * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
    }

    // Values of the target variable on the rows that satisfy the conditioning constraints.
    private List<Double> getConditionedDataContinuous(String target) {
        return columnValues(target, getConditionedRows());
    }

    // Values of the target variable on the given rows, in row order.
    private List<Double> columnValues(String target, List<Integer> rows) {
        int index = this.dataSet.getColumn(this.dataSet.getVariable(target));
        List<Double> values = new ArrayList<>(rows.size());
        for (int row : rows) {
            values.add(this.dataSet.getDouble(row, index));
        }
        return values;
    }

    // Returns the rows in the data that satisfy the conditioning constraints.
    private List<Integer> getConditionedRows() {
        List<Integer> rows = new ArrayList<>();
        for (int row = 0; row < this.dataSet.getNumRows(); row++) {
            if (satisfiesConditions(row)) {
                rows.add(row);
            }
        }
        return rows;
    }

    // True if the row lies inside every continuous range and matches every discrete value being conditioned on.
    private boolean satisfiesConditions(int row) {
        for (Map.Entry<Node, double[]> entry : this.continuousIntervals.entrySet()) {
            double[] range = entry.getValue();
            double value = this.dataSet.getDouble(row, this.dataSet.getColumn(entry.getKey()));
            // Negated form so that a NaN value fails the constraint.
            if (!(value >= range[0] && value <= range[1])) {
                return false;
            }
        }

        for (Map.Entry<Node, Integer> entry : this.discreteValues.entrySet()) {
            int observed = this.dataSet.getInt(row, this.dataSet.getColumn(entry.getKey()));
            if (observed != entry.getValue().intValue()) {
                return false;
            }
        }

        return true;
    }

    // Builds one point per conditioned row from the named variables, applying jitter to x and then to y.
    private Vector<Point2D.Double> pairs(String x, String y) {
        List<Integer> rows = getConditionedRows();
        List<Double> xs = columnValues(x, rows);
        List<Double> ys = columnValues(y, rows);

        double spreadX = getRange(xs);
        double spreadY = getRange(ys);

        Vector<Point2D.Double> points = new Vector<>(xs.size());
        for (int i = 0; i < xs.size(); i++) {
            double px = jitter(xs.get(i), spreadX);
            double py = jitter(ys.get(i), spreadY);
            points.add(new Point2D.Double(px, py));
        }
        return points;
    }

    // Adds random noise, scaled to the variable's observed range, according to the current jitter style.
    private double jitter(double value, double spread) {
        // With no spread the jitter would have zero width; skip it rather than ask the random number generator for
        // a degenerate distribution (presumably rejected by the underlying library -- TODO confirm).
        if (!(spread > 0)) {
            return value;
        }

        if (this.jitterStyle == JitterStyle.Gaussian) {
            return value + RandomUtil.getInstance().nextNormal(0, spread * JITTER_FRACTION);
        } else if (this.jitterStyle == JitterStyle.Uniform) {
            return value + RandomUtil.getInstance().nextUniform(-2 * spread * JITTER_FRACTION,
                    2 * spread * JITTER_FRACTION);
        }

        return value;
    }

    // Difference between the largest and smallest value in the list (NaN entries are ignored by the comparisons).
    private double getRange(List<Double> values) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double d : values) {
            if (d < min) min = d;
            if (d > max) max = d;
        }
        return max - min;
    }

    // Unboxes a list of doubles into an array.
    private static double[] toArray(List<Double> values) {
        double[] array = new double[values.size()];
        for (int i = 0; i < array.length; i++) {
            array[i] = values.get(i);
        }
        return array;
    }

    /**
     * The kinds of random jitter that can be added to plotted points.
     */
    public enum JitterStyle {None, Gaussian, Uniform}

    /**
     * A pair of equally long x and y arrays, optionally with zero points filtered out.
     */
    private static class Result {
        final double[] xdata;
        final double[] ydata;

        /**
         * @param xdata                   the x values.
         * @param ydata                   the y values, index-aligned with xdata.
         * @param removeZeroPointsPerPlot if true, only points whose x and y values are both nonzero are kept;
         *                                otherwise the arrays are stored as given.
         */
        Result(double[] xdata, double[] ydata, boolean removeZeroPointsPerPlot) {
            if (removeZeroPointsPerPlot) {
                List<Double> keptX = new ArrayList<>();
                List<Double> keptY = new ArrayList<>();

                for (int i = 0; i < xdata.length; i++) {
                    if (xdata[i] != 0 && ydata[i] != 0) {
                        keptX.add(xdata[i]);
                        keptY.add(ydata[i]);
                    }
                }

                this.xdata = toArray(keptX);
                this.ydata = toArray(keptY);
            } else {
                this.xdata = xdata;
                this.ydata = ydata;
            }
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy