edu.cmu.tetradapp.editor.ScatterPlot Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.editor;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import java.awt.geom.Point2D;
import java.util.*;
import static java.lang.Math.abs;
import static java.lang.Math.log;
/**
* This is the scatterplot model class holding the necessary information to
* create a scatterplot. It uses Point2D to hold the pair of values need to
* create the scatterplot.
*
* @author Adrian Tang
* @author Joseph Ramsey
*/
public class ScatterPlot {
private final String x;
private final String y;
private final boolean includeLine;
private final DataSet dataSet;
private Map continuousIntervals;
/**
* Constructor.
*
* @param includeLine whether or not to include the regression line in the
* plot.
* @param x y-axis variable name.
* @param y x-axis variable name.
*/
public ScatterPlot(
DataSet dataSet,
boolean includeLine,
String x,
String y) {
this.dataSet = dataSet;
this.x = x;
this.y = y;
this.includeLine = includeLine;
this.continuousIntervals = new HashMap<>();
}
private RegressionResult getRegressionResult() {
List regressors = new ArrayList<>();
regressors.add(this.dataSet.getVariable(this.x));
Node target = this.dataSet.getVariable(this.y);
Regression regression = new RegressionDataset(this.dataSet);
RegressionResult result = regression.regress(target, regressors);
System.out.println(result);
return result;
}
public double getCorrelationCoeff() {
DataSet dataSet = getDataSet();
Matrix data = dataSet.getDoubleData();
int _x = dataSet.getColumn(dataSet.getVariable(this.x));
int _y = dataSet.getColumn(dataSet.getVariable(this.y));
double[] xdata = data.getColumn(_x).toArray();
double[] ydata = data.getColumn(_y).toArray();
double correlation = StatUtils.correlation(xdata, ydata);
if (correlation > 1) correlation = 1;
else if (correlation < -1) correlation = -1;
return correlation;
}
/**
* @return the p-value of the correlation coefficient statistics.
*/
public double getCorrelationPValue() {
double r = getCorrelationCoeff();
double fisherZ = fisherz(r);
double pValue;
if (Double.isInfinite(fisherZ)) {
pValue = 0;
} else {
pValue = 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0, 1, abs(fisherZ)));
}
return pValue;
}
private double fisherz(double r) {
return 0.5 * Math.sqrt(getSampleSize() - 3.0) * (log(1.0 + r) - log(1.0 - r));
}
/**
* @return the minimum x-axis value from the set of sample values.
*/
public double getXmin() {
double min = Double.POSITIVE_INFINITY;
Vector cleanedSampleValues = getSievedValues();
for (Point2D.Double cleanedSampleValue : cleanedSampleValues) {
min = Math.min(min, cleanedSampleValue.getX());
}
return min;
}
/**
* @return the minimum y-axis value from the set of sample values.
*/
public double getYmin() {
double min = Double.POSITIVE_INFINITY;
Vector cleanedSampleValues = getSievedValues();
for (Point2D.Double cleanedSampleValue : cleanedSampleValues) {
min = Math.min(min, cleanedSampleValue.getY());
}
return min;
}
/**
* @return the maximum x-axis value from the set of sample values.
*/
public double getXmax() {
double max = Double.NEGATIVE_INFINITY;
Vector cleanedSampleValues = getSievedValues();
for (Point2D.Double cleanedSampleValue : cleanedSampleValues) {
max = Math.max(max, cleanedSampleValue.getX());
}
return max;
}
/**
* @return the maximum y-axis value from the set of sample values.
*/
public double getYmax() {
double max = Double.NEGATIVE_INFINITY;
Vector cleanedSampleValues = getSievedValues();
for (Point2D.Double cleanedSampleValue : cleanedSampleValues) {
max = Math.max(max, cleanedSampleValue.getY());
}
return max;
}
/**
* Seives through the sample values and grabs only the values for the
* response and predictor variables.
*
* @return a vector containing the filtered values.
*/
public Vector getSievedValues() {
return pairs(this.x, this.y);
}
/**
* @return size of the sample.
*/
private int getSampleSize() {
return getSievedValues().size();
}
/**
* @return the name of the predictor variable.
*/
public String getXvar() {
return this.x;
}
/**
* @return the name of the response variable.
*/
public String getYvar() {
return this.y;
}
/**
* @return whether or not to include the regression line.
*/
public boolean isIncludeLine() {
return this.includeLine;
}
/**
* Calculates the regression coefficient for the variables
* return a regression coeff
*/
public double getRegressionCoeff() {
return getRegressionResult().getCoef()[1];
}
/**
* @return the zero intercept of the regression equation.
*/
public double getRegressionIntercept() {
return getRegressionResult().getCoef()[0];
}
public DataSet getDataSet() {
return this.dataSet;
}
//========================================PUBLIC METHODS=================================//
/**
* Adds a continuous conditioning variables, conditioning on a range of values.
*
* @param variable The name of the variable in the data set.
* @param low The low end of the conditioning range.
* @param high The high end of the conditioning range.
*/
public void addConditioningVariable(String variable, double low, double high) {
if (!(low < high)) throw new IllegalArgumentException("Low must be less than high: " + low + " >= " + high);
Node node = this.dataSet.getVariable(variable);
if (!(node instanceof ContinuousVariable)) throw new IllegalArgumentException("Variable must be continuous.");
if (this.continuousIntervals.containsKey(node))
throw new IllegalArgumentException("Please remove conditioning variable first.");
this.continuousIntervals.put(node, new double[]{low, high});
}
/**
* Removes a conditioning variable.
*
* @param variable The name of the conditioning variable to remove.
*/
public void removeConditioningVariable(String variable) {
Node node = this.dataSet.getVariable(variable);
if (!(this.continuousIntervals.containsKey(node))) {
throw new IllegalArgumentException("Not a conditioning node: " + variable);
}
this.continuousIntervals.remove(node);
}
public void removeConditioningVariables() {
this.continuousIntervals = new HashMap<>();
}
/**
* For a continuous target, returns the number of values histogrammed. This may be
* less than the sample size of the data set because of conditioning.
*/
public int getN(String target) {
List conditionedDataContinuous = getConditionedDataContinuous(target);
return conditionedDataContinuous.size();
}
/**
* A convenience method to return the data for a particular named continuous
* variable.
*
* @param variable The name of the variable.
*/
public double[] getContinuousData(String variable) {
int index = this.dataSet.getColumn(this.dataSet.getVariable(variable));
List _data = new ArrayList<>();
for (int i = 0; i < this.dataSet.getNumRows(); i++) {
_data.add(this.dataSet.getDouble(i, index));
}
return asDoubleArray(_data);
}
//======================================PRIVATE METHODS=======================================//
private double[] asDoubleArray(List data) {
double[] _data = new double[data.size()];
for (int i = 0; i < data.size(); i++) _data[i] = data.get(i);
return _data;
}
private List getUnconditionedDataContinuous(String target) {
int index = this.dataSet.getColumn(this.dataSet.getVariable(target));
List _data = new ArrayList<>();
for (int i = 0; i < this.dataSet.getNumRows(); i++) {
_data.add(this.dataSet.getDouble(i, index));
}
return _data;
}
private List getConditionedDataContinuous(String target) {
if (this.continuousIntervals == null) return getUnconditionedDataContinuous(target);
List rows = getConditionedRows();
int index = this.dataSet.getColumn(this.dataSet.getVariable(target));
List _data = new ArrayList<>();
for (Integer row : rows) {
_data.add(this.dataSet.getDouble(row, index));
}
return _data;
}
// Returns the rows in the data that satisfy the conditioning constraints.
private List getConditionedRows() {
List rows = new ArrayList<>();
I:
for (int i = 0; i < this.dataSet.getNumRows(); i++) {
for (Node node : this.continuousIntervals.keySet()) {
double[] range = this.continuousIntervals.get(node);
int index = this.dataSet.getColumn(node);
double value = this.dataSet.getDouble(i, index);
if (!(value > range[0] && value < range[1])) {
continue I;
}
}
rows.add(i);
}
return rows;
}
private Vector pairs(String x, String y) {
Point2D.Double pt;
Vector cleanedVals = new Vector<>();
List _x = getConditionedDataContinuous(x);
List _y = getConditionedDataContinuous(y);
for (int row = 0; row < _x.size(); row++) {
pt = new Point2D.Double();
pt.setLocation(_x.get(row), _y.get(row));
cleanedVals.add(pt);
}
return cleanedVals;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy