edu.cmu.tetrad.search.MarkovCheck Maven / Gradle / Ivy
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.test.MsepTest;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.UniformityTest;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import static org.apache.commons.math3.util.FastMath.min;
/**
* Checks whether a graph is locally Markov or locally Faithful given a data set. First a lists of m-separation
* predictions are made for each pair of variables in the graph given the parents of one of the variables, one list (for
* local Markov) where the m-separation holds and another list (for local Faithfulness) where the m-separation does not
* hold. Then the predictions are tested against the data set using the independence test. For the Markov test, since an
* independence test yielding p-values should be Uniform under the null hypothesis, these p-values are tested for
* Uniformity using the Kolmogorov-Smirnov test. Also, a fraction of dependent judgments is returned, which should equal
* the alpha level of the independence test if the test is Uniform under the null hypothesis. For the Faithfulness test,
* the p-values are tested for Uniformity using the Kolmogorov-Smirnov test; these should be dependent. Also, a fraction
* of dependent judgments is returned, which should be maximal./p>
*
*
A "Markov adequacy score" is also given, which simply returns zero if the Markov p-value Uniformity test
* fails and the fraction of dependent judgments for the local Faithfulness check otherwise. Maximizing this score picks
* out models for which Markov holds and faithfulness holds to the extend possible; these model should generally have
* good accuracy scores.
*
* @author josephramsey
*/
public class MarkovCheck {
private final Graph graph;
private final IndependenceTest independenceTest;
private final MsepTest msep;
private final List resultsIndep = new ArrayList<>();
private final List resultsDep = new ArrayList<>();
private ConditioningSetType setType;
private boolean parallelized = false;
private double fractionDependentIndep = Double.NaN;
private double fractionDependentDep = Double.NaN;
private double ksPValueIndep = Double.NaN;
private double ksPValueDep = Double.NaN;
/**
* Constructor. Takes a graph and an independence test over the variables of the graph.
*
* @param graph The graph.
* @param independenceTest The test over the variables of the graph.
*/
public MarkovCheck(Graph graph, IndependenceTest independenceTest, ConditioningSetType setType) {
this.graph = GraphUtils.replaceNodes(graph, independenceTest.getVariables());
this.independenceTest = independenceTest;
this.msep = new MsepTest(this.graph);
this.setType = setType;
}
/**
* Generates all results, for both the local Markov and local Faithfulness checks, for each node in the graph given
* the parents of that node. These results are stored in the resultsIndep and resultsDep lists.
*
* @see #getResults(boolean)
*/
public void generateResults() {
resultsIndep.clear();
resultsDep.clear();
if (setType == ConditioningSetType.GLOBAL_MARKOV) {
AllSubsetsIndependenceFacts result = getAllSubsetsIndependenceFacts(graph);
generateResultsAllSubsets(true, result.msep, result.mconn);
generateResultsAllSubsets(false, result.msep, result.mconn);
} else {
List variables = independenceTest.getVariables();
List nodes = new ArrayList<>(variables);
Collections.sort(nodes);
List order = graph.paths().getValidOrder(graph.getNodes(), true);
for (Node x : nodes) {
Set z;
switch (setType) {
case LOCAL_MARKOV:
z = new HashSet<>(graph.getParents(x));
break;
case ORDERED_LOCAL_MARKOV:
if (order == null) throw new IllegalArgumentException("No valid order found.");
z = new HashSet<>(graph.getParents(x));
// Keep only the parents in Prefix(x).
for (Node w : new ArrayList<>(z)) {
int i1 = order.indexOf(x);
int i2 = order.indexOf(w);
if (i2 >= i1) {
z.remove(w);
}
}
break;
case MARKOV_BLANKET:
z = GraphUtils.markovBlanket(x, graph);
break;
default:
throw new IllegalArgumentException("Unknown separation set type: " + setType);
}
Set msep = new HashSet<>();
Set mconn = new HashSet<>();
List other = new ArrayList<>(graph.getNodes());
Collections.sort(other);
other.removeAll(z);
for (Node y : other) {
if (y == x) continue;
if (z.contains(x) || z.contains(y)) continue;
if (this.msep.isMSeparated(x, y, z)) {
msep.add(y);
} else {
mconn.add(y);
}
}
generateResults(true, x, z, msep, mconn);
generateResults(false, x, z, msep, mconn);
}
}
calcStats(true);
calcStats(false);
}
@NotNull
public static AllSubsetsIndependenceFacts getAllSubsetsIndependenceFacts(Graph graph) {
List variables = new ArrayList<>(graph.getNodes());
MsepTest msepTest = new MsepTest(graph);
List nodes = new ArrayList<>(variables);
Collections.sort(nodes);
List msep = new ArrayList<>();
List mconn = new ArrayList<>();
for (Node x : nodes) {
List other = new ArrayList<>(variables);
Collections.sort(other);
other.remove(x);
for (Node y : other) {
List _other = new ArrayList<>(other);
_other.remove(y);
SublistGenerator generator = new SublistGenerator(_other.size(), _other.size());
int[] list;
while ((list = generator.next()) != null) {
Set z = GraphUtils.asSet(list, _other);
if (msepTest.isMSeparated(x, y, z)) {
msep.add(new IndependenceFact(x, y, z));
} else {
mconn.add(new IndependenceFact(x, y, z));
}
}
}
}
return new AllSubsetsIndependenceFacts(msep, mconn);
}
public static class AllSubsetsIndependenceFacts {
private final List msep;
private final List mconn;
public AllSubsetsIndependenceFacts(List msep, List mconn) {
this.msep = msep;
this.mconn = mconn;
}
public String toStringIndep() {
StringBuilder builder = new StringBuilder("All subsets independence facts:\n");
for (IndependenceFact fact : msep) {
builder.append(fact).append("\n");
}
return builder.toString();
}
public String toStringDep() {
StringBuilder builder = new StringBuilder("All subsets independence facts:\n");
for (IndependenceFact fact : mconn) {
builder.append(fact).append("\n");
}
return builder.toString();
}
public List getMsep() {
return msep;
}
public List getMconn() {
return mconn;
}
}
/**
* Returns type of conditioning sets to use in the Markov check.
*
* @return The type of conditioning sets to use in the Markov check.
* @see ConditioningSetType
*/
public ConditioningSetType getSetType() {
return this.setType;
}
/**
* Sets the type of conditioning sets to use in the Markov check.
*
* @param setType The type of conditioning sets to use in the Markov check.
* @see ConditioningSetType
*/
public void setSetType(ConditioningSetType setType) {
this.setType = setType;
}
/**
* True if the checks should be parallelized. (Not always a good idea.)
*
* @param parallelized True if the checks should be parallelized.
*/
public void setParallelized(boolean parallelized) {
this.parallelized = parallelized;
}
/**
* After the generateResults method has been called, this method returns the results for the local Markov or local
* Faithfulness check, depending on the value of the indep parameter.
*
* @param indep True for the local Markov results, false for the local Faithfulness results.
* @return The results for the local Markov or local Faithfulness check.
*/
public List getResults(boolean indep) {
if (indep) {
return new ArrayList<>(this.resultsIndep);
} else {
return new ArrayList<>(this.resultsDep);
}
}
/**
* Returns the list of p-values for the given list of results.
*
* @param results The results.
* @return Their p-values.
*/
public List getPValues(List results) {
List pValues = new ArrayList<>();
for (IndependenceResult result : results) {
pValues.add(result.getPValue());
}
return pValues;
}
/**
* Returns the fraction of dependent judgments for the given list of results.
*
* @param indep True for the local Markov results, false for the local Faithfulness results.
* @return The fraction of dependent judgments for this condition.
*/
public double getFractionDependent(boolean indep) {
if (indep) {
return fractionDependentIndep;
} else {
return fractionDependentDep;
}
}
/**
* Returns the Kolmorogov-Smirnov p-value for the given list of results.
*
* @param indep True for the local Markov results, false for the local Faithfulness results.
* @return The Kolmorogov-Smirnov p-value for this condition.
*/
public double getKsPValue(boolean indep) {
if (indep) {
return ksPValueIndep;
} else {
return ksPValueDep;
}
}
/**
* Returns the Markov Adequacy Score for the graph. This is zero if the p-value of the KS test of Uniformity is less
* than alpha, and the fraction of dependent pairs otherwise. This is only for continuous Gaussian data, as it
* hard-codes the Fisher Z test for the local Markov and Faithfulness check.
*
* @param alpha The alpha level for the KS test of Uniformity. An alpha level greater than this will be considered
* uniform.
* @return The Markov Adequacy Score for this graph given the data.
*/
public double getMarkovAdequacyScore(double alpha) {
if (getKsPValue(true) > alpha) {
return getFractionDependent(false);
} else {
return 0.0;
}
}
/**
* Returns the variables of the independence test.
*
* @return The variables of the independence test.
*/
public List getVariables() {
return new ArrayList<>(independenceTest.getVariables());
}
/**
* Returns the variable with the given name.
*
* @param name The name of the variables.
* @return The variable with the given name.
*/
public Node getVariable(String name) {
return independenceTest.getVariable(name);
}
/**
* Returns the independence test being used.
*
* @return This test.
*/
public IndependenceTest getIndependenceTest() {
return this.independenceTest;
}
private void generateResults(boolean indep, Node x, Set z, Set msep, Set mconn) {
List facts = new ArrayList<>();
// Listing all facts before checking any (in preparation for parallelization).
if (indep) {
for (Node y : msep) {
if (z.contains(y)) continue;
facts.add(new IndependenceFact(x, y, z));
}
} else {
for (Node y : mconn) {
if (z.contains(y)) continue;
facts.add(new IndependenceFact(x, y, z));
}
}
class IndCheckTask implements Callable> {
private final int from;
private final int to;
private final List facts;
private final IndependenceTest independenceTest;
IndCheckTask(int from, int to, List facts, IndependenceTest test) {
this.from = from;
this.to = to;
this.facts = facts;
this.independenceTest = test;
}
@Override
public List call() {
List results = new ArrayList<>();
for (int i = from; i < to; i++) {
if (Thread.interrupted()) break;
IndependenceFact fact = facts.get(i);
Node x = fact.getX();
Node y = fact.getY();
Set z = fact.getZ();
boolean verbose = independenceTest.isVerbose();
independenceTest.setVerbose(false);
IndependenceResult result;
try {
result = independenceTest.checkIndependence(x, y, z);
} catch (Exception e) {
throw new RuntimeException(e);
}
boolean indep = result.isIndependent();
double pValue = result.getPValue();
independenceTest.setVerbose(verbose);
if (!Double.isNaN(pValue)) {
results.add(new IndependenceResult(fact, indep, pValue, Double.NaN));
}
}
return results;
}
}
List>> tasks = new ArrayList<>();
int chunkSize = getChunkSize(facts.size());
for (int i = 0; i < facts.size() && !Thread.currentThread().isInterrupted(); i += chunkSize) {
IndCheckTask task = new IndCheckTask(i, min(facts.size(), i + chunkSize), facts, independenceTest);
if (!parallelized) {
List _results = task.call();
getResultsLocal(indep).addAll(_results);
} else {
tasks.add(task);
}
}
if (parallelized) {
List>> theseResults = ForkJoinPool.commonPool().invokeAll(tasks);
for (Future> future : theseResults) {
try {
getResultsLocal(indep).addAll(future.get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}
}
private void generateResultsAllSubsets(boolean indep, List msep, List mconn) {
List facts = indep ? msep : mconn;
class IndCheckTask implements Callable> {
private final int from;
private final int to;
private final List facts;
private final IndependenceTest independenceTest;
IndCheckTask(int from, int to, List facts, IndependenceTest test) {
this.from = from;
this.to = to;
this.facts = facts;
this.independenceTest = test;
}
@Override
public List call() {
List results = new ArrayList<>();
for (int i = from; i < to; i++) {
if (Thread.interrupted()) break;
IndependenceFact fact = facts.get(i);
Node x = fact.getX();
Node y = fact.getY();
Set z = fact.getZ();
boolean verbose = independenceTest.isVerbose();
independenceTest.setVerbose(false);
IndependenceResult result;
try {
result = independenceTest.checkIndependence(x, y, z);
} catch (Exception e) {
throw new RuntimeException(e);
}
boolean indep = result.isIndependent();
double pValue = result.getPValue();
independenceTest.setVerbose(verbose);
if (!Double.isNaN(pValue)) {
results.add(new IndependenceResult(fact, indep, pValue, Double.NaN));
}
}
return results;
}
}
List>> tasks = new ArrayList<>();
int chunkSize = getChunkSize(facts.size());
for (int i = 0; i < facts.size() && !Thread.currentThread().isInterrupted(); i += chunkSize) {
IndCheckTask task = new IndCheckTask(i, min(facts.size(), i + chunkSize), facts, independenceTest);
if (!parallelized) {
List _results = task.call();
getResultsLocal(indep).addAll(_results);
} else {
tasks.add(task);
}
}
if (parallelized) {
List>> theseResults = ForkJoinPool.commonPool().invokeAll(tasks);
for (Future> future : theseResults) {
try {
getResultsLocal(indep).addAll(future.get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}
}
private void calcStats(boolean indep) {
List results = getResultsLocal(indep);
int dependent = 0;
for (IndependenceResult result : results) {
if (result.isDependent() && !Double.isNaN(result.getPValue())) dependent++;
}
if (indep) {
fractionDependentIndep = dependent / (double) results.size();
} else {
fractionDependentDep = dependent / (double) results.size();
}
List pValues = getPValues(results);
if (indep) {
if (pValues.size() < 2) {
ksPValueIndep = Double.NaN;
} else {
ksPValueIndep = UniformityTest.getPValue(pValues);
}
} else {
if (pValues.size() < 2) {
ksPValueDep = Double.NaN;
} else {
ksPValueDep = UniformityTest.getPValue(pValues);
}
}
}
private int getChunkSize(int n) {
int chunk = (int) FastMath.ceil((n / ((double) (5 * Runtime.getRuntime().availableProcessors()))));
if (chunk < 1) chunk = 1;
return chunk;
}
private List getResultsLocal(boolean indep) {
if (indep) {
return this.resultsIndep;
} else {
return this.resultsDep;
}
}
}