All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.search.SvarFges Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.ScoredGraph;
import edu.cmu.tetrad.search.utils.DagScorer;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TaskManager;
import edu.cmu.tetrad.util.TetradLogger;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.*;

/**
 * Adapts FGES for the time series setting, assuming the data is generated by a SVAR (structural vector autoregression).
 * The main difference is that time order is imposed, and when an edge is added or removed, all homologous edges are
 * added or removed as well, to preserve the time-repeating structure assumed by SvarFCI. Based on (but not identical
 * to) code by Entner and Hoyer for their 2010 paper. Modified by dmalinsky 4/21/2016.
 * 

 * The references are as follows:
 *
 * Malinsky, D., & Spirtes, P. (2018, August). Causal structure learning from multivariate time series in settings
 * with unmeasured confounding. In Proceedings of 2018 ACM SIGKDD workshop on causal discovery (pp. 23-47). PMLR.
 *
 * Entner, D., & Hoyer, P. O. (2010). On causal discovery from time series data using FCI. Probabilistic graphical
 * models, 121-128.

* This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
 * tiers.
 *
 * @author danielmalinsky
 * @version $Id: $Id
 * @see Fges
 * @see Knowledge
 * @see SvarFci
 */
public final class SvarFges implements IGraphSearch, DagScorer {
    /**
     * The number of threads to use: sizes the fork-join pool and the work split in getMinChunk. Taken from the number
     * of available processors. Note that this field is package-private, and it must stay declared before pool.
     */
    final int maxThreads = Runtime.getRuntime().availableProcessors();
    /**
     * The top n graphs found by the algorithm, where n is numCPDAGsToStore. Cleared at the start of each search.
     */
    private final LinkedList topGraphs = new LinkedList<>();
    /**
     * The ForkJoinPool this instance uses for its parallel initialization and re-evaluation tasks. It is a per-instance
     * (not static) pool sized to maxThreads.
     */
    private final ForkJoinPool pool = new ForkJoinPool(maxThreads);
    /**
     * Progress counter for the edge-initialization tasks, which add 1000 for every 1000 nodes processed and log it when
     * verbose. Kept in a one-element array so that the local task classes can update it.
     * NOTE(review): updated from several fork-join workers without synchronization, so the value is approximate.
     */
    private final int[] count = new int[1];
    /**
     * The set of removed edges; pairs whose undirected edge is in this set are skipped when two-step edges are
     * initialized.
     */
    private final Set removedEdges = new HashSet<>();
    /**
     * Running index handed to each new arrow (see addArrow) so that arrows with the same bump can still be
     * distinguished in sortedArrows. The ordering doesn't matter; it just has to be transitive.
     */
    private int arrowIndex;
    /**
     * Specification of forbidden and required edges.
     */
    private Knowledge knowledge = new Knowledge();
    /**
     * The MEASURED variables of the score, in score order (filled by setScore).
     */
    private List variables;
    /**
     * The true graph, if known. If provided, verbose INSERT log lines get an asterisk when the inserted edge is also an
     * edge of the true graph.
     * NOTE(review): an earlier description said asterisks mark false positives; the code marks true edges. Confirm
     * which is intended.
     */
    private Graph trueGraph;
    /**
     * An initial graph to start from. When non-null, search() starts from a copy of it instead of the empty graph.
     */
    private Graph externalGraph;
    /**
     * Elapsed time of the most recent search, in milliseconds.
     */
    private long elapsedTime;
    /**
     * The score used to evaluate local changes; supplied through the constructor.
     */
    private Score score;
    /**
     * The number of top CPDAGs to store.
     */
    private int numCPDAGsToStore;
    /**
     * True if verbose output should be printed.
     */
    private boolean verbose;
    /**
     * Potential arrows sorted by bump high to low. The first one is a candidate for adding to the graph.
     * NOTE(review): the element type (presumably Arrow) appears to have been stripped from this declaration.
     */
    private SortedSet sortedArrows;
    /**
     * Lookup of the arrows that were added to sortedArrows, presumably keyed by the pair of nodes they connect
     * (maintained through addLookupArrow and clearArrow).
     * NOTE(review): the generic arguments of this declaration were lost ("Map, Set>" does not compile); restore them
     * from the upstream source.
     */
    private Map, Set> lookupArrows;
    /**
     * For each node, the set of its undirected-edge neighbors as last recorded. The searches compare it with the
     * current neighbors to decide which nodes need re-evaluation.
     * NOTE(review): generic arguments appear stripped here as well.
     */
    private Map> neighbors;
    /**
     * Map from variables to their column indices in the data set (filled by buildIndexing).
     */
    private ConcurrentMap hashIndices;
    /**
     * Running tally of the score bumps applied by the forward and backward searches.
     */
    private double totalScore;
    /**
     * A graph where X--Y means that X and Y have non-zero total effect on one another. In heuristic-speedup mode, only
     * pairs adjacent in this graph are considered for insertion.
     */
    private Graph effectEdgesGraph;
    /**
     * Where printed output is sent.
     */
    private transient PrintStream out = System.out;
    /**
     * Optional preset adjacencies. When non-null, only pairs adjacent in this graph are considered.
     */
    private Graph adjacencies;
    /**
     * The graph being constructed.
     */
    private Graph graph;
    /**
     * The current search mode. It selects which candidate neighbors reevaluateForward considers.
     */
    private Mode mode = Mode.heuristicSpeedup;
    /**
     * True if one-edge faithfulness is assumed. Speeds the algorithm up.
     */
    private boolean faithfulnessAssumed = false;
    /**
     * Bounds the indegree of the graph; -1 means unlimited. Overwritten by setScore with the score's maximum degree.
     */
    private int maxIndegree = -1;

    /**
     * Construct a Score and pass it in here. The score should return a positive value in case of conditional
     * dependence and a negative value in case of conditional independence. See Chickering (2002), locally consistent
     * scoring criterion.
     *
     * @param score a {@link edu.cmu.tetrad.search.score.Score} object
     * @throws NullPointerException if score is null
     */
    public SvarFges(Score score) {
        if (score == null) throw new NullPointerException();
        setScore(score);
        this.graph = new EdgeListGraph(getVariables());
    }

    /**
     * Traverses a semi-directed graph from a given node along a given edge.
     *
     * @param node The starting node of the traversal.
     * @param edge The edge to traverse from the starting node.
     * @return The destination node of the traversal if the edge is traversable from the starting node, or null if the
     * edge is not traversable from the starting node.
*/ private static Node traverseSemiDirected(Node node, Edge edge) { if (node == edge.getNode1()) { if (edge.getEndpoint1() == Endpoint.TAIL) { return edge.getNode2(); } } else if (node == edge.getNode2()) { if (edge.getEndpoint2() == Endpoint.TAIL) { return edge.getNode1(); } } return null; } /** * Greedy equivalence search: Start from the empty graph, add edges till the model is significant. Then start * deleting edges till a minimum is achieved. * * @return the resulting CPDAG. */ public Graph search() { this.topGraphs.clear(); this.lookupArrows = new ConcurrentHashMap<>(); List nodes = new ArrayList<>(this.variables); this.graph = new EdgeListGraph(nodes); if (this.adjacencies != null) { this.adjacencies = GraphUtils.replaceNodes(this.adjacencies, nodes); } if (this.externalGraph != null) { this.graph = new EdgeListGraph(this.externalGraph); this.graph = GraphUtils.replaceNodes(this.graph, nodes); } addRequiredEdges(this.graph); if (this.faithfulnessAssumed) { initializeForwardEdgesFromEmptyGraph(getVariables()); // Do forward search. this.mode = Mode.heuristicSpeedup; fes(); bes(); this.mode = Mode.coverNoncolliders; initializeTwoStepEdges(getVariables()); } else { initializeForwardEdgesFromEmptyGraph(getVariables()); // Do forward search. this.mode = Mode.heuristicSpeedup; fes(); bes(); this.mode = Mode.allowUnfaithfulness; initializeForwardEdgesFromExistingGraph(getVariables()); } fes(); bes(); long start = MillisecondTimes.timeMillis(); this.totalScore = 0.0; long endTime = MillisecondTimes.timeMillis(); this.elapsedTime = endTime - start; if (verbose) { TetradLogger.getInstance().log("Elapsed time = " + (this.elapsedTime) / 1000. + " s"); } // The final totalScore after search. return this.graph; } /** * Set to true if it is assumed that all path pairs with one length 1 path do not cancel. * * @param faithfulnessAssumed a boolean */ public void setFaithfulnessAssumed(boolean faithfulnessAssumed) { this.faithfulnessAssumed = faithfulnessAssumed; } /** *

Getter for the field knowledge.

*
     * @return the background knowledge currently held by this search.
     */
    public Knowledge getKnowledge() {
        return knowledge;
    }

    /**
     * Replaces the background knowledge with a copy (made via the Knowledge copy constructor) of the argument.
     *
     * @param knowledge the knowledge object, specifying forbidden and required edges.
     */
    public void setKnowledge(Knowledge knowledge) {
        Knowledge copy = new Knowledge(knowledge);
        this.knowledge = copy;
    }

    /**
     * Reports how long the most recent search took.
     *
     * @return the elapsed time of the last search, in milliseconds.
     */
    public long getElapsedTime() {
        return elapsedTime;
    }

    /**
     * Supplies the true graph. When set, verbose insert messages are tagged with an asterisk if the inserted edge also
     * appears in the true graph.
     *
     * @param trueGraph the true graph.
     */
    public void setTrueGraph(Graph trueGraph) {
        this.trueGraph = trueGraph;
    }

    /**
     * Scores a DAG by delegating to scoreDag.
     *
     * @param dag the dag to score.
     * @return the score of the given DAG, up to a constant.
     */
    public double getScore(Graph dag) {
        double dagScore = scoreDag(dag);
        return dagScore;
    }

    /**

Getter for the field numCPDAGsToStore.

*
     * @return the number of CPDAGs (patterns) to store.
     */
    public int getnumCPDAGsToStore() {
        return this.numCPDAGsToStore;
    }

    /**
     * Sets the number of CPDAGs (patterns) to store. This should be set to zero for fast search.
     *
     * @param numCPDAGsToStore the number of CPDAGs to store; must be non-negative.
     * @throws IllegalArgumentException if numCPDAGsToStore is negative.
     */
    public void setNumCPDAGsToStore(int numCPDAGsToStore) {
        if (numCPDAGsToStore < 0) {
            throw new IllegalArgumentException("# graphs to store must be at least 0: " + numCPDAGsToStore);
        }

        this.numCPDAGsToStore = numCPDAGsToStore;
    }

    /**

Getter for the field externalGraph.

*
     * @return the initial graph for the search. The search is initialized to this graph and proceeds from there.
     */
    public Graph getExternalGraph() {
        return externalGraph;
    }

    /**
     * Sets the initial graph. A non-null graph is first re-expressed over this search's variables (via
     * GraphUtils.replaceNodes), and the result must contain exactly those variables.
     *
     * @param externalGraph the initial graph, or null to start from the empty graph.
     * @throws IllegalArgumentException if the graph's nodes differ from this search's variables.
     */
    public void setExternalGraph(Graph externalGraph) {
        if (externalGraph == null) {
            this.externalGraph = null;
            return;
        }

        Graph aligned = GraphUtils.replaceNodes(externalGraph, this.variables);
        Set<Node> alignedNodes = new HashSet<>(aligned.getNodes());

        if (!alignedNodes.equals(new HashSet<>(this.variables))) {
            throw new IllegalArgumentException("Variables aren't the same.");
        }

        this.externalGraph = aligned;
    }

    /**
     * Turns verbose output on or off.
     *
     * @param verbose true if verbose output should be produced.
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**

Getter for the field out.

* * @return the output stream that output (except for log output) should be sent to. */ public PrintStream getOut() { return this.out; } /** * Sets the output stream that output (except for log output) should be sent to. By detault System.out. * * @param out the output stream. */ public void setOut(PrintStream out) { this.out = out; } /** * Retrieves the adjacency graph. * * @return the adjacency graph. */ public Graph getAdjacencies() { return this.adjacencies; } /** * Sets the set of preset adjacencies for the algorithm; edges not in this adjacencies graph will not be added. * * @param adjacencies the adjacencies graph. */ public void setAdjacencies(Graph adjacencies) { this.adjacencies = adjacencies; } /** * The maximum of parents any nodes can have in the output pattern. * * @return -1 for unlimited. */ public int getMaxIndegree() { return this.maxIndegree; } /** * The maximum of parents any nodes can have in the output pattern. * * @param maxIndegree -1 for unlimited. */ public void setMaxIndegree(int maxIndegree) { if (maxIndegree < -1) throw new IllegalArgumentException(); this.maxIndegree = maxIndegree; } /** * Returns the minimum number of operations to perform before parallelizing. * * @param n The total number of operations to be performed. * @return The minimum number of operations to do before parallelizing. */ public int getMinChunk(int n) { // The minimum number of operations to do before parallelizing. int minChunk = 100; return FastMath.max(n / this.maxThreads, minChunk); } /** * Calculates the score of a Directed Acyclic Graph (DAG). * * @param dag The Directed Acyclic Graph to calculate the score for. * @return The score of the DAG. 
*/
    public double scoreDag(Graph dag) {
        buildIndexing(dag.getNodes());
        double total = 0.0;

        for (Node child : dag.getNodes()) {
            // Same HashSet as before, so the parent indices are produced in the same order.
            Set<Node> parentSet = new HashSet<>(dag.getParents(child));
            int[] parentIdx = new int[parentSet.size()];
            int k = 0;

            for (Node parent : parentSet) {
                parentIdx[k++] = this.hashIndices.get(parent);
            }

            int childIdx = this.hashIndices.get(child);
            total += this.score.localScore(childIdx, parentIdx);
        }

        return total;
    }

    /**
     * Installs the score. Records its MEASURED variables, rebuilds the node index over all of the score's variables,
     * and adopts the score's maximum degree as the indegree bound.
     *
     * @param totalScore The score to use.
     */
    private void setScore(Score totalScore) {
        this.score = totalScore;

        List<Node> measured = new ArrayList<>();
        this.variables = measured;

        for (Node candidate : totalScore.getVariables()) {
            if (candidate.getNodeType() == NodeType.MEASURED) {
                measured.add(candidate);
            }
        }

        buildIndexing(totalScore.getVariables());
        this.maxIndegree = this.score.getMaxDegree();
    }

    /**
     * Initializes the forward edges from an empty graph.
     *
     * @param nodes The list of nodes in the graph.
*/
    private void initializeForwardEdgesFromEmptyGraph(List nodes) {
        // Reset the per-pass arrow bookkeeping.
        this.sortedArrows = new ConcurrentSkipListSet<>();
        this.lookupArrows = new ConcurrentHashMap<>();
        this.neighbors = new ConcurrentHashMap<>();
        Set emptySet = new HashSet<>();

        long start = MillisecondTimes.timeMillis();
        // Always rebuilt from scratch here (the other initializers only create it when it is null).
        this.effectEdgesGraph = new EdgeListGraph(nodes);

        /*
         * Driver task. It splits the node list into contiguous ranges of numNodesPerTask (at least 100) nodes and forks
         * one NodeTaskEmptyGraph per range. Finished tasks are reaped after each fork, and the oldest tasks are joined
         * until no more than maxThreads remain outstanding.
         */
        class InitializeFromEmptyGraphTask extends RecursiveTask {
            public InitializeFromEmptyGraphTask() {
            }

            @Override
            protected Boolean compute() {
                Queue tasks = new ArrayDeque<>();
                int numNodesPerTask = FastMath.max(100, nodes.size() / SvarFges.this.maxThreads);

                for (int i = 0; i < nodes.size(); i += numNodesPerTask) {
                    NodeTaskEmptyGraph task = new NodeTaskEmptyGraph(i, FastMath.min(nodes.size(), i + numNodesPerTask),
                            nodes, emptySet);
                    tasks.add(task);
                    task.fork();

                    // Reap completed tasks; iterate over a copy so they can be removed from the queue.
                    for (NodeTaskEmptyGraph _task : new ArrayList<>(tasks)) {
                        if (_task.isDone()) {
                            _task.join();
                            tasks.remove(_task);
                        }
                    }

                    // Throttle: block on the oldest tasks until at most maxThreads are outstanding.
                    while (tasks.size() > SvarFges.this.maxThreads) {
                        NodeTaskEmptyGraph _task = tasks.poll();

                        if (_task != null) {
                            _task.join();
                        }
                    }
                }

                // Wait for everything still running.
                for (NodeTaskEmptyGraph task : tasks) {
                    task.join();
                }

                return true;
            }
        }

        try {
            this.pool.invoke(new InitializeFromEmptyGraphTask());
        } catch (Exception e) {
            // NOTE(review): any failure, not only an interruption, sets the interrupt flag before rethrowing.
            Thread.currentThread().interrupt();
            throw e;
        }

        if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
            throw new IllegalStateException("Pool timed out");
        }

        long stop = MillisecondTimes.timeMillis();

        if (this.verbose) {
            TetradLogger.getInstance().log("Elapsed initializeForwardEdgesFromEmptyGraph = " + (stop - start) + " ms");
        }
    }

    /**
     * Initializes two-step edges based on given list of nodes.
     *
     * @param nodes The list of nodes to initialize two-step edges.
*/
    private void initializeTwoStepEdges(List nodes) {
        // Reset the progress counter and the per-pass arrow bookkeeping.
        this.count[0] = 0;
        this.sortedArrows = new ConcurrentSkipListSet<>();
        this.lookupArrows = new ConcurrentHashMap<>();
        this.neighbors = new ConcurrentHashMap<>();

        // Reuse the effect-edges graph from the previous pass if it exists.
        if (this.effectEdgesGraph == null) {
            this.effectEdgesGraph = new EdgeListGraph(nodes);
        }

        // Every adjacency of the external graph also becomes an effect edge.
        if (this.externalGraph != null) {
            for (Edge edge : this.externalGraph.getEdges()) {
                if (!this.effectEdgesGraph.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    this.effectEdgesGraph.addUndirectedEdge(edge.getNode1(), edge.getNode2());
                }
            }
        }

        Set emptySet = new HashSet<>(0);

        /*
         * Recursively halves [from, to) until a range is no larger than chunk. For each node y in a leaf range, every
         * node m reachable as y - n - m becomes a candidate, provided m is not adjacent to y and n is not a definite
         * collider on m - n - y. Forward arrows x -> y are then computed for the surviving candidates x.
         */
        class InitializeFromExistingGraphTask extends RecursiveTask {
            private final int chunk;
            private final int from;
            private final int to;

            public InitializeFromExistingGraphTask(int chunk, int from, int to) {
                this.chunk = chunk;
                this.from = from;
                this.to = to;
            }

            @Override
            protected Boolean compute() {
                if (TaskManager.getInstance().isCanceled()) return false;

                if (this.to - this.from <= this.chunk) {
                    for (int i = this.from; i < this.to; i++) {
                        if ((i + 1) % 1000 == 0) {
                            // NOTE(review): unsynchronized update from several workers; the count is approximate.
                            SvarFges.this.count[0] += 1000;

                            if (verbose) {
                                TetradLogger.getInstance().log("Initializing effect edges: " + (SvarFges.this.count[0]));
                            }
                        }

                        Node y = nodes.get(i);
                        // Two-step candidates for y.
                        Set g = new HashSet<>();

                        for (Node n : SvarFges.this.graph.getAdjacentNodes(y)) {
                            for (Node m : SvarFges.this.graph.getAdjacentNodes(n)) {
                                if (SvarFges.this.graph.isAdjacentTo(y, m)) {
                                    continue;
                                }

                                if (SvarFges.this.graph.isDefCollider(m, n, y)) {
                                    continue;
                                }

                                g.add(m);
                            }
                        }

                        for (Node x : g) {
                            if (existsKnowledge()) {
                                // Skip pairs that are forbidden in both directions.
                                if (getKnowledge().isForbidden(x.getName(), y.getName()) && getKnowledge().isForbidden(y.getName(), x.getName())) {
                                    continue;
                                }

                                // NOTE(review): this check depends only on y, not on x, so it could be hoisted.
                                if (invalidSetByKnowledge(y, emptySet)) {
                                    continue;
                                }
                            }

                            if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
                                continue;
                            }

                            // Skip pairs whose undirected edge is recorded as removed.
                            if (SvarFges.this.removedEdges.contains(Edges.undirectedEdge(x, y))) {
                                continue;
                            }

                            calculateArrowsForward(x, y);
                        }
                    }
                } else {
                    int mid = (this.to + this.from) / 2;

                    InitializeFromExistingGraphTask left = new InitializeFromExistingGraphTask(this.chunk, this.from, mid);
                    InitializeFromExistingGraphTask right = new InitializeFromExistingGraphTask(this.chunk, mid, this.to);

                    // Fork the left half and compute the right half in this thread.
                    left.fork();
                    right.compute();
                    left.join();
                }

                return true;
            }
        }

        try {
            this.pool.invoke(new InitializeFromExistingGraphTask(getMinChunk(nodes.size()), 0, nodes.size()));
        } catch (Exception e) {
            Thread.currentThread().interrupt();
            throw e;
        }

        if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
            throw new IllegalStateException("Pool timed out");
        }
    }

    /**
     * Initializes the forward edges from an existing graph. Every other variable is a candidate parent of each node
     * (the m-connection filter is commented out), subject to knowledge and preset adjacencies.
     *
     * @param nodes The list of nodes in the graph.
     */
    private void initializeForwardEdgesFromExistingGraph(List nodes) {
        // Reset the progress counter and the per-pass arrow bookkeeping.
        this.count[0] = 0;
        this.sortedArrows = new ConcurrentSkipListSet<>();
        this.lookupArrows = new ConcurrentHashMap<>();
        this.neighbors = new ConcurrentHashMap<>();

        if (this.effectEdgesGraph == null) {
            this.effectEdgesGraph = new EdgeListGraph(nodes);
        }

        // Every adjacency of the external graph also becomes an effect edge.
        if (this.externalGraph != null) {
            for (Edge edge : this.externalGraph.getEdges()) {
                if (!this.effectEdgesGraph.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    this.effectEdgesGraph.addUndirectedEdge(edge.getNode1(), edge.getNode2());
                }
            }
        }

        Set emptySet = new HashSet<>(0);

        /*
         * Recursively halves [from, to) until a range is no larger than chunk. For each node y in a leaf range, forward
         * arrows x -> y are computed for every variable x other than y. Stops early if the thread is interrupted or
         * the task manager reports cancellation.
         */
        class InitializeFromExistingGraphTask extends RecursiveTask {
            private final int chunk;
            private final int from;
            private final int to;

            public InitializeFromExistingGraphTask(int chunk, int from, int to) {
                this.chunk = chunk;
                this.from = from;
                this.to = to;
            }

            @Override
            protected Boolean compute() {
                if (TaskManager.getInstance().isCanceled()) return false;

                if (this.to - this.from <= this.chunk) {
                    for (int i = this.from; i < this.to; i++) {
                        if (Thread.currentThread().isInterrupted()) {
                            break;
                        }

                        if ((i + 1) % 1000 == 0) {
                            // NOTE(review): unsynchronized update from several workers; the count is approximate.
                            SvarFges.this.count[0] += 1000;

                            if (verbose) {
                                TetradLogger.getInstance().log("Initializing effect edges: " + (SvarFges.this.count[0]));
                            }
                        }

                        Node y = nodes.get(i);

                        // Set cond = new HashSet<>();
                        Set D = new HashSet<>(variables);// SvarFges.this.graph.paths().getMConnectedVars(y, cond));
                        D.remove(y);
                        // SvarFges.this.effectEdgesGraph.getAdjacentNodes(y).forEach(D::remove);

                        for (Node x : D) {
                            if (existsKnowledge()) {
                                // Skip pairs that are forbidden in both directions.
                                if (getKnowledge().isForbidden(x.getName(), y.getName()) && getKnowledge().isForbidden(y.getName(), x.getName())) {
                                    continue;
                                }

                                // NOTE(review): this check depends only on y, not on x, so it could be hoisted.
                                if (invalidSetByKnowledge(y, emptySet)) {
                                    continue;
                                }
                            }

                            if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
                                continue;
                            }

                            calculateArrowsForward(x, y);
                        }
                    }
                } else {
                    int mid = (this.to + this.from) / 2;

                    InitializeFromExistingGraphTask left = new InitializeFromExistingGraphTask(this.chunk, this.from, mid);
                    InitializeFromExistingGraphTask right = new InitializeFromExistingGraphTask(this.chunk, mid, this.to);

                    // Fork the left half and compute the right half in this thread.
                    left.fork();
                    right.compute();
                    left.join();
                }

                return true;
            }
        }

        try {
            this.pool.invoke(new InitializeFromExistingGraphTask(getMinChunk(nodes.size()), 0, nodes.size()));
        } catch (Exception e) {
            Thread.currentThread().interrupt();
            throw e;
        }

        if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
            throw new IllegalStateException("Pool timed out");
        }
    }

    /**
     * Perform Forward Equivalence Search.
     *

* This method applies the Forward Equivalence Search algorithm to search for equivalence classes in a graph. It
     * iteratively processes arrows in a sorted order until all arrows are processed or the thread is interrupted. Each
     * arrow is re-validated against the current graph before it is applied, because the graph may have changed since
     * the arrow was scored.
     */
    private void fes() {
        if (verbose) {
            TetradLogger.getInstance().log("** FORWARD EQUIVALENCE SEARCH");
        }

        while (!this.sortedArrows.isEmpty()) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

            // Take the best remaining candidate.
            Arrow arrow = this.sortedArrows.first();
            this.sortedArrows.remove(arrow);

            Node x = arrow.getA();
            Node y = arrow.getB();

            // Discard stale arrows: the pair is already adjacent, NaYX has changed, the stored T is no longer a subset
            // of the current T-neighbors, or the insert is no longer valid.
            if (this.graph.isAdjacentTo(x, y)) {
                continue;
            }

            if (!arrow.getNaYX().equals(getNaYX(x, y))) {
                continue;
            }

            if (!new HashSet<>(getTNeighbors(x, y)).containsAll(arrow.getHOrT())) {
                continue;
            }

            if (!validInsert(x, y, arrow.getHOrT(), getNaYX(x, y))) {
                continue;
            }

            Set T = arrow.getHOrT();
            double bump = arrow.getBump();

            boolean inserted = insert(x, y, T, bump);
            if (!inserted) continue;

            this.totalScore += bump;

            // Re-orient, then re-evaluate nodes whose undirected neighborhood changed, plus both endpoints.
            Set visited = reapplyOrientation();
            Set toProcess = new HashSet<>();

            for (Node node : visited) {
                if (Thread.currentThread().isInterrupted()) {
                    break;
                }

                Set neighbors1 = getNeighbors(node);
                Set storedNeighbors = this.neighbors.get(node);

                if (!(neighbors1.equals(storedNeighbors))) {
                    toProcess.add(node);
                }
            }

            toProcess.add(x);
            toProcess.add(y);

            storeGraph();
            reevaluateForward(toProcess);
        }
    }

    /**
     * Performs the Backward Equivalence Search algorithm.
* Arrows are seeded from the current graph's edges and applied best-first. Each is re-validated before use, and the
     * result is re-oriented with the Meek rules under the current knowledge at the end.
     */
    private void bes() {
        if (verbose) {
            TetradLogger.getInstance().log("** BACKWARD EQUIVALENCE SEARCH");
        }

        // Fresh bookkeeping for the deletion phase.
        this.sortedArrows = new ConcurrentSkipListSet<>();
        this.lookupArrows = new ConcurrentHashMap<>();
        this.neighbors = new ConcurrentHashMap<>();

        initializeArrowsBackward();

        while (!this.sortedArrows.isEmpty()) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

            // Take the best remaining deletion candidate.
            Arrow arrow = this.sortedArrows.first();
            this.sortedArrows.remove(arrow);

            Node x = arrow.getA();
            Node y = arrow.getB();

            // Discard stale arrows: NaYX changed, the pair is no longer adjacent, the edge now points into x, or the
            // delete is no longer valid.
            if (!arrow.getNaYX().equals(getNaYX(x, y))) {
                continue;
            }

            if (!this.graph.isAdjacentTo(x, y)) continue;

            Edge edge = this.graph.getEdge(x, y);
            if (edge.pointsTowards(x)) continue;

            if (!validDelete(x, y, arrow.getHOrT(), arrow.getNaYX())) continue;

            Set H = arrow.getHOrT();
            double bump = arrow.getBump();

            boolean deleted = delete(x, y, H, bump, arrow.getNaYX());
            if (!deleted) continue;

            this.totalScore += bump;

            clearArrow(x, y);

            // Re-orient, then re-evaluate nodes whose undirected neighborhood changed, both endpoints, and their
            // common adjacents.
            Set visited = reapplyOrientation();

            Set toProcess = new HashSet<>();

            for (Node node : visited) {
                Set neighbors1 = getNeighbors(node);
                Set storedNeighbors = this.neighbors.get(node);

                if (!(neighbors1.equals(storedNeighbors))) {
                    toProcess.add(node);
                }
            }

            toProcess.add(x);
            toProcess.add(y);
            toProcess.addAll(getCommonAdjacents(x, y));

            storeGraph();
            reevaluateBackward(toProcess);
        }

        meekOrientRestricted(getKnowledge());
    }

    /**
     * Retrieves the common adjacent nodes of two given nodes (the local variable name notwithstanding, these are
     * adjacents, not only children).
     *
     * @param x The first node.
     * @param y The second node.
     * @return A set of nodes that are adjacent to both x and y.
     */
    private Set getCommonAdjacents(Node x, Node y) {
        Set commonChildren = new HashSet<>(this.graph.getAdjacentNodes(x));
        commonChildren.retainAll(this.graph.getAdjacentNodes(y));
        return commonChildren;
    }

    /**
     * Re-runs meekOrientRestricted with the current background knowledge.
     *
     * @return the set returned by meekOrientRestricted, which the searches treat as the nodes visited.
     */
    private Set reapplyOrientation() {
        return meekOrientRestricted(getKnowledge());
    }

    /**
     * Checks if knowledge exists.
* * @return true if knowledge is not empty; false otherwise. */ private boolean existsKnowledge() { return !this.knowledge.isEmpty(); } /** * Initializes the sorted arrows lists for the backward search. *

* This method iterates over each edge in the graph and performs the following steps: 1. Check if the current thread * is interrupted. If interrupted, the iteration is stopped. 2. Get the source node (x) and target node (y) of the * edge. 3. If knowledge exists, check if there is no required edge between the names of x and y. If required edge * exists, continue to next iteration. 4. Clear the arrow from x to y and from y to x (if they exist). 5. Determine * the direction of the edge and calculate the arrows backward accordingly. (If edge points towards y, calculate * arrows backward from x to y. If edge points towards x, calculate arrows backward from y to x. Otherwise, * calculate arrows backward from x to y and from y to x.) 6. Update the neighbors map with the neighbors of x and * y. */ private void initializeArrowsBackward() { for (Edge edge : this.graph.getEdges()) { if (Thread.currentThread().isInterrupted()) { break; } Node x = edge.getNode1(); Node y = edge.getNode2(); if (existsKnowledge()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { continue; } } clearArrow(x, y); clearArrow(y, x); if (edge.pointsTowards(y)) { calculateArrowsBackward(x, y); } else if (edge.pointsTowards(x)) { calculateArrowsBackward(y, x); } else { calculateArrowsBackward(x, y); calculateArrowsBackward(y, x); } this.neighbors.put(x, getNeighbors(x)); this.neighbors.put(y, getNeighbors(y)); } } /** * Recalculates new arrows based on changes in the graph for the forward search. 
* * @param nodes the set of nodes to reevaluate */ private void reevaluateForward(Set nodes) { class AdjTask extends RecursiveTask { private final List nodes; private final int from; private final int to; private final int chunk; public AdjTask(int chunk, List nodes, int from, int to) { this.nodes = nodes; this.from = from; this.to = to; this.chunk = chunk; } @Override protected Boolean compute() { if (this.to - this.from <= this.chunk) { for (int _w = this.from; _w < this.to; _w++) { if (Thread.currentThread().isInterrupted()) { break; } Node x = this.nodes.get(_w); List adj; if (SvarFges.this.mode == Mode.heuristicSpeedup) { adj = SvarFges.this.effectEdgesGraph.getAdjacentNodes(x); } else if (SvarFges.this.mode == Mode.coverNoncolliders) { Set g = new HashSet<>(); for (Node n : SvarFges.this.graph.getAdjacentNodes(x)) { for (Node m : SvarFges.this.graph.getAdjacentNodes(n)) { if (SvarFges.this.graph.isAdjacentTo(x, m)) { continue; } if (SvarFges.this.graph.isDefCollider(m, n, x)) { continue; } g.add(m); } } adj = new ArrayList<>(g); } else if (SvarFges.this.mode == Mode.allowUnfaithfulness) { // HashSet D = new HashSet<>(SvarFges.this.graph.paths().getMConnectedVars(x, new HashSet<>())); // D.remove(x); adj = new ArrayList<>(variables); adj.remove(x); } else { throw new IllegalStateException(); } for (Node w : adj) { if (SvarFges.this.adjacencies != null && !(SvarFges.this.adjacencies.isAdjacentTo(w, x))) { continue; } if (w == x) continue; if (!SvarFges.this.graph.isAdjacentTo(w, x)) { clearArrow(w, x); calculateArrowsForward(w, x); } } } } else { int mid = (this.to - this.from) / 2; List tasks = new ArrayList<>(); tasks.add(new AdjTask(this.chunk, this.nodes, this.from, this.from + mid)); tasks.add(new AdjTask(this.chunk, this.nodes, this.from + mid, this.to)); invokeAll(tasks); } return true; } } AdjTask task = new AdjTask(getMinChunk(nodes.size()), new ArrayList<>(nodes), 0, nodes.size()); try { pool.invoke(task); } catch (Exception e) { 
Thread.currentThread().interrupt(); throw e; } if (!this.pool.awaitQuiescence(1, TimeUnit.DAYS)) { Thread.currentThread().interrupt(); } } /** * Calculates the new arrows for an a->b edge. * * @param a the Node representing the start node of the edge * @param b the Node representing the end node of the edge */ private void calculateArrowsForward(Node a, Node b) { if (this.mode == Mode.heuristicSpeedup && !this.effectEdgesGraph.isAdjacentTo(a, b)) return; if (this.adjacencies != null && !this.adjacencies.isAdjacentTo(a, b)) return; this.neighbors.put(b, getNeighbors(b)); if (existsKnowledge()) { if (getKnowledge().isForbidden(a.getName(), b.getName())) { return; } } Set naYX = getNaYX(a, b); if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = new ArrayList<>(getTNeighbors(a, b)); int _maxIndegree = this.maxIndegree == -1 ? 1000 : this.maxIndegree; int _max = FastMath.min(TNeighbors.size(), _maxIndegree - this.graph.getIndegree(b)); Set> previousCliques = new HashSet<>(); previousCliques.add(new HashSet<>()); Set> newCliques = new HashSet<>(); FOR: for (int i = 0; i <= _max; i++) { ChoiceGenerator gen = new ChoiceGenerator(TNeighbors.size(), i); int[] choice; while ((choice = gen.next()) != null) { if (Thread.currentThread().isInterrupted()) { break; } Set T = GraphUtils.asSet(choice, TNeighbors); Set union = new HashSet<>(naYX); union.addAll(T); boolean foundAPreviousClique = false; for (Set clique : previousCliques) { if (union.containsAll(clique)) { foundAPreviousClique = true; break; } } if (!foundAPreviousClique) { break FOR; } if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, this.hashIndices); if (bump > 0.0) { addArrow(a, b, naYX, T, bump); } } previousCliques = newCliques; newCliques = new HashSet<>(); } } /** * Adds an arrow between two nodes. 
*
     * @param a    the starting node of the arrow
     * @param b    the ending node of the arrow
     * @param naYX the NaYX set that the callers computed with getNaYX(a, b) when scoring this operator
     * @param hOrT the T set of a forward Insert(a, b, T), or the H set of a backward Delete(a, b, H)
     * @param bump the score improvement (bump) of the operator; arrows are kept in sortedArrows in bump order
     */
    private void addArrow(Node a, Node b, Set naYX, Set hOrT, double bump) {
        // arrowIndex distinguishes arrows that would otherwise compare equal.
        Arrow arrow = new Arrow(bump, a, b, hOrT, naYX, this.arrowIndex++);
        this.sortedArrows.add(arrow);
        addLookupArrow(a, b, arrow);
    }

    /**
     * Reevaluates arrows after removing an edge from the graph. For each node r to process, its neighbors are
     * recorded, and the deletion arrows are recomputed for every edge between r and an adjacent node w.
     *
     * @param toProcess a set of nodes to process
     * @throws IllegalStateException if the fork-join pool does not become quiescent within the timeout.
     */
    private void reevaluateBackward(Set toProcess) {
        /*
         * Recursively halves [from, to) of r's adjacency list. For each w in a leaf range: an edge w -> r gets its
         * w -> r deletion arrows recomputed, and an undirected w - r edge gets them recomputed in both directions.
         * Stale arrows for the pair are cleared first.
         */
        class BackwardTask extends RecursiveTask {
            private final Node r;
            private final List adj;
            private final Map hashIndices;
            private final int chunk;
            private final int from;
            private final int to;

            public BackwardTask(Node r, List adj, int chunk, int from, int to, Map hashIndices) {
                this.adj = adj;
                this.hashIndices = hashIndices;
                this.chunk = chunk;
                this.from = from;
                this.to = to;
                this.r = r;
            }

            @Override
            protected Boolean compute() {
                if (this.to - this.from <= this.chunk) {
                    for (int _w = this.from; _w < this.to; _w++) {
                        Node w = this.adj.get(_w);
                        Edge e = SvarFges.this.graph.getEdge(w, this.r);

                        if (e != null) {
                            if (e.pointsTowards(this.r)) {
                                clearArrow(w, this.r);
                                clearArrow(this.r, w);

                                calculateArrowsBackward(w, this.r);
                            } else if (Edges.isUndirectedEdge(SvarFges.this.graph.getEdge(w, this.r))) {
                                clearArrow(w, this.r);
                                clearArrow(this.r, w);

                                calculateArrowsBackward(w, this.r);
                                calculateArrowsBackward(this.r, w);
                            }
                        }
                    }
                } else {
                    int mid = (this.to - this.from) / 2;

                    List tasks = new ArrayList<>();

                    tasks.add(new BackwardTask(this.r, this.adj, this.chunk, this.from, this.from + mid, this.hashIndices));
                    tasks.add(new BackwardTask(this.r, this.adj, this.chunk, this.from + mid, this.to, this.hashIndices));

                    invokeAll(tasks);
                }

                return true;
            }
        }

        for (Node r : toProcess) {
            this.neighbors.put(r, getNeighbors(r));
            List adjacentNodes = new ArrayList<>(this.graph.getAdjacentNodes(r));

            try {
                this.pool.invoke(new BackwardTask(r, adjacentNodes, getMinChunk(adjacentNodes.size()), 0, adjacentNodes.size(), this.hashIndices));
            } catch (Exception e) {
                Thread.currentThread().interrupt();
                throw e;
            }

            if (!pool.awaitQuiescence(1, TimeUnit.DAYS)) {
                throw new IllegalStateException("Pool timed out");
            }
        }
    }

    /**
     * Calculates the arrows for the removal in the backward direction. For each subset diff of NaYX(a, b), with
     * h = NaYX \ diff, Delete(a, b) is scored via deleteEval, and an arrow carrying h is queued when the bump is
     * positive. Removals blocked by required-edge knowledge, or subsets that knowledge rejects, are skipped.
     *
     * @param a the starting node
     * @param b the ending node
     */
    private void calculateArrowsBackward(Node a, Node b) {
        if (existsKnowledge()) {
            if (!getKnowledge().noEdgeRequired(a.getName(), b.getName())) {
                return;
            }
        }

        Set naYX = getNaYX(a, b);

        List _naYX = new ArrayList<>(naYX);

        int _depth = _naYX.size();

        // Enumerate every subset of NaYX, by increasing size.
        for (int i = 0; i <= _depth; i++) {
            ChoiceGenerator gen = new ChoiceGenerator(_naYX.size(), i);
            int[] choice;

            while ((choice = gen.next()) != null) {
                if (Thread.currentThread().isInterrupted()) {
                    break;
                }

                Set diff = GraphUtils.asSet(choice, _naYX);

                Set h = new HashSet<>(_naYX);
                h.removeAll(diff);

                if (existsKnowledge()) {
                    if (invalidSetByKnowledge(b, h)) {
                        continue;
                    }
                }

                double bump = deleteEval(a, b, diff, this.hashIndices);

                if (bump > 0.0) {
                    addArrow(a, b, naYX, h, bump);
                }
            }
        }
    }

    /**
     * Returns a set of neighbors of node Y that are connected to Y by an undirected edge and are not adjacent to node
     * X.
     *
     * @param x the node X
     * @param y the node Y
     * @return a set of neighbors of node Y that fulfill the specified conditions
     */
    private Set getTNeighbors(Node x, Node y) {
        Set yEdges = this.graph.getEdges(y);
        Set tNeighbors = new HashSet<>();

        for (Edge edge : yEdges) {
            if (!Edges.isUndirectedEdge(edge)) {
                continue;
            }

            Node z = edge.getDistalNode(y);

            if (this.graph.isAdjacentTo(z, x)) {
                continue;
            }

            tNeighbors.add(z);
        }

        return tNeighbors;
    }

    /**
     * Returns a set of neighboring nodes connected to the given node in the graph.
* * @param y the node for which to retrieve neighbors * @return a set of neighboring nodes connected to the given node */ private Set getNeighbors(Node y) { Set yEdges = this.graph.getEdges(y); Set neighbors = new HashSet<>(); for (Edge edge : yEdges) { if (!Edges.isUndirectedEdge(edge)) { continue; } Node z = edge.getDistalNode(y); neighbors.add(z); } return neighbors; } /** * Evaluate the Insert(X, Y, T) operator. This method calculates the score of the graph change caused by inserting * node Y into node X with a set of target nodes T. * * @param x The destination node X * @param y The node to be inserted Y * @param t The set of target nodes T * @param naYX The set of nodes that are forbidden to be parents of Y if X is its parent * @param hashIndices The map of node-to-index associations * @return The score of the graph change caused by the node insertion */ private double insertEval(Node x, Node y, Set t, Set naYX, Map hashIndices) { Set set = new HashSet<>(naYX); set.addAll(t); set.addAll(this.graph.getParents(y)); return scoreGraphChange(y, set, x, hashIndices); } /** * Evaluate the Delete(X, Y, T) operator (Definition 12 from Chickering, 2002). * * @param x The node to delete from the graph. * @param y The node to be evaluated as a parent of X. * @param diff The set of nodes representing the difference between the original graph and the modified * graph. * @param hashIndices The map containing the indices of nodes in the hash table. * @return The score of the graph change after deleting the specified node. */ private double deleteEval(Node x, Node y, Set diff, Map hashIndices) { Set set = new HashSet<>(diff); set.addAll(this.graph.getParents(y)); set.remove(x); return -scoreGraphChange(y, set, x, hashIndices); } /** * Inserts an edge into the graph based on the given parameters. 
* * @param x the source node of the edge to be inserted * @param y the target node of the edge to be inserted * @param T the set of nodes to be connected to the target node * @param bump the bump value for the insertion * @return true if the insertion is successful, false otherwise */ private boolean insert(Node x, Node y, Set T, double bump) { if (this.graph.isAdjacentTo(x, y)) { return false; // The initial graph may already have put this edge in the graph. } Edge trueEdge = null; if (this.trueGraph != null) { Node _x = this.trueGraph.getNode(x.getName()); Node _y = this.trueGraph.getNode(y.getName()); trueEdge = this.trueGraph.getEdge(_x, _y); } this.graph.addDirectedEdge(x, y); // Adding similar edges to enforce repeating structure **/ addSimilarEdges(x, y); // **/ if (this.verbose) { String label = this.trueGraph != null && trueEdge != null ? "*" : ""; String message = this.graph.getNumEdges() + ". INSERT " + this.graph.getEdge(x, y) + " " + T + " " + bump + " " + label; TetradLogger.getInstance().log(message); } int numEdges = this.graph.getNumEdges(); if (verbose) { if (numEdges % 1000 == 0) TetradLogger.getInstance().log("Num edges added: " + numEdges); } if (this.verbose) { String label = this.trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance().log(this.graph.getNumEdges() + ". 
INSERT " + this.graph.getEdge(x, y) + " " + T + " " + bump + " " + label + " degree = " + GraphUtils.getDegree(this.graph) + " indegree = " + GraphUtils.getIndegree(this.graph)); } for (Node _t : T) { if (Thread.currentThread().isInterrupted()) { break; } this.graph.removeEdge(_t, y); // removing similar edges to enforce repeating structure **/ removeSimilarEdges(_t, y); this.graph.addDirectedEdge(_t, y); // Adding similar edges to enforce repeating structure **/ addSimilarEdges(_t, y); // **/ if (this.verbose) { String message = "--- Directing " + this.graph.getEdge(_t, y); TetradLogger.getInstance().log(message); } } return true; } /** * Delete an edge from the graph. * * @param x The source node of the edge to be deleted. * @param y The target node of the edge to be deleted. * @param H The set of nodes. * @param bump The weight of the edge being deleted. * @param naYX The set of nodes not adjacent to both x and y. * @return true if the deletion is successful, false otherwise. */ private boolean delete(Node x, Node y, Set H, double bump, Set naYX) { Edge trueEdge = null; if (this.trueGraph != null) { Node _x = this.trueGraph.getNode(x.getName()); Node _y = this.trueGraph.getNode(y.getName()); trueEdge = this.trueGraph.getEdge(_x, _y); } Edge oldxy = this.graph.getEdge(x, y); Set diff = new HashSet<>(naYX); diff.removeAll(H); this.graph.removeEdge(oldxy); this.removedEdges.add(Edges.undirectedEdge(x, y)); // removing similar edges to enforce repeating structure **/ removeSimilarEdges(x, y); int numEdges = this.graph.getNumEdges(); if (verbose) { if (numEdges % 1000 == 0) TetradLogger.getInstance().log("Num edges (backwards) = " + numEdges); } if (this.verbose) { String label = this.trueGraph != null && trueEdge != null ? "*" : ""; String message = (this.graph.getNumEdges()) + ". 
DELETE " + x + "-->" + y + " H = " + H + " NaYX = " + naYX + " diff = " + diff + " (" + bump + ") " + label; TetradLogger.getInstance().log(message); } for (Node h : H) { if (this.graph.isParentOf(h, y) || this.graph.isParentOf(h, x)) continue; Edge oldyh = this.graph.getEdge(y, h); this.graph.removeEdge(oldyh); this.graph.addEdge(Edges.directedEdge(y, h)); // removing similar edges (which should be undirected) and adding similar directed edges **/ removeSimilarEdges(y, h); addSimilarEdges(y, h); // **/ if (this.verbose) { String message = "--- Directing " + oldyh + " to " + this.graph.getEdge(y, h); TetradLogger.getInstance().log(message); } Edge oldxh = this.graph.getEdge(x, h); if (Edges.isUndirectedEdge(oldxh)) { this.graph.removeEdge(oldxh); this.graph.addEdge(Edges.directedEdge(x, h)); // removing similar edges (which should be undirected) and adding similar directed edges **/ removeSimilarEdges(x, h); addSimilarEdges(x, h); // **/ if (this.verbose) { String message = "--- Directing " + oldxh + " to " + this.graph.getEdge(x, h); TetradLogger.getInstance().log(message); } } } return true; } /** * Test if the candidate insertion is a valid operation. *

* This method checks if the candidate insertion of a node `x` and node `y` is a valid operation based on certain * conditions. *

* The conditions checked are: *

    *
  1. If there is any knowledge that forbids the insertion of nodes `x` and `y`. *
  2. If there is any knowledge that forbids the insertion of any nodes in the set `T` with node `y`. *
  3. If the union of set `T` and set `naYX` forms a clique in the graph. *
  4. If there exists any unblocked semi-directed path from node `y` to node `x` with a cycle bound. *
* * @param x The node to be inserted. * @param y The existing node with which `x` is to be connected. * @param T The set of nodes in the graph. * @param naYX The set of non-adjacent nodes of `y` except `x`. * @return Returns true if the candidate insertion is valid, otherwise false. */ private boolean validInsert(Node x, Node y, Set T, Set naYX) { boolean violatesKnowledge = false; if (existsKnowledge()) { if (this.knowledge.isForbidden(x.getName(), y.getName())) { violatesKnowledge = true; } for (Node t : T) { if (this.knowledge.isForbidden(t.getName(), y.getName())) { violatesKnowledge = true; } } } Set union = new HashSet<>(T); union.addAll(naYX); boolean clique = GraphUtils.isClique(union, this.graph); int cycleBound = -1; boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } /** * Validates if the delete operation is allowed based on the given parameters. * * @param x the first node involved in the delete operation * @param y the second node involved in the delete operation * @param H the set of nodes representing external knowledge * @param naYX the set of nodes that are neighbors to y and adjacent to x. * @return true if the delete operation is valid, false otherwise */ private boolean validDelete(Node x, Node y, Set H, Set naYX) { boolean violatesKnowledge = false; if (existsKnowledge()) { for (Node h : H) { if (this.knowledge.isForbidden(x.getName(), h.getName())) { violatesKnowledge = true; } if (this.knowledge.isForbidden(y.getName(), h.getName())) { violatesKnowledge = true; } } } Set diff = new HashSet<>(naYX); diff.removeAll(H); return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } /** * Adds edges required by knowledge. 
* * @param graph the graph to add the required edges to */ private void addRequiredEdges(Graph graph) { if (!existsKnowledge()) return; for (Iterator it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) { if (Thread.currentThread().isInterrupted()) { break; } KnowledgeEdge next = it.next(); Node nodeA = graph.getNode(next.getFrom()); Node nodeB = graph.getNode(next.getTo()); if (!graph.paths().isAncestorOf(nodeB, nodeA)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeA, nodeB); if (verbose) { String message = "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB); TetradLogger.getInstance().log(message); } } } for (Edge edge : graph.getEdges()) { if (Thread.currentThread().isInterrupted()) { break; } String A = edge.getNode1().getName(); String B = edge.getNode2().getName(); if (this.knowledge.isForbidden(A, B)) { Node nodeA = edge.getNode1(); Node nodeB = edge.getNode2(); if (graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.paths().isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); if (verbose) { String message = "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA); TetradLogger.getInstance().log(message); } } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.paths().isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); if (verbose) { String message = "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA); TetradLogger.getInstance().log(message); } } } } else if (this.knowledge.isForbidden(B, A)) { Node nodeA = edge.getNode2(); Node nodeB = edge.getNode1(); if (graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.paths().isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); if (verbose) { String message = "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA); 
TetradLogger.getInstance().log(message); } } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.paths().isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); if (verbose) { String message = "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA); TetradLogger.getInstance().log(message); } } } } } } /** * Determines if the given subset of nodes contains any node that, when oriented towards the specified node 'y', * violates any forbidden directions according to prior knowledge. * * @param y the node towards which the orientation is being checked * @param subset the set of nodes to be checked for invalid orientations * @return true if any node in the subset violates a forbidden direction when oriented towards 'y', false otherwise */ private boolean invalidSetByKnowledge(Node y, Set subset) { for (Node node : subset) { if (getKnowledge().isForbidden(node.getName(), y.getName())) { return true; } } return false; } /** * Finds all nodes that are neighbors to y and adjacent to x. * * @param x The first node (X). * @param y The second node (Y). * @return A set of nodes that are neighbors to y and adjacent to x. */ private Set getNaYX(Node x, Node y) { List adj = this.graph.getAdjacentNodes(y); Set nayx = new HashSet<>(); for (Node z : adj) { if (z == x) continue; Edge yz = this.graph.getEdge(y, z); if (!Edges.isUndirectedEdge(yz)) continue; if (!this.graph.isAdjacentTo(z, x)) continue; nayx.add(z); } return nayx; } /** * Returns true if a path consisting of undirected and directed edges toward 'to' exists of length at most 'bound.' * Cycle checker in other words. * * @param from the starting node * @param to the target node * @param cond the set of nodes to be ignored during traversal * @param bound the maximum length of the path. If -1, there is no limit. * @return true if an unblocked semi-directed path exists, false otherwise. 
*/ private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, int bound) { Queue Q = new LinkedList<>(); Set V = new HashSet<>(); Q.offer(from); V.add(from); Node e = null; int distance = 0; while (!Q.isEmpty()) { Node t = Q.remove(); if (t == to) { return true; } if (e == t) { e = null; distance++; if (distance > (bound == -1 ? 1000 : bound)) return false; } for (Node u : this.graph.getAdjacentNodes(t)) { if (Thread.currentThread().isInterrupted()) { break; } Edge edge = this.graph.getEdge(t, u); Node c = SvarFges.traverseSemiDirected(t, edge); if (c == null) continue; if (cond.contains(c)) continue; if (c == to) { return true; } if (!V.contains(c)) { V.add(c); Q.offer(c); if (e == null) { e = u; } } } } return false; } /** * Runs Meek rules on just the changed adj. * * @param knowledge the knowledge used for orienting implied edges by MeekRules * @return a set of nodes representing the results of orienting implied edges */ private Set meekOrientRestricted(Knowledge knowledge) { MeekRules rules = new MeekRules(); rules.setKnowledge(knowledge); return rules.orientImplied(this.graph); } /** * Builds the indexing of the nodes for quick lookup. * * @param nodes The list of nodes. */ private void buildIndexing(List nodes) { this.hashIndices = new ConcurrentHashMap<>(); int i = -1; for (Node n : nodes) { this.hashIndices.put(n, ++i); } } /** * Removes information associated with an edge x->y. * * @param x the source node of the edge * @param y the target node of the edge */ private synchronized void clearArrow(Node x, Node y) { OrderedPair pair = new OrderedPair<>(x, y); Set lookupArrows = this.lookupArrows.get(pair); if (lookupArrows != null) { this.sortedArrows.removeAll(lookupArrows); } this.lookupArrows.remove(pair); } /** * Adds the given arrow for the adjacency i->j. * * @param i The starting node of the arrow. * @param j The ending node of the arrow. * @param arrow The arrow to be added. 
*/ private void addLookupArrow(Node i, Node j, Arrow arrow) { OrderedPair pair = new OrderedPair<>(i, j); Set arrows = this.lookupArrows.get(pair); if (arrows == null) { arrows = new ConcurrentSkipListSet<>(); this.lookupArrows.put(pair, arrows); } arrows.add(arrow); } /** * Calculates the score graph change given a node, set of parent nodes, another node, and a hash map of node * indices. * * @param y the node for which the score graph change is calculated * @param parents the set of parent nodes * @param x the other node * @param hashIndices the hash map containing node indices * @return the score graph change as a double value */ private double scoreGraphChange(Node y, Set parents, Node x, Map hashIndices) { int yIndex = hashIndices.get(y); if (parents.contains(x)) return Double.NaN;//throw new IllegalArgumentException(); int[] parentIndices = new int[parents.size()]; int count = 0; for (Node parent : parents) { parentIndices[count++] = hashIndices.get(parent); } return this.score.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); } private List getVariables() { return this.variables; } /** * Stores the current graph if its total score is high enough to be considered as one of the top graphs. *

* If the number of graphs to store is greater than zero, a copy of the current graph is added to the list of top * graphs along with its `totalScore`. The copy is created to prevent any subsequent modifications to affect the * stored graph. *

* If the list of top graphs exceeds the desired number of graphs to store, the lowest scored graph is removed from * the list. *

* Note: This method does not return any value. */ private void storeGraph() { if (getnumCPDAGsToStore() > 0) { Graph graphCopy = new EdgeListGraph(this.graph); this.topGraphs.addLast(new ScoredGraph(graphCopy, this.totalScore)); } if (this.topGraphs.size() == getnumCPDAGsToStore() + 1) { this.topGraphs.removeFirst(); } } /** * Returns pairs of similar nodes based on the given nodes x and y. * * @param x the first node * @param y the second node * @return the list of similar pairs of nodes */ private List> returnSimilarPairs(Node x, Node y) { if (x.getName().equals("time") || y.getName().equals("time")) { return new ArrayList<>(); } int ntiers = this.knowledge.getNumTiers(); int indx_tier = this.knowledge.isInWhichTier(x); int indy_tier = this.knowledge.isInWhichTier(y); int tier_diff = FastMath.max(indx_tier, indy_tier) - FastMath.min(indx_tier, indy_tier); int indx_comp = -1; int indy_comp = -1; List tier_x = this.knowledge.getTier(indx_tier); List tier_y = this.knowledge.getTier(indy_tier); int i; for (i = 0; i < tier_x.size(); ++i) { if (getNameNoLag(x.getName()).equals(getNameNoLag(tier_x.get(i)))) { indx_comp = i; break; } } for (i = 0; i < tier_y.size(); ++i) { if (getNameNoLag(y.getName()).equals(getNameNoLag(tier_y.get(i)))) { indy_comp = i; break; } } List simListX = new ArrayList<>(); List simListY = new ArrayList<>(); for (i = 0; i < ntiers - tier_diff; ++i) { if (this.knowledge.getTier(i).size() == 1) continue; String A; Node x1; String B; Node y1; List tmp_tier1; List tmp_tier2; if (indx_tier >= indy_tier) { tmp_tier1 = this.knowledge.getTier(i + tier_diff); tmp_tier2 = this.knowledge.getTier(i); } else { tmp_tier1 = this.knowledge.getTier(i); tmp_tier2 = this.knowledge.getTier(i + tier_diff); } A = tmp_tier1.get(indx_comp); B = tmp_tier2.get(indy_comp); if (A.equals(B)) continue; if (A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp))) continue; if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue; x1 = 
this.graph.getNode(A); y1 = this.graph.getNode(B); simListX.add(x1); simListY.add(y1); } List> pairList = new ArrayList<>(); pairList.add(simListX); pairList.add(simListY); return (pairList); } /** * Retrieves the name from the given object without any lag. * * @param obj the object from which to retrieve the name * @return the name extracted from the object */ public String getNameNoLag(Object obj) { String tempS = obj.toString(); if (tempS.indexOf(':') == -1) { return tempS; } else return tempS.substring(0, tempS.indexOf(':')); } /** * Adds similar edges between two nodes. * * @param x The first node. * @param y The second node. */ public void addSimilarEdges(Node x, Node y) { List> simList = returnSimilarPairs(x, y); if (simList.isEmpty()) return; List x1List = simList.get(0); List y1List = simList.get(1); Iterator itx = x1List.iterator(); Iterator ity = y1List.iterator(); while (itx.hasNext() && ity.hasNext()) { Node x1 = itx.next(); Node y1 = ity.next(); this.graph.addDirectedEdge(x1, y1); } } /** * Removes similar edges between two nodes. * * @param x the first node * @param y the second node */ public void removeSimilarEdges(Node x, Node y) { List> simList = returnSimilarPairs(x, y); if (simList.isEmpty()) return; List x1List = simList.get(0); List y1List = simList.get(1); Iterator itx = x1List.iterator(); Iterator ity = y1List.iterator(); while (itx.hasNext() && ity.hasNext()) { Node x1 = itx.next(); Node y1 = ity.next(); Edge oldxy = this.graph.getEdge(x1, y1); this.graph.removeEdge(oldxy); this.removedEdges.add(Edges.undirectedEdge(x1, y1)); } } /** * The Mode enum represents different modes/options for a particular algorithm. *

* It provides several options that can be used to configure the behavior of the algorithm. Each option has a brief * description explaining its purpose. */ private enum Mode { /** * Indicates whether unfaithfulness is allowed. */ allowUnfaithfulness, /** * Represents a mode option for the algorithm. *

* This option is used to specify whether to cover noncolliders during the processing or not. */ coverNoncolliders, /** * Represents a heuristic speedup option. *

* This option is used to enable or disable a heuristic speedup algorithm in a particular context. When this * option is enabled, the algorithm will attempt to speed up the processing by using heuristics. * * @see Mode#heuristicSpeedup */ heuristicSpeedup } // Basic data structure for an arrow a->b considered for addition or removal from the graph, together with // associated sets needed to make this determination. For both forward and backward direction, NaYX is needed. // For the forward direction, T neighbors are needed; for the backward direction, H neighbors are needed. // See Chickering (2002). The totalScore difference resulting from added in the edge (hypothetically) is recorded // as the "bump," private static class Arrow implements Comparable { private final double bump; private final Node a; private final Node b; private final Set hOrT; private final Set naYX; private final int index; /** * Represents an arrow in a graph connecting two nodes. * * @param bump the bump value of the arrow object * @param a a node object representing the 'a' node of the arrow object * @param b a node object representing the 'b' node of the arrow object * @param hOrT a set of nodes associated with H or T for the current arrow object * @param naYX a set of nodes associated with NaYX for the current arrow object * @param index the index value of the arrow object */ public Arrow(double bump, Node a, Node b, Set hOrT, Set naYX, int index) { this.bump = bump; this.a = a; this.b = b; this.hOrT = hOrT; this.naYX = naYX; this.index = index; } /** * Returns the bump value of the Arrow object. * * @return the bump value */ public double getBump() { return this.bump; } /** * Returns the Node object representing the 'a' node of the Arrow object. * * @return the Node 'a' */ public Node getA() { return this.a; } /** * Returns the Node object representing the 'b' node of the Arrow object. 
* * @return the Node 'b' */ public Node getB() { return this.b; } /** * Returns the set of nodes associated with H or T for the current Arrow object. * * @return the set of nodes associated with H or T */ public Set getHOrT() { return this.hOrT; } /** * Returns the set of nodes that are neighbors of y and adjacent to x for the current Arrow object. * * @return the set of nodes associated with NaYX */ public Set getNaYX() { return this.naYX; } /** * Compares this Arrow object with the specified Arrow object for order. Returns a negative integer, zero, or a * positive integer as this object is less than, equal to, or greater than the specified object. * * @param arrow the Arrow object to be compared * @return a negative integer, zero, or a positive integer as this object is less than, equal to, or greater * than the specified object */ public int compareTo(@NotNull Arrow arrow) { int compare = Double.compare(arrow.getBump(), getBump()); if (compare == 0) { return Integer.compare(getIndex(), arrow.getIndex()); } return compare; } public String toString() { return "Arrow<" + this.a + "->" + this.b + " bump = " + this.bump + " t/h = " + this.hOrT + " naYX = " + this.naYX + ">"; } public int getIndex() { return this.index; } } /** * An internal class representing a recursive task to initialize effect edges in the graph. 
*/
    private class NodeTaskEmptyGraph extends RecursiveTask<Boolean> {
        private final int from;
        private final int to;
        private final List<Node> nodes;
        private final Set<Node> emptySet;

        /**
         * Creates a task covering the nodes at positions [from, to) of {@code nodes}.
         *
         * @param from     the first position (inclusive)
         * @param to       the last position (exclusive)
         * @param nodes    all nodes; each node in the range is paired with every later node in the list
         * @param emptySet the set passed as NaYX and T/H for the initial arrows
         */
        public NodeTaskEmptyGraph(int from, int to, List<Node> nodes, Set<Node> emptySet) {
            this.from = from;
            this.to = to;
            this.nodes = nodes;
            this.emptySet = emptySet;
        }

        /**
         * For each node y in the range and each later node x, scores the edge between x and y against the empty
         * parent set. A pair is skipped when either of these holds:
         * <ul>
         *   <li>knowledge forbids both directions;</li>
         *   <li>the optional adjacency restriction excludes it.</li>
         * </ul>
         * When the bump is positive, the edge is recorded in {@code effectEdgesGraph} and arrows are added in both
         * directions.
         *
         * @return always true
         */
        @Override
        protected Boolean compute() {
            for (int i = this.from; i < this.to; i++) {
                // NOTE(review): this progress counter is updated from several worker threads without
                // synchronization.
                if ((i + 1) % 1000 == 0) {
                    SvarFges.this.count[0] += 1000;
                }

                Node y = this.nodes.get(i);
                SvarFges.this.neighbors.put(y, this.emptySet);

                for (int j = i + 1; j < this.nodes.size(); j++) {
                    if (Thread.currentThread().isInterrupted()) {
                        break;
                    }

                    Node x = this.nodes.get(j);

                    if (existsKnowledge()) {
                        if (getKnowledge().isForbidden(x.getName(), y.getName())
                                && getKnowledge().isForbidden(y.getName(), x.getName())) {
                            continue;
                        }

                        if (invalidSetByKnowledge(y, this.emptySet)) {
                            continue;
                        }
                    }

                    if (SvarFges.this.adjacencies != null && !SvarFges.this.adjacencies.isAdjacentTo(x, y)) {
                        continue;
                    }

                    int child = SvarFges.this.hashIndices.get(y);
                    int parent = SvarFges.this.hashIndices.get(x);
                    double bump = SvarFges.this.score.localScoreDiff(parent, child);

                    if (bump > 0.0) {
                        SvarFges.this.effectEdgesGraph.addEdge(Edges.undirectedEdge(x, y));
                        addArrow(x, y, this.emptySet, this.emptySet, bump);
                        addArrow(y, x, this.emptySet, this.emptySet, bump);
                    }
                }
            }

            return true;
        }
    }
}





// © 2015 - 2024 Weber Informatics LLC | Privacy Policy