edu.cmu.tetrad.search.Fges Maven / Gradle / Ivy
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph //
// Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.score.GraphScore;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.ScoredGraph;
import edu.cmu.tetrad.search.utils.Bes;
import edu.cmu.tetrad.search.utils.DagScorer;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import org.jetbrains.annotations.NotNull;
import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.*;
import static edu.cmu.tetrad.graph.Edges.directedEdge;
import static org.apache.commons.math3.util.FastMath.max;
import static org.apache.commons.math3.util.FastMath.min;
/**
* Implements the Fast Greedy Equivalence Search (FGES) algorithm. This is an implementation of the Greedy Equivalence
* Search algorithm, originally due to Chris Meek but developed significantly by Max Chickering. FGES adds some
* optimizations that allow it to scale accurately to thousands of variables in the sparse case. The
* reference for FGES is this:
*
* Ramsey, J., Glymour, M., Sanchez-Romero, R., & Glymour, C. (2017). A million variables and more: the fast greedy
* equivalence search algorithm for learning high-dimensional graphical causal models, with an application to functional
* magnetic resonance images. International journal of data science and analytics, 3, 121-129.
*
* The reference for Chickering's GES is this:
*
* Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine Learning Research.
*
* FGES works for the continuous case, the discrete case, and the mixed continuous/discrete case, so long as a BIC score
* is available for the type of data in question.
*
* To speed things up, it has been assumed that variables X and Y with zero correlation do not correspond to edges in
* the graph. This is a restricted form of the heuristic speedup assumption, something GES does not assume. The search
* applies this heuristic only in its first forward/backward pass; later passes consider further candidate edges, and
* unless one-edge faithfulness is assumed (see setFaithfulnessAssumed(boolean)), a final pass considers all pairs of
* variables.
*
* Also, edges to be added to or removed from the graph in the forward or backward phase, respectively, are cached,
* together with the ancillary information needed to do the additions or removals, to reduce rescoring.
*
* A number of other optimizations were also made; see the code for details.
*
* This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
* tiers.
*
* @author Ricardo Silva
* @author josephramsey
* @version $Id: $Id
* @see Grasp
* @see Boss
* @see Sp
* @see Knowledge
*/
public final class Fges implements IGraphSearch, DagScorer {
/**
 * An always-empty set of nodes, handed to the tasks that initialize the effect-edges graph.
 */
private final Set<Node> emptySet = new HashSet<>();
/**
 * Not referenced by the methods in this part of the file. NOTE(review): formerly documented as a counter for
 * semidirected-path cycle checking; confirm whether it is still needed.
 */
private final int[] count = new int[1];
/**
 * Upper bound on the size of the T subsets enumerated when scoring a candidate insertion; also the depth passed to
 * BES.
 */
private final int depth = 10000;
/**
 * The logger for this class. The config needs to be set.
 */
private final TetradLogger logger = TetradLogger.getInstance();
/**
 * The top graphs found by the algorithm; cleared at the start of each search.
 */
private final LinkedList<ScoredGraph> topGraphs = new LinkedList<>();
/**
 * Candidate arrows, ordered so that the first one is the next candidate for adding to the graph.
 */
private final SortedSet<Arrow> sortedArrows = new ConcurrentSkipListSet<>();
/**
 * For each directed pair X-->Y, the configuration (T-neighbors, NaYX, parents of Y) under which it was last scored;
 * lets unchanged pairs skip rescoring.
 */
private final Map<Edge, ArrowConfig> arrowsMap = new ConcurrentHashMap<>();
/**
 * The fork join pool used to run scoring tasks.
 */
private ForkJoinPool pool;
/**
 * Whether one-edge faithfulness should be assumed.
 */
private boolean faithfulnessAssumed = false;
/**
 * Specification of forbidden and required edges.
 */
private Knowledge knowledge = new Knowledge();
/**
 * The measured variables of the score, in order.
 */
private List<Node> variables;
/**
 * An initial graph to start from.
 */
private Graph initialGraph;
/**
 * If non-null, edges not adjacent in this graph will not be added.
 */
private Graph boundGraph = null;
/**
 * Elapsed time of the most recent search, in milliseconds.
 */
private long elapsedTime;
/**
 * The score used to evaluate local models.
 */
private Score score;
/**
 * True if verbose output should be printed.
 */
private boolean verbose = false;
/**
 * Whether Meek rules should be verbose.
 */
private boolean meekVerbose = false;
/**
 * Map from each of the score's variables to its index in the score's variable list.
 */
private ConcurrentMap<Node, Integer> hashIndices;
/**
 * A graph where X--Y means that X and Y have non-zero total effect on one another.
 */
private Graph effectEdgesGraph;
/**
 * Where printed output is sent.
 */
private transient PrintStream out = System.out;
/**
 * The graph being constructed.
 */
private Graph graph;
/**
 * Counter that gives each new Arrow a distinct index, so arrows with equal bumps can still be told apart in
 * sortedArrows. The ordering doesn't matter; it just has to be transitive.
 */
private int arrowIndex = 0;
/**
 * The score of the model.
 */
private double modelScore;
/**
 * The candidate-generation mode used by the current forward pass.
 */
private Mode mode = Mode.heuristicSpeedup;
/**
 * Bounds the degree of the graph; -1 means no explicit bound.
 */
private int maxDegree = -1;
/**
 * True if the first step of adding an edge to an empty graph should be scored in both directions for each edge with
 * the maximum score chosen.
 */
private boolean symmetricFirstStep = false;
/**
 * The number of threads to use to run the algorithm.
 */
private int numThreads = 1;
/**
 * Constructs an FGES search over the measured variables of the given score. The score should return a positive value
 * in case of conditional dependence and a negative value in case of conditional independence; see Chickering (2002),
 * locally consistent scoring criterion. The maximum degree is initialized from the score. The search runs on a
 * single thread by default; see {@link #setNumThreads(int)}.
 *
 * @param score The score to use. The score should yield better scores for more correct local models. The algorithm
 *              as given by Chickering assumes the score will be a BIC score of some sort.
 * @throws NullPointerException if score is null.
 */
public Fges(Score score) {
    Objects.requireNonNull(score, "score");
    setScore(score);
    this.graph = new EdgeListGraph(getVariables());
    this.pool = new ForkJoinPool(numThreads);
}
/**
 * Follows {@code edge} away from {@code node} in the semidirected sense: the step is allowed only when the endpoint
 * at {@code node} is a tail.
 *
 * @param node the node being left.
 * @param edge an edge incident to {@code node}.
 * @return the opposite node, or null if {@code node} is not on the edge or its endpoint there is not a tail.
 */
private static Node traverseSemiDirected(Node node, Edge edge) {
    Node first = edge.getNode1();
    Node second = edge.getNode2();

    if (node == first) {
        return edge.getEndpoint1() == Endpoint.TAIL ? second : null;
    }

    if (node == second) {
        return edge.getEndpoint2() == Endpoint.TAIL ? first : null;
    }

    return null;
}
/**
 * Runs the search. Starting from the initial graph (or the empty graph if none was set), with required edges from the
 * knowledge added, this alternates forward (edge-adding) and backward (edge-removing) greedy phases in up to three
 * passes. The first pass is restricted to effect edges. The second considers nodes two steps away through a
 * non-collider. The third, run only when one-edge faithfulness is not assumed, considers all pairs of variables.
 *
 * @return the resulting CPDAG.
 */
public Graph search() {
    long start = MillisecondTimes.timeMillis();
    topGraphs.clear();

    // Discard arrows cached by a previous call; stale configurations would otherwise let candidate edges be skipped
    // without being rescored against the freshly reset graph.
    sortedArrows.clear();
    arrowsMap.clear();

    // An interrupted earlier search shuts the pool down; give this search a working one.
    if (pool.isShutdown()) {
        pool = new ForkJoinPool(numThreads);
    }

    graph = new EdgeListGraph(getVariables());

    if (boundGraph != null) {
        boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables());
    }

    if (initialGraph != null) {
        graph = new EdgeListGraph(initialGraph);
        graph = GraphUtils.replaceNodes(graph, getVariables());
    }

    addRequiredEdges(graph);
    initializeEffectEdges(getVariables());

    this.mode = Mode.heuristicSpeedup;
    fes();
    bes();

    this.mode = Mode.coverNoncolliders;
    fes();
    bes();

    if (!faithfulnessAssumed) {
        this.mode = Mode.allowUnfaithfulness;
        fes();
        bes();
    }

    long endTime = MillisecondTimes.timeMillis();
    this.elapsedTime = endTime - start;

    if (verbose) {
        this.logger.log("Elapsed time = " + (elapsedTime) / 1000. + " s");
    }

    this.modelScore = scoreDag(GraphTransforms.dagFromCpdag(graph, null, true, verbose), true);
    return graph;
}
/**
 * Sets whether one-edge faithfulness is assumed: that if X and Y are unconditionally dependent, X and Y are adjacent
 * in the graph. Path cancellation can violate this, e.g. when the paths A->B->C->D and A->D cancel. When set to true,
 * search() skips its final pass over all pairs of variables.
 *
 * @param faithfulnessAssumed True, if so.
 */
public void setFaithfulnessAssumed(boolean faithfulnessAssumed) {
    this.faithfulnessAssumed = faithfulnessAssumed;
}

/**
 * Returns the background knowledge currently in use.
 *
 * @return This knowledge
 */
public Knowledge getKnowledge() {
    return this.knowledge;
}

/**
 * Replaces the background knowledge (forbidden and required edges, tiers).
 *
 * @param knowledge the knowledge object; must not be null.
 * @throws NullPointerException if knowledge is null.
 */
public void setKnowledge(Knowledge knowledge) {
    this.knowledge = Objects.requireNonNull(knowledge);
}

/**
 * Returns how long the most recent search took, in milliseconds.
 *
 * @return This elapsed time.
 */
public long getElapsedTime() {
    return this.elapsedTime;
}
/**
 * Scores a DAG as the sum, over the score's measured variables, of each variable's local score given its parents in
 * the DAG. Per-node scores are not recorded. Returns 0 when the score is a GraphScore.
 *
 * @param dag The input DAG to be scored.
 * @return The score of the DAG.
 */
public double scoreDag(Graph dag) {
    final boolean recordScores = false;
    return scoreDag(dag, recordScores);
}

/**
 * Turns general verbose output on or off. Output from the Meek rules is controlled separately.
 *
 * @param verbose True iff the case.
 * @see #setMeekVerbose(boolean)
 */
public void setVerbose(boolean verbose) {
    this.verbose = verbose;
}

/**
 * Turns verbose output from the Meek rules on or off.
 *
 * @param meekVerbose True iff the case.
 */
public void setMeekVerbose(boolean meekVerbose) {
    this.meekVerbose = meekVerbose;
}

/**
 * Returns the stream that printed (non-log) output goes to.
 *
 * @return the output stream used for printing.
 */
public PrintStream getOut() {
    return this.out;
}

/**
 * Redirects printed output (everything except log output). By default System.out.
 *
 * @param out This print stream.
 */
public void setOut(PrintStream out) {
    this.out = out;
}
/**
 * Restricts the search so that only pairs adjacent in the given graph may become adjacent. The graph's nodes are
 * mapped onto this search's variables. Passing null removes the restriction.
 *
 * @param boundGraph This bound graph, or null.
 */
public void setBoundGraph(Graph boundGraph) {
    this.boundGraph = (boundGraph == null) ? null : GraphUtils.replaceNodes(boundGraph, getVariables());
}
/**
 * Returns the maximum degree (number of adjacencies) a node may have for the forward phase to add an edge at it.
 * Initialized from the score's maximum degree.
 *
 * @return the maximum degree, or -1 if no explicit limit is set (the forward phase then uses 1000).
 */
public int getMaxDegree() {
    return maxDegree;
}

/**
 * Sets the maximum degree (number of adjacencies) a node may have for the forward phase to add an edge at it.
 *
 * @param maxDegree the maximum degree, or -1 for no explicit limit (the forward phase then uses 1000).
 * @throws IllegalArgumentException if maxDegree is less than -1.
 */
public void setMaxDegree(int maxDegree) {
    if (maxDegree < -1) {
        throw new IllegalArgumentException("maxDegree must be -1 (unlimited) or non-negative: " + maxDegree);
    }

    this.maxDegree = maxDegree;
}
/**
 * Controls whether, when adding X--Y in the first step, both X->Y and Y->X are scored and the higher score is kept.
 *
 * @param symmetricFirstStep True iff the case.
 */
public void setSymmetricFirstStep(boolean symmetricFirstStep) {
    this.symmetricFirstStep = symmetricFirstStep;
}

/**
 * Returns the score recorded by the most recent search, computed on a DAG extracted from the output CPDAG.
 *
 * @return This score.
 */
public double getModelScore() {
    return this.modelScore;
}
/**
 * Installs the score. Keeps its MEASURED variables, in order, as the search variables. Indexes all of the score's
 * variables and adopts the score's maximum degree.
 */
private void setScore(Score score) {
    this.score = score;

    List<Node> measured = new ArrayList<>(score.getVariables());
    measured.removeIf(node -> node.getNodeType() != NodeType.MEASURED);
    this.variables = measured;

    buildIndexing(score.getVariables());
    this.maxDegree = this.score.getMaxDegree();
}
/**
 * Splits n work items evenly across the available processors, but never into chunks smaller than 100 items.
 */
private int getChunkSize(int n) {
    int perProcessor = n / Runtime.getRuntime().availableProcessors();
    return Math.max(perProcessor, 100);
}
/**
 * Initializes the effect-edges graph over the given nodes. It runs NodeTaskEmptyGraph tasks on the pool, one per
 * contiguous chunk of the node list.
 *
 * @param nodes The list of nodes.
 * @throws RuntimeException if the current thread is interrupted, or (possibly wrapping) the failure of any task.
 */
private void initializeEffectEdges(final List<Node> nodes) {
    long start = MillisecondTimes.timeMillis();
    this.effectEdgesGraph = new EdgeListGraph(nodes);

    List<NodeTaskEmptyGraph> tasks = new ArrayList<>();
    int chunkSize = getChunkSize(nodes.size());

    for (int i = 0; i < nodes.size(); i += chunkSize) {
        if (Thread.currentThread().isInterrupted()) {
            this.pool.shutdownNow();
            throw new RuntimeException("Interrupted");
        }

        tasks.add(new NodeTaskEmptyGraph(i, min(nodes.size(), i + chunkSize), nodes, emptySet));
    }

    // invokeAll waits for every task but does not rethrow their failures. Calling get() surfaces them, so a failed
    // chunk cannot silently leave the effect-edges graph incomplete.
    for (Future<?> future : pool.invokeAll(tasks)) {
        try {
            future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) throw (RuntimeException) cause;
            if (cause instanceof Error) throw (Error) cause;
            throw new RuntimeException(cause);
        }
    }

    long stop = MillisecondTimes.timeMillis();

    if (verbose) {
        out.println("Elapsed initializeForwardEdgesFromEmptyGraph = " + (stop - start) + " ms");
    }
}
/**
 * Runs one forward (edge-adding) phase.
 * <p>
 * Every variable is first rescored as a candidate target. Then, repeatedly, the first cached arrow X-->Y is taken
 * from sortedArrows and inserted only if it is still applicable:
 * <ul>
 * <li>X and Y are not already adjacent;</li>
 * <li>neither X nor Y has reached the maximum degree (-1 is treated as 1000);</li>
 * <li>NaYX, the T-neighbors of (X, Y), and the parents of Y are unchanged since the arrow was scored;</li>
 * <li>the insertion is valid (Theorem 15, Chickering 2002) and not forbidden by knowledge.</li>
 * </ul>
 * After each insertion the graph is reverted to a CPDAG. Arrows are then rescored for the nodes reported by the
 * Meek rules, for X and Y, and for their common adjacents. The phase ends when no cached arrows remain.
 */
private void fes() {
    int maxDegree = this.maxDegree == -1 ? 1000 : this.maxDegree;

    reevaluateForward(new HashSet<>(variables));

    while (!sortedArrows.isEmpty()) {
        Arrow arrow = sortedArrows.first();
        sortedArrows.remove(arrow);

        Node x = arrow.getA();
        Node y = arrow.getB();

        if (graph.isAdjacentTo(x, y)) {
            continue;
        }

        if (graph.getDegree(x) >= maxDegree || graph.getDegree(y) >= maxDegree) {
            continue;
        }

        // Skip arrows whose scoring context is stale; the affected pairs are rescored after each insertion.
        Set<Node> naYX = getNaYX(x, y);

        if (!naYX.equals(arrow.getNaYX())) {
            continue;
        }

        if (!new HashSet<>(getTNeighbors(x, y)).equals(arrow.getTNeighbors())) {
            continue;
        }

        if (!new HashSet<>(graph.getParents(y)).equals(new HashSet<>(arrow.getParents()))) {
            continue;
        }

        if (!validInsert(x, y, arrow.getHOrT(), naYX)) {
            continue;
        }

        insert(x, y, arrow.getHOrT(), arrow.getBump());

        Set<Node> process = revertToCpdag();
        process.add(x);
        process.add(y);
        process.addAll(getCommonAdjacents(x, y));

        reevaluateForward(new HashSet<>(process));
    }
}
/**
 * Runs the backward (edge-removing) phase on the working graph using {@link Bes}. BES is configured with this
 * search's score, depth, verbosity and knowledge.
 *
 * @see Bes
 */
private void bes() {
    Bes backward = new Bes(score);
    backward.setDepth(depth);
    backward.setVerbose(verbose);
    backward.setKnowledge(knowledge);
    backward.bes(graph, variables);
}
/**
 * Reports whether any background knowledge has been supplied.
 *
 * @return true if the knowledge is non-empty, false otherwise.
 */
private boolean existsKnowledge() {
    boolean empty = this.knowledge.isEmpty();
    return !empty;
}
/**
 * Rescores candidate arrows X-->Y for every Y in {@code nodes}.
 * <p>
 * The candidate X's depend on the current mode:
 * <ul>
 * <li>heuristicSpeedup: Y's neighbors in the effect-edges graph;</li>
 * <li>coverNoncolliders: nodes M not adjacent to Y that are adjacent to a neighbor N of Y, excluding those where
 * M, N, Y is a definite collider;</li>
 * <li>allowUnfaithfulness: all variables.</li>
 * </ul>
 * Candidates not adjacent to Y in the bound graph (if any) are skipped. The Y's are split into chunks that run on
 * the pool.
 *
 * @param nodes the set of nodes for which to reevaluate arrows
 * @throws RuntimeException if the current thread is interrupted, or (possibly wrapping) the failure of any task.
 */
private void reevaluateForward(final Set<Node> nodes) {
    class AdjTask implements Callable<Boolean> {
        private final List<Node> nodes;
        private final int from;
        private final int to;

        private AdjTask(List<Node> nodes, int from, int to) {
            this.nodes = nodes;
            this.from = from;
            this.to = to;
        }

        @Override
        public Boolean call() {
            for (int _y = from; _y < to; _y++) {
                if (Thread.currentThread().isInterrupted()) break;

                Node y = nodes.get(_y);
                List<Node> adj;

                if (mode == Mode.heuristicSpeedup) {
                    adj = effectEdgesGraph.getAdjacentNodes(y);
                } else if (mode == Mode.coverNoncolliders) {
                    Set<Node> g = new HashSet<>();

                    for (Node n : graph.getAdjacentNodes(y)) {
                        for (Node m : graph.getAdjacentNodes(n)) {
                            if (graph.isAdjacentTo(y, m)) {
                                continue;
                            }

                            if (graph.isDefCollider(m, n, y)) {
                                continue;
                            }

                            g.add(m);
                        }
                    }

                    adj = new ArrayList<>(g);
                } else if (mode == Mode.allowUnfaithfulness) {
                    adj = new ArrayList<>(variables);
                } else {
                    throw new IllegalStateException();
                }

                for (Node x : adj) {
                    if (boundGraph != null && !(boundGraph.isAdjacentTo(x, y))) {
                        continue;
                    }

                    calculateArrowsForward(x, y);
                }
            }

            return true;
        }
    }

    // One ordered snapshot shared by all tasks; each task reads a disjoint index range of it.
    List<Node> nodeList = new ArrayList<>(nodes);
    List<AdjTask> tasks = new ArrayList<>();
    int chunkSize = getChunkSize(nodeList.size());

    for (int i = 0; i < nodeList.size(); i += chunkSize) {
        if (Thread.currentThread().isInterrupted()) {
            pool.shutdownNow();
            throw new RuntimeException("Interrupted");
        }

        tasks.add(new AdjTask(nodeList, i, min(nodeList.size(), i + chunkSize)));
    }

    // invokeAll waits for every task but does not rethrow their failures; get() surfaces them.
    for (Future<Boolean> future : pool.invokeAll(tasks)) {
        try {
            future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) throw (RuntimeException) cause;
            if (cause instanceof Error) throw (Error) cause;
            throw new RuntimeException(cause);
        }
    }
}
/**
* Calculates and adds forward arrows from node a to node b based on the given conditions.
*
* @param a the starting node
* @param b the ending node
*/
private void calculateArrowsForward(Node a, Node b) {
if (boundGraph != null && !boundGraph.isAdjacentTo(a, b)) {
return;
}
if (a == b) return;
if (graph.isAdjacentTo(a, b)) return;
if (existsKnowledge()) {
if (getKnowledge().isForbidden(a.getName(), b.getName())) {
return;
}
}
Set naYX = getNaYX(a, b);
List TNeighbors = getTNeighbors(a, b);
Set parents = new HashSet<>(graph.getParents(b));
HashSet TNeighborsSet = new HashSet<>(TNeighbors);
ArrowConfig config = new ArrowConfig(TNeighborsSet, naYX, parents);
ArrowConfig storedConfig = arrowsMap.get(directedEdge(a, b));
if (storedConfig != null && storedConfig.equals(config)) return;
arrowsMap.put(directedEdge(a, b), new ArrowConfig(TNeighborsSet, naYX, parents));
int _depth = min(depth, TNeighbors.size());
final SublistGenerator gen = new SublistGenerator(TNeighbors.size(), _depth);// TNeighbors.size());
int[] choice;
Set maxT = null;
double maxBump = Double.NEGATIVE_INFINITY;
List> TT = new ArrayList<>();
while ((choice = gen.next()) != null) {
Set _T = GraphUtils.asSet(choice, TNeighbors);
TT.add(_T);
}
class EvalTask implements Callable {
private final List> Ts;
private final ConcurrentMap hashIndices;
private final int from;
private final int to;
private Set maxT = null;
private double maxBump = Double.NEGATIVE_INFINITY;
public EvalTask(List> Ts, int from, int to, ConcurrentMap hashIndices) {
this.Ts = Ts;
this.hashIndices = hashIndices;
this.from = from;
this.to = to;
}
@Override
public EvalPair call() {
for (int k = from; k < to; k++) {
if (Thread.currentThread().isInterrupted()) break;
double _bump = insertEval(a, b, Ts.get(k), naYX, parents, this.hashIndices);
if (_bump > maxBump) {
maxT = Ts.get(k);
maxBump = _bump;
}
}
EvalPair pair = new EvalPair();
pair.T = maxT;
pair.bump = maxBump;
return pair;
}
}
int chunkSize = getChunkSize(TT.size());
List tasks = new ArrayList<>();
for (int i = 0; i < TT.size(); i += chunkSize) {
if (Thread.currentThread().isInterrupted()) {
pool.shutdownNow();
// break;
throw new RuntimeException("Interrupted");
}
EvalTask task = new EvalTask(TT, i, min(TT.size(), i + chunkSize), hashIndices);
tasks.add(task);
}
List> futures = pool.invokeAll(tasks);
for (Future future : futures) {
try {
EvalPair pair = future.get();
if (pair.bump > maxBump) {
maxT = pair.T;
maxBump = pair.bump;
}
} catch (InterruptedException | ExecutionException e) {
Thread.currentThread().interrupt();
TetradLogger.getInstance().log(e.getMessage());
return;
}
}
if (maxBump > 0) {
addArrowForward(a, b, maxT, TNeighborsSet, naYX, parents, maxBump);
}
}
/**
 * Caches a candidate arrow a-->b in sortedArrows, tagged with a fresh index.
 * <p>
 * This is called from calculateArrowsForward, which runs on pool threads. The shared counter is therefore
 * incremented under a lock; an unsynchronized {@code arrowIndex++} can hand two arrows the same index. Since the
 * index is what distinguishes equal-bump arrows in the sorted set, duplicates could presumably make one of them be
 * dropped (depends on Arrow.compareTo — confirm).
 *
 * @param a          the starting node of the arrow
 * @param b          the ending node of the arrow
 * @param hOrT       the chosen subset T of b's T-neighbors to orient into b
 * @param TNeighbors the T-neighbors of (a, b) at scoring time
 * @param naYX       NaYX for (a, b) at scoring time
 * @param parents    the parents of b at scoring time
 * @param bump       the score bump of the insertion
 */
private void addArrowForward(Node a, Node b, Set<Node> hOrT, Set<Node> TNeighbors, Set<Node> naYX,
                             Set<Node> parents, double bump) {
    int index;

    synchronized (this) {
        index = arrowIndex++;
    }

    sortedArrows.add(new Arrow(bump, a, b, hOrT, TNeighbors, naYX, parents, index));
}
/**
 * Returns the nodes adjacent to both x and y in the working graph.
 *
 * @param x the first node
 * @param y the second node
 * @return a new, mutable set of common adjacent nodes
 */
private Set<Node> getCommonAdjacents(Node x, Node y) {
    Set<Node> adj = new HashSet<>(graph.getAdjacentNodes(x));
    adj.retainAll(graph.getAdjacentNodes(y));
    return adj;
}
/**
 * Returns the T-neighbors of (x, y): nodes joined to y by an undirected edge that are not adjacent to x.
 *
 * @param x the first node
 * @param y the second node
 * @return a new list of T-neighbors
 */
private List<Node> getTNeighbors(Node x, Node y) {
    List<Node> tNeighbors = new ArrayList<>();

    for (Edge edge : graph.getEdges(y)) {
        if (!Edges.isUndirectedEdge(edge)) {
            continue;
        }

        Node z = edge.getDistalNode(y);

        if (graph.isAdjacentTo(z, x)) {
            continue;
        }

        tNeighbors.add(z);
    }

    return tNeighbors;
}
/**
 * Evaluates the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002). It scores adding x as a parent of y
 * given the conditioning set NaYX union T union parents(y).
 *
 * @param x           The starting node of the edge.
 * @param y           The ending node of the edge.
 * @param T           The subset of y's T-neighbors to be oriented into y.
 * @param naYX        The neighbors of y (via undirected edges) that are adjacent to x.
 * @param parents     The current parents of y.
 * @param hashIndices The map of nodes to their corresponding indices.
 * @return The score change for the insertion.
 */
private double insertEval(Node x, Node y, Set<Node> T, Set<Node> naYX, Set<Node> parents,
                          Map<Node, Integer> hashIndices) {
    Set<Node> set = new HashSet<>(naYX);
    set.addAll(T);
    set.addAll(parents);
    return scoreGraphChange(x, y, set, hashIndices);
}
/**
 * Performs Insert(X, Y, T) (Definition 12 from Chickering, 2002): adds x-->y, then replaces each edge between a node
 * of T and y with t-->y.
 *
 * @param x    the source node
 * @param y    the target node
 * @param T    neighbors of y whose edges to y are redirected into y
 * @param bump the score bump of the insertion; used only for logging
 */
private void insert(Node x, Node y, Set<Node> T, double bump) {
    graph.addDirectedEdge(x, y);

    int numEdges = graph.getNumEdges();

    // Progress report, printed regardless of verbosity.
    if (numEdges % 1000 == 0) {
        out.println("Num edges added: " + numEdges);
    }

    if (verbose) {
        int cond = T.size() + getNaYX(x, y).size() + graph.getParents(y).size();
        final String message = graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + T + " " + bump
                + " degree = " + GraphUtils.getDegree(graph) + " indegree = " + GraphUtils.getIndegree(graph)
                + " cond = " + cond;
        TetradLogger.getInstance().log(message);
    }

    for (Node _t : T) {
        graph.removeEdge(_t, y);
        graph.addDirectedEdge(_t, y);

        if (verbose) {
            String message = "--- Directing " + graph.getEdge(_t, y);
            TetradLogger.getInstance().log(message);
        }
    }
}
/**
 * Checks whether Insert(X, Y, T) is valid (Theorem 15 from Chickering, 2002) and permitted by the knowledge.
 * <p>
 * Neither x-->y nor any t-->y (t in T) may be forbidden. NaYX union T must be a clique. Every semidirected path from
 * y to x must pass through NaYX union T.
 *
 * @param x    the source of the edge to insert
 * @param y    the target of the edge to insert
 * @param T    the subset of y's T-neighbors to be oriented into y
 * @param naYX the neighbors of y (via undirected edges) that are adjacent to x
 * @return true if the insertion is valid and permitted, false otherwise
 */
private boolean validInsert(Node x, Node y, Set<Node> T, Set<Node> naYX) {
    if (existsKnowledge()) {
        if (knowledge.isForbidden(x.getName(), y.getName())) {
            return false;
        }

        for (Node t : T) {
            if (knowledge.isForbidden(t.getName(), y.getName())) {
                return false;
            }
        }
    }

    Set<Node> union = new HashSet<>(T);
    union.addAll(naYX);

    return isClique(union) && semidirectedPathCondition(y, x, union);
}
/**
 * Adds the knowledge's required edges to the graph and redirects edges whose orientation is forbidden.
 * <p>
 * Each required edge A-->B replaces any edges between A and B, unless B is already an ancestor of A. Required edges
 * naming a variable not in the graph are skipped. Then every edge forbidden in one direction is redirected the
 * other way, where that does not create a cycle.
 *
 * @param graph the graph to which the required edges will be added
 */
private void addRequiredEdges(Graph graph) {
    if (!existsKnowledge()) {
        return;
    }

    for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) {
        if (Thread.currentThread().isInterrupted()) {
            pool.shutdownNow();
            throw new RuntimeException("Interrupted");
        }

        KnowledgeEdge next = it.next();
        Node nodeA = graph.getNode(next.getFrom());
        Node nodeB = graph.getNode(next.getTo());

        // The knowledge may mention variables that are not in this graph.
        if (nodeA == null || nodeB == null) {
            continue;
        }

        // NOTE(review): if B-->A already exists, B is presumably an ancestor of A, so the required edge is not added.
        if (!graph.paths().isAncestorOf(nodeB, nodeA)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeA, nodeB);

            if (verbose) {
                TetradLogger.getInstance().log("Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
            }
        }
    }

    // Iterate over a snapshot, because edges are removed and re-added below.
    for (Edge edge : new ArrayList<>(graph.getEdges())) {
        if (Thread.currentThread().isInterrupted()) {
            pool.shutdownNow();
            throw new RuntimeException("Interrupted");
        }

        final String A = edge.getNode1().getName();
        final String B = edge.getNode2().getName();

        if (knowledge.isForbidden(A, B)) {
            orientAwayFromForbidden(graph, edge.getNode1(), edge.getNode2());
        } else if (knowledge.isForbidden(B, A)) {
            orientAwayFromForbidden(graph, edge.getNode2(), edge.getNode1());
        }
    }
}

/**
 * Given that nodeA-->nodeB is forbidden, replaces the edges between them with nodeB-->nodeA. This is skipped when
 * graph.isChildOf(nodeA, nodeB) already holds, or when nodeA is an ancestor of nodeB (where nodeB-->nodeA would
 * create a cycle).
 * <p>
 * NOTE(review): an existing nodeA-->nodeB presumably makes nodeA an ancestor of nodeB, so such an edge is left in
 * place; confirm whether that is intended.
 */
private void orientAwayFromForbidden(Graph graph, Node nodeA, Node nodeB) {
    if (graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) {
        if (!graph.paths().isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);

            if (verbose) {
                TetradLogger.getInstance().log("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
            }
        }
    }

    // Second check retained from the original logic; it applies even if the pair is no longer adjacent.
    if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
        if (!graph.paths().isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);

            if (verbose) {
                TetradLogger.getInstance().log("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
            }
        }
    }
}
/**
 * Checks whether the knowledge forbids node-->y for at least one node in the subset.
 *
 * @param y      The target node.
 * @param subset The candidate source nodes.
 * @return True if some node-->y with node in the subset is forbidden, otherwise false.
 */
private boolean invalidSetByKnowledge(Node y, Set<Node> subset) {
    for (Node node : subset) {
        if (getKnowledge().isForbidden(node.getName(), y.getName())) {
            return true;
        }
    }

    return false;
}
/**
 * Returns NaYX: the nodes z (other than x) that are joined to y by an undirected edge y--z and are adjacent to x.
 *
 * @param x the first node
 * @param y the second node
 * @return a new set containing NaYX
 */
private Set<Node> getNaYX(Node x, Node y) {
    Set<Node> nayx = new HashSet<>();

    for (Node z : graph.getAdjacentNodes(y)) {
        if (z == x) {
            continue;
        }

        Edge yz = graph.getEdge(y, z);

        if (!Edges.isUndirectedEdge(yz)) {
            continue;
        }

        if (!graph.isAdjacentTo(z, x)) {
            continue;
        }

        nayx.add(z);
    }

    return nayx;
}
/**
 * Checks whether every pair of the given nodes is adjacent in the working graph. Empty and singleton sets are
 * cliques.
 *
 * @param nodes the set of nodes to check
 * @return true if the given set of nodes forms a clique, false otherwise
 */
private boolean isClique(Set<Node> nodes) {
    List<Node> list = new ArrayList<>(nodes);

    for (int i = 0; i < list.size(); i++) {
        for (int j = i + 1; j < list.size(); j++) {
            if (!graph.isAdjacentTo(list.get(i), list.get(j))) {
                return false;
            }
        }
    }

    return true;
}
/**
 * Checks the semidirected-path condition of Theorem 15 (Chickering, 2002).
 * <p>
 * Returns true iff every semidirected path from {@code from} to {@code to} passes through a node in {@code cond}.
 * Equivalently, a breadth-first walk from {@code from} over semidirected steps (a-->b or a--b), which does not expand
 * nodes in {@code cond}, never reaches {@code to}. If {@code from} or {@code to} is itself in {@code cond}, the
 * result is true.
 *
 * @param from the source node
 * @param to   the target node
 * @param cond the blocking set
 * @return {@code true} if {@code cond} blocks every semidirected path from {@code from} to {@code to};
 * {@code false} if some semidirected path avoids it
 * @throws IllegalArgumentException if the source node and the target node are the same
 */
private boolean semidirectedPathCondition(Node from, Node to, Set<Node> cond) {
    if (from == to) throw new IllegalArgumentException();

    Queue<Node> queue = new ArrayDeque<>();
    Set<Node> visited = new HashSet<>();
    queue.add(from);
    visited.add(from);

    while (!queue.isEmpty()) {
        Node t = queue.remove();

        // Nodes in cond block the path; do not walk past them.
        if (cond.contains(t)) {
            continue;
        }

        if (t == to) {
            return false;
        }

        for (Node u : graph.getAdjacentNodes(t)) {
            Edge edge = graph.getEdge(t, u);
            Node c = traverseSemiDirected(t, edge);

            if (c != null && visited.add(c)) {
                queue.offer(c);
            }
        }
    }

    return true;
}
/**
 * Applies the Meek rules to the working graph in place. The rules respect the knowledge, prevent cycles, and are
 * verbose only if Meek verbosity is on.
 *
 * @return the node set returned by MeekRules.orientImplied; presumably the nodes whose edges were visited or
 * reoriented (fes() rescores them and adds to the set, so it must be mutable — confirm)
 */
private Set<Node> revertToCpdag() {
    MeekRules rules = new MeekRules();
    rules.setKnowledge(getKnowledge());
    rules.setMeekPreventCycles(true);
    rules.setVerbose(meekVerbose);
    return rules.orientImplied(graph);
}
/**
 * Rebuilds hashIndices, mapping each node to its position in the list (the last position wins for duplicates).
 *
 * @param nodes the list of nodes to build indexing for
 */
private void buildIndexing(List<Node> nodes) {
    this.hashIndices = new ConcurrentHashMap<>();
    int index = 0;

    for (Node n : nodes) {
        this.hashIndices.put(n, index++);
    }
}
/**
 * Scores a DAG as the sum, over the measured variables, of each variable's local score given its parents in the DAG.
 * The DAG's nodes are first mapped onto this search's variables. Returns 0 for a GraphScore.
 *
 * @param dag          The DAG to be scored.
 * @param recordScores If true, stores each node's local score as its "Score" attribute. It also stores the total as
 *                     the "Score" attribute of the working graph (not of {@code dag}).
 * @return The total score of the DAG.
 */
private double scoreDag(Graph dag, boolean recordScores) {
    if (score instanceof GraphScore) return 0.0;

    dag = GraphUtils.replaceNodes(dag, getVariables());
    double _score = 0;

    for (Node node : getVariables()) {
        List<Node> parents = dag.getParents(node);
        int[] parentIndices = new int[parents.size()];
        int k = 0;

        for (Node parent : parents) {
            parentIndices[k++] = hashIndices.get(parent);
        }

        final double nodeScore = score.localScore(hashIndices.get(node), parentIndices);

        if (recordScores) {
            node.addAttribute("Score", nodeScore);
        }

        _score += nodeScore;
    }

    if (recordScores) {
        graph.addAttribute("Score", _score);
    }

    return _score;
}
/**
 * Returns score.localScoreDiff(x, y, parents). This is presumably the change in y's local score when x is added to
 * y's conditioning set {@code parents}.
 *
 * @param x           The candidate new parent.
 * @param y           The child whose local score changes.
 * @param parents     The conditioning set for y; must not contain y.
 * @param hashIndices A mapping of nodes to their corresponding indices.
 * @return The score change.
 * @throws IllegalArgumentException If x is the same as y or parents contains y.
 */
private double scoreGraphChange(Node x, Node y, Set<Node> parents, Map<Node, Integer> hashIndices) {
    if (x == y) {
        throw new IllegalArgumentException("x and y must be different nodes: " + x);
    }

    if (parents.contains(y)) {
        throw new IllegalArgumentException("The conditioning set must not contain y: " + y);
    }

    int xIndex = hashIndices.get(x);
    int yIndex = hashIndices.get(y);

    int[] parentIndices = new int[parents.size()];
    int k = 0;

    for (Node parent : parents) {
        parentIndices[k++] = hashIndices.get(parent);
    }

    return score.localScoreDiff(xIndex, yIndex, parentIndices);
}
/**
 * Returns the measured variables of the score, in order. This is the internal list, not a copy.
 *
 * @return the list of variables
 */
private List<Node> getVariables() {
    return variables;
}
/**
 * Sets the number of threads to use. By default, the number of threads is 1. A new pool is created. The previous
 * pool is shut down (tasks already submitted to it still complete), so its worker threads are released.
 *
 * @param numThreads The number of threads to use. Must be at least 1.
 * @throws IllegalArgumentException if numThreads is less than 1.
 */
public void setNumThreads(int numThreads) {
    if (numThreads < 1) throw new IllegalArgumentException("numThreads must be at least 1.");

    this.numThreads = numThreads;

    ForkJoinPool previous = this.pool;
    this.pool = new ForkJoinPool(numThreads);

    if (previous != null) {
        previous.shutdown();
    }
}
/**
 * Stores the graph to be used as this search's initial graph.
 *
 * @param initialGraph the graph to store; it is kept by reference, not copied
 */
public void setInitialGraph(Graph initialGraph) {
    this.initialGraph = initialGraph;
}
/**
 * Named modes of operation for this search.
 * <p>
 * NOTE(review): the descriptions below are inferred from the constant names only. The code that consumes these
 * constants lies outside this section, so confirm before relying on them.
 */
private enum Mode {
    /** Presumably permits searching under unfaithfulness. Confirm. */
    allowUnfaithfulness,
    /** Presumably enables a heuristic speedup. Confirm. */
    heuristicSpeedup,
    /** Presumably concerns covering noncolliders. Confirm. */
    coverNoncolliders
}
/**
 * A value object holding a (T, NaYX, parents) triple of node sets. Equality and hashing are defined over all three
 * sets.
 * <p>
 * NOTE(review): the sets are stored by reference and the setters allow replacing them. Mutating a set, or calling a
 * setter, after an instance has been placed in a hash-based collection will corrupt that collection.
 */
private static class ArrowConfig {
    /**
     * The T set.
     */
    private Set T;
    /**
     * The NaYX set.
     */
    private Set nayx;
    /**
     * The parent set.
     */
    private Set parents;

    /**
     * Constructs a new ArrowConfig with the specified sets of nodes.
     *
     * @param T       The T set.
     * @param nayx    The NaYX set.
     * @param parents The parent set.
     */
    public ArrowConfig(Set T, Set nayx, Set parents) {
        this.setT(T);
        this.setNayx(nayx);
        this.setParents(parents);
    }

    /**
     * Sets the T set.
     *
     * @param t the set of Node objects for the T set.
     */
    public void setT(Set t) {
        T = t;
    }

    /**
     * Sets the NaYX set.
     *
     * @param nayx the set of Node objects for the NaYX set
     */
    public void setNayx(Set nayx) {
        this.nayx = nayx;
    }

    /**
     * Sets the parent set.
     *
     * @param parents the set of Node objects to use as the parent set
     */
    public void setParents(Set parents) {
        this.parents = parents;
    }

    /**
     * Two ArrowConfigs are equal when their T, NaYX, and parent sets are pairwise equal. Null sets are compared
     * null-safely, which keeps equals consistent with {@link #hashCode()} ({@code Objects.hash} also accepts
     * nulls).
     *
     * @param o the object to be compared for equality with this ArrowConfig
     * @return true if the specified object is equal to this ArrowConfig, false otherwise
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ArrowConfig that = (ArrowConfig) o;
        return Objects.equals(T, that.T) && Objects.equals(nayx, that.nayx) && Objects.equals(parents, that.parents);
    }

    /**
     * Returns a hash code combining the T, NaYX, and parent sets via {@link Objects#hash(Object...)}.
     *
     * @return the computed hash code value for this ArrowConfig object
     */
    @Override
    public int hashCode() {
        return Objects.hash(T, nayx, parents);
    }
}
/**
 * Basic data structure for an arrow a->b considered for addition to or removal from the graph, together with the
 * associated sets needed to make that determination. Both the forward and the backward direction need NaYX. The
 * forward direction needs T neighbors; the backward direction needs H neighbors. See Chickering (2002). The
 * totalScore difference that would result from (hypothetically) adding the edge is recorded as the "bump."
 * <p>
 * Arrows are ordered by descending bump. Ties are broken by ascending {@link #getIndex() index}.
 */
private static class Arrow implements Comparable<Arrow> {
    /**
     * The bump value for this arrow. It determines the ordering of arrows in a SortedSet; higher bumps sort first.
     */
    private final double bump;
    /**
     * The tail a of a->b.
     */
    private final Node a;
    /**
     * The head b of a->b.
     */
    private final Node b;
    /**
     * A T set or an H set, depending on whether the forward or the backward search created this arrow.
     */
    private final Set<Node> hOrT;
    /**
     * The NaYX set.
     */
    private final Set<Node> naYX;
    /**
     * The parent set recorded with this arrow.
     */
    private final Set<Node> parents;
    /**
     * An index, intended to be unique, used to break ties between arrows with equal bumps.
     */
    private final int index;
    /**
     * The T neighbors. Unlike the other sets, this one can be replaced via {@link #setTNeighbors(Set)}.
     */
    private Set<Node> TNeighbors;

    /**
     * Constructs an arrow a->b.
     *
     * @param bump    The bump value of the arrow.
     * @param a       The tail of the arrow.
     * @param b       The head of the arrow.
     * @param hOrT    The T set (forward search) or H set (backward search).
     * @param capTorH The initial T-neighbors set, stored via {@link #setTNeighbors(Set)}.
     * @param naYX    The NaYX set.
     * @param parents The parent set.
     * @param index   The tie-breaking index of the arrow.
     */
    Arrow(double bump, Node a, Node b, Set<Node> hOrT, Set<Node> capTorH, Set<Node> naYX, Set<Node> parents, int index) {
        this.bump = bump;
        this.a = a;
        this.b = b;
        this.setTNeighbors(capTorH);
        this.hOrT = hOrT;
        this.naYX = naYX;
        this.index = index;
        this.parents = parents;
    }

    /**
     * Retrieves the bump value of the Arrow.
     *
     * @return The bump value of the Arrow.
     */
    public double getBump() {
        return bump;
    }

    /**
     * Retrieves the tail a of a->b.
     *
     * @return the first node of the arrow.
     */
    public Node getA() {
        return a;
    }

    /**
     * Retrieves the head b of a->b.
     *
     * @return the second node of the arrow.
     */
    public Node getB() {
        return b;
    }

    /**
     * Retrieves the set of nodes representing H or T.
     *
     * @return The set of nodes representing H or T.
     */
    Set<Node> getHOrT() {
        return hOrT;
    }

    /**
     * Retrieves the NaYX set.
     *
     * @return The set of nodes representing NaYX.
     */
    Set<Node> getNaYX() {
        return naYX;
    }

    /**
     * Orders arrows by bump, high to low. A SortedSet will not add an element that compares as zero to an element
     * it already holds. Arrows with equal bumps are therefore ordered by their index, so two distinct arrows never
     * compare as equal as long as indices are unique.
     *
     * @param arrow the arrow to compare against
     * @return a negative value if this arrow sorts first, a positive value if it sorts after, 0 if bump and index
     *         both match
     */
    @Override
    public int compareTo(@NotNull Arrow arrow) {
        // Argument order reversed on purpose: a higher bump sorts first.
        final int compare = Double.compare(arrow.getBump(), getBump());
        if (compare == 0) {
            return Integer.compare(getIndex(), arrow.getIndex());
        }
        return compare;
    }

    /**
     * Returns a string representation of this Arrow object.
     *
     * @return a string representation of this Arrow object
     */
    @Override
    public String toString() {
        return "Arrow<" + a + "->" + b + " bump = " + bump + " t/h = " + hOrT + " TNeighbors = " + getTNeighbors() + " parents = " + parents + " naYX = " + naYX + ">";
    }

    /**
     * Retrieves the tie-breaking index of the Arrow.
     *
     * @return The index of the Arrow.
     */
    public int getIndex() {
        return index;
    }

    /**
     * Retrieves the T-neighbors set.
     *
     * @return The set of neighboring nodes represented as T.
     */
    public Set<Node> getTNeighbors() {
        return TNeighbors;
    }

    /**
     * Replaces the T-neighbors set.
     *
     * @param TNeighbors The set of neighboring nodes represented as T.
     */
    public void setTNeighbors(Set<Node> TNeighbors) {
        this.TNeighbors = TNeighbors;
    }

    /**
     * Retrieves the parent set recorded with this arrow.
     *
     * @return The set of parent nodes.
     */
    public Set<Node> getParents() {
        return parents;
    }
}
/**
 * A mutable holder that pairs a node set T with a bump value.
 * <p>
 * Both fields are package-private and have no accessors; users read and write them directly.
 */
private static class EvalPair {
    /** The node set T. */
    Set T;

    /** The bump value associated with T. */
    double bump;
}
/**
 * A task that seeds the forward search from the empty graph for a slice of the node list.
 * <p>
 * For each i in [from, to) and each j greater than i, the pair (x = nodes[j], y = nodes[i]) is scored with
 * {@code score.localScoreDiff}. When {@code symmetricFirstStep} is set, the larger of the two directions is used.
 * Pairs excluded by knowledge or by the bound graph are skipped. Every pair with a positive bump is added as an
 * undirected edge to {@code effectEdgesGraph} and as arrows in both directions via {@code addArrowForward}.
 */
class NodeTaskEmptyGraph implements Callable<Boolean> {
    private final int from;
    private final int to;
    private final List<Node> nodes;
    private final Set<Node> emptySet;

    /**
     * Creates a task covering outer indices {@code from} (inclusive) to {@code to} (exclusive) of {@code nodes}.
     *
     * @param from     First outer index to process (inclusive).
     * @param to       Last outer index to process (exclusive).
     * @param nodes    The full node list; inner indices range over the whole list.
     * @param emptySet The empty set passed as the conditioning/T/NaYX sets for first-step arrows.
     */
    NodeTaskEmptyGraph(int from, int to, List<Node> nodes, Set<Node> emptySet) {
        this.from = from;
        this.to = to;
        this.nodes = nodes;
        this.emptySet = emptySet;
    }

    /**
     * Scores all first-step edges for this task's slice.
     *
     * @return always {@code true}. An interrupt observed in the inner loop instead shuts the pool down and throws a
     * RuntimeException.
     */
    @Override
    public Boolean call() {
        for (int i = from; i < to; i++) {
            // NOTE(review): an interrupt seen here ends the task normally (returning true), whereas one seen in the
            // inner loop shuts the pool down and throws. The two paths are inconsistent.
            if (Thread.currentThread().isInterrupted()) break;

            if ((i + 1) % 1000 == 0) {
                // Tasks for different slices share this progress counter and may run concurrently. Guard it so
                // increments are not lost and the printed totals are consistent.
                synchronized (count) {
                    count[0] += 1000;
                    out.println("Initializing effect edges: " + (count[0]));
                }
            }

            Node y = nodes.get(i);

            for (int j = i + 1; j < nodes.size(); j++) {
                if (Thread.currentThread().isInterrupted()) {
                    pool.shutdownNow();
                    throw new RuntimeException("Interrupted");
                }

                Node x = nodes.get(j);

                if (existsKnowledge()) {
                    // Skip the pair only when both orientations are forbidden.
                    if (getKnowledge().isForbidden(x.getName(), y.getName()) && getKnowledge().isForbidden(y.getName(), x.getName())) {
                        continue;
                    }
                    if (invalidSetByKnowledge(y, emptySet)) {
                        continue;
                    }
                }

                // When a bound graph is set, only pairs adjacent in it are candidates.
                if (boundGraph != null && !boundGraph.isAdjacentTo(x, y)) {
                    continue;
                }

                int child = hashIndices.get(y);
                int parent = hashIndices.get(x);
                double bump = score.localScoreDiff(parent, child);

                if (symmetricFirstStep) {
                    double bump2 = score.localScoreDiff(child, parent);
                    bump = max(bump, bump2);
                }

                if (bump > 0) {
                    effectEdgesGraph.addEdge(Edges.undirectedEdge(x, y));
                    addArrowForward(x, y, emptySet, emptySet, emptySet, emptySet, bump);
                    addArrowForward(y, x, emptySet, emptySet, emptySet, emptySet, bump);
                }
            }
        }

        return true;
    }
}
}