// edu.cmu.tetrad.search.work_in_progress.IndTestFisherZRecursive (Maven / Gradle / Ivy)
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search.work_in_progress;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndTestFisherZ;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.PartialCorrelation;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import java.text.DecimalFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import static java.lang.StrictMath.log;
import static org.apache.commons.math3.util.FastMath.abs;
import static org.apache.commons.math3.util.FastMath.sqrt;
/**
 * Checks conditional independence of variables in a continuous data set using Fisher's Z test. Partial correlations
 * are obtained from a {@link PartialCorrelation} built over the covariance matrix. See Spirtes, Glymour, and
 * Scheines, "Causation, Prediction and Search," 2nd edition, page 94.
 *
 * <p>The most recently computed Fisher Z statistic is kept in a plain (unsynchronized) field so that
 * {@link #getPValue()} can report on the last test.
 *
 * @author josephramsey
 * @author Frank Wimberly adapted IndTestCramerT for Fisher's Z
 * @version $Id: $Id
 */
public final class IndTestFisherZRecursive implements IndependenceTest {
    /**
     * The covariance matrix.
     */
    private final ICovarianceMatrix covMatrix;
    /**
     * Maps each variable to its position in {@link #variables}.
     */
    private final Map<Node, Integer> indexMap;
    /**
     * Maps each variable name to the variable carrying that name.
     */
    private final Map<String, Node> nameMap;
    /**
     * Standard normal distribution used to turn a Fisher Z statistic into a p-value.
     */
    private final NormalDistribution normal = new NormalDistribution(0, 1);
    /**
     * Computes the partial correlations used by the test.
     */
    private final PartialCorrelation recursivePartialCorrelation;
    /**
     * Stores a reference to the dataset being analyzed; {@code null} when the test was built from a covariance
     * matrix only.
     */
    private final DataSet dataSet;
    /**
     * The variables of the covariance matrix, in order. (Unmodifiable list.)
     */
    private List<Node> variables;
    /**
     * The significance level of the independence tests.
     */
    private double alpha;
    /**
     * Whether independence judgments are logged.
     */
    private boolean verbose = true;
    /**
     * The Fisher Z statistic of the most recent call to {@link #checkIndependence}.
     */
    private double fisherZ = Double.NaN;
    /**
     * The critical value of the statistic for the current alpha; set by {@link #setAlpha(double)}.
     */
    private double cutoff = Double.NaN;

    //==========================CONSTRUCTORS=============================//

    /**
     * Constructs a new Independence test which checks independence facts based on the correlation matrix implied by the
     * given data set (must be continuous). The given significance level is used.
     *
     * @param dataSet A data set containing only continuous columns.
     * @param alpha   The alpha level of the test.
     * @throws IllegalArgumentException if the data set is not continuous or alpha is outside [0, 1].
     */
    public IndTestFisherZRecursive(DataSet dataSet, double alpha) {
        if (!(dataSet.isContinuous())) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }

        if (!(alpha >= 0 && alpha <= 1)) {
            throw new IllegalArgumentException("Alpha mut be in [0, 1]");
        }

        this.covMatrix = new CovarianceMatrix(dataSet);
        this.variables = Collections.unmodifiableList(this.covMatrix.getVariables());
        this.indexMap = buildIndexMap(this.variables);
        this.nameMap = buildNameMap(this.variables);
        setAlpha(alpha);
        this.dataSet = dataSet;
        this.recursivePartialCorrelation = new PartialCorrelation(this.covMatrix);
    }

    /**
     * Constructs a new Fisher Z independence test with the listed arguments.
     *
     * @param data      A 2D continuous data set with no missing values.
     * @param variables A list of variables, a subset of the variables of {@code data}.
     * @param alpha     The significance cutoff level. p values less than alpha will be reported as dependent.
     * @throws IllegalArgumentException if alpha is outside [0, 1].
     */
    public IndTestFisherZRecursive(Matrix data, List<Node> variables, double alpha) {
        this.dataSet = new BoxDataSet(new VerticalDoubleDataBox(data.transpose().toArray()), variables);
        this.covMatrix = new CovarianceMatrix(this.dataSet);
        this.variables = Collections.unmodifiableList(variables);
        this.indexMap = buildIndexMap(variables);
        this.nameMap = buildNameMap(variables);
        setAlpha(alpha);
        this.recursivePartialCorrelation = new PartialCorrelation(this.covMatrix);
    }

    /**
     * Constructs a new independence test that will determine conditional independence facts using the given correlation
     * matrix and the given significance level. No data set is associated with a test built this way.
     *
     * @param covMatrix a {@link edu.cmu.tetrad.data.ICovarianceMatrix} object
     * @param alpha     a double
     * @throws IllegalArgumentException if alpha is outside [0, 1].
     */
    public IndTestFisherZRecursive(ICovarianceMatrix covMatrix, double alpha) {
        this.dataSet = null;
        this.covMatrix = covMatrix;
        // Wrapped so that this constructor honors the same "unmodifiable list" contract as the others.
        this.variables = Collections.unmodifiableList(covMatrix.getVariables());
        this.indexMap = buildIndexMap(this.variables);
        this.nameMap = buildNameMap(this.variables);
        setAlpha(alpha);
        this.recursivePartialCorrelation = new PartialCorrelation(this.covMatrix);
    }

    //==========================PUBLIC METHODS=============================//

    /**
     * {@inheritDoc}
     *
     * Creates a new independence test instance for a subset of the variables. The new test is built on the
     * corresponding submatrix of the covariance matrix and uses the current alpha.
     *
     * @throws IllegalArgumentException if {@code vars} is empty or contains a variable this test does not know.
     */
    public IndependenceTest indTestSubset(List<Node> vars) {
        if (vars.isEmpty()) {
            throw new IllegalArgumentException("Subset may not be empty.");
        }

        int[] indices = new int[vars.size()];

        for (int i = 0; i < indices.length; i++) {
            Node var = vars.get(i);
            // One map lookup both validates membership and yields the index; the explicit null test keeps a null
            // entry an IllegalArgumentException rather than the NullPointerException ConcurrentHashMap would throw.
            Integer index = var == null ? null : this.indexMap.get(var);

            if (index == null) {
                throw new IllegalArgumentException(
                        "All vars must be original vars");
            }

            indices[i] = index;
        }

        ICovarianceMatrix newCovMatrix = this.covMatrix.getSubmatrix(indices);
        // Return the same kind of test as this one rather than silently switching implementations.
        return new IndTestFisherZRecursive(newCovMatrix, getAlpha());
    }

    /**
     * {@inheritDoc}
     *
     * Determines whether variable x is independent of variable y given a list of conditioning variables z. With r the
     * partial correlation of x and y given z and n the sample size, the statistic is
     * {@code sqrt(n - 3 - |z|) * |0.5 * (ln(1 + r) - ln(1 - r))|}; independence is judged when the statistic is below
     * the cutoff for the current alpha. The reported score is the statistic minus the cutoff.
     *
     * @param x a {@link edu.cmu.tetrad.graph.Node} object
     * @param y a {@link edu.cmu.tetrad.graph.Node} object
     * @param z a {@link java.util.Set} object
     * @return a {@link edu.cmu.tetrad.search.test.IndependenceResult} object
     * @throws RuntimeException if a singularity is encountered or the statistic is NaN.
     */
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        double r;

        try {
            r = partialCorrelation(x, y, z);
        } catch (SingularMatrixException e) {
            // Keep the original exception as the cause so the failing computation can be diagnosed.
            throw new RuntimeException("Singularity encountered when testing " +
                                       LogUtilsSearch.independenceFact(x, y, z), e);
        }

        double q = 0.5 * (log(1.0 + r) - FastMath.log(1.0 - r));
        double statistic = sqrt(sampleSize() - 3 - z.size()) * abs(q);
        this.fisherZ = statistic;

        // NaN arises e.g. when n - 3 - |z| is negative or r is NaN.
        if (Double.isNaN(statistic)) {
            throw new RuntimeException("NaN Fisher's Z encountered when testing " +
                                       LogUtilsSearch.independenceFact(x, y, z));
        }

        boolean independent = statistic < this.cutoff;
        // Derived from the local statistic, not the shared field, so the reported value belongs to this call.
        double pValue = pValueFor(statistic);

        if (this.verbose && independent) {
            TetradLogger.getInstance().log(
                    LogUtilsSearch.independenceFactMsg(x, y, z, pValue));
        }

        // The statistic is non-negative here, so no further abs() is needed for the score.
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, pValue, statistic - this.cutoff);
    }

    /**
     * getPValue.
     *
     * @return the probability associated with the most recently computed independence test.
     */
    public double getPValue() {
        return pValueFor(this.fisherZ);
    }

    /**
     * Gets the current significance level.
     *
     * @return a double
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * {@inheritDoc}
     *
     * Sets the significance level at which independence judgments should be made. Affects the cutoff for partial
     * correlations to be considered statistically equal to zero.
     *
     * @throws IllegalArgumentException if alpha is NaN or outside [0, 1].
     */
    public void setAlpha(double alpha) {
        // Written as a negated range test so that NaN is rejected as well.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Significance out of range: " + alpha);
        }

        this.alpha = alpha;
        this.cutoff = StatUtils.getZForAlpha(alpha);
    }

    /**
     * Getter for the field {@code variables}.
     *
     * @return the list of variables over which this independence checker is capable of determining independence
     * relations-- that is, all the variables in the given graph or the given data set.
     */
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Setter for the field {@code variables}. The covariance matrix and the name and index lookups are updated to
     * the new variables as well.
     *
     * @param variables a {@link java.util.List} object; must have as many entries as the current variable list.
     * @throws IllegalArgumentException if the number of variables differs from the current number.
     */
    public void setVariables(List<Node> variables) {
        if (variables.size() != this.variables.size()) throw new IllegalArgumentException("Wrong # of variables.");

        List<Node> copy = Collections.unmodifiableList(new ArrayList<>(variables));
        // Build the replacement lookups first so a bad entry fails before any state is touched.
        Map<Node, Integer> newIndexMap = buildIndexMap(copy);
        Map<String, Node> newNameMap = buildNameMap(copy);

        this.variables = copy;
        this.covMatrix.setVariables(variables);

        // Without this refresh getVariable() and indTestSubset() would keep answering for the old nodes.
        this.indexMap.clear();
        this.indexMap.putAll(newIndexMap);
        this.nameMap.clear();
        this.nameMap.putAll(newNameMap);
        // NOTE(review): recursivePartialCorrelation was built from covMatrix at construction time; whether it
        // observes the replaced variables depends on PartialCorrelation's internals -- confirm.
    }

    /**
     * {@inheritDoc}
     */
    public Node getVariable(String name) {
        return this.nameMap.get(name);
    }

    /**
     * {@inheritDoc}
     *
     * Reports determinism when the covariance submatrix over the conditioning variables z cannot be inverted. The
     * variable x is used only in the logged message. An empty z never determines.
     */
    public boolean determines(Set<Node> _z, Node x) throws UnsupportedOperationException {
        if (_z.isEmpty()) {
            return false;
        }

        List<Node> z = new ArrayList<>(_z);
        Collections.sort(z);

        List<Node> covVariables = this.covMatrix.getVariables();
        int[] parents = new int[z.size()];

        for (int j = 0; j < parents.length; j++) {
            parents[j] = covVariables.indexOf(z.get(j));
        }

        Matrix Czz = this.covMatrix.getSelection(parents, parents);

        try {
            // Only invertibility matters here; the inverse itself is discarded.
            Czz.inverse();
        } catch (SingularMatrixException e) {
            TetradLogger.getInstance().log(LogUtilsSearch.determinismDetected(_z, x));
            return true;
        }

        return false;
    }

    /**
     * getData.
     *
     * @return the data set being analyzed.
     */
    public DataSet getData() {
        return this.dataSet;
    }

    /**
     * toString.
     *
     * @return a string representation of this test.
     */
    @Override
    public String toString() {
        return "Fisher Z, alpha = " + new DecimalFormat("0.0E0").format(getAlpha());
    }

    /**
     * getCov.
     *
     * @return a {@link edu.cmu.tetrad.data.ICovarianceMatrix} object
     */
    public ICovarianceMatrix getCov() {
        return this.covMatrix;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public List<DataSet> getDataSets() {
        List<DataSet> dataSets = new ArrayList<>();
        dataSets.add(this.dataSet);
        return dataSets;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public int getSampleSize() {
        return this.covMatrix.getSampleSize();
    }

    /**
     * isVerbose.
     *
     * @return a boolean
     */
    public boolean isVerbose() {
        return this.verbose;
    }

    /**
     * {@inheritDoc}
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    //==========================PRIVATE METHODS============================//

    /**
     * Returns the partial correlation of x and y given the conditioning set, passing the conditioning variables in
     * sorted order.
     */
    private double partialCorrelation(Node x, Node y, Set<Node> _z) throws SingularMatrixException {
        List<Node> z = new ArrayList<>(_z);
        Collections.sort(z);
        return this.recursivePartialCorrelation.corr(x, y, z);
    }

    /**
     * Returns the two-sided standard normal p-value of the given Fisher Z statistic.
     */
    private double pValueFor(double statistic) {
        return 2.0 * (1.0 - this.normal.cumulativeProbability(abs(statistic)));
    }

    /**
     * Returns the sample size recorded in the covariance matrix.
     */
    private int sampleSize() {
        return this.covMatrix.getSampleSize();
    }

    /**
     * Builds the lookup from variable name to variable.
     */
    private static Map<String, Node> buildNameMap(List<Node> variables) {
        Map<String, Node> nameMap = new ConcurrentHashMap<>();

        for (Node node : variables) {
            nameMap.put(node.getName(), node);
        }

        return nameMap;
    }

    /**
     * Builds the lookup from variable to its position in the given list.
     */
    private static Map<Node, Integer> buildIndexMap(List<Node> variables) {
        Map<Node, Integer> indexMap = new ConcurrentHashMap<>();

        for (int i = 0; i < variables.size(); i++) {
            indexMap.put(variables.get(i), i);
        }

        return indexMap;
    }
}