// edu.cmu.tetrad.search.Ccd (Maven / Gradle / Ivy)
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.SepsetProducer;
import edu.cmu.tetrad.search.utils.SepsetsSet;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.*;
/**
* Implemented the Cyclic Causal Discovery (CCD) algorithm by Thomas Richardson.
* A reference for this is here:
*
* Richardson, T. S. (2013). A discovery algorithm for directed cyclic graphs. arXiv
* preprint arXiv:1302.3599.
*
* See also Chapter 7 of:
*
* Glymour, C. N., & Cooper, G. F. (Eds.). (1999). Computation, causation, and
* discovery. Aaai Press.
*
* The graph takes continuous data from a cyclic model as input and returns a cyclic
* PAG graphs, with various types of underlining, that represents a Markov equivalence of the true DAG.
*
* This class is not configured to respect knowledge of forbidden and required
* edges.
*
* @author Frank C. Wimberly
* @author josephramsey
*/
public final class Ccd implements IGraphSearch {
private final IndependenceTest independenceTest;
private final List nodes;
private boolean applyR1;
/**
* Construct a CCD algorithm with the given independence test.
*
* @param test The test to be used.
* @see IndependenceTest
*/
public Ccd(IndependenceTest test) {
if (test == null) throw new NullPointerException("Test is not provided");
this.independenceTest = test;
this.nodes = test.getVariables();
}
/**
* The search method assumes that the IndependenceTest provided to the constructor is a conditional independence
* oracle for the SEM (or Bayes network) which describes the causal structure of the population. The method returns
* a PAG instantiated as a Tetrad GaSearchGraph which represents the equivalence class of digraphs which are
* m-separation equivalent to the digraph of the underlying model (SEM or BN). Although they are not returned by the
* search method, it also computes two lists of triples which, respectively, store the underlines and dotted
* underlines of the PAG.
*
* @return The CCD cyclic PAG for the data.
*/
public Graph search() {
Map> supSepsets = new HashMap<>();
// Step A.
Fas fas = new Fas(this.independenceTest);
Graph psi = fas.search();
psi.reorientAllWith(Endpoint.CIRCLE);
SepsetProducer sepsets = new SepsetsSet(fas.getSepsets(), this.independenceTest);
stepB(psi);
stepC(psi, sepsets);
stepD(psi, sepsets, supSepsets);
stepE(supSepsets, psi);
stepF(psi, sepsets, supSepsets);
orientAwayFromArrow(psi);
return psi;
}
/**
* Returns true iff the R1 rule should be applied.
*
* @return True if the case.
*/
public boolean isApplyR1() {
return this.applyR1;
}
/**
* Sets whether the R1 rule should be applied.
*
* @param applyR1 True if the case.
*/
public void setApplyR1(boolean applyR1) {
this.applyR1 = applyR1;
}
private void orientAwayFromArrow(Graph graph) {
for (Edge edge : graph.getEdges()) {
Node n1 = edge.getNode1();
Node n2 = edge.getNode2();
edge = graph.getEdge(n1, n2);
if (edge.pointsTowards(n1)) {
orientAwayFromArrow(n2, n1, graph);
} else if (edge.pointsTowards(n2)) {
orientAwayFromArrow(n1, n2, graph);
}
}
}
private void stepB(Graph graph) {
Map colliders = new HashMap<>();
Map noncolliders = new HashMap<>();
for (Node node : this.nodes) {
doNodeCollider(graph, colliders, noncolliders, node);
}
List collidersList = new ArrayList<>(colliders.keySet());
List noncollidersList = new ArrayList<>(noncolliders.keySet());
for (Triple triple : collidersList) {
Node a = triple.getX();
Node b = triple.getY();
Node c = triple.getZ();
graph.removeEdge(a, b);
graph.removeEdge(c, b);
graph.addDirectedEdge(a, b);
graph.addDirectedEdge(c, b);
}
for (Triple triple : noncollidersList) {
Node a = triple.getX();
Node b = triple.getY();
Node c = triple.getZ();
graph.addUnderlineTriple(a, b, c);
}
}
private void doNodeCollider(Graph graph, Map colliders, Map noncolliders, Node b) {
List adjacentNodes = new ArrayList<>(graph.getAdjacentNodes(b));
if (adjacentNodes.size() < 2) {
return;
}
ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
int[] combination;
while ((combination = cg.next()) != null) {
Node a = adjacentNodes.get(combination[0]);
Node c = adjacentNodes.get(combination[1]);
// Skip triples that are shielded.
if (graph.isAdjacentTo(a, c)) {
continue;
}
List adja = new ArrayList<>(graph.getAdjacentNodes(a));
double score = Double.POSITIVE_INFINITY;
Set S = null;
SublistGenerator cg2 = new SublistGenerator(adja.size(), -1);
int[] comb2;
while ((comb2 = cg2.next()) != null) {
Set s = GraphUtils.asSet(comb2, adja);
IndependenceResult result = this.independenceTest.checkIndependence(a, c, s);
double _score = result.getScore();
if (_score < score) {
score = _score;
S = s;
}
}
List adjc = new ArrayList<>(graph.getAdjacentNodes(c));
SublistGenerator cg3 = new SublistGenerator(adjc.size(), -1);
int[] comb3;
while ((comb3 = cg3.next()) != null) {
Set s = GraphUtils.asSet(comb3, adjc);
IndependenceResult result = this.independenceTest.checkIndependence(c, a, s);
double _score = result.getScore();
if (_score < score) {
score = _score;
S = s;
}
}
// This could happen if there are undefined values and such.
if (S == null) {
continue;
}
if (S.contains(b)) {
noncolliders.put(new Triple(a, b, c), score);
} else {
colliders.put(new Triple(a, b, c), score);
}
}
}
private void stepC(Graph psi, SepsetProducer sepsets) {
TetradLogger.getInstance().log("info", "\nStep C");
EDGE:
for (Edge edge : psi.getEdges()) {
Node x = edge.getNode1();
Node y = edge.getNode2();
// x and y are adjacent.
List adjx = psi.getAdjacentNodes(x);
List adjy = psi.getAdjacentNodes(y);
for (Node node : adjx) {
if (psi.getEdge(node, x).getProximalEndpoint(x) == Endpoint.ARROW
&& psi.isUnderlineTriple(y, x, node)) {
continue EDGE;
}
}
// Check each A
for (Node a : this.nodes) {
if (a == x) continue;
if (a == y) continue;
//...A is not adjacent to X and A is not adjacent to Y...
if (adjx.contains(a)) continue;
if (adjy.contains(a)) continue;
// Orientable...
if (!(psi.getEndpoint(y, x) == Endpoint.CIRCLE &&
(psi.getEndpoint(x, y) == Endpoint.CIRCLE || psi.getEndpoint(x, y) == Endpoint.TAIL))) {
continue;
}
//...X is not in sepset...
Set sepset = sepsets.getSepset(a, y);
if (sepset == null) {
continue;
}
if (sepset.contains(x)) continue;
if (!sepsets.isIndependent(a, x, sepset)) {
psi.removeEdge(x, y);
psi.addDirectedEdge(y, x);
orientAwayFromArrow(y, x, psi);
break;
}
}
}
}
private void stepD(Graph psi, SepsetProducer sepsets, Map> supSepsets) {
Map> local = new HashMap<>();
for (Node node : psi.getNodes()) {
local.put(node, local(psi, node));
}
for (Node node : this.nodes) {
doNodeStepD(psi, sepsets, supSepsets, local, node);
}
}
private void doNodeStepD(Graph psi, SepsetProducer sepsets, Map> supSepsets,
Map> local, Node b) {
List adj = new ArrayList<>(psi.getAdjacentNodes(b));
if (adj.size() < 2) {
return;
}
ChoiceGenerator gen = new ChoiceGenerator(adj.size(), 2);
int[] choice;
while ((choice = gen.next()) != null) {
List _adj = GraphUtils.asList(choice, adj);
Node a = _adj.get(0);
Node c = _adj.get(1);
if (!psi.isDefCollider(a, b, c)) continue;
Set S = sepsets.getSepset(a, c);
if (S == null) continue;
ArrayList TT = new ArrayList<>(local.get(a));
TT.removeAll(S);
TT.remove(b);
TT.remove(c);
SublistGenerator gen2 = new SublistGenerator(TT.size(), -1);
int[] choice2;
while ((choice2 = gen2.next()) != null) {
Set T = GraphUtils.asSet(choice2, TT);
Set B = new HashSet<>(T);
B.addAll(S);
B.add(b);
if (sepsets.isIndependent(a, c, new HashSet<>(B))) {
psi.addDottedUnderlineTriple(a, b, c);
supSepsets.put(new Triple(a, b, c), B);
break;
}
}
}
}
private void stepE(Map> supSepset, Graph psi) {
TetradLogger.getInstance().log("info", "\nStep E");
for (Triple triple : psi.getDottedUnderlines()) {
Node a = triple.getX();
Node b = triple.getY();
Node c = triple.getZ();
List aAdj = psi.getAdjacentNodes(a);
for (Node d : aAdj) {
if (d == b) continue;
if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) {
continue;
}
if (supSepset.get(triple).contains(d)) {
// Orient B*-oD as B*-D
psi.setEndpoint(b, d, Endpoint.TAIL);
} else {
if (psi.getEndpoint(d, b) == Endpoint.ARROW) {
continue;
}
// Or orient Bo-oD or B-oD as B->D...
psi.removeEdge(b, d);
psi.addDirectedEdge(b, d);
orientAwayFromArrow(b, d, psi);
}
}
List cAdj = psi.getAdjacentNodes(c);
for (Node d : cAdj) {
if (d == b) continue;
if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) {
continue;
}
if (supSepset.get(triple).contains(d)) {
// Orient B*-oD as B*-D
psi.setEndpoint(b, d, Endpoint.TAIL);
} else {
if (psi.getEndpoint(d, b) == Endpoint.ARROW) {
continue;
}
// Or orient Bo-oD or B-oD as B->D...
psi.removeEdge(b, d);
psi.addDirectedEdge(b, d);
orientAwayFromArrow(b, d, psi);
}
}
}
}
private void stepF(Graph psi, SepsetProducer sepsets, Map> supSepsets) {
for (Triple triple : psi.getDottedUnderlines()) {
Node a = triple.getX();
Node b = triple.getY();
Node c = triple.getZ();
Set adj = new HashSet<>(psi.getAdjacentNodes(a));
adj.addAll(psi.getAdjacentNodes(c));
for (Node d : adj) {
if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) {
continue;
}
if (psi.getEndpoint(d, b) == Endpoint.ARROW) {
continue;
}
//...and D is not adjacent to both A and C in psi...
if (psi.isAdjacentTo(a, d) && psi.isAdjacentTo(c, d)) {
continue;
}
//...and B and D are adjacent...
if (!psi.isAdjacentTo(b, d)) {
continue;
}
Set supSepUnionD = new HashSet<>();
supSepUnionD.add(d);
supSepUnionD.addAll(supSepsets.get(triple));
Set listSupSepUnionD = new HashSet<>(supSepUnionD);
//If A and C are a pair of vertices d-connected given
//SupSepset union {D} then orient Bo-oD or B-oD
//as B->D in psi.
if (!sepsets.isIndependent(a, c, listSupSepUnionD)) {
psi.removeEdge(b, d);
psi.addDirectedEdge(b, d);
orientAwayFromArrow(b, d, psi);
}
}
}
}
private List local(Graph psi, Node x) {
Set nodes = new HashSet<>(psi.getAdjacentNodes(x));
for (Node y : new HashSet<>(nodes)) {
for (Node z : psi.getAdjacentNodes(y)) {
if (psi.isDefCollider(x, y, z)) {
if (z != x) {
nodes.add(z);
}
}
}
}
return new ArrayList<>(nodes);
}
private void orientAwayFromArrow(Node a, Node b, Graph graph) {
if (!isApplyR1()) return;
for (Node c : graph.getAdjacentNodes(b)) {
if (c == a) continue;
orientAwayFromArrowVisit(a, b, c, graph);
}
}
private boolean orientAwayFromArrowVisit(Node a, Node b, Node c, Graph graph) {
if (!Edges.isNondirectedEdge(graph.getEdge(b, c))) {
return false;
}
if (!(graph.isUnderlineTriple(a, b, c))) {
return false;
}
if (graph.getEdge(b, c).pointsTowards(b)) {
return false;
}
graph.removeEdge(b, c);
graph.addDirectedEdge(b, c);
for (Node d : graph.getAdjacentNodes(c)) {
if (d == b) return true;
Edge bc = graph.getEdge(b, c);
if (!orientAwayFromArrowVisit(b, c, d, graph)) {
graph.removeEdge(b, c);
graph.addEdge(bc);
}
}
return true;
}
}