///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.utils.*;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.util.FastMath;
import java.util.*;
import static org.apache.commons.math3.util.FastMath.abs;
import static org.apache.commons.math3.util.FastMath.sqrt;
/**
 * Implements the Find One Factor Clusters (FOFC) algorithm by Erich Kummerfeld, which uses reasoning about vanishing
 * tetrads to infer clusters of the measured variables in a dataset that can each be explained by a single
 * latent variable. A reference is the following:
*
* Kummerfeld, E., & Ramsey, J. (2016, August). Causal clustering for 1-factor measurement models. In Proceedings of
* the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (pp. 1655-1664).
*
 * The algorithm employs tests of vanishing tetrads (lists of 4 variables that follow a certain pattern in the
 * exchangeability of latent paths with respect to the data). The notion of vanishing tetrads is an old one and is
* explained in this book:
*
* Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.
*
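 * A minimal usage sketch (the {@code dataSet} variable is a placeholder, and the test type, algorithm, and alpha
 * shown are just one possible choice):
 * <pre>{@code
 * Fofc fofc = new Fofc(dataSet, BpcTestType.TETRAD_DELTA, Fofc.Algorithm.GAP, 0.001);
 * fofc.setVerbose(true);
 * Graph clusterGraph = fofc.search();
 * List<List<Node>> clusters = fofc.getClusters();
 * }</pre>
 *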
* @author peterspirtes
* @author erichkummerfeld
* @author josephramsey
* @version $Id: $Id
* @see Ftfc
* @see Bpc
*/
public class Fofc {
/**
     * The correlation matrix for the data.
*/
private final CorrelationMatrix corr;
/**
* The list of all variables.
*/
    private final List<Node> variables;
/**
* The significance level.
*/
private final double alpha;
/**
* The Delta test. Testing two tetrads simultaneously.
*/
private final DeltaTetradTest test;
/**
* The tetrad test--using Ricardo's. Used only for Wishart.
*/
private final TetradTestContinuous test2;
/**
* The data.
*/
private final transient DataModel dataModel;
/**
* The type of test used.
*/
private final BpcTestType testType;
/**
* The type of FOFC algorithm used.
*/
private final Algorithm algorithm;
/**
* The clusters that are output by the algorithm from the last call to search().
*/
    private List<List<Node>> clusters;
/**
* Whether verbose output is desired.
*/
private boolean verbose;
/**
     * Whether the significance of each cluster should be checked.
*/
private boolean significanceChecked;
/**
     * The type of cluster check that should be performed.
*/
private ClusterSignificance.CheckType checkType = ClusterSignificance.CheckType.Clique;
/**
* Constructor.
*
* @param cov The covariance matrix searched over.
* @param testType The type of test used.
* @param algorithm The type of FOFC algorithm used.
* @param alpha The alpha significance cutoff.
* @see BpcTestType
* @see Algorithm
*/
public Fofc(ICovarianceMatrix cov, BpcTestType testType, Algorithm algorithm, double alpha) {
if (testType == null) throw new NullPointerException("Null indepTest type.");
cov = new CovarianceMatrix(cov);
this.variables = cov.getVariables();
this.alpha = alpha;
this.testType = testType;
this.test = new DeltaTetradTest(cov);
this.test2 = new TetradTestContinuous(cov, testType, alpha);
this.dataModel = cov;
this.algorithm = algorithm;
this.corr = new CorrelationMatrix(cov);
}
/**
     * Constructor.
*
* @param dataSet The continuous dataset searched over.
* @param testType The type of test used.
* @param algorithm The type of FOFC algorithm used.
* @param alpha The alpha significance cutoff.
* @see BpcTestType
* @see Algorithm
*/
public Fofc(DataSet dataSet, BpcTestType testType, Algorithm algorithm, double alpha) {
if (testType == null) throw new NullPointerException("Null test type.");
this.variables = dataSet.getVariables();
this.alpha = alpha;
this.testType = testType;
this.test = new DeltaTetradTest(dataSet);
this.test2 = new TetradTestContinuous(dataSet, testType, alpha);
this.dataModel = dataSet;
this.algorithm = algorithm;
this.corr = new CorrelationMatrix(dataSet);
}
/**
     * Runs the search and returns a graph of clusters, each with its respective latent parent.
*
* @return This graph.
*/
public Graph search() {
        Set<List<Integer>> allClusters;
if (this.algorithm == Algorithm.SAG) {
allClusters = estimateClustersTetradsFirst();
} else if (this.algorithm == Algorithm.GAP) {
allClusters = estimateClustersTriplesFirst();
} else {
throw new IllegalStateException("Expected SAG or GAP: " + this.testType);
}
this.clusters = ClusterSignificance.variablesForIndices2(allClusters, variables);
System.out.println("allClusters = " + allClusters);
System.out.println("this.clusters = " + this.clusters);
ClusterSignificance clusterSignificance = new ClusterSignificance(variables, dataModel);
clusterSignificance.printClusterPValues(allClusters);
return convertToGraph(allClusters);
}
/**
* Sets whether the significance of the cluster should be checked for each cluster.
*
* @param significanceChecked True, if so.
*/
public void setSignificanceChecked(boolean significanceChecked) {
this.significanceChecked = significanceChecked;
}
/**
* Sets which type of cluster check should be performed.
*
* @param checkType The type to be performed.
* @see ClusterSignificance.CheckType
*/
public void setCheckType(ClusterSignificance.CheckType checkType) {
this.checkType = checkType;
}
/**
* The clusters that are output by the algorithm from the last call to search().
*
* @return a {@link java.util.List} object
*/
    public List<List<Node>> getClusters() {
return this.clusters;
}
/**
     * Sets whether verbose output should be printed.
     *
     * @param verbose True, if so.
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* Returns the index of the variable that occurs most frequently in the given array. (renjiey).
*
* @param outliers An array of integers representing variables.
* @return The index of the most frequently occurring variable.
*/
private int findFrequentestIndex(Integer[] outliers) {
        Map<Integer, Integer> map = new HashMap<>();
for (Integer outlier : outliers) {
if (map.containsKey(outlier)) {
map.put(outlier, map.get(outlier) + 1);
} else {
map.put(outlier, 1);
}
}
        Set<Map.Entry<Integer, Integer>> set = map.entrySet();
        Iterator<Map.Entry<Integer, Integer>> it = set.iterator();
        int nums = 0; // the highest occurrence count seen so far
        int key = 0;  // the variable that occurs most often
        while (it.hasNext()) {
            Map.Entry<Integer, Integer> entry = it.next();
if (entry.getValue() > nums) {
nums = entry.getValue();
key = entry.getKey();
}
}
return (key);
}
/**
     * This is the main function. It removes variables from the data so that the remaining correlation matrix does not
     * contain extreme values. Inputs: the correlation matrix and the lower and upper bounds for unacceptable
     * correlations. Output: a dynamic array of removed variables. (renjiey)
*/
    private ArrayList<Integer> removeVariables(Matrix correlationMatrix, double lowerBound, double upperBound,
double percentBound) {
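        // Flattened sequence of 1-based (row, column) index pairs whose absolute correlation lies outside
        // [lowerBound, upperBound]; in-range pairs are recorded as (0, 0) placeholders.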
Integer[] outlier = new Integer[correlationMatrix.getNumRows() * (correlationMatrix.getNumRows() - 1)];
int count = 0;
for (int i = 2; i < (correlationMatrix.getNumRows() + 1); i++) {
for (int j = 1; j < i; j++) {
if ((abs(correlationMatrix.get(i - 1, j - 1)) < lowerBound)
|| (abs(correlationMatrix.get(i - 1, j - 1)) > upperBound)) {
outlier[count * 2] = i;
outlier[count * 2 + 1] = j;
} else {
outlier[count * 2] = 0;
outlier[count * 2 + 1] = 0;
}
count = count + 1;
}
}
//find out the variables that should be deleted
        ArrayList<Integer> removedVariables = new ArrayList<>();
// Added the percent bound jdramsey
while (outlier.length > 1 && removedVariables.size() < percentBound * correlationMatrix.getNumRows()) {
//find out the variable that occurs most frequently in outlier
int worstVariable = findFrequentestIndex(outlier);
if (worstVariable > 0) {
removedVariables.add(worstVariable);
}
//remove the correlations having the bad variable (change the relevant variables to 0)
for (int i = 1; i < outlier.length + 1; i++) {
if (outlier[i - 1] == worstVariable) {
outlier[i - 1] = 0;
if (i % 2 != 0) {
outlier[i] = 0;
} else {
outlier[i - 2] = 0;
}
}
}
//delete zero elements in outlier
outlier = removeZeroIndex(outlier);
}
log(removedVariables.size() + " variables removed: " + ClusterSignificance.variablesForIndices(removedVariables, variables));
return (removedVariables);
}
/**
* Removes the elements with zero index from the given integer array. (renjiey)
*
* @param outlier The array of integers.
* @return The updated array with zero index elements removed.
*/
private Integer[] removeZeroIndex(Integer[] outlier) {
        List<Integer> list = new ArrayList<>();
Collections.addAll(list, outlier);
for (Integer element : outlier) {
if (element < 1) {
list.remove(element);
}
}
        return list.toArray(new Integer[0]);
}
/**
* Estimates clusters using the triples-first algorithm.
*
* @return A set of lists of integers representing the clusters.
*/
    private Set<List<Integer>> estimateClustersTriplesFirst() {
        List<Integer> _variables = allVariables();
        Set<Set<Integer>> triples = findPuretriples(_variables);
        Set<Set<Integer>> combined = combinePuretriples(triples, _variables);
        Set<List<Integer>> _combined = new HashSet<>();
        for (Set<Integer> c : combined) {
            List<Integer> a = new ArrayList<>(c);
Collections.sort(a);
_combined.add(a);
}
return _combined;
}
/**
* Retrieves a list of all variables.
*
* @return A list of integers representing all variables.
*/
    private List<Integer> allVariables() {
        List<Integer> _variables = new ArrayList<>();
for (int i = 0; i < this.variables.size(); i++) _variables.add(i);
return _variables;
}
/**
* Estimates clusters using the tetrads-first algorithm.
*
* @return A set of lists of integers representing the clusters.
*/
    private Set<List<Integer>> estimateClustersTetradsFirst() {
        List<Integer> _variables = allVariables();
        Set<List<Integer>> pureClusters = findPureClusters(_variables);
        Set<List<Integer>> mixedClusters = findMixedClusters(_variables, unionPure(pureClusters));
        Set<List<Integer>> allClusters = new HashSet<>(pureClusters);
allClusters.addAll(mixedClusters);
return allClusters;
}
/**
* Finds pure triples from the given list of variables.
*
* @param allVariables The list of integers representing all variables.
* @return A set of sets of integers representing the pure triples.
*/
    private Set<Set<Integer>> findPuretriples(List<Integer> allVariables) {
if (allVariables.size() < 4) {
return new HashSet<>();
}
log("Finding pure triples.");
ChoiceGenerator gen = new ChoiceGenerator(allVariables.size(), 3);
int[] choice;
        Set<Set<Integer>> puretriples = new HashSet<>();
CHOICE:
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = allVariables.get(choice[0]);
int n2 = allVariables.get(choice[1]);
int n3 = allVariables.get(choice[2]);
            List<Integer> triple = triple(n1, n2, n3);
if (zeroCorr(triple)) continue;
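            // The triple is treated as pure only if every tetrad formed by adding one more variable vanishes;
            // a single failing tetrad rejects it.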
for (int o : allVariables) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (triple.contains(o)) {
continue;
}
                List<Integer> quartet = quartet(n1, n2, n3, o);
boolean vanishes = vanishes(quartet);
if (!vanishes) {
continue CHOICE;
}
}
            HashSet<Integer> _cluster = new HashSet<>(triple);
if (this.verbose) {
log("++" + ClusterSignificance.variablesForIndices(triple, variables));
}
puretriples.add(_cluster);
}
return puretriples;
}
/**
* Combines pure triples with given variables.
*
* @param puretriples The set of pure triples.
* @param _variables The list of variables.
* @return A set of combined clusters.
*/
    private Set<Set<Integer>> combinePuretriples(Set<Set<Integer>> puretriples, List<Integer> _variables) {
log("Growing pure triples.");
        Set<Set<Integer>> grown = new HashSet<>();
// Lax grow phase with speedup.
if (true) {
            Set<Integer> t = new HashSet<>();
int count = 0;
int total = puretriples.size();
do {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (!puretriples.iterator().hasNext()) {
break;
}
                Set<Integer> cluster = puretriples.iterator().next();
                Set<Integer> _cluster = new HashSet<>(cluster);
for (int o : _variables) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (_cluster.contains(o)) continue;
                    List<Integer> _cluster2 = new ArrayList<>(_cluster);
int rejected = 0;
int accepted = 0;
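                    // Majority vote: o is added only if at least as many pairs from the current cluster form a
                    // known pure triple with o as do not.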
ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
t.clear();
t.add(_cluster2.get(choice[0]));
t.add(_cluster2.get(choice[1]));
t.add(o);
if (!puretriples.contains(t)) {
rejected++;
} else {
accepted++;
}
}
if (rejected > accepted) {
continue;
}
_cluster.add(o);
ClusterSignificance clusterSignificance = new ClusterSignificance(variables, dataModel);
clusterSignificance.setCheckType(checkType);
if (significanceChecked && clusterSignificance.significant(_cluster2, alpha)) {
_cluster2.remove(o);
}
}
// This takes out all pure clusters that are subsets of _cluster.
ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
int[] choice2;
                List<Integer> _cluster3 = new ArrayList<>(_cluster);
while ((choice2 = gen2.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = _cluster3.get(choice2[0]);
int n2 = _cluster3.get(choice2[1]);
int n3 = _cluster3.get(choice2[2]);
t.clear();
t.add(n1);
t.add(n2);
t.add(n3);
puretriples.remove(t);
}
if (this.verbose) {
log("Grown " + (++count) + " of " + total + ": "
+ ClusterSignificance.variablesForIndices(new ArrayList<>(_cluster), variables));
}
grown.add(_cluster);
} while (!puretriples.isEmpty());
}
// Lax grow phase without speedup.
if (false) {
int count = 0;
int total = puretriples.size();
// Optimized lax version of grow phase.
            for (Set<Integer> cluster : new HashSet<>(puretriples)) {
                Set<Integer> _cluster = new HashSet<>(cluster);
for (int o : _variables) {
if (_cluster.contains(o)) continue;
                    List<Integer> _cluster2 = new ArrayList<>(_cluster);
int rejected = 0;
int accepted = 0;
ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 4);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = _cluster2.get(choice[0]);
int n2 = _cluster2.get(choice[1]);
                        List<Integer> triple = triple(n1, n2, o);
                        Set<Integer> t = new HashSet<>(triple);
if (!puretriples.contains(t)) {
rejected++;
} else {
accepted++;
}
// if (avgSumLnP(triple) < -10) continue CLUSTER;
}
if (rejected > accepted) {
continue;
}
_cluster.add(o);
}
                for (Set<Integer> c : new HashSet<>(puretriples)) {
if (_cluster.containsAll(c)) {
puretriples.remove(c);
}
}
if (this.verbose) {
System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
}
grown.add(_cluster);
}
}
// Strict grow phase.
if (false) {
            Set<Integer> t = new HashSet<>();
int count = 0;
int total = puretriples.size();
do {
if (!puretriples.iterator().hasNext()) {
break;
}
                Set<Integer> cluster = puretriples.iterator().next();
                Set<Integer> _cluster = new HashSet<>(cluster);
VARIABLES:
for (int o : _variables) {
if (_cluster.contains(o)) continue;
                    List<Integer> _cluster2 = new ArrayList<>(_cluster);
ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 4);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = _cluster2.get(choice[0]);
int n2 = _cluster2.get(choice[1]);
int n3 = _cluster2.get(choice[2]);
int n4 = _cluster2.get(choice[3]);
t.clear();
t.add(n1);
t.add(n2);
t.add(n3);
t.add(n4);
t.add(o);
if (!puretriples.contains(t)) {
continue VARIABLES;
}
// if (avgSumLnP(new ArrayList(t)) < -10) continue CLUSTER;
}
_cluster.add(o);
}
// This takes out all pure clusters that are subsets of _cluster.
ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
int[] choice2;
                List<Integer> _cluster3 = new ArrayList<>(_cluster);
while ((choice2 = gen2.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = _cluster3.get(choice2[0]);
int n2 = _cluster3.get(choice2[1]);
int n3 = _cluster3.get(choice2[2]);
t.clear();
t.add(n1);
t.add(n2);
t.add(n3);
puretriples.remove(t);
}
if (this.verbose) {
System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
}
grown.add(_cluster);
} while (!puretriples.isEmpty());
}
if (false) {
System.out.println("# pure triples = " + puretriples.size());
            List<Set<Integer>> clusters = new LinkedList<>(puretriples);
            Set<Integer> t = new HashSet<>();
for (int i = 0; i < clusters.size(); i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
System.out.println("I = " + i);
J:
for (int j = i + 1; j < clusters.size(); j++) {
                    Set<Integer> ci = clusters.get(i);
                    Set<Integer> cj = clusters.get(j);
if (ci == null) continue;
if (cj == null) continue;
                    Set<Integer> ck = new HashSet<>(ci);
ck.addAll(cj);
                    List<Integer> cm = new ArrayList<>(ck);
ChoiceGenerator gen = new ChoiceGenerator(cm.size(), 3);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
t.clear();
t.add(cm.get(choice[0]));
t.add(cm.get(choice[1]));
t.add(cm.get(choice[2]));
if (!puretriples.contains(t)) {
continue J;
}
}
clusters.set(i, ck);
clusters.remove(j);
j--;
System.out.println("Removing " + ci + ", " + cj + ", adding " + ck);
}
}
grown = new HashSet<>(clusters);
}
// Optimized pick phase.
log("Choosing among grown clusters.");
        for (Set<Integer> l : grown) {
            ArrayList<Integer> _l = new ArrayList<>(l);
Collections.sort(_l);
if (this.verbose) {
log("Grown: " + ClusterSignificance.variablesForIndices(_l, variables));
}
}
        Set<Set<Integer>> out = new HashSet<>();
        List<Set<Integer>> list = new ArrayList<>(grown);
        list.sort((o1, o2) -> o2.size() - o1.size());
        Set<Integer> all = new HashSet<>();
        CLUSTER:
        for (Set<Integer> cluster : list) {
for (Integer i : cluster) {
if (all.contains(i)) continue CLUSTER;
}
out.add(cluster);
all.addAll(cluster);
}
return out;
}
// Finds clusters of size 4 or higher for the tetrad-first algorithm.
    private Set<List<Integer>> findPureClusters(List<Integer> _variables) {
        Set<List<Integer>> clusters = new HashSet<>();
VARIABLES:
while (!_variables.isEmpty()) {
if (this.verbose) {
System.out.println(_variables);
}
if (_variables.size() < 4) break;
ChoiceGenerator gen = new ChoiceGenerator(_variables.size(), 4);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int n1 = _variables.get(choice[0]);
int n2 = _variables.get(choice[1]);
int n3 = _variables.get(choice[2]);
int n4 = _variables.get(choice[3]);
                List<Integer> cluster = quartet(n1, n2, n3, n4);
// Note that purity needs to be assessed with respect to all the variables in order to
// remove all latent-measure impurities between pairs of latents.
if (pure(cluster)) {
addOtherVariables(_variables, cluster);
if (this.verbose) {
log("Cluster found: " + ClusterSignificance.variablesForIndices(cluster, variables));
}
clusters.add(cluster);
_variables.removeAll(cluster);
continue VARIABLES;
}
}
break;
}
return clusters;
}
/**
     * Adds each remaining variable to the given cluster if every triple of current cluster members forms a pure
     * quartet with it.
*
* @param _variables The list of available variables.
* @param cluster The current cluster.
*/
    private void addOtherVariables(List<Integer> _variables, List<Integer> cluster) {
O:
for (int o : _variables) {
if (cluster.contains(o)) continue;
            List<Integer> _cluster = new ArrayList<>(cluster);
ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
int[] choice2;
// boolean found = false;
while ((choice2 = gen2.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int t1 = _cluster.get(choice2[0]);
int t2 = _cluster.get(choice2[1]);
int t3 = _cluster.get(choice2[2]);
                List<Integer> quartet = triple(t1, t2, t3);
quartet.add(o);
if (!pure(quartet)) {
continue O;
}
}
log("Extending by " + this.variables.get(o));
cluster.add(o);
}
}
/**
* Determines if adding a new cluster to the existing clusters would result in an insignificant model.
*
* @param clusters The set of existing clusters.
* @param cluster The new cluster to be added.
* @param variable The list of variables.
* @param dataModel The data model to be used in significance calculations.
* @return True if adding the new cluster would result in an insignificant model, false otherwise.
*/
    private boolean modelInsignificantWithNewCluster(Set<List<Integer>> clusters, List<Integer> cluster, List<Node> variable, DataModel dataModel) {
        List<List<Integer>> __clusters = new ArrayList<>(clusters);
__clusters.add(cluster);
ClusterSignificance clusterSignificance = new ClusterSignificance(variables, dataModel);
clusterSignificance.setCheckType(checkType);
double significance3 = clusterSignificance.getModelPValue(__clusters);
if (this.verbose) {
log("Significance * " + __clusters + " = " + significance3);
}
return significance3 < this.alpha;
}
/**
     * Finds clusters of size 3 for the quartet-first algorithm.
*
* @param remaining The list of remaining variables.
* @param unionPure The set of union pure variables.
* @return A set of lists of integers representing the mixed clusters.
*/
    private Set<List<Integer>> findMixedClusters(List<Integer> remaining, Set<Integer> unionPure) {
        Set<List<Integer>> triples = new HashSet<>();
if (unionPure.isEmpty()) {
return new HashSet<>();
}
REMAINING:
while (true) {
if (remaining.size() < 3) break;
ChoiceGenerator gen = new ChoiceGenerator(remaining.size(), 3);
int[] choice;
while ((choice = gen.next()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
int t2 = remaining.get(choice[0]);
int t3 = remaining.get(choice[1]);
int t4 = remaining.get(choice[2]);
                List<Integer> cluster = new ArrayList<>();
cluster.add(t2);
cluster.add(t3);
cluster.add(t4);
if (zeroCorr(cluster)) {
continue;
}
// Check all x as a cross-check; really only one should be necessary.
boolean allVanish = true;
boolean someVanish = false;
for (int t1 : allVariables()) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (cluster.contains(t1)) continue;
                    List<Integer> _cluster = new ArrayList<>(cluster);
_cluster.add(t1);
if (vanishes(_cluster)) {
someVanish = true;
} else {
allVanish = false;
break;
}
}
if (someVanish && allVanish) {
triples.add(cluster);
unionPure.addAll(cluster);
remaining.removeAll(cluster);
if (this.verbose) {
log("3-cluster found: " + ClusterSignificance.variablesForIndices(cluster, variables));
}
continue REMAINING;
}
}
break;
}
return triples;
}
/**
* Calculate the degrees of freedom for Drton's method.
*
* @param n The number of variables.
* @return The number of degrees of freedom.
*/
private int dofDrton(int n) {
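        // For example, n = 4 gives (2 * 1) / 2 - 2 = -1 (clamped to 0 below), and n = 6 gives (4 * 3) / 2 - 2 = 4.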
int dof = ((n - 2) * (n - 3)) / 2 - 2;
if (dof < 0) dof = 0;
return dof;
}
/**
     * Determines whether a given quartet of variables is pure: the quartet itself must vanish, and every quartet
     * obtained by substituting any outside variable for one of its members must also vanish.
*
* @param quartet The list of integers representing a quartet of variables.
* @return True if the quartet is pure, false otherwise.
*/
    private boolean pure(List<Integer> quartet) {
if (zeroCorr(quartet)) {
return false;
}
if (vanishes(quartet)) {
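            // The quartet itself vanishes; additionally require that swapping any one of its members for any outside
            // variable also yields a vanishing quartet, so that no impurity involves these four variables.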
for (int o : allVariables()) {
if (quartet.contains(o)) continue;
for (int i = 0; i < quartet.size(); i++) {
                    List<Integer> _quartet = new ArrayList<>(quartet);
_quartet.remove(quartet.get(i));
_quartet.add(o);
if (!(vanishes(_quartet))) {
return false;
}
}
}
return true;
}
return false;
}
/**
* Constructs a quartet from four given integers.
*
* @param n1 The first integer.
* @param n2 The second integer.
* @param n3 The third integer.
* @param n4 The fourth integer.
* @return A list containing the four integers in the order they were passed in.
* @throws IllegalArgumentException If any of the integers are duplicated.
*/
    private List<Integer> quartet(int n1, int n2, int n3, int n4) {
        List<Integer> quartet = new ArrayList<>();
quartet.add(n1);
quartet.add(n2);
quartet.add(n3);
quartet.add(n4);
if (new HashSet<>(quartet).size() < 4)
throw new IllegalArgumentException("quartet elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ", " + n4 + ">");
return quartet;
}
/**
* Constructs a {@link List} of integers representing a triple.
*
* @param n1 The first integer.
* @param n2 The second integer.
* @param n3 The third integer.
* @return A {@link List} containing the three integers in the order they were passed in.
* @throws IllegalArgumentException If any of the integers are duplicated.
*/
    private List<Integer> triple(int n1, int n2, int n3) {
        List<Integer> triple = new ArrayList<>();
triple.add(n1);
triple.add(n2);
triple.add(n3);
if (new HashSet<>(triple).size() < 3)
throw new IllegalArgumentException("triple elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ">");
return triple;
}
/**
* Determines if the quartet of variables vanishes based on the test type.
*
* @param quartet The list of integers representing the quartet of variables.
* @return True if the quartet vanishes, false otherwise.
*/
    private boolean vanishes(List<Integer> quartet) {
int n1 = quartet.get(0);
int n2 = quartet.get(1);
int n3 = quartet.get(2);
int n4 = quartet.get(3);
return vanishes(n1, n2, n3, n4);
}
/**
     * Checks whether any pair of variables in the given cluster fails to be significantly correlated at the alpha
     * level.
     *
     * @param cluster The list of integers representing the cluster.
     * @return True if at least one pairwise correlation is not significantly different from zero, false otherwise.
*/
    private boolean zeroCorr(List<Integer> cluster) {
int count = 0;
for (int i = 0; i < cluster.size(); i++) {
for (int j = i + 1; j < cluster.size(); j++) {
double r = this.corr.getValue(cluster.get(i), cluster.get(j));
int N = this.corr.getSampleSize();
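                // Fisher-style z statistic for the null hypothesis r = 0; |f| is referred to a standard normal to
                // obtain a two-sided p-value. A single non-significant pairwise correlation flags the cluster.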
double f = sqrt(N) * FastMath.log((1. + r) / (1. - r));
double p = 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0, 1, abs(f)));
if (p > this.alpha) count++;
}
}
return count >= 1;
}
/**
* Determines if the quartet of variables vanishes based on the test type.
*
* @param x The first variable index.
* @param y The second variable index.
* @param z The third variable index.
* @param w The fourth variable index.
* @return True if the quartet vanishes, false otherwise.
*/
private boolean vanishes(int x, int y, int z, int w) {
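        // For four indicators of a single latent factor, all three tetrad differences vanish; any two of the
        // constraints imply the third, so only two tetrads need to be tested here.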
if (this.testType == BpcTestType.TETRAD_DELTA) {
Tetrad t1 = new Tetrad(this.variables.get(x), this.variables.get(y), this.variables.get(z), this.variables.get(w));
Tetrad t2 = new Tetrad(this.variables.get(x), this.variables.get(y), this.variables.get(w), this.variables.get(z));
return this.test.getPValue(t1, t2) > this.alpha;
} else if (this.testType == BpcTestType.TETRAD_WISHART) {
return this.test2.tetradPValue(x, y, z, w) > this.alpha && this.test2.tetradPValue(x, y, w, z) > this.alpha;
}
throw new IllegalArgumentException("Only the delta and wishart tests are being used: " + this.testType);
}
/**
* Converts search graph nodes to a Graph object.
*
* @param clusters The set of sets of Node objects representing the clusters.
* @return A Graph object representing the search graph nodes.
*/
    private Graph convertSearchGraphNodes(Set<Set<Node>> clusters) {
        Graph graph = new EdgeListGraph();
        List<Node> latents = new ArrayList<>();
        List<Set<Node>> _clusters = new ArrayList<>(clusters);
for (int i = 0; i < _clusters.size(); i++) {
Node latent = new GraphNode(ClusterUtils.LATENT_PREFIX + (i + 1));
latent.setNodeType(NodeType.LATENT);
latents.add(latent);
graph.addNode(latent);
for (Node node : _clusters.get(i)) {
if (!graph.containsNode(node)) graph.addNode(node);
graph.addDirectedEdge(latents.get(i), node);
}
}
return graph;
}
/**
     * Converts a set of clusters of variable indices to a Graph in which each cluster is given a latent parent.
     *
     * @param allClusters The set of clusters, each a list of variable indices.
     * @return A Graph object representing the clusters and their latent parents.
*/
    private Graph convertToGraph(Set<List<Integer>> allClusters) {
        Set<Set<Node>> _clustering = new HashSet<>();
        for (List<Integer> cluster : allClusters) {
            Set<Node> nodes = new HashSet<>();
for (int i : cluster) {
nodes.add(this.variables.get(i));
}
_clustering.add(nodes);
}
return convertSearchGraphNodes(_clustering);
}
/**
* Returns the union of all integers in the given list of clusters.
*
* @param pureClusters The set of clusters, where each cluster is represented as a list of integers.
* @return A set containing the union of all integers in the clusters.
*/
    private Set<Integer> unionPure(Set<List<Integer>> pureClusters) {
        Set<Integer> unionPure = new HashSet<>();
        for (List<Integer> cluster : pureClusters) {
unionPure.addAll(cluster);
}
return unionPure;
}
/**
* Logs a message if the verbose flag is set to true.
*
* @param s The message to log.
*/
private void log(String s) {
if (this.verbose) {
TetradLogger.getInstance().log(s);
}
}
/**
* Gives the options to be used in FOFC to sort through the various possibilities for forming clusters to find the
* best options. SAG (Seed and Grow) looks for good seed clusters and then grows them by adding one variable at a
     * time. GAP (Grow and Pick) grows out all the clusters initially and then just picks from among these. SAG is
* generally faster; GAP is generally slower but more accurate.
*/
public enum Algorithm {
/**
* The SAG algorithm.
*/
SAG,
/**
* The GAP algorithm.
*/
GAP
}
}