// edu.cmu.tetrad.search.SvarFas (Maven / Gradle / Ivy artifact source listing)
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014 by Peter Spirtes, Richard Scheines, Joseph //
// Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.SepsetMap;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.util.FastMath;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
/**
* Adapts FAS for the time series setting, assuming the data is generated by a
* SVAR (structural vector autoregression). The main difference is that time order is imposed, and if an edge is
* removed, it will also remove all homologous edges to preserve the time-repeating structure assumed by SvarFCI. Based
* on (but not identical to) code by Entner and Hoyer for their 2010 paper. Modified by dmalinsky 4/21/2016.
*
* The references are as follows:
*
* Malinsky, D., & Spirtes, P. (2018, August). Causal structure learning from multivariate
* time series in settings with unmeasured confounding. In Proceedings of 2018 ACM SIGKDD workshop on causal discovery
* (pp. 23-47). PMLR.
*
* Entner, D., & Hoyer, P. O. (2010). On causal discovery from time series data using FCI. Probabilistic
* graphical models, 121-128.
*
* This class is configured to respect knowledge of forbidden and required
* edges, including knowledge of temporal tiers.
*
* @author dmalinsky
* @see Fas
* @see Knowledge
* @see SvarFci
*/
public class SvarFas implements IFas {
/**
* The search graph. It is assumed going in that all the true adjacencies of x are in this graph for every node x.
* It is hoped (i.e., true in the large sample limit) that true adjacencies are never removed.
*/
private final Graph graph;
/**
* The independence test. This should be appropriate to the types
*/
private final IndependenceTest test;
/**
* The logger, by default the empty logger.
*/
private final TetradLogger logger = TetradLogger.getInstance();
private final NumberFormat nf = new DecimalFormat("0.00E0");
/**
* Specification of which edges are forbidden or required.
*/
private Knowledge knowledge = new Knowledge();
/**
* The maximum number of variables conditioned on in any conditional independence test. If the depth is -1, it will
* be taken to be the maximum value, which is 1000. Otherwise, it should be set to a non-negative integer.
*/
private int depth = 1000;
/**
* The number of independence tests.
*/
private int numIndependenceTests;
/**
* The sepsets found during the search.
*/
private SepsetMap sepset = new SepsetMap();
/**
* The depth 0 graph, specified initially.
*/
private Graph externalGraph;
/**
* True iff verbose output should be printed.
*/
private boolean verbose;
private PrintStream out = System.out;
/**
* Constructs a new FastAdjacencySearch.
*/
public SvarFas(IndependenceTest test) {
this.graph = new EdgeListGraph(test.getVariables());
this.test = test;
}
/**
* Discovers all adjacencies in data. The procedure is to remove edges in the graph which connect pairs of
* variables which are independent, conditional on some other set of variables in the graph (the "sepset"). These are
* removed in tiers. First, edges which are independent conditional on zero other variables are removed, then edges
* which are independent conditional on one other variable are removed, then two, then three, and so on, until no
* more edges can be removed from the graph. The edges which remain in the graph after this procedure are the
* adjacencies in the data.
*
* @return a SepSet, which indicates which variables are independent conditional on which other variables
*/
public Graph search() {
this.logger.log("info", "Starting Fast Adjacency Search.");
this.graph.removeEdges(this.graph.getEdges());
this.sepset = new SepsetMap();
int _depth = this.depth;
if (_depth == -1) {
_depth = 1000;
}
Map> adjacencies = new HashMap<>();
List nodes = this.graph.getNodes();
for (Node node : nodes) {
adjacencies.put(node, new TreeSet<>());
}
for (int d = 0; d <= _depth; d++) {
boolean more;
if (d == 0) {
more = searchAtDepth0(nodes, this.test, adjacencies);
} else {
more = searchAtDepth(nodes, this.test, adjacencies, d);
}
if (!more) {
break;
}
}
for (int i = 0; i < nodes.size(); i++) {
for (int j = i + 1; j < nodes.size(); j++) {
Node x = nodes.get(i);
Node y = nodes.get(j);
if (adjacencies.get(x).contains(y)) {
this.graph.addUndirectedEdge(x, y);
}
}
}
this.logger.log("info", "Finishing Fast Adjacency Search.");
return this.graph;
}
/**
* Sets the depth--i.e., the maximum number of variables conditioned on in any test, -1 for unlimited.
*
* @param depth This depth.
*/
public void setDepth(int depth) {
if (depth < -1) {
throw new IllegalArgumentException(
"Depth must be -1 (unlimited) or >= 0.");
}
this.depth = depth;
}
/**
* Sets the knowledge used in the search.
*
* @param knowledge This knowledge.
*/
public void setKnowledge(Knowledge knowledge) {
this.knowledge = new Knowledge(knowledge);
}
/**
* Returns the number of independence tests.
*
* @return This number.
*/
public int getNumIndependenceTests() {
return this.numIndependenceTests;
}
/**
* Returns a map for x _||_ y | Z from {x, y} to Z.
*
* @return This map.
*/
public SepsetMap getSepsets() {
return this.sepset;
}
/**
* Sets an external graph.
*
* @param externalGraph This graph.
*/
public void setExternalGraph(Graph externalGraph) {
this.externalGraph = externalGraph;
}
/**
* Sets whether verbose output should be printed.
*
* @param verbose True, if so.
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
* @throws UnsupportedOperationException This method is not used.
*/
@Override
public long getElapsedTime() {
throw new UnsupportedOperationException("This method is not used.");
}
/**
* Returns the nodes of the test.
*
* @return This list.
*/
@Override
public List getNodes() {
return this.test.getVariables();
}
/**
* @throws UnsupportedOperationException This method is not used.
*/
@Override
public List getAmbiguousTriples(Node node) {
throw new UnsupportedOperationException("This method is not used.");
}
/**
* Sets the output stream for printing, default is System.out.
*
* @param out The print stream.
* @see PrintStream
*/
@Override
public void setOut(PrintStream out) {
this.out = out;
}
private boolean searchAtDepth0(List nodes, IndependenceTest test, Map> adjacencies) {
Set empty = Collections.emptySet();
List simListX = new ArrayList<>();
List simListY = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
if (this.verbose) {
if ((i + 1) % 100 == 0) this.out.println("Node # " + (i + 1));
}
Node x = nodes.get(i);
for (int j = i + 1; j < nodes.size(); j++) {
Node y = nodes.get(j);
//if the current nodes under consideration were already handled by similarNodes, skip this pair
String xName = x.getName();
String yName = y.getName();
boolean skippair = false;
Iterator itx1 = simListX.iterator();
Iterator ity1 = simListY.iterator();
while (itx1.hasNext() && ity1.hasNext()) {
Node x1 = itx1.next();
Node y1 = ity1.next();
String simX = x1.getName();
String simY = y1.getName();
if ((Objects.equals(xName, simX) && Objects.equals(yName, simY)) ||
(Objects.equals(xName, simY) && Objects.equals(yName, simX))) {
skippair = true;
System.out.println("Skipping pair x,y = " + xName + ", " + yName);
break;
}
}
if (skippair) continue;
if (this.externalGraph != null) {
Node x2 = this.externalGraph.getNode(x.getName());
Node y2 = this.externalGraph.getNode(y.getName());
if (!this.externalGraph.isAdjacentTo(x2, y2)) {
continue;
}
}
IndependenceResult result;
try {
this.numIndependenceTests++;
result = test.checkIndependence(x, y, empty);
System.out.println("############# independence given empty set: x,y " + x + ", " +
y + " independence = " + result.isIndependent());
} catch (Exception e) {
e.printStackTrace();
result = new IndependenceResult(new IndependenceFact(x, y, empty), false, Double.NaN, Double.NaN);
}
boolean noEdgeRequired =
this.knowledge.noEdgeRequired(x.getName(), y.getName());
// getSepsets().setReturnEmptyIfNotSet(false); // added 05.30.2016
if (result.isIndependent() && noEdgeRequired) {
// if (!getSepsets().isReturnEmptyIfNotSet()) {
getSepsets().set(x, y, empty);
System.out.println("$$$$$$$$$$$ look for similar pairs x,y = " + x + ", " + y);
List> simList = returnSimilarPairs(test, x, y);
if (simList.isEmpty()) continue;
List x1List = simList.get(0);
List y1List = simList.get(1);
simListX.addAll(x1List);
simListY.addAll(y1List);
Iterator itx = x1List.iterator();
Iterator ity = y1List.iterator();
while (itx.hasNext() && ity.hasNext()) {
Node x1 = itx.next();
Node y1 = ity.next();
System.out.println("$$$$$$$$$$$ found similar pair x,y = " + x1 + ", " + y1);
getSepsets().set(x1, y1, empty);
}
TetradLogger.getInstance().log("independencies", LogUtilsSearch.independenceFact(x, y, empty) + " score = " +
this.nf.format(result.getScore()));
if (this.verbose) {
this.out.println(LogUtilsSearch.independenceFact(x, y, empty) + " score = " +
this.nf.format(result.getScore()));
}
} else if (!forbiddenEdge(x, y)) {
System.out.println("adding edge between x = " + x + " and y = " + y);
adjacencies.get(x).add(y);
adjacencies.get(y).add(x);
// This would add edges to all similar pairs that are found to be dependent...
List> simList = returnSimilarPairs(test, x, y);
if (simList.isEmpty()) continue;
List x1List = simList.get(0);
List y1List = simList.get(1);
simListX.addAll(x1List);
simListY.addAll(y1List);
Iterator itx = x1List.iterator();
Iterator ity = y1List.iterator();
while (itx.hasNext() && ity.hasNext()) {
Node x1 = itx.next();
Node y1 = ity.next();
System.out.println("$$$$$$$$$$$ similar pair x,y = " + x1 + ", " + y1);
System.out.println("adding edge between x = " + x1 + " and y = " + y1);
adjacencies.get(x1).add(y1);
adjacencies.get(y1).add(x1);
}
if (this.verbose) {
TetradLogger.getInstance().log("dependencies", LogUtilsSearch.independenceFact(x, y, empty) + " score = " +
this.nf.format(result.getScore()));
}
}
}
}
return freeDegree(nodes, adjacencies) > 0;
}
private int freeDegree(List nodes, Map> adjacencies) {
int max = 0;
for (Node x : nodes) {
Set opposites = adjacencies.get(x);
for (Node y : opposites) {
Set adjx = new HashSet<>(opposites);
adjx.remove(y);
if (adjx.size() > max) {
max = adjx.size();
}
}
}
return max;
}
private boolean forbiddenEdge(Node x, Node y) {
String name1 = x.getName();
String name2 = y.getName();
if (this.knowledge.isForbidden(name1, name2) &&
this.knowledge.isForbidden(name2, name1)) {
this.logger.log("edgeRemoved", "Removed " + Edges.undirectedEdge(x, y) + " because it was " +
"forbidden by background knowledge.");
return true;
}
return false;
}
private boolean searchAtDepth(List nodes, IndependenceTest test, Map> adjacencies, int depth) {
int count = 0;
for (Node x : nodes) {
if (this.verbose) {
if (++count % 100 == 0) this.out.println("count " + count + " of " + nodes.size());
}
List adjx = new ArrayList<>(adjacencies.get(x));
EDGE:
for (Node y : adjx) {
List _adjx = new ArrayList<>(adjacencies.get(x));
_adjx.remove(y);
List ppx = possibleParents(x, _adjx, this.knowledge);
if (ppx.size() >= depth) {
ChoiceGenerator cg = new ChoiceGenerator(ppx.size(), depth);
int[] choice;
while ((choice = cg.next()) != null) {
Set condSet = GraphUtils.asSet(choice, ppx);
boolean independent;
try {
this.numIndependenceTests++;
independent = test.checkIndependence(x, y, condSet).isIndependent();
} catch (Exception e) {
independent = false;
}
boolean noEdgeRequired =
this.knowledge.noEdgeRequired(x.getName(), y.getName());
if (independent && noEdgeRequired) {
adjacencies.get(x).remove(y);
adjacencies.get(y).remove(x);
getSepsets().set(x, y, condSet);
// This is the added component to enforce repeating structure
removeSimilarPairs(adjacencies, test, x, y, condSet);
continue EDGE;
}
}
}
}
}
return freeDegree(nodes, adjacencies) > depth;
}
private List possibleParents(Node x, List adjx,
Knowledge knowledge) {
List possibleParents = new LinkedList<>();
String _x = x.getName();
for (Node z : adjx) {
String _z = z.getName();
if (possibleParentOf(_z, _x, knowledge)) {
possibleParents.add(z);
}
}
return possibleParents;
}
private boolean possibleParentOf(String z, String x, Knowledge knowledge) {
return !knowledge.isForbidden(z, x) && !knowledge.isRequired(x, z);
}
// removeSimilarPairs based on orientSimilarPairs in SvarFciOrient.java by Entner and Hoyer
private void removeSimilarPairs(Map> adjacencies, IndependenceTest test, Node x, Node y, Set condSet) {
System.out.println("Entering removeSimilarPairs method...");
System.out.println("original independence: " + x + " and " + y + " conditional on " + condSet);
if (x.getName().equals("time") || y.getName().equals("time")) {
System.out.println("Not removing similar pairs b/c variable pair includes time.");
return;
}
for (Node tempNode : condSet) {
if (tempNode.getName().equals("time")) {
System.out.println("Not removing similar pairs b/c conditioning set includes time.");
return;
}
}
int ntiers = this.knowledge.getNumTiers();
int indx_tier = this.knowledge.isInWhichTier(x);
int indy_tier = this.knowledge.isInWhichTier(y);
int tier_diff = FastMath.max(indx_tier, indy_tier) - FastMath.min(indx_tier, indy_tier);
int indx_comp = -1;
int indy_comp = -1;
List tier_x = this.knowledge.getTier(indx_tier);
List tier_y = this.knowledge.getTier(indy_tier);
int i;
for (i = 0; i < tier_x.size(); ++i) {
if (getNameNoLag(x.getName()).equals(getNameNoLag(tier_x.get(i)))) {
indx_comp = i;
break;
}
}
for (i = 0; i < tier_y.size(); ++i) {
if (getNameNoLag(y.getName()).equals(getNameNoLag(tier_y.get(i)))) {
indy_comp = i;
break;
}
}
if (indx_comp == -1) System.out.println("WARNING: indx_comp = -1!!!! ");
if (indy_comp == -1) System.out.println("WARNING: indy_comp = -1!!!! ");
for (i = 0; i < ntiers - tier_diff; ++i) {
if (this.knowledge.getTier(i).size() == 1) continue;
String A;
Node x1;
String B;
Node y1;
if (indx_tier >= indy_tier) {
List tmp_tier1 = this.knowledge.getTier(i + tier_diff);
List tmp_tier2 = this.knowledge.getTier(i);
A = tmp_tier1.get(indx_comp);
B = tmp_tier2.get(indy_comp);
if (A.equals(B)) continue;
if (A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp))) continue;
if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue;
x1 = test.getVariable(A);
y1 = test.getVariable(B);
adjacencies.get(x1).remove(y1);
adjacencies.get(y1).remove(x1);
System.out.println("removed edge between " + x1 + " and " + y1 + " because of structure knowledge");
Set condSetAB = new HashSet<>();
for (Node tempNode : condSet) {
int ind_temptier = this.knowledge.isInWhichTier(tempNode);
List temptier = this.knowledge.getTier(ind_temptier);
int ind_temp = -1;
for (int j = 0; j < temptier.size(); ++j) {
if (getNameNoLag(tempNode.getName()).equals(getNameNoLag(temptier.get(j)))) {
ind_temp = j;
break;
}
}
int cond_diff = indx_tier - ind_temptier;
int condAB_tier = this.knowledge.isInWhichTier(x1) - cond_diff;
if (condAB_tier < 0 || condAB_tier > (ntiers - 1)
|| this.knowledge.getTier(condAB_tier).size() == 1) { // added condition for time tier 05.29.2016
System.out.println("Warning: For nodes " + x1 + "," + y1 + " the conditioning variable is outside "
+ "of window, so not added to SepSet");
continue;
}
List new_tier = this.knowledge.getTier(condAB_tier);
String tempNode1 = new_tier.get(ind_temp);
System.out.println("adding variable " + tempNode1 + " to SepSet");
condSetAB.add(test.getVariable(tempNode1));
}
System.out.println("done");
getSepsets().set(x1, y1, condSetAB);
} else {
List tmp_tier1 = this.knowledge.getTier(i);
List tmp_tier2 = this.knowledge.getTier(i + tier_diff);
A = tmp_tier1.get(indx_comp);
B = tmp_tier2.get(indy_comp);
if (A.equals(B)) continue;
if (A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp))) continue;
if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue;
x1 = test.getVariable(A);
y1 = test.getVariable(B);
adjacencies.get(x1).remove(y1);
adjacencies.get(y1).remove(x1);
System.out.println("removed edge between " + x1 + " and " + y1 + " because of structure knowledge");
Set condSetAB = new HashSet<>();
for (Node tempNode : condSet) {
int ind_temptier = this.knowledge.isInWhichTier(tempNode);
List temptier = this.knowledge.getTier(ind_temptier);
int ind_temp = -1;
for (int j = 0; j < temptier.size(); ++j) {
if (getNameNoLag(tempNode.getName()).equals(getNameNoLag(temptier.get(j)))) {
ind_temp = j;
break;
}
}
int cond_diff = indx_tier - ind_temptier;
int condAB_tier = this.knowledge.isInWhichTier(x1) - cond_diff;
if (condAB_tier < 0 || condAB_tier > (ntiers - 1)
|| this.knowledge.getTier(condAB_tier).size() == 1) { // added condition for time tier 05.29.2016
System.out.println("Warning: For nodes " + x1 + "," + y1 + " the conditioning variable is outside "
+ "of window, so not added to SepSet");
continue;
}
List new_tier = this.knowledge.getTier(condAB_tier);
String tempNode1 = new_tier.get(ind_temp);
System.out.println("adding variable " + tempNode1 + " to SepSet");
condSetAB.add(test.getVariable(tempNode1));
}
System.out.println("done");
getSepsets().set(x1, y1, condSetAB);
}
}
}
// returnSimilarPairs based on orientSimilarPairs in SvarFciOrient.java by Entner and Hoyer
private List> returnSimilarPairs(IndependenceTest test, Node x, Node y) {
System.out.println("$$$$$ Entering returnSimilarPairs method with x,y = " + x + ", " + y);
if (x.getName().equals("time") || y.getName().equals("time")) {
return new ArrayList<>();
}
int ntiers = this.knowledge.getNumTiers();
int indx_tier = this.knowledge.isInWhichTier(x);
int indy_tier = this.knowledge.isInWhichTier(y);
int tier_diff = FastMath.max(indx_tier, indy_tier) - FastMath.min(indx_tier, indy_tier);
int indx_comp = -1;
int indy_comp = -1;
List tier_x = this.knowledge.getTier(indx_tier);
List tier_y = this.knowledge.getTier(indy_tier);
int i;
for (i = 0; i < tier_x.size(); ++i) {
if (getNameNoLag(x.getName()).equals(getNameNoLag(tier_x.get(i)))) {
indx_comp = i;
break;
}
}
for (i = 0; i < tier_y.size(); ++i) {
if (getNameNoLag(y.getName()).equals(getNameNoLag(tier_y.get(i)))) {
indy_comp = i;
break;
}
}
System.out.println("original independence: " + x + " and " + y);
if (indx_comp == -1) System.out.println("WARNING: indx_comp = -1!!!! ");
if (indy_comp == -1) System.out.println("WARNING: indy_comp = -1!!!! ");
List simListX = new ArrayList<>();
List simListY = new ArrayList<>();
for (i = 0; i < ntiers - tier_diff; ++i) {
if (this.knowledge.getTier(i).size() == 1) continue;
String A;
Node x1;
String B;
Node y1;
List tmp_tier1;
List tmp_tier2;
if (indx_tier >= indy_tier) {
tmp_tier1 = this.knowledge.getTier(i + tier_diff);
tmp_tier2 = this.knowledge.getTier(i);
} else {
tmp_tier1 = this.knowledge.getTier(i);
tmp_tier2 = this.knowledge.getTier(i + tier_diff);
}
A = tmp_tier1.get(indx_comp);
B = tmp_tier2.get(indy_comp);
if (A.equals(B)) continue;
if (A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp))) continue;
if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue;
x1 = test.getVariable(A);
y1 = test.getVariable(B);
System.out.println("Adding pair to simList = " + x1 + " and " + y1);
simListX.add(x1);
simListY.add(y1);
}
List> pairList = new ArrayList<>();
pairList.add(simListX);
pairList.add(simListY);
return (pairList);
}
private String getNameNoLag(Object obj) {
String tempS = obj.toString();
if (tempS.indexOf(':') == -1) {
return tempS;
} else return tempS.substring(0, tempS.indexOf(':'));
}
}