// edu.cmu.tetrad.search.SvarFges (Maven / Gradle / Ivy)
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.ScoredGraph;
import edu.cmu.tetrad.search.utils.DagScorer;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TaskManager;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;
import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.*;
/**
* Adapts FGES for the time series setting, assuming the data is generated by a SVAR (structural vector autoregression).
* The main difference is that time order is imposed, and if an edge is removed, it will also remove all homologous
* edges to preserve the time-repeating structure assumed by SvarFCI. Based on (but not identical to) code by Entner and
* Hoyer for their 2010 paper. Modified by dmalinsky 4/21/2016.
*
* The references are as follows:
*
* Malinsky, D., & Spirtes, P. (2018, August). Causal structure learning from multivariate time series in settings
* with unmeasured confounding. In Proceedings of 2018 ACM SIGKDD workshop on causal discovery (pp. 23-47). PMLR.
*
* Entner, D., & Hoyer, P. O. (2010). On causal discovery from time series data using FCI. Probabilistic graphical
* models, 121-128.
*
* This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
* tiers.
*
* @author danielmalinsky
* @version $Id: $Id
* @see Fges
* @see Knowledge
* @see SvarFci
*/
public final class SvarFges implements IGraphSearch, DagScorer {
/**
* Number of worker threads: the number of processors available to the JVM. It also sizes the fork/join pool and
* bounds how many chunk tasks are outstanding during initialization.
*/
final int maxThreads = Runtime.getRuntime().availableProcessors();
/**
* The top n graphs found by the algorithm, where n is numCPDAGsToStore. Cleared at the start of each search.
* NOTE(review): presumably filled by storeGraph(), which is not shown here — confirm.
*/
private final LinkedList topGraphs = new LinkedList<>();
/**
* Fork/join pool that runs the parallel initialization and re-evaluation tasks. There is one per instance (not
* static), sized to maxThreads.
*/
private final ForkJoinPool pool = new ForkJoinPool(maxThreads);
/**
* Progress counter for verbose logging during edge initialization, advanced in steps of 1000 nodes. It is held in a
* one-element array so the local task classes can update it. NOTE(review): it is incremented from pool threads
* without synchronization, so the logged value is approximate.
*/
private final int[] count = new int[1];
/**
* Undirected edges that the two-step initialization must not propose again (checked in initializeTwoStepEdges).
* NOTE(review): the code that adds edges to this set is not visible here — confirm.
*/
private final Set removedEdges = new HashSet<>();
/**
* Running counter that gives each new Arrow an index (see addArrow). This lets arrows with the same bump still be
* ordered distinctly in sortedArrows. The particular ordering doesn't matter; it just has to be transitive.
*/
private int arrowIndex;
/**
* Specification of forbidden and required edges.
*/
private Knowledge knowledge = new Knowledge();
/**
* The measured variables of the score, in order (see setScore).
*/
private List variables;
/**
* The true graph, if known. If this is provided, asterisks will be printed out next to false positive added edges
* (that is, edges added that aren't adjacencies in the true graph). NOTE(review): it is not read in the visible
* part of this class — confirm where it is used.
*/
private Graph trueGraph;
/**
* An initial graph to start from.
*/
private Graph externalGraph;
/**
* Elapsed time of the most recent search, in milliseconds.
*/
private long elapsedTime;
/**
* The score used to evaluate local (node, parent set) configurations.
*/
private Score score;
/**
* The number of top CPDAGs to store.
*/
private int numCPDAGsToStore;
/**
* True if verbose output should be printed.
*/
private boolean verbose;
/**
* Potential arrows sorted by bump high to low. The first one is a candidate for adding to the graph.
*/
private SortedSet sortedArrows;
/**
* Candidate arrows added to sortedArrows, indexed by their (a, b) node pair (see addLookupArrow).
*/
private Map, Set> lookupArrows;
/**
* Snapshot of each node's undirected neighbors (see getNeighbors). After each insert or delete, nodes whose
* current neighbors differ from this snapshot are re-evaluated.
*/
private Map> neighbors;
/**
* Map from variables to the integer indices passed to score.localScore (built by buildIndexing).
*/
private ConcurrentMap hashIndices;
/**
* Running sum of the bumps of the operators applied by the forward and backward searches.
*/
private double totalScore;
/**
* A graph where X--Y means that X and Y have non-zero total effect on one another. In heuristicSpeedup mode,
* forward candidates are restricted to pairs adjacent in this graph.
*/
private Graph effectEdgesGraph;
/**
* Where printed output is sent. Defaults to System.out.
*/
private transient PrintStream out = System.out;
/**
* Optional preset adjacencies. If set, only pairs adjacent in this graph are considered as forward candidates.
*/
private Graph adjacencies;
/**
* The graph being constructed; search() returns it.
*/
private Graph graph;
/**
* Current search phase. It selects which candidate adjacencies reevaluateForward considers. In heuristicSpeedup
* mode it also restricts calculateArrowsForward to effect-edge adjacencies.
*/
private Mode mode = Mode.heuristicSpeedup;
/**
* True if one-edge faithfulness is assumed. Speeds the algorithm up.
*/
private boolean faithfulnessAssumed = false;
/**
* Maximum number of parents per node; -1 means unbounded. It is overwritten by score.getMaxDegree() in setScore.
*/
private int maxIndegree = -1;
/**
 * Creates a search over the measured variables of the given score.
 * <p>
 * The score should be locally consistent (Chickering, 2002). It should return a positive value in case of
 * conditional dependence and a negative value in case of conditional independence.
 *
 * @param score a {@link edu.cmu.tetrad.search.score.Score} object; must not be null.
 * @throws NullPointerException if {@code score} is null.
 */
public SvarFges(Score score) {
    setScore(Objects.requireNonNull(score));
    this.graph = new EdgeListGraph(getVariables());
}
/**
 * Steps across {@code edge} away from {@code node}, but only if the edge has a tail at {@code node}.
 *
 * @param node the node the step starts from (compared by identity with the edge's endpoints).
 * @param edge the edge to step across.
 * @return the opposite endpoint if {@code node} is an endpoint of {@code edge} carrying a TAIL mark; otherwise
 *         {@code null}.
 */
private static Node traverseSemiDirected(Node node, Edge edge) {
    if (node == edge.getNode1()) {
        return edge.getEndpoint1() == Endpoint.TAIL ? edge.getNode2() : null;
    }

    boolean tailAtNode2 = node == edge.getNode2() && edge.getEndpoint2() == Endpoint.TAIL;
    return tailAtNode2 ? edge.getNode1() : null;
}
/**
 * Greedy equivalence search.
 * <p>
 * The graph is initialized to the external graph, if one was set, or otherwise to the empty graph over the
 * variables. Required edges from the knowledge are added, and preset adjacencies (if any) are remapped onto the
 * variables.
 * <p>
 * A forward (FES) and backward (BES) pass is then run with candidates restricted to the effect-edges graph. This
 * is followed by a second FES/BES pass whose candidate set depends on {@link #setFaithfulnessAssumed(boolean)}:
 * two-step (non-collider-covering) candidates if true, all variable pairs if false.
 *
 * @return the resulting CPDAG.
 */
public Graph search() {
    // Start the clock and reset the running score before any search work. Previously the start time was taken
    // after both passes (so elapsedTime was ~0), and totalScore was zeroed after it had been accumulated.
    long start = MillisecondTimes.timeMillis();
    this.totalScore = 0.0;

    this.topGraphs.clear();
    this.lookupArrows = new ConcurrentHashMap<>();

    List<Node> nodes = new ArrayList<>(this.variables);
    this.graph = new EdgeListGraph(nodes);

    if (this.adjacencies != null) {
        this.adjacencies = GraphUtils.replaceNodes(this.adjacencies, nodes);
    }

    if (this.externalGraph != null) {
        this.graph = new EdgeListGraph(this.externalGraph);
        this.graph = GraphUtils.replaceNodes(this.graph, nodes);
    }

    addRequiredEdges(this.graph);

    // First pass (identical in both configurations): restrict candidates to effect-edge adjacencies.
    initializeForwardEdgesFromEmptyGraph(getVariables());
    this.mode = Mode.heuristicSpeedup;
    fes();
    bes();

    // Second pass: widen the candidate set.
    if (this.faithfulnessAssumed) {
        this.mode = Mode.coverNoncolliders;
        initializeTwoStepEdges(getVariables());
    } else {
        this.mode = Mode.allowUnfaithfulness;
        initializeForwardEdgesFromExistingGraph(getVariables());
    }

    fes();
    bes();

    this.elapsedTime = MillisecondTimes.timeMillis() - start;

    if (verbose) {
        TetradLogger.getInstance().log("Elapsed time = " + (this.elapsedTime) / 1000. + " s");
    }

    return this.graph;
}
/**
 * Declares whether one-edge faithfulness may be assumed, i.e. that no pair of paths, one of which has length 1,
 * cancels.
 * <p>
 * When true, the second forward pass of {@link #search()} only considers two-step candidates. When false, it
 * reconsiders every pair of variables.
 *
 * @param faithfulnessAssumed true to assume one-edge faithfulness.
 */
public void setFaithfulnessAssumed(boolean faithfulnessAssumed) {
    this.faithfulnessAssumed = faithfulnessAssumed;
}
/**
 * Returns the background knowledge used by the search.
 *
 * @return the background knowledge (the instance held by this search, not a copy).
 */
public Knowledge getKnowledge() {
    return knowledge;
}

/**
 * Sets the background knowledge. A new {@code Knowledge} is constructed from the argument and stored, so the
 * caller's object is not held directly.
 *
 * @param knowledge the knowledge object, specifying forbidden and required edges.
 */
public void setKnowledge(Knowledge knowledge) {
    Knowledge copy = new Knowledge(knowledge);
    this.knowledge = copy;
}
/**
 * Returns how long the most recent search took.
 *
 * @return the elapsed time, in milliseconds.
 */
public long getElapsedTime() {
    return elapsedTime;
}
/**
 * Supplies the true graph, if known, so that some edges in output graphs can be marked with asterisks.
 *
 * @param trueGraph the true graph.
 */
public void setTrueGraph(Graph trueGraph) {
    this.trueGraph = trueGraph;
}
/**
 * Scores a DAG by delegating to {@link #scoreDag(Graph)}.
 *
 * @param dag the DAG to score.
 * @return the score of the DAG, up to a constant.
 */
public double getScore(Graph dag) {
    double dagScore = scoreDag(dag);
    return dagScore;
}
/**
 * Returns the number of top CPDAGs to store.
 *
 * @return the number of patterns to store.
 */
public int getnumCPDAGsToStore() {
    return this.numCPDAGsToStore;
}

/**
 * Sets the number of patterns to store. This should be set to zero for fast search.
 *
 * @param numCPDAGsToStore the number of patterns to store; must be non-negative.
 * @throws IllegalArgumentException if {@code numCPDAGsToStore} is negative.
 */
public void setNumCPDAGsToStore(int numCPDAGsToStore) {
    if (numCPDAGsToStore < 0) {
        // Message previously read "must at least 0".
        throw new IllegalArgumentException("# graphs to store must be at least 0: " + numCPDAGsToStore);
    }

    this.numCPDAGsToStore = numCPDAGsToStore;
}
/**
 * Returns the initial graph for the search.
 *
 * @return the initial graph. The search is initialized to this graph and proceeds from there.
 */
public Graph getExternalGraph() {
    return externalGraph;
}

/**
 * Sets the initial graph. A non-null graph is first re-expressed over this search's variables and must then have
 * exactly the same node set.
 *
 * @param externalGraph the initial graph, or null for none.
 * @throws IllegalArgumentException if the graph's nodes differ from the search variables.
 */
public void setExternalGraph(Graph externalGraph) {
    Graph remapped = externalGraph;

    if (remapped != null) {
        remapped = GraphUtils.replaceNodes(remapped, this.variables);

        Set<Node> graphNodes = new HashSet<>(remapped.getNodes());
        if (!graphNodes.equals(new HashSet<>(this.variables))) {
            throw new IllegalArgumentException("Variables aren't the same.");
        }
    }

    this.externalGraph = remapped;
}
/**
 * Turns verbose logging on or off.
 *
 * @param verbose true to log progress messages.
 */
public void setVerbose(boolean verbose) {
    this.verbose = verbose;
}
/**
 * Returns the stream that printed output (other than log output) is sent to.
 *
 * @return the output stream.
 */
public PrintStream getOut() {
    return out;
}

/**
 * Sets the stream that printed output (other than log output) is sent to. By default this is System.out.
 *
 * @param out the output stream.
 */
public void setOut(PrintStream out) {
    this.out = out;
}
/**
 * Returns the preset adjacency graph, if any.
 *
 * @return the adjacency graph, or null if none was set.
 */
public Graph getAdjacencies() {
    return adjacencies;
}

/**
 * Sets the preset adjacencies. Edges whose endpoints are not adjacent in this graph will not be added.
 *
 * @param adjacencies the adjacencies graph, or null for no restriction.
 */
public void setAdjacencies(Graph adjacencies) {
    this.adjacencies = adjacencies;
}
/**
 * Returns the maximum number of parents any node may have in the output pattern.
 *
 * @return the bound, or -1 for unlimited.
 */
public int getMaxIndegree() {
    return this.maxIndegree;
}

/**
 * Sets the maximum number of parents any node may have in the output pattern. This overrides the value taken
 * from {@code score.getMaxDegree()} when the score was set.
 *
 * @param maxIndegree the bound, or -1 for unlimited.
 * @throws IllegalArgumentException if {@code maxIndegree} is less than -1.
 */
public void setMaxIndegree(int maxIndegree) {
    if (maxIndegree < -1) {
        throw new IllegalArgumentException("maxIndegree must be -1 (unlimited) or >= 0: " + maxIndegree);
    }

    this.maxIndegree = maxIndegree;
}
/**
 * Chooses the chunk size used to split {@code n} operations across the pool. This is {@code n / maxThreads},
 * but never less than 100.
 *
 * @param n The total number of operations to be performed.
 * @return The minimum number of operations to do before parallelizing.
 */
public int getMinChunk(int n) {
    final int floor = 100; // smallest piece of work worth handing to a separate task
    int perThread = n / this.maxThreads;
    return perThread > floor ? perThread : floor;
}
/**
 * Scores a DAG as the sum of local scores of each node given its parents.
 * <p>
 * The node indexing is rebuilt from the DAG's nodes first (via buildIndexing). NOTE(review): this presumably
 * replaces the indexing set up from the score's variables — confirm if called in the middle of a search.
 *
 * @param dag The Directed Acyclic Graph to calculate the score for.
 * @return The score of the DAG.
 */
public double scoreDag(Graph dag) {
    buildIndexing(dag.getNodes());

    double total = 0.0;

    for (Node child : dag.getNodes()) {
        Set<Node> parentSet = new HashSet<>(dag.getParents(child));
        int[] parentIndices = new int[parentSet.size()];

        int k = 0;
        for (Node parent : parentSet) {
            parentIndices[k++] = this.hashIndices.get(parent);
        }

        int childIndex = this.hashIndices.get(child);
        total += this.score.localScore(childIndex, parentIndices);
    }

    return total;
}
/**
 * Installs the score and derives state from it:
 * <ul>
 *     <li>the search variables (the score's MEASURED variables, in order);</li>
 *     <li>the node indexing (over all of the score's variables);</li>
 *     <li>maxIndegree (from {@code score.getMaxDegree()}).</li>
 * </ul>
 *
 * @param score The score to be set.
 */
private void setScore(Score score) {
    this.score = score;

    List<Node> measured = new ArrayList<>(score.getVariables());
    measured.removeIf(node -> node.getNodeType() != NodeType.MEASURED);
    this.variables = measured;

    buildIndexing(score.getVariables());
    this.maxIndegree = this.score.getMaxDegree();
}
/**
* Initializes the forward-search candidate arrows, starting from an empty graph.
*
* This resets sortedArrows, lookupArrows and neighbors and creates a fresh, edgeless effectEdgesGraph over the given
* nodes. It then runs, in the fork/join pool, one NodeTaskEmptyGraph per contiguous chunk of node indices.
* NOTE(review): what each chunk task computes is defined in NodeTaskEmptyGraph (not shown here). Presumably it
* scores additions against the empty set and records candidate arrows / effect edges — confirm there.
*
* @param nodes The list of nodes in the graph.
* @throws IllegalStateException if the pool does not become quiescent within the timeout.
*/
private void initializeForwardEdgesFromEmptyGraph(List nodes) {
this.sortedArrows = new ConcurrentSkipListSet<>();
this.lookupArrows = new ConcurrentHashMap<>();
this.neighbors = new ConcurrentHashMap<>();
// A single empty set shared by every chunk task.
Set emptySet = new HashSet<>();
long start = MillisecondTimes.timeMillis();
this.effectEdgesGraph = new EdgeListGraph(nodes);
// Driver task: forks one NodeTaskEmptyGraph per chunk and keeps at most maxThreads of them outstanding.
class InitializeFromEmptyGraphTask extends RecursiveTask {
public InitializeFromEmptyGraphTask() {
}
@Override
protected Boolean compute() {
Queue tasks = new ArrayDeque<>();
// Each chunk covers at least 100 nodes, or nodes.size() / maxThreads if that is larger.
int numNodesPerTask = FastMath.max(100, nodes.size() / SvarFges.this.maxThreads);
for (int i = 0; i < nodes.size(); i += numNodesPerTask) {
NodeTaskEmptyGraph task = new NodeTaskEmptyGraph(i, FastMath.min(nodes.size(), i + numNodesPerTask), nodes, emptySet);
tasks.add(task);
task.fork();
// Reap chunk tasks that have already finished; iterate over a copy so removal from the queue is safe.
for (NodeTaskEmptyGraph _task : new ArrayList<>(tasks)) {
if (_task.isDone()) {
_task.join();
tasks.remove(_task);
}
}
// Throttle: while more than maxThreads tasks are outstanding, join the oldest one (queue head).
while (tasks.size() > SvarFges.this.maxThreads) {
NodeTaskEmptyGraph _task = tasks.poll();
if (_task != null) {
_task.join();
}
}
}
// Wait for whatever chunk tasks are still outstanding.
for (NodeTaskEmptyGraph task : tasks) {
task.join();
}
return true;
}
}
try {
this.pool.invoke(new InitializeFromEmptyGraphTask());
} catch (Exception e) {
// NOTE(review): the interrupt flag is set for any exception, not only interruption-related ones.
Thread.currentThread().interrupt();
throw e;
}
// Wait until the pool has no running or queued work before continuing.
if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
throw new IllegalStateException("Pool timed out");
}
long stop = MillisecondTimes.timeMillis();
if (this.verbose) {
TetradLogger.getInstance().log("Elapsed initializeForwardEdgesFromEmptyGraph = " + (stop - start) + " ms");
}
}
/**
 * Initializes forward candidates for the "cover non-colliders" phase. {@link #search()} uses this phase when
 * faithfulness is assumed.
 * <p>
 * First, the candidate structures and the progress counter are reset. The method also ensures that
 * {@code effectEdgesGraph} exists and contains every adjacency of the external graph, as an undirected edge.
 * <p>
 * Then, in parallel over {@code nodes}, it collects for each node y the nodes m reachable as y - n - m in the
 * current graph, where m is not adjacent to y and (m, n, y) is not a definite collider. Each such m other than y
 * itself is passed to {@code calculateArrowsForward(m, y)} if it survives these filters:
 * <ul>
 *     <li>knowledge;</li>
 *     <li>preset adjacencies;</li>
 *     <li>removed edges.</li>
 * </ul>
 *
 * @param nodes The list of nodes to initialize two-step edges.
 * @throws IllegalStateException if the pool does not become quiescent within the timeout.
 */
private void initializeTwoStepEdges(List<Node> nodes) {
    this.count[0] = 0;
    this.sortedArrows = new ConcurrentSkipListSet<>();
    this.lookupArrows = new ConcurrentHashMap<>();
    this.neighbors = new ConcurrentHashMap<>();

    if (this.effectEdgesGraph == null) {
        this.effectEdgesGraph = new EdgeListGraph(nodes);
    }

    if (this.externalGraph != null) {
        for (Edge edge : this.externalGraph.getEdges()) {
            if (!this.effectEdgesGraph.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                this.effectEdgesGraph.addUndirectedEdge(edge.getNode1(), edge.getNode2());
            }
        }
    }

    Set<Node> emptySet = new HashSet<>(0);

    class InitializeFromExistingGraphTask extends RecursiveTask<Boolean> {
        private final int chunk;
        private final int from;
        private final int to;

        public InitializeFromExistingGraphTask(int chunk, int from, int to) {
            this.chunk = chunk;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Boolean compute() {
            if (TaskManager.getInstance().isCanceled()) return false;

            if (this.to - this.from <= this.chunk) {
                for (int i = this.from; i < this.to; i++) {
                    // Stop promptly on interruption, as initializeForwardEdgesFromExistingGraph does.
                    if (Thread.currentThread().isInterrupted()) {
                        break;
                    }

                    if ((i + 1) % 1000 == 0) {
                        // Not synchronized across pool threads; the logged count is approximate.
                        SvarFges.this.count[0] += 1000;
                        if (verbose) {
                            TetradLogger.getInstance().log("Initializing effect edges: " + (SvarFges.this.count[0]));
                        }
                    }

                    Node y = nodes.get(i);

                    // Two-step candidates: m with y - n - m, m not adjacent to y, and (m, n, y) not a collider.
                    Set<Node> g = new HashSet<>();
                    for (Node n : SvarFges.this.graph.getAdjacentNodes(y)) {
                        for (Node m : SvarFges.this.graph.getAdjacentNodes(n)) {
                            if (SvarFges.this.graph.isAdjacentTo(y, m)) {
                                continue;
                            }
                            if (SvarFges.this.graph.isDefCollider(m, n, y)) {
                                continue;
                            }
                            g.add(m);
                        }
                    }

                    for (Node x : g) {
                        // y itself can be reached as y - n - y whenever isDefCollider(y, n, y) is false. Never
                        // propose a self-loop; reevaluateForward applies the same guard to this candidate set.
                        if (x == y) {
                            continue;
                        }

                        if (existsKnowledge()) {
                            if (getKnowledge().isForbidden(x.getName(), y.getName()) && getKnowledge().isForbidden(y.getName(), x.getName())) {
                                continue;
                            }
                            if (invalidSetByKnowledge(y, emptySet)) {
                                continue;
                            }
                        }

                        if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
                            continue;
                        }

                        if (SvarFges.this.removedEdges.contains(Edges.undirectedEdge(x, y))) {
                            continue;
                        }

                        calculateArrowsForward(x, y);
                    }
                }
            } else {
                int mid = (this.to + this.from) / 2;
                InitializeFromExistingGraphTask left = new InitializeFromExistingGraphTask(this.chunk, this.from, mid);
                InitializeFromExistingGraphTask right = new InitializeFromExistingGraphTask(this.chunk, mid, this.to);
                left.fork();
                right.compute();
                left.join();
            }

            return true;
        }
    }

    try {
        this.pool.invoke(new InitializeFromExistingGraphTask(getMinChunk(nodes.size()), 0, nodes.size()));
    } catch (Exception e) {
        Thread.currentThread().interrupt();
        throw e;
    }

    if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
        throw new IllegalStateException("Pool timed out");
    }
}
/**
* Initializes forward candidates for the "allow unfaithfulness" phase. {@link #search()} uses this phase when
* faithfulness is not assumed.
*
* First, the candidate structures and the progress counter are reset. The method also ensures that effectEdgesGraph
* exists and contains every adjacency of the external graph, as an undirected edge. Then, in parallel over nodes,
* it considers every other variable x as a candidate parent of each node y. Each pair that passes the knowledge
* and preset-adjacency filters is handed to calculateArrowsForward(x, y).
*
* @param nodes The list of nodes in the graph.
* @throws IllegalStateException if the pool does not become quiescent within the timeout.
*/
private void initializeForwardEdgesFromExistingGraph(List nodes) {
this.count[0] = 0;
this.sortedArrows = new ConcurrentSkipListSet<>();
this.lookupArrows = new ConcurrentHashMap<>();
this.neighbors = new ConcurrentHashMap<>();
if (this.effectEdgesGraph == null) {
this.effectEdgesGraph = new EdgeListGraph(nodes);
}
// Make sure every adjacency of the external graph is present in the effect-edges graph.
if (this.externalGraph != null) {
for (Edge edge : this.externalGraph.getEdges()) {
if (!this.effectEdgesGraph.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
this.effectEdgesGraph.addUndirectedEdge(edge.getNode1(), edge.getNode2());
}
}
}
Set emptySet = new HashSet<>(0);
// Recursively splits [from, to) in half until a range fits in one chunk, then processes it sequentially.
class InitializeFromExistingGraphTask extends RecursiveTask {
private final int chunk;
private final int from;
private final int to;
public InitializeFromExistingGraphTask(int chunk, int from, int to) {
this.chunk = chunk;
this.from = from;
this.to = to;
}
@Override
protected Boolean compute() {
if (TaskManager.getInstance().isCanceled()) return false;
if (this.to - this.from <= this.chunk) {
for (int i = this.from; i < this.to; i++) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if ((i + 1) % 1000 == 0) {
// NOTE(review): not synchronized across pool threads; the logged count is approximate.
SvarFges.this.count[0] += 1000;
if (verbose) {
TetradLogger.getInstance().log("Initializing effect edges: " + (SvarFges.this.count[0]));
}
}
Node y = nodes.get(i);
// Candidate parents: every variable other than y. No m-connection or effect-edge restriction is applied
// in this phase.
Set D = new HashSet<>(variables);
D.remove(y);
for (Node x : D) {
if (existsKnowledge()) {
// Skip pairs forbidden in both directions.
if (getKnowledge().isForbidden(x.getName(), y.getName()) && getKnowledge().isForbidden(y.getName(), x.getName())) {
continue;
}
// NOTE(review): this check does not depend on x, so it skips every x for this y or none of them.
if (invalidSetByKnowledge(y, emptySet)) {
continue;
}
}
if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
continue;
}
calculateArrowsForward(x, y);
}
}
} else {
int mid = (this.to + this.from) / 2;
InitializeFromExistingGraphTask left = new InitializeFromExistingGraphTask(this.chunk, this.from, mid);
InitializeFromExistingGraphTask right = new InitializeFromExistingGraphTask(this.chunk, mid, this.to);
left.fork();
right.compute();
left.join();
}
return true;
}
}
try {
this.pool.invoke(new InitializeFromExistingGraphTask(getMinChunk(nodes.size()), 0, nodes.size()));
} catch (Exception e) {
// NOTE(review): the interrupt flag is set for any exception, not only interruption-related ones.
Thread.currentThread().interrupt();
throw e;
}
if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
throw new IllegalStateException("Pool timed out");
}
}
/**
 * Runs the forward equivalence search (FES) phase.
 * <p>
 * Candidates are taken from {@code sortedArrows} best-first. A candidate is discarded as stale when any of the
 * following holds:
 * <ul>
 *     <li>its endpoints are already adjacent;</li>
 *     <li>its recorded NaYX no longer matches;</li>
 *     <li>its T set is no longer contained in the current T-neighbors;</li>
 *     <li>the insert is no longer valid.</li>
 * </ul>
 * Otherwise the insert is applied and its bump is added to the running total. The graph is re-oriented, and every
 * node whose undirected neighborhood changed is re-evaluated, together with the two endpoints. The loop ends when
 * no candidates remain or the current thread is interrupted.
 */
private void fes() {
    if (verbose) {
        TetradLogger.getInstance().log("** FORWARD EQUIVALENCE SEARCH");
    }

    while (!this.sortedArrows.isEmpty() && !Thread.currentThread().isInterrupted()) {
        Arrow candidate = this.sortedArrows.first();
        this.sortedArrows.remove(candidate);

        Node tail = candidate.getA();
        Node head = candidate.getB();

        boolean stale = this.graph.isAdjacentTo(tail, head)
                || !candidate.getNaYX().equals(getNaYX(tail, head))
                || !new HashSet<>(getTNeighbors(tail, head)).containsAll(candidate.getHOrT())
                || !validInsert(tail, head, candidate.getHOrT(), getNaYX(tail, head));
        if (stale) {
            continue;
        }

        double bump = candidate.getBump();
        if (!insert(tail, head, candidate.getHOrT(), bump)) {
            continue;
        }
        this.totalScore += bump;

        // Re-evaluate nodes whose undirected neighborhood no longer matches the stored snapshot.
        Set<Node> toProcess = new HashSet<>();
        for (Node node : reapplyOrientation()) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            if (!getNeighbors(node).equals(this.neighbors.get(node))) {
                toProcess.add(node);
            }
        }
        toProcess.add(tail);
        toProcess.add(head);

        storeGraph();
        reevaluateForward(toProcess);
    }
}
/**
 * Runs the backward equivalence search (BES) phase.
 * <p>
 * The candidate structures are rebuilt from the current edges. Candidates are then taken best-first and discarded
 * as stale when any of the following holds:
 * <ul>
 *     <li>their NaYX changed;</li>
 *     <li>the edge is gone;</li>
 *     <li>the edge points toward the candidate's tail;</li>
 *     <li>the delete is no longer valid.</li>
 * </ul>
 * Each applied delete adds its bump to the running total and re-orients the graph. It then re-evaluates the
 * changed nodes, the two endpoints and their common adjacents. A final restricted Meek orientation is applied at
 * the end.
 */
private void bes() {
    if (verbose) {
        TetradLogger.getInstance().log("** BACKWARD EQUIVALENCE SEARCH");
    }

    this.sortedArrows = new ConcurrentSkipListSet<>();
    this.lookupArrows = new ConcurrentHashMap<>();
    this.neighbors = new ConcurrentHashMap<>();

    initializeArrowsBackward();

    while (!this.sortedArrows.isEmpty() && !Thread.currentThread().isInterrupted()) {
        Arrow candidate = this.sortedArrows.first();
        this.sortedArrows.remove(candidate);

        Node tail = candidate.getA();
        Node head = candidate.getB();

        boolean stale = !candidate.getNaYX().equals(getNaYX(tail, head))
                || !this.graph.isAdjacentTo(tail, head)
                || this.graph.getEdge(tail, head).pointsTowards(tail)
                || !validDelete(tail, head, candidate.getHOrT(), candidate.getNaYX());
        if (stale) {
            continue;
        }

        double bump = candidate.getBump();
        if (!delete(tail, head, candidate.getHOrT(), bump, candidate.getNaYX())) {
            continue;
        }
        this.totalScore += bump;

        clearArrow(tail, head);

        Set<Node> toProcess = new HashSet<>();
        for (Node node : reapplyOrientation()) {
            if (!getNeighbors(node).equals(this.neighbors.get(node))) {
                toProcess.add(node);
            }
        }
        toProcess.add(tail);
        toProcess.add(head);
        toProcess.addAll(getCommonAdjacents(tail, head));

        storeGraph();
        reevaluateBackward(toProcess);
    }

    meekOrientRestricted(getKnowledge());
}
/**
 * Returns the nodes adjacent to both {@code x} and {@code y} in the current graph.
 *
 * @param x The first node.
 * @param y The second node.
 * @return A new, mutable set of the shared adjacents.
 */
private Set<Node> getCommonAdjacents(Node x, Node y) {
    Set<Node> adjacentToY = new HashSet<>(this.graph.getAdjacentNodes(y));
    Set<Node> shared = new HashSet<>(this.graph.getAdjacentNodes(x));
    shared.removeIf(node -> !adjacentToY.contains(node));
    return shared;
}
/**
 * Re-runs the restricted Meek orientation on the current graph, honoring the background knowledge.
 *
 * @return the node set returned by meekOrientRestricted. Callers compare each of these nodes' neighborhoods
 *         against the stored snapshot.
 */
private Set<Node> reapplyOrientation() {
    Knowledge background = getKnowledge();
    return meekOrientRestricted(background);
}
/**
 * Reports whether any background knowledge has been supplied.
 *
 * @return true if the knowledge is non-empty.
 */
private boolean existsKnowledge() {
    boolean none = this.knowledge.isEmpty();
    return !none;
}
/**
 * Seeds the backward search with deletion candidates for every edge of the current graph.
 * <p>
 * Iteration stops early if the thread is interrupted. For each edge x–y:
 * <ul>
 *     <li>edges required by the knowledge are skipped;</li>
 *     <li>otherwise the stored arrows for x->y and y->x are cleared;</li>
 *     <li>deletion candidates are computed for (x, y) unless the edge points toward x, and for (y, x) unless it
 *     points toward y;</li>
 *     <li>the neighbor snapshots of x and y are refreshed.</li>
 * </ul>
 */
private void initializeArrowsBackward() {
    for (Edge edge : this.graph.getEdges()) {
        if (Thread.currentThread().isInterrupted()) {
            break;
        }

        Node x = edge.getNode1();
        Node y = edge.getNode2();

        boolean required = existsKnowledge() && !getKnowledge().noEdgeRequired(x.getName(), y.getName());
        if (required) {
            continue;
        }

        clearArrow(x, y);
        clearArrow(y, x);

        boolean towardY = edge.pointsTowards(y);
        boolean towardX = !towardY && edge.pointsTowards(x);

        if (!towardX) {
            calculateArrowsBackward(x, y);
        }
        if (!towardY) {
            calculateArrowsBackward(y, x);
        }

        this.neighbors.put(x, getNeighbors(x));
        this.neighbors.put(y, getNeighbors(y));
    }
}
/**
 * Recomputes forward (insertion) candidates after the graph has changed.
 * <p>
 * For each node x in {@code nodes}, the candidate tails w depend on the current mode:
 * <ul>
 *     <li>{@code heuristicSpeedup}: the nodes adjacent to x in the effect-edges graph;</li>
 *     <li>{@code coverNoncolliders}: nodes m with x - n - m in the current graph, where m is not adjacent to x and
 *     (m, n, x) is not a definite collider;</li>
 *     <li>{@code allowUnfaithfulness}: every other variable.</li>
 * </ul>
 * Each candidate w is skipped if it is x itself or is excluded by the preset adjacencies. Otherwise, if w is not
 * yet adjacent to x, its w -> x arrows are cleared and recomputed. The work is split across the fork/join pool.
 *
 * @param nodes the set of nodes to reevaluate
 * @throws IllegalStateException if the mode is unrecognized, or if the pool does not become quiescent within the
 *                               timeout.
 */
private void reevaluateForward(Set<Node> nodes) {
    class AdjTask extends RecursiveTask<Boolean> {
        private final List<Node> nodes;
        private final int from;
        private final int to;
        private final int chunk;

        public AdjTask(int chunk, List<Node> nodes, int from, int to) {
            this.nodes = nodes;
            this.from = from;
            this.to = to;
            this.chunk = chunk;
        }

        @Override
        protected Boolean compute() {
            if (this.to - this.from <= this.chunk) {
                for (int _w = this.from; _w < this.to; _w++) {
                    if (Thread.currentThread().isInterrupted()) {
                        break;
                    }

                    Node x = this.nodes.get(_w);

                    for (Node w : candidateTails(x)) {
                        if (SvarFges.this.adjacencies != null && !(SvarFges.this.adjacencies.isAdjacentTo(w, x))) {
                            continue;
                        }
                        if (w == x) continue;

                        if (!SvarFges.this.graph.isAdjacentTo(w, x)) {
                            clearArrow(w, x);
                            calculateArrowsForward(w, x);
                        }
                    }
                }
            } else {
                int mid = (this.to - this.from) / 2;
                List<AdjTask> tasks = new ArrayList<>();
                tasks.add(new AdjTask(this.chunk, this.nodes, this.from, this.from + mid));
                tasks.add(new AdjTask(this.chunk, this.nodes, this.from + mid, this.to));
                invokeAll(tasks);
            }

            return true;
        }

        /** Returns the candidate tails w for new arrows w -> x under the current search mode. */
        private List<Node> candidateTails(Node x) {
            if (SvarFges.this.mode == Mode.heuristicSpeedup) {
                return SvarFges.this.effectEdgesGraph.getAdjacentNodes(x);
            } else if (SvarFges.this.mode == Mode.coverNoncolliders) {
                Set<Node> g = new HashSet<>();
                for (Node n : SvarFges.this.graph.getAdjacentNodes(x)) {
                    for (Node m : SvarFges.this.graph.getAdjacentNodes(n)) {
                        if (SvarFges.this.graph.isAdjacentTo(x, m)) {
                            continue;
                        }
                        if (SvarFges.this.graph.isDefCollider(m, n, x)) {
                            continue;
                        }
                        g.add(m);
                    }
                }
                return new ArrayList<>(g);
            } else if (SvarFges.this.mode == Mode.allowUnfaithfulness) {
                List<Node> adj = new ArrayList<>(variables);
                adj.remove(x);
                return adj;
            }

            throw new IllegalStateException();
        }
    }

    AdjTask task = new AdjTask(getMinChunk(nodes.size()), new ArrayList<>(nodes), 0, nodes.size());

    try {
        this.pool.invoke(task);
    } catch (Exception e) {
        Thread.currentThread().interrupt();
        throw e;
    }

    // This previously set the interrupt flag on timeout, which made fes() stop as if the search had been
    // cancelled. Fail loudly instead, as the other pool-driven methods do.
    if (!this.pool.awaitQuiescence(1, TimeUnit.DAYS)) {
        throw new IllegalStateException("Pool timed out");
    }
}
/**
* Computes candidate Insert(a, b, T) operators for adding the edge a->b. Each one with a positive bump is added
* to sortedArrows via addArrow.
*
* Nothing is added if any of the following holds:
* - the mode is heuristicSpeedup and a, b are not adjacent in the effect-edges graph;
* - the preset adjacencies exclude the pair;
* - the knowledge forbids a->b;
* - NaYX(a, b) is not a clique.
* b's undirected-neighbor snapshot is refreshed once the first two checks pass.
*
* T ranges over subsets of getTNeighbors(a, b) of size 0..min(|TNeighbors|, maxIndegree - indegree(b)). A
* maxIndegree of -1 is treated as 1000. If that bound is negative, no subsets (not even the empty one) are tried.
*
* @param a the tail of the candidate edge a->b
* @param b the head of the candidate edge a->b
*/
private void calculateArrowsForward(Node a, Node b) {
if (this.mode == Mode.heuristicSpeedup && !this.effectEdgesGraph.isAdjacentTo(a, b)) return;
if (this.adjacencies != null && !this.adjacencies.isAdjacentTo(a, b)) return;
this.neighbors.put(b, getNeighbors(b));
if (existsKnowledge()) {
if (getKnowledge().isForbidden(a.getName(), b.getName())) {
return;
}
}
Set naYX = getNaYX(a, b);
if (!GraphUtils.isClique(naYX, this.graph)) return;
List TNeighbors = new ArrayList<>(getTNeighbors(a, b));
int _maxIndegree = this.maxIndegree == -1 ? 1000 : this.maxIndegree;
int _max = FastMath.min(TNeighbors.size(), _maxIndegree - this.graph.getIndegree(b));
// Clique pruning: subsets are enumerated by increasing size. previousCliques holds the NaYX ∪ T sets that were
// cliques at the previous size; it starts with {∅}, so every size-0 union passes.
Set> previousCliques = new HashSet<>();
previousCliques.add(new HashSet<>());
Set> newCliques = new HashSet<>();
FOR:
for (int i = 0; i <= _max; i++) {
ChoiceGenerator gen = new ChoiceGenerator(TNeighbors.size(), i);
int[] choice;
while ((choice = gen.next()) != null) {
// Interruption only ends the current size; the outer loop then runs through the remaining sizes.
if (Thread.currentThread().isInterrupted()) {
break;
}
Set T = GraphUtils.asSet(choice, TNeighbors);
Set union = new HashSet<>(naYX);
union.addAll(T);
boolean foundAPreviousClique = false;
for (Set clique : previousCliques) {
if (union.containsAll(clique)) {
foundAPreviousClique = true;
break;
}
}
// If this union contains none of the previous size's cliques, stop the whole enumeration.
if (!foundAPreviousClique) {
break FOR;
}
if (!GraphUtils.isClique(union, this.graph)) continue;
newCliques.add(union);
double bump = insertEval(a, b, T, naYX, this.hashIndices);
if (bump > 0.0) {
addArrow(a, b, naYX, T, bump);
}
}
previousCliques = newCliques;
newCliques = new HashSet<>();
}
}
/**
 * Records a candidate operator as an {@link Arrow} in {@code sortedArrows} and registers it for lookup by its
 * (a, b) pair.
 *
 * @param a    the tail of the candidate edge a->b.
 * @param b    the head of the candidate edge a->b.
 * @param naYX the NaYX set for (a, b) when the candidate was evaluated. It is compared later to detect stale
 *             candidates.
 * @param hOrT the H set (for a deletion) or T set (for an insertion) of the operator.
 * @param bump the score improvement of the operator; sortedArrows ranks candidates by it.
 */
private void addArrow(Node a, Node b, Set<Node> naYX, Set<Node> hOrT, double bump) {
    // addArrow is reached from fork/join worker threads (via calculateArrowsForward and
    // calculateArrowsBackward). A bare arrowIndex++ can then hand two arrows the same index, which defeats the
    // index's purpose as a tie-breaker among arrows with equal bumps. Take the index atomically.
    int index;
    synchronized (this) {
        index = this.arrowIndex++;
    }

    Arrow arrow = new Arrow(bump, a, b, hOrT, naYX, index);
    this.sortedArrows.add(arrow);
    addLookupArrow(a, b, arrow);
}
/**
 * Recomputes backward (deletion) candidates around each node in {@code toProcess} after an edge removal.
 * <p>
 * For every node r, the stored neighbor snapshot is refreshed. Then every node w adjacent to r is examined in the
 * fork/join pool:
 * <ul>
 *     <li>if the w–r edge points toward r, the arrows in both directions are cleared and the candidates for
 *     (w, r) are recomputed;</li>
 *     <li>if the edge is undirected, the candidates for both (w, r) and (r, w) are recomputed;</li>
 *     <li>edges pointing away from r are left untouched here.</li>
 * </ul>
 *
 * @param toProcess the nodes whose incident edges should be re-evaluated.
 * @throws IllegalStateException if the pool does not become quiescent within the timeout.
 */
private void reevaluateBackward(Set<Node> toProcess) {
    class BackwardTask extends RecursiveTask<Boolean> {
        private final Node r;
        private final List<Node> adj;
        private final int chunk;
        private final int from;
        private final int to;

        // The former hashIndices argument was stored but never read, so it has been dropped.
        public BackwardTask(Node r, List<Node> adj, int chunk, int from, int to) {
            this.adj = adj;
            this.chunk = chunk;
            this.from = from;
            this.to = to;
            this.r = r;
        }

        @Override
        protected Boolean compute() {
            if (this.to - this.from <= this.chunk) {
                for (int _w = this.from; _w < this.to; _w++) {
                    Node w = this.adj.get(_w);
                    Edge e = SvarFges.this.graph.getEdge(w, this.r);
                    if (e == null) {
                        continue;
                    }

                    // Reuse e instead of looking the same edge up a second time.
                    boolean intoR = e.pointsTowards(this.r);
                    boolean undirected = !intoR && Edges.isUndirectedEdge(e);
                    if (!intoR && !undirected) {
                        continue;
                    }

                    clearArrow(w, this.r);
                    clearArrow(this.r, w);
                    calculateArrowsBackward(w, this.r);
                    if (undirected) {
                        calculateArrowsBackward(this.r, w);
                    }
                }
            } else {
                int mid = (this.to - this.from) / 2;
                List<BackwardTask> tasks = new ArrayList<>();
                tasks.add(new BackwardTask(this.r, this.adj, this.chunk, this.from, this.from + mid));
                tasks.add(new BackwardTask(this.r, this.adj, this.chunk, this.from + mid, this.to));
                invokeAll(tasks);
            }

            return true;
        }
    }

    for (Node r : toProcess) {
        this.neighbors.put(r, getNeighbors(r));
        List<Node> adjacentNodes = new ArrayList<>(this.graph.getAdjacentNodes(r));

        try {
            this.pool.invoke(new BackwardTask(r, adjacentNodes, getMinChunk(adjacentNodes.size()), 0,
                    adjacentNodes.size()));
        } catch (Exception e) {
            Thread.currentThread().interrupt();
            throw e;
        }

        if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
            throw new IllegalStateException("Pool timed out");
        }
    }
}
/**
 * Computes candidate Delete(a, b, H) operators for removing the a–b edge.
 * <p>
 * Nothing is computed if the knowledge requires an edge between a and b. Otherwise, for every subset
 * {@code removed} of NaYX(a, b), H is the rest of NaYX. Subsets whose H is invalid by knowledge for b are
 * skipped. The bump is {@code deleteEval(a, b, removed, hashIndices)}, and an arrow is added when it is positive.
 * Interruption only cuts short the current subset size.
 *
 * @param a the starting node
 * @param b the ending node
 */
private void calculateArrowsBackward(Node a, Node b) {
    if (existsKnowledge() && !getKnowledge().noEdgeRequired(a.getName(), b.getName())) {
        return;
    }

    Set<Node> naYX = getNaYX(a, b);
    List<Node> members = new ArrayList<>(naYX);
    int size = members.size();

    for (int k = 0; k <= size; k++) {
        ChoiceGenerator gen = new ChoiceGenerator(size, k);

        for (int[] choice = gen.next(); choice != null; choice = gen.next()) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

            Set<Node> removed = GraphUtils.asSet(choice, members);
            Set<Node> kept = new HashSet<>(members);
            kept.removeAll(removed);

            if (existsKnowledge() && invalidSetByKnowledge(b, kept)) {
                continue;
            }

            double bump = deleteEval(a, b, removed, this.hashIndices);
            if (bump > 0.0) {
                addArrow(a, b, naYX, kept, bump);
            }
        }
    }
}
/**
 * Collects the T-neighbors of y with respect to x: the nodes joined to y by an undirected edge that are not
 * adjacent to x.
 *
 * @param x the node X; candidates adjacent to it are excluded
 * @param y the node Y whose undirected neighbors are inspected
 * @return a new set holding the qualifying nodes
 */
private Set<Node> getTNeighbors(Node x, Node y) {
    Set<Node> result = new HashSet<>();
    for (Edge edge : this.graph.getEdges(y)) {
        if (Edges.isUndirectedEdge(edge)) {
            Node other = edge.getDistalNode(y);
            if (!this.graph.isAdjacentTo(other, x)) {
                result.add(other);
            }
        }
    }
    return result;
}
/**
 * Collects the neighbors of y: every node joined to y by an undirected edge.
 *
 * @param y the node whose undirected neighbors are wanted
 * @return a new set holding those nodes
 */
private Set<Node> getNeighbors(Node y) {
    Set<Node> result = new HashSet<>();
    for (Edge edge : this.graph.getEdges(y)) {
        if (Edges.isUndirectedEdge(edge)) {
            result.add(edge.getDistalNode(y));
        }
    }
    return result;
}
/**
 * Evaluates the Insert(X, Y, T) operator of GES (Chickering, 2002): the score change for Y when X is added as a
 * parent, with Y conditioned on NaYX, T and Y's current parents.
 *
 * @param x           the prospective new parent X
 * @param y           the node Y that would receive X as a parent
 * @param t           the set T of Y's neighbors that would be oriented into Y
 * @param naYX        the neighbors of Y that are adjacent to X
 * @param hashIndices maps each node to its index in the score
 * @return the value of {@code scoreGraphChange} for adding X to Y's conditioning set (NaN if X is already in it)
 */
private double insertEval(Node x, Node y, Set<Node> t, Set<Node> naYX, Map<Node, Integer> hashIndices) {
    // Build NaYX union T union Pa(Y) in the same order as before so the parent iteration order is unchanged.
    Set<Node> conditioning = new HashSet<>(naYX);
    conditioning.addAll(t);
    conditioning.addAll(this.graph.getParents(y));
    return scoreGraphChange(y, conditioning, x, hashIndices);
}
/**
 * Evaluates the Delete(X, Y, H) operator of GES (Chickering, 2002): the score gain obtained by removing X as a
 * parent of Y.
 * <p>
 * Y's conditioning set is {@code diff} together with Y's current parents, minus X. The result is the negation of
 * {@code scoreGraphChange} for adding X to that set. A positive value therefore means dropping X is preferred.
 *
 * @param x           the parent X that would be removed from Y
 * @param y           the node Y whose parent set changes
 * @param diff        the part of NaYX(X, Y) that is kept in Y's conditioning set (NaYX minus H)
 * @param hashIndices maps each node to its index in the score
 * @return the score gain of the deletion
 */
private double deleteEval(Node x, Node y, Set<Node> diff, Map<Node, Integer> hashIndices) {
    Set<Node> conditioning = new HashSet<>(diff);
    conditioning.addAll(this.graph.getParents(y));
    conditioning.remove(x);
    double changeFromAdding = scoreGraphChange(y, conditioning, x, hashIndices);
    return -changeFromAdding;
}
/**
 * Applies the Insert(X, Y, T) operator to the working graph.
 * <p>
 * The method adds the directed edge x-->y, then replaces each t-y edge (t in T) with t-->y. Every change is
 * mirrored onto the corresponding edges at other time tiers (via addSimilarEdges/removeSimilarEdges) to keep the
 * repeating structure. When verbose, exactly one INSERT line is logged per insertion.
 *
 * @param x    the tail of the edge to insert
 * @param y    the head of the edge to insert
 * @param T    the nodes whose edges with y are redirected into y
 * @param bump the score improvement associated with this insertion; used only for logging
 * @return false if x and y were already adjacent (the graph is left unchanged), true otherwise
 */
private boolean insert(Node x, Node y, Set<Node> T, double bump) {
    if (this.graph.isAdjacentTo(x, y)) {
        return false; // The initial graph may already have put this edge in the graph.
    }

    Edge trueEdge = null;
    if (this.trueGraph != null) {
        Node _x = this.trueGraph.getNode(x.getName());
        Node _y = this.trueGraph.getNode(y.getName());
        trueEdge = this.trueGraph.getEdge(_x, _y);
    }

    this.graph.addDirectedEdge(x, y);
    // Adding similar edges to enforce repeating structure.
    addSimilarEdges(x, y);

    if (this.verbose) {
        int numEdges = this.graph.getNumEdges();
        if (numEdges % 1000 == 0) {
            TetradLogger.getInstance().log("Num edges added: " + numEdges);
        }
        // trueEdge can only be non-null when a true graph was supplied.
        String label = trueEdge != null ? "*" : "";
        TetradLogger.getInstance().log(numEdges + ". INSERT " + this.graph.getEdge(x, y) + " " + T + " " + bump
                + " " + label + " degree = " + GraphUtils.getDegree(this.graph)
                + " indegree = " + GraphUtils.getIndegree(this.graph));
    }

    for (Node _t : T) {
        if (Thread.currentThread().isInterrupted()) {
            break;
        }

        this.graph.removeEdge(_t, y);
        // Removing similar edges to enforce repeating structure.
        removeSimilarEdges(_t, y);

        this.graph.addDirectedEdge(_t, y);
        // Adding similar edges to enforce repeating structure.
        addSimilarEdges(_t, y);

        if (this.verbose) {
            String message = "--- Directing " + this.graph.getEdge(_t, y);
            TetradLogger.getInstance().log(message);
        }
    }

    return true;
}
/**
 * Applies the Delete(X, Y, H) operator to the working graph.
 * <p>
 * The method removes the x-y edge and records it (as undirected) in {@code removedEdges}. For each h in H that is
 * not already a parent of y or of x, it replaces the y-h edge with y-->h and, if x-h is undirected, replaces it
 * with x-->h. Every change is mirrored onto the corresponding edges at other time tiers (via
 * removeSimilarEdges/addSimilarEdges).
 *
 * @param x    one endpoint of the edge being deleted (logged as the tail)
 * @param y    the other endpoint of the edge being deleted (logged as the head)
 * @param H    the nodes to orient away from y (and from x, where that edge is undirected)
 * @param bump the score improvement associated with this deletion; used only for logging
 * @param naYX the neighbors of y that are adjacent to x; used only for logging
 * @return always true
 */
private boolean delete(Node x, Node y, Set<Node> H, double bump, Set<Node> naYX) {
    Edge trueEdge = null;
    if (this.trueGraph != null) {
        Node trueX = this.trueGraph.getNode(x.getName());
        Node trueY = this.trueGraph.getNode(y.getName());
        trueEdge = this.trueGraph.getEdge(trueX, trueY);
    }

    Edge removed = this.graph.getEdge(x, y);
    Set<Node> diff = new HashSet<>(naYX);
    diff.removeAll(H);

    this.graph.removeEdge(removed);
    this.removedEdges.add(Edges.undirectedEdge(x, y));
    // Removing similar edges to enforce repeating structure.
    removeSimilarEdges(x, y);

    if (this.verbose) {
        int edgeCount = this.graph.getNumEdges();
        if (edgeCount % 1000 == 0) {
            TetradLogger.getInstance().log("Num edges (backwards) = " + edgeCount);
        }
        String label = trueEdge == null ? "" : "*";
        TetradLogger.getInstance().log(edgeCount + ". DELETE " + x + "-->" + y + " H = " + H + " NaYX = " + naYX
                + " diff = " + diff + " (" + bump + ") " + label);
    }

    for (Node h : H) {
        boolean alreadyParent = this.graph.isParentOf(h, y) || this.graph.isParentOf(h, x);
        if (alreadyParent) {
            continue;
        }

        // Replace the y-h edge with y-->h; similar edges (expected undirected) are swapped for directed ones.
        Edge previousYH = this.graph.getEdge(y, h);
        this.graph.removeEdge(previousYH);
        this.graph.addEdge(Edges.directedEdge(y, h));
        removeSimilarEdges(y, h);
        addSimilarEdges(y, h);
        if (this.verbose) {
            TetradLogger.getInstance().log("--- Directing " + previousYH + " to " + this.graph.getEdge(y, h));
        }

        // Only an undirected x-h edge is reoriented to x-->h.
        Edge previousXH = this.graph.getEdge(x, h);
        if (!Edges.isUndirectedEdge(previousXH)) {
            continue;
        }
        this.graph.removeEdge(previousXH);
        this.graph.addEdge(Edges.directedEdge(x, h));
        removeSimilarEdges(x, h);
        addSimilarEdges(x, h);
        if (this.verbose) {
            TetradLogger.getInstance().log("--- Directing " + previousXH + " to " + this.graph.getEdge(x, h));
        }
    }

    return true;
}
/**
 * Tests whether the Insert(x, y, T) operator may be applied to the current graph.
 * <p>
 * The insertion is valid only if all of the following hold:
 * <ul>
 *     <li>when knowledge exists, neither x-->y nor any t-->y (t in T) is forbidden;</li>
 *     <li>NaYX union T is a clique in the current graph;</li>
 *     <li>no semi-directed path leads from y to x while avoiding NaYX union T. Otherwise adding x-->y could
 *     create a cycle. The path search is called with bound -1, i.e. its built-in depth limit.</li>
 * </ul>
 *
 * @param x    the prospective tail (new parent of y)
 * @param y    the prospective head
 * @param T    the nodes whose edges with y would be oriented into y
 * @param naYX the neighbors of y that are adjacent to x
 * @return true if the insertion is valid, false otherwise
 */
private boolean validInsert(Node x, Node y, Set<Node> T, Set<Node> naYX) {
    boolean forbidden = false;
    if (existsKnowledge()) {
        forbidden = this.knowledge.isForbidden(x.getName(), y.getName());
        for (Node t : T) {
            forbidden |= this.knowledge.isForbidden(t.getName(), y.getName());
        }
    }

    Set<Node> tUnionNaYX = new HashSet<>(T);
    tUnionNaYX.addAll(naYX);

    boolean isClique = GraphUtils.isClique(tUnionNaYX, this.graph);
    boolean wouldCloseCycle = existsUnblockedSemiDirectedPath(y, x, tUnionNaYX, -1);

    return isClique && !wouldCloseCycle && !forbidden;
}
/**
 * Tests whether the Delete(x, y, H) operator may be applied to the current graph.
 * <p>
 * The deletion is valid when NaYX minus H is a clique and, if knowledge exists, neither x-->h nor y-->h is
 * forbidden for any h in H (the delete step orients edges that way).
 *
 * @param x    one endpoint of the edge to delete
 * @param y    the other endpoint of the edge to delete
 * @param H    the nodes that would be oriented away from x and y
 * @param naYX the neighbors of y that are adjacent to x
 * @return true if the deletion is valid, false otherwise
 */
private boolean validDelete(Node x, Node y, Set<Node> H, Set<Node> naYX) {
    boolean forbidden = false;
    if (existsKnowledge()) {
        for (Node h : H) {
            String hName = h.getName();
            if (this.knowledge.isForbidden(x.getName(), hName) || this.knowledge.isForbidden(y.getName(), hName)) {
                forbidden = true;
            }
        }
    }

    Set<Node> remaining = new HashSet<>(naYX);
    remaining.removeAll(H);
    return GraphUtils.isClique(remaining, this.graph) && !forbidden;
}
/**
 * Imposes the background knowledge on the given graph.
 * <p>
 * First, each required edge from-->to is added, replacing any existing edges between the pair. This is skipped
 * when "to" is already an ancestor of "from", since the new edge would close a cycle. Second, every edge whose
 * direction knowledge forbids is reoriented the other way; see {@link #reverseForbiddenEdge}.
 *
 * @param graph the graph to modify in place
 */
private void addRequiredEdges(Graph graph) {
    if (!existsKnowledge()) return;

    for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) {
        if (Thread.currentThread().isInterrupted()) {
            break;
        }

        KnowledgeEdge next = it.next();
        Node nodeA = graph.getNode(next.getFrom());
        Node nodeB = graph.getNode(next.getTo());

        if (!graph.paths().isAncestorOf(nodeB, nodeA)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeA, nodeB);
            if (verbose) {
                String message = "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB);
                TetradLogger.getInstance().log(message);
            }
        }
    }

    for (Edge edge : graph.getEdges()) {
        if (Thread.currentThread().isInterrupted()) {
            break;
        }

        Node node1 = edge.getNode1();
        Node node2 = edge.getNode2();

        if (this.knowledge.isForbidden(node1.getName(), node2.getName())) {
            reverseForbiddenEdge(graph, node1, node2);
        } else if (this.knowledge.isForbidden(node2.getName(), node1.getName())) {
            reverseForbiddenEdge(graph, node2, node1);
        }
    }
}

/**
 * Reorients the adjacency between {@code from} and {@code to} as to-->from, given that knowledge forbids
 * from-->to.
 * <p>
 * Nothing is done when:
 * <ul>
 *     <li>the nodes are not adjacent;</li>
 *     <li>the edge is already to-->from;</li>
 *     <li>{@code from} is an ancestor of {@code to}, because reversing would then close a cycle.</li>
 * </ul>
 * <p>
 * NOTE(review): an existing directed edge from-->to makes {@code from} an ancestor of {@code to}, so such an edge
 * is left in place even though it is forbidden. Confirm this is intended.
 *
 * @param graph the graph to modify in place
 * @param from  the node knowledge forbids as a tail toward {@code to}
 * @param to    the node that should become the tail of the reoriented edge
 */
private void reverseForbiddenEdge(Graph graph, Node from, Node to) {
    if (!graph.isAdjacentTo(from, to) || graph.isChildOf(from, to)) {
        return;
    }
    if (graph.paths().isAncestorOf(from, to)) {
        return;
    }

    graph.removeEdges(from, to);
    graph.addDirectedEdge(to, from);

    if (verbose) {
        String message = "Adding edge by knowledge: " + graph.getEdge(to, from);
        TetradLogger.getInstance().log(message);
    }
}
/**
 * Reports whether knowledge forbids orienting any node of {@code subset} into {@code y}.
 *
 * @param y      the prospective head of each oriented edge
 * @param subset the prospective tails
 * @return true as soon as some node-->y in the subset is forbidden; false otherwise (including for an empty subset)
 */
private boolean invalidSetByKnowledge(Node y, Set<Node> subset) {
    return subset.stream()
            .anyMatch(node -> getKnowledge().isForbidden(node.getName(), y.getName()));
}
/**
 * Computes NaYX: the nodes other than x that are joined to y by an undirected edge and are adjacent to x.
 *
 * @param x the node X
 * @param y the node Y whose undirected neighbors are inspected
 * @return a new set holding NaYX
 */
private Set<Node> getNaYX(Node x, Node y) {
    Set<Node> result = new HashSet<>();
    for (Node z : this.graph.getAdjacentNodes(y)) {
        boolean qualifies = z != x
                && Edges.isUndirectedEdge(this.graph.getEdge(y, z))
                && this.graph.isAdjacentTo(z, x);
        if (qualifies) {
            result.add(z);
        }
    }
    return result;
}
/**
 * Breadth-first search for a semi-directed path from {@code from} to {@code to} that never enters a node of
 * {@code cond}; used as the cycle check for Insert operators.
 * <p>
 * Each step from node t along an edge is delegated to {@code SvarFges.traverseSemiDirected(t, edge)}. A null
 * result means the step is not allowed. Presumably undirected edges and edges directed out of t are followed;
 * TODO confirm. The search depth is counted in BFS levels. Once the level count exceeds {@code bound}, or 1000
 * when {@code bound == -1}, the method gives up and returns false.
 *
 * @param from  the starting node (it is never tested against {@code cond})
 * @param to    the target node
 * @param cond  nodes that block the path and are never entered
 * @param bound the maximum number of BFS levels to explore; -1 selects the default cap of 1000
 * @return true if {@code to} was reached, false otherwise
 */
private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, int bound) {
// BFS frontier (Q) and the set of nodes already enqueued (V).
Queue Q = new LinkedList<>();
Set V = new HashSet<>();
Q.offer(from);
V.add(from);
// e marks the first node enqueued for the next BFS level; dequeuing it means one more level has been entered.
Node e = null;
int distance = 0;
while (!Q.isEmpty()) {
Node t = Q.remove();
if (t == to) {
return true;
}
if (e == t) {
e = null;
distance++;
// Depth limit: -1 means "use the default cap of 1000 levels".
if (distance > (bound == -1 ? 1000 : bound)) return false;
}
for (Node u : this.graph.getAdjacentNodes(t)) {
// An interrupt only stops scanning t's neighbors; the outer BFS loop keeps draining the queue.
if (Thread.currentThread().isInterrupted()) {
break;
}
Edge edge = this.graph.getEdge(t, u);
// Presumably the node reached by stepping along edge from t, or null if that step is not semi-directed.
Node c = SvarFges.traverseSemiDirected(t, edge);
if (c == null) continue;
// Nodes in cond block the path.
if (cond.contains(c)) continue;
if (c == to) {
return true;
}
if (!V.contains(c)) {
V.add(c);
Q.offer(c);
if (e == null) {
// NOTE(review): records u while c is what was enqueued; the level marker is correct only if c == u here.
e = u;
}
}
}
}
return false;
}
/**
 * Applies the Meek orientation rules, constrained by the given knowledge, to the working graph.
 * <p>
 * NOTE(review): despite the name, the rules are run over all of {@code this.graph}, not only over recently
 * changed adjacencies.
 *
 * @param knowledge the knowledge the Meek rules must respect
 * @return the set returned by {@code MeekRules.orientImplied}; presumably the nodes whose edges were visited or
 * reoriented. Confirm against MeekRules.
 */
private Set<Node> meekOrientRestricted(Knowledge knowledge) {
    MeekRules meek = new MeekRules();
    meek.setKnowledge(knowledge);
    Set<Node> result = meek.orientImplied(this.graph);
    return result;
}
/**
 * Replaces {@code hashIndices} with a new map that assigns each node its position in {@code nodes}, counting from
 * zero. The scoring code uses this map to translate nodes into score indices.
 *
 * @param nodes the nodes to index, in score-variable order
 */
private void buildIndexing(List<Node> nodes) {
    this.hashIndices = new ConcurrentHashMap<>();
    int position = 0;
    for (Node node : nodes) {
        this.hashIndices.put(node, position++);
    }
}
/**
 * Discards every queued arrow for the ordered pair x->y. The arrows are removed from {@code sortedArrows}, and the
 * pair's entry is dropped from {@code lookupArrows}. Synchronized on this instance.
 *
 * @param x the tail of the arrow
 * @param y the head of the arrow
 */
private synchronized void clearArrow(Node x, Node y) {
    Set<Arrow> stale = this.lookupArrows.remove(new OrderedPair<>(x, y));
    if (stale != null) {
        this.sortedArrows.removeAll(stale);
    }
}
/**
 * Records {@code arrow} in the lookup index for the ordered adjacency i->j, so that {@code clearArrow} can later
 * find and discard it.
 * <p>
 * Arrows are produced from ForkJoin tasks (e.g. BackwardTask -> calculateArrowsBackward -> addArrow), so this
 * method can run concurrently. The previous get / null-check / put sequence let two threads each create a fresh
 * set for the same pair, and one of those sets, with its arrows, was lost. {@code computeIfAbsent} creates at most
 * one set per pair. This holds provided {@code lookupArrows} is a ConcurrentHashMap, as it presumably is; confirm.
 *
 * @param i     the starting node of the arrow
 * @param j     the ending node of the arrow
 * @param arrow the arrow to index
 */
private void addLookupArrow(Node i, Node j, Arrow arrow) {
    this.lookupArrows
            .computeIfAbsent(new OrderedPair<>(i, j), pair -> new ConcurrentSkipListSet<>())
            .add(arrow);
}
/**
 * Computes {@code score.localScoreDiff(index(x), index(y), indices(parents))}: the local score difference for y
 * when x is added to the conditioning set {@code parents}.
 *
 * @param y           the node being scored
 * @param parents     the conditioning (parent) set for y, which must not contain x
 * @param x           the node whose addition is being evaluated
 * @param hashIndices maps each node to its index in the score; a missing node causes a NullPointerException
 * @return the score difference, or NaN if {@code parents} already contains x
 */
private double scoreGraphChange(Node y, Set<Node> parents, Node x, Map<Node, Integer> hashIndices) {
    int yIndex = hashIndices.get(y);

    if (parents.contains(x)) {
        // x is already conditioned on, so the difference is undefined.
        return Double.NaN;
    }

    int[] parentIndices = parents.stream().mapToInt(hashIndices::get).toArray();
    return this.score.localScoreDiff(hashIndices.get(x), yIndex, parentIndices);
}
/**
 * Returns the {@code variables} field of this search.
 *
 * @return the variable list itself (not a copy)
 */
private List<Node> getVariables() {
    return this.variables;
}
/**
 * Records a snapshot of the current graph and its score in {@code topGraphs}, which holds at most
 * {@code getnumCPDAGsToStore()} entries.
 * <p>
 * When that number is positive, a copy of the graph is appended together with {@code totalScore}. The copy keeps
 * later search steps from mutating the snapshot. Entries are then removed from the front of the list (the earliest
 * stored) until the list fits. A negative setting is treated as zero.
 * <p>
 * NOTE(review): the earliest entry is the lowest-scoring one only if the total score increases between calls;
 * confirm.
 */
private void storeGraph() {
    int capacity = getnumCPDAGsToStore();

    if (capacity > 0) {
        Graph graphCopy = new EdgeListGraph(this.graph);
        this.topGraphs.addLast(new ScoredGraph(graphCopy, this.totalScore));
    }

    // Trim in a loop so the list can never stay over capacity (e.g. after the capacity was lowered), and clamp at
    // zero so a negative setting never calls removeFirst() on an empty list.
    int limit = Math.max(capacity, 0);
    while (this.topGraphs.size() > limit) {
        this.topGraphs.removeFirst();
    }
}
/**
 * Finds the edges at other knowledge tiers that correspond to the x-y edge, so that a change to x-y can be mirrored
 * and the graph keeps a repeating structure across tiers.
 * <p>
 * Let d be the distance between the tiers of x and y. For every tier i where both tiers i and i + d exist, the
 * method takes the entries at x's and y's positions within their tiers. A position is the index whose lag-stripped
 * name matches. The direction of the tier offset is kept: when x is in the later tier, the x-side entry comes from
 * tier i + d, otherwise from tier i. The following are skipped:
 * <ul>
 *     <li>pairs that join an entry to itself;</li>
 *     <li>the x-y pair itself, in either order;</li>
 *     <li>iterations where tier i has exactly one entry.</li>
 * </ul>
 *
 * @param x the first node of the edge
 * @param y the second node of the edge
 * @return an empty list if either node is named "time". Otherwise a two-element list: the first element holds the
 * x-side nodes and the second the matching y-side nodes, index for index.
 */
private List> returnSimilarPairs(Node x, Node y) {
// The "time" variable never takes part in the repeating structure.
if (x.getName().equals("time") || y.getName().equals("time")) {
return new ArrayList<>();
}
int ntiers = this.knowledge.getNumTiers();
int indx_tier = this.knowledge.isInWhichTier(x);
int indy_tier = this.knowledge.isInWhichTier(y);
// Absolute tier distance between x and y.
int tier_diff = FastMath.max(indx_tier, indy_tier) - FastMath.min(indx_tier, indy_tier);
int indx_comp = -1;
int indy_comp = -1;
List tier_x = this.knowledge.getTier(indx_tier);
List tier_y = this.knowledge.getTier(indy_tier);
int i;
// Position of x's variable (lag suffix ignored) within x's tier.
for (i = 0; i < tier_x.size(); ++i) {
if (getNameNoLag(x.getName()).equals(getNameNoLag(tier_x.get(i)))) {
indx_comp = i;
break;
}
}
// Position of y's variable (lag suffix ignored) within y's tier.
for (i = 0; i < tier_y.size(); ++i) {
if (getNameNoLag(y.getName()).equals(getNameNoLag(tier_y.get(i)))) {
indy_comp = i;
break;
}
}
// NOTE(review): if no entry matches, indx_comp/indy_comp stay -1 and the get(...) calls below throw
// IndexOutOfBoundsException. The positional lookup also assumes every tier lists its variables in the same
// order. TODO confirm.
List simListX = new ArrayList<>();
List simListY = new ArrayList<>();
for (i = 0; i < ntiers - tier_diff; ++i) {
// NOTE(review): only tier i is checked here, not tier i + tier_diff.
if (this.knowledge.getTier(i).size() == 1) continue;
String A;
Node x1;
String B;
Node y1;
List tmp_tier1;
List tmp_tier2;
// Keep the same direction of tier offset as between x and y.
if (indx_tier >= indy_tier) {
tmp_tier1 = this.knowledge.getTier(i + tier_diff);
tmp_tier2 = this.knowledge.getTier(i);
} else {
tmp_tier1 = this.knowledge.getTier(i);
tmp_tier2 = this.knowledge.getTier(i + tier_diff);
}
A = tmp_tier1.get(indx_comp);
B = tmp_tier2.get(indy_comp);
// Skip self-pairs and the original x-y pair (in either order).
if (A.equals(B)) continue;
if (A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp))) continue;
if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue;
// Presumably null if the name is not a node of the graph; confirm.
x1 = this.graph.getNode(A);
y1 = this.graph.getNode(B);
simListX.add(x1);
simListY.add(y1);
}
List> pairList = new ArrayList<>();
pairList.add(simListX);
pairList.add(simListY);
return (pairList);
}
/**
 * Strips the lag suffix from a variable name: everything from the first ':' onward is dropped.
 *
 * @param obj the object whose {@code toString()} is used as the name
 * @return the text before the first ':', or the whole string if it contains no ':'
 */
public String getNameNoLag(Object obj) {
    String name = obj.toString();
    int colon = name.indexOf(':');
    return colon < 0 ? name : name.substring(0, colon);
}
/**
 * Adds a directed edge for every pair that corresponds to x-->y at other tiers (as found by returnSimilarPairs).
 * This keeps the graph's repeating structure in step with the x-->y edge.
 *
 * @param x the tail of the reference edge
 * @param y the head of the reference edge
 */
public void addSimilarEdges(Node x, Node y) {
    List<List<Node>> pairs = returnSimilarPairs(x, y);
    if (pairs.isEmpty()) {
        return;
    }

    List<Node> tails = pairs.get(0);
    List<Node> heads = pairs.get(1);
    int pairCount = Math.min(tails.size(), heads.size());

    for (int k = 0; k < pairCount; k++) {
        this.graph.addDirectedEdge(tails.get(k), heads.get(k));
    }
}
/**
 * Removes the edge for every pair that corresponds to x-y at other tiers (as found by returnSimilarPairs). Each
 * such pair is also recorded, as an undirected edge, in {@code removedEdges}.
 * <p>
 * NOTE(review): if a pair has no edge, {@code getEdge} presumably returns null and that value is passed to
 * {@code removeEdge}; the pair is still recorded in {@code removedEdges}. Confirm this is intended.
 *
 * @param x the first node of the reference edge
 * @param y the second node of the reference edge
 */
public void removeSimilarEdges(Node x, Node y) {
    List<List<Node>> pairs = returnSimilarPairs(x, y);
    if (pairs.isEmpty()) {
        return;
    }

    List<Node> firsts = pairs.get(0);
    List<Node> seconds = pairs.get(1);
    int pairCount = Math.min(firsts.size(), seconds.size());

    for (int k = 0; k < pairCount; k++) {
        Node first = firsts.get(k);
        Node second = seconds.get(k);
        this.graph.removeEdge(this.graph.getEdge(first, second));
        this.removedEdges.add(Edges.undirectedEdge(first, second));
    }
}
/**
 * Option constants for the search.
 * <p>
 * NOTE(review): nothing in this part of the file reads these constants. Their exact effect, and whether the enum is
 * used at all, must be confirmed where {@code Mode} is referenced.
 */
private enum Mode {
/**
 * Option named for allowing unfaithfulness; presumably tolerates violations of the faithfulness assumption.
 * Confirm where used.
 */
allowUnfaithfulness,
/**
 * Option named for covering noncolliders. Confirm its effect where used.
 */
coverNoncolliders,
/**
 * Option named for a heuristic speedup. Confirm its effect where used.
 */
heuristicSpeedup
}
/**
 * A candidate arrow a->b considered for addition to, or removal from, the graph, together with the node sets needed
 * to evaluate that operation (see Chickering, 2002).
 * <p>
 * Both directions carry NaYX. A forward (insert) arrow carries the T neighbors, and a backward (delete) arrow
 * carries the H neighbors. The score difference of (hypothetically) applying the operation is stored as the "bump".
 * <p>
 * Arrows sort by decreasing bump, with ties broken by increasing index.
 */
private static class Arrow implements Comparable<Arrow> {
    private final double bump;
    private final Node a;
    private final Node b;
    private final Set<Node> hOrT;
    private final Set<Node> naYX;
    private final int index;

    /**
     * Creates a candidate arrow.
     *
     * @param bump  the score difference of applying the operation
     * @param a     the tail node of the arrow
     * @param b     the head node of the arrow
     * @param hOrT  the H set (backward) or T set (forward) for this operation
     * @param naYX  the NaYX set for this operation
     * @param index the tie-breaking index used by {@link #compareTo}
     */
    public Arrow(double bump, Node a, Node b, Set<Node> hOrT, Set<Node> naYX, int index) {
        this.a = a;
        this.b = b;
        this.bump = bump;
        this.naYX = naYX;
        this.hOrT = hOrT;
        this.index = index;
    }

    /**
     * @return the score difference of applying this operation
     */
    public double getBump() {
        return this.bump;
    }

    /**
     * @return the tail node a
     */
    public Node getA() {
        return this.a;
    }

    /**
     * @return the head node b
     */
    public Node getB() {
        return this.b;
    }

    /**
     * @return the H set (backward) or T set (forward) stored with this arrow
     */
    public Set<Node> getHOrT() {
        return this.hOrT;
    }

    /**
     * @return the NaYX set stored with this arrow
     */
    public Set<Node> getNaYX() {
        return this.naYX;
    }

    /**
     * Orders arrows by decreasing bump, then by increasing index.
     *
     * @param arrow the arrow to compare against
     * @return a negative value if this arrow sorts first, a positive value if it sorts later, 0 if bump and index
     * are both equal
     */
    @Override
    public int compareTo(@NotNull Arrow arrow) {
        int byBumpDescending = Double.compare(arrow.getBump(), getBump());
        return byBumpDescending != 0 ? byBumpDescending : Integer.compare(getIndex(), arrow.getIndex());
    }

    @Override
    public String toString() {
        return "Arrow<" + this.a + "->" + this.b + " bump = " + this.bump + " t/h = " + this.hOrT + " naYX = " + this.naYX + ">";
    }

    /**
     * @return the tie-breaking index of this arrow
     */
    public int getIndex() {
        return this.index;
    }
}
/**
 * A task that seeds the search from the empty graph for the node positions [from, to) of a node list.
 * <p>
 * For each node y in that range, the task does the following:
 * <ul>
 *     <li>records the given (empty) neighbor set for y;</li>
 *     <li>scores adding each later node x in the list as the single parent of y;</li>
 *     <li>for each pair with a positive score difference, adds an undirected edge to {@code effectEdgesGraph} and
 *     queues candidate arrows in both directions.</li>
 * </ul>
 */
private class NodeTaskEmptyGraph extends RecursiveTask<Boolean> {
    private final int from;
    private final int to;
    private final List<Node> nodes;
    private final Set<Node> emptySet;

    /**
     * @param from     the first node position processed (inclusive)
     * @param to       the last node position processed (exclusive)
     * @param nodes    the full node list; pairs are formed with every later node
     * @param emptySet the set used as each node's neighbor set and as the (empty) T/NaYX sets of the arrows
     */
    public NodeTaskEmptyGraph(int from, int to, List<Node> nodes, Set<Node> emptySet) {
        this.from = from;
        this.to = to;
        this.nodes = nodes;
        this.emptySet = emptySet;
    }

    @Override
    protected Boolean compute() {
        for (int i = this.from; i < this.to; i++) {
            if ((i + 1) % 1000 == 0) {
                // Progress counter. NOTE(review): if several slices run in parallel, this read-modify-write is not
                // atomic and updates may be lost. Confirm whether count needs an atomic type.
                SvarFges.this.count[0] += 1000;
            }

            Node y = this.nodes.get(i);
            SvarFges.this.neighbors.put(y, this.emptySet);

            for (int j = i + 1; j < this.nodes.size(); j++) {
                if (Thread.currentThread().isInterrupted()) {
                    break;
                }

                Node x = this.nodes.get(j);

                if (existsKnowledge()) {
                    // Skip the pair only when knowledge forbids both orientations.
                    if (getKnowledge().isForbidden(x.getName(), y.getName())
                            && getKnowledge().isForbidden(y.getName(), x.getName())) {
                        continue;
                    }
                    // Always false when emptySet really is empty; kept in case a caller passes a non-empty set.
                    if (invalidSetByKnowledge(y, this.emptySet)) {
                        continue;
                    }
                }

                if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
                    continue;
                }

                int child = SvarFges.this.hashIndices.get(y);
                int parent = SvarFges.this.hashIndices.get(x);
                double bump = SvarFges.this.score.localScoreDiff(parent, child);

                // A positive bump both records the effect edge and queues the candidate arrows.
                if (bump > 0.0) {
                    SvarFges.this.effectEdgesGraph.addEdge(Edges.undirectedEdge(x, y));
                    addArrow(x, y, this.emptySet, this.emptySet, bump);
                    addArrow(y, x, this.emptySet, this.emptySet, bump);
                }
            }
        }

        return true;
    }
}
}