edu.cmu.tetrad.search.work_in_progress.Ion Maven / Gradle / Ivy
The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search.work_in_progress;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.PossibleMConnectingPath;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.*;
/**
* Implements the ION (Integration of Overlapping Networks) algorithm for distributed causal inference. The algorithm
* takes as input a set of PAGs (presumably learned using a local learning algorithm) over variable sets that may have
* some variables in common and others not in common. The algorithm returns a complete set of PAGs over every variable
* from an input PAG that are consistent (same d-separations and d-connections) with every input PAG.
*
* @author Robert Tillman
* @author josephramsey
* @version $Id: $Id
*/
public class Ion {
/**
* The input PAGs to be integrated, possibly FCI outputs.
*/
private final List input = new ArrayList<>();
/**
* The output PAGs over all variables that are consistent with the input PAGs. search() appends to this list and
* returns it.
*/
private final List output = new ArrayList<>();
/**
* The names of all variables appearing in the input PAGs, each recorded once, in first-seen order.
*/
private final List variables = new ArrayList<>();
/**
* Definite noncolliders transferred from the input PAGs onto the combined graph (filled in transferLocal).
*/
private final Set definiteNoncolliders = new HashSet<>();
// Graphs queued for another pass of the final orientation rules (see doFinalOrientation).
private final Set discrimGraphs = new HashSet<>();
// Graphs on which the final orientation rules have finished; read in step 5 of search().
private final Set finalResult = new HashSet<>();
// Statistics reported by getRuntime(), getIterations() and getStats():
// recGraphs holds the size of the step-3 search queue at the start of each branching pass,
// recHitTimes holds the seconds spent in each hitting-set computation.
private final List recGraphs = new ArrayList<>();
private final List recHitTimes = new ArrayList<>();
// If true (the default), step 3 branches over possibly m-connecting paths one path length at a time.
private boolean pathLengthSearch = true;
// If true, step 3 branches over one known separation at a time; false by default.
private boolean doAdjacencySearch;
/**
* The separation facts found in the input PAGs; assigned during search().
*/
private Set separations;
/**
* Set to false at the start of each pass in doFinalOrientation; presumably set to true by an orientation rule
* that changes a graph. Must be true on entry to doFinalOrientation, which also restores it to true on exit.
*/
private boolean changeFlag = true;
// Total running time of the last search, in seconds.
private double runtime;
// Peak sampled memory usage (JVM total minus free memory), in bytes.
private double maxMemory;
// Background knowledge, set via setKnowledge(...).
private Knowledge knowledge = new Knowledge();
//============================= Constructor ============================//
/**
* Constructs a new instance of the ION search from the input PAGs
*
* @param pags The PAGs to be integrated
*/
public Ion(List pags) {
this.input.addAll(pags);
for (Graph pag : this.input) {
for (Node node : pag.getNodes()) {
if (!this.variables.contains(node.getName())) {
this.variables.add(node.getName());
}
}
for (Triple triple : getAllTriples(pag)) {
if (pag.isDefNoncollider(triple.getX(), triple.getY(), triple.getZ())) {
pag.addUnderlineTriple(triple.getX(), triple.getY(), triple.getZ());
}
}
}
}
//============================= Public Methods ============================//
/**
* treks.
*
* @param graph a {@link edu.cmu.tetrad.graph.Graph} object
* @param node1 a {@link edu.cmu.tetrad.graph.Node} object
* @param node2 a {@link edu.cmu.tetrad.graph.Node} object
* @return a {@link java.util.List} object
*/
public static List> treks(Graph graph, Node node1, Node node2) {
List> paths = new LinkedList<>();
Ion.treks(graph, node1, node2, new LinkedList<>(), paths);
return paths;
}
/**
* Constructs the list of treks between node1 and node2.
*/
private static void treks(Graph graph, Node node1, Node node2,
LinkedList path, List> paths) {
path.addLast(node1);
for (Edge edge : graph.getEdges(node1)) {
Node next = Edges.traverse(node1, edge);
if (next == null) {
continue;
}
if (path.size() > 1) {
Node node0 = path.get(path.size() - 2);
if (next == node0) {
continue;
}
if (graph.isDefCollider(node0, node1, next)) {
continue;
}
}
if (next == node2) {
LinkedList _path = new LinkedList<>(path);
_path.add(next);
paths.add(_path);
continue;
}
if (path.contains(next)) {
continue;
}
Ion.treks(graph, next, node2, path, paths);
}
path.removeLast();
}
/**
 * Turns branching over path lengths in step 3 of {@link #search()} on or off. It is on by default; when on,
 * possibly m-connecting paths are examined one path length at a time, starting from length 2.
 *
 * @param doPathLengthSearch {@code true} to enable path-length branching.
 */
public void setDoPathLengthSearch(boolean doPathLengthSearch) {
    pathLengthSearch = doPathLengthSearch;
}
/**
 * Turns adjacency search on or off. It is off by default; when on, step 3 of {@link #search()} branches over
 * one known separation at a time instead of over all of them at once.
 *
 * @param doAdjacencySearch {@code true} to enable adjacency branching.
 */
public void setDoAdjacencySearch(boolean doAdjacencySearch) {
    this.doAdjacencySearch = doAdjacencySearch;
}
/**
 * Replaces the background knowledge used by this search.
 *
 * @param knowledge The knowledge to use; must not be null.
 * @throws NullPointerException if {@code knowledge} is null.
 */
public void setKnowledge(Knowledge knowledge) {
    this.knowledge = Objects.requireNonNull(knowledge, "Knowledge must not be null.");
}
/**
* Runs the ION search and returns a list of compatible graphs.
*
* @return These graphs.
*/
public List search() {
long start = MillisecondTimes.timeMillis();
TetradLogger.getInstance().log("Starting ION Search.");
logGraphs("\nInitial Pags: ", this.input);
TetradLogger.getInstance().log("Transfering local information.");
long steps = MillisecondTimes.timeMillis();
/*
* Step 1 - Create the empty graph
*/
List varNodes = new ArrayList<>();
for (String varName : this.variables) {
varNodes.add(new GraphNode(varName));
}
Graph graph = new EdgeListGraph(varNodes);
/*
* Step 2 - Transfer local information from the PAGs (adjacencies
* and edge orientations)
*/
// transfers edges from each graph and finds definite noncolliders
transferLocal(graph);
// adds edges for variables never jointly measured
for (NodePair pair : nonIntersection(graph)) {
graph.addEdge(new Edge(pair.getFirst(), pair.getSecond(), Endpoint.CIRCLE, Endpoint.CIRCLE));
}
String message3 = "Steps 1-2: " + (MillisecondTimes.timeMillis() - steps) / 1000. + "s";
TetradLogger.getInstance().log(message3);
System.out.println("step2");
System.out.println(graph);
/*
* Step 3
*
* Branch and prune step that blocks problematic undirectedPaths, possibly d-connecting undirectedPaths
*/
steps = MillisecondTimes.timeMillis();
Queue searchPags = new LinkedList<>();
// place graph constructed in step 2 into the queue
searchPags.offer(graph);
// get d-separations and d-connections
List> sepAndAssoc = findSepAndAssoc(graph);
this.separations = sepAndAssoc.get(0);
Set associations = sepAndAssoc.get(1);
Map, List> paths;
// Queue step3PagsSet = new LinkedList();
HashSet step3PagsSet = new HashSet<>();
Set reject = new HashSet<>();
// if no d-separations, nothing left to search
if (this.separations.isEmpty()) {
// makes orientations preventing definite noncolliders from becoming colliders
// do final orientations
// doFinalOrientation(graph);
step3PagsSet.add(graph);
}
// sets length to iterate once if search over path lengths not enabled, otherwise set to 2
int numNodes = graph.getNumNodes();
int pl = numNodes - 1;
if (this.pathLengthSearch) {
pl = 2;
}
// iterates over path length, then adjacencies
for (int l = pl; l < numNodes; l++) {
if (this.pathLengthSearch) {
TetradLogger.getInstance().log("Braching over path lengths: " + l + " of " + (numNodes - 1));
}
int seps = this.separations.size();
final int currentSep = 1;
int numAdjacencies = this.separations.size();
for (IonIndependenceFacts fact : this.separations) {
if (this.doAdjacencySearch) {
TetradLogger.getInstance().log("Braching over path nonadjacencies: " + currentSep + " of " + numAdjacencies);
}
seps--;
// uses two queues to keep up with which PAGs are being iterated and which have been
// accepted to be iterated over in the next iteration of the above for loop
searchPags.addAll(step3PagsSet);
this.recGraphs.add(searchPags.size());
step3PagsSet.clear();
while (!searchPags.isEmpty()) {
System.out.println("ION Step 3 size: " + searchPags.size());
double currentUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
if (currentUsage > this.maxMemory) this.maxMemory = currentUsage;
// deques first PAG from searchPags
Graph pag = searchPags.poll();
// Part 3.a - finds possibly d-connecting undirectedPaths between each pair of nodes
// known to be m-separated
List mConnections = new ArrayList<>();
// checks to see if looping over adjacencies
if (this.doAdjacencySearch) {
for (Collection conditions : fact.getZ()) {
// checks to see if looping over path lengths
if (this.pathLengthSearch) {
mConnections.addAll(PossibleMConnectingPath.findMConnectingPathsOfLength
(pag, fact.getX(), fact.getY(), conditions, l));
} else {
mConnections.addAll(PossibleMConnectingPath.findMConnectingPaths
(pag, fact.getX(), fact.getY(), conditions));
}
}
} else {
for (IonIndependenceFacts allfact : this.separations) {
for (Collection conditions : allfact.getZ()) {
// checks to see if looping over path lengths
if (this.pathLengthSearch) {
mConnections.addAll(PossibleMConnectingPath.findMConnectingPathsOfLength
(pag, allfact.getX(), allfact.getY(), conditions, l));
} else {
mConnections.addAll(PossibleMConnectingPath.findMConnectingPaths
(pag, allfact.getX(), allfact.getY(), conditions));
}
}
}
}
// accept PAG_of_the_true_DAG go to next PAG_of_the_true_DAG if no possibly d-connecting undirectedPaths
if (mConnections.isEmpty()) {
step3PagsSet.add(pag);
continue;
}
// maps conditioning sets to list of possibly d-connecting undirectedPaths
paths = new HashMap<>();
for (PossibleMConnectingPath path : mConnections) {
List p = paths.get(path.getConditions());
if (p == null) {
p = new LinkedList<>();
}
p.add(path);
paths.put(path.getConditions(), p);
}
// Part 3.b - finds minimal graphical changes to block possibly d-connecting undirectedPaths
List> possibleChanges = new ArrayList<>();
for (Set changes : findChanges(paths)) {
Set newChanges = new HashSet<>();
for (GraphChange gc : changes) {
boolean okay = true;
for (Triple collider : gc.getColliders()) {
if (pag.isUnderlineTriple(collider.getX(), collider.getY(), collider.getZ())) {
okay = false;
break;
}
}
if (!okay) {
continue;
}
for (Triple collider : gc.getNoncolliders()) {
if (pag.isDefCollider(collider.getX(), collider.getY(), collider.getZ())) {
okay = false;
break;
}
}
if (okay) {
newChanges.add(gc);
}
}
if (!newChanges.isEmpty()) {
possibleChanges.add(newChanges);
} else {
possibleChanges.clear();
break;
}
}
float starthitset = MillisecondTimes.timeMillis();
Collection hittingSets = IonHittingSet.findHittingSet(possibleChanges);
this.recHitTimes.add((MillisecondTimes.timeMillis() - starthitset) / 1000.);
// Part 3.c - checks the newly constructed graphs from 3.b and rejects those that
// cycles or produce independencies known not to occur from the input PAGs or
// include undirectedPaths from definite nonancestors
for (GraphChange gc : hittingSets) {
boolean badhittingset = false;
for (Edge edge : gc.getRemoves()) {
Node node1 = edge.getNode1();
Node node2 = edge.getNode2();
Set triples = new HashSet<>();
triples.addAll(gc.getColliders());
triples.addAll(gc.getNoncolliders());
if (triples.size() != (gc.getColliders().size() + gc.getNoncolliders().size())) {
badhittingset = true;
break;
}
for (Triple triple : triples) {
if (node1.equals(triple.getY())) {
if (node2.equals(triple.getX()) ||
node2.equals(triple.getZ())) {
badhittingset = true;
break;
}
}
if (node2.equals(triple.getY())) {
if (node1.equals(triple.getX()) ||
node1.equals(triple.getZ())) {
badhittingset = true;
break;
}
}
}
if (badhittingset) {
break;
}
for (NodePair pair : gc.getOrients()) {
if ((node1.equals(pair.getFirst()) && node2.equals(pair.getSecond())) ||
(node2.equals(pair.getFirst()) && node1.equals(pair.getSecond()))) {
badhittingset = true;
break;
}
}
if (badhittingset) {
break;
}
}
if (!badhittingset) {
for (NodePair pair : gc.getOrients()) {
for (Triple triple : gc.getNoncolliders()) {
if (pair.getSecond().equals(triple.getY())) {
if (pair.getFirst().equals(triple.getX()) &&
pag.getEndpoint(triple.getZ(), triple.getY()).equals(Endpoint.ARROW)) {
badhittingset = true;
break;
}
if (pair.getFirst().equals(triple.getZ()) &&
pag.getEndpoint(triple.getX(), triple.getY()).equals(Endpoint.ARROW)) {
badhittingset = true;
break;
}
}
if (badhittingset) {
break;
}
}
if (badhittingset) {
break;
}
}
}
if (badhittingset) {
continue;
}
Graph changed = gc.applyTo(pag);
// if graph change has already been rejected move on to next graph
if (reject.contains(changed)) {
continue;
}
// if graph change has already been accepted move on to next graph
if (step3PagsSet.contains(changed)) {
continue;
}
// reject if null, predicts false independencies or has cycle
if (predictsFalseIndependence(associations, changed)
|| changed.paths().existsDirectedCycle()) {
reject.add(changed);
}
// makes orientations preventing definite noncolliders from becoming colliders
// do final orientations
// doFinalOrientation(changed);
// now add graph to queue
step3PagsSet.add(changed);
}
}
// exits loop if not looping over adjacencies
if (!this.doAdjacencySearch) {
break;
}
}
}
String message2 = "Step 3: " + (MillisecondTimes.timeMillis() - steps) / 1000. + "s";
TetradLogger.getInstance().log(message2);
Queue step3Pags = new LinkedList<>(step3PagsSet);
/*
* Step 4
*
* Finds redundant undirectedPaths and uses this information to expand the list
* of possible graphs
*/
steps = MillisecondTimes.timeMillis();
Map necEdges;
Set outputPags = new HashSet<>();
while (!step3Pags.isEmpty()) {
Graph pag = step3Pags.poll();
necEdges = new HashMap<>();
// Step 4.a - if x and y are known to be unconditionally associated and there is
// exactly one trek between them, mark each edge on that trek as necessary and
// make the tiples on the trek definite noncolliders
// initially mark each edge as not necessary
for (Edge edge : pag.getEdges()) {
necEdges.put(edge, false);
}
// look for unconditional associations
for (IonIndependenceFacts fact : associations) {
for (Set nodes : fact.getZ()) {
if (nodes.isEmpty()) {
List> treks = Ion.treks(pag, fact.x, fact.y);
if (treks.size() == 1) {
List trek = treks.get(0);
List triples = new ArrayList<>();
for (int i = 1; i < trek.size(); i++) {
// marks each edge in trek as necessary
necEdges.put(pag.getEdge(trek.get(i - 1), trek.get(i)), true);
if (i == 1) {
continue;
}
// makes each triple a noncollider
pag.addUnderlineTriple(trek.get(i - 2), trek.get(i - 1), trek.get(i));
}
}
// stop looping once the empty set is found
break;
}
}
}
// Part 4.b - branches by generating graphs for every combination of removing
// redundant undirectedPaths
boolean elimTreks;
// checks to see if removing redundant undirectedPaths eliminates every trek between
// two variables known to be nconditionally assoicated
List possRemovePags = possRemove(pag, necEdges);
double currentUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
if (currentUsage > this.maxMemory) this.maxMemory = currentUsage;
for (Graph newPag : possRemovePags) {
elimTreks = false;
// looks for unconditional associations
for (IonIndependenceFacts fact : associations) {
for (Set nodes : fact.getZ()) {
if (nodes.isEmpty()) {
if (Ion.treks(newPag, fact.x, fact.y).isEmpty()) {
elimTreks = true;
}
// stop looping once the empty set is found
break;
}
}
}
// add new PAG to output unless a necessary trek has been eliminated
if (!elimTreks) {
outputPags.add(newPag);
}
}
}
outputPags = removeMoreSpecific(outputPags);
// outputPags = applyKnowledge(outputPags);
String message1 = "Step 4: " + (MillisecondTimes.timeMillis() - steps) / 1000. + "s";
TetradLogger.getInstance().log(message1);
/*
* Step 5
*
* Generate the Markov equivalence classes for graphs and accept only
* those that do not predict false d-separations
*/
steps = MillisecondTimes.timeMillis();
Set outputSet = new HashSet<>();
for (Graph pag : outputPags) {
Set unshieldedPossibleColliders = new HashSet<>();
for (Triple triple : getPossibleTriples(pag)) {
if (!pag.isAdjacentTo(triple.getX(), triple.getZ())) {
unshieldedPossibleColliders.add(triple);
}
}
PowerSet pset = new PowerSet<>(unshieldedPossibleColliders);
for (Set set : pset) {
Graph newGraph = new EdgeListGraph(pag);
for (Triple triple : set) {
newGraph.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
newGraph.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
}
doFinalOrientation(newGraph);
}
for (Graph outputPag : this.finalResult) {
if (!predictsFalseIndependence(associations, outputPag)) {
Set underlineTriples = new HashSet<>(outputPag.getUnderLines());
for (Triple triple : underlineTriples) {
outputPag.removeUnderlineTriple(triple.getX(), triple.getY(), triple.getZ());
}
outputSet.add(outputPag);
}
}
}
// outputSet = applyKnowledge(outputSet);
outputSet = checkPaths(outputSet);
this.output.addAll(outputSet);
String message = "Step 5: " + (MillisecondTimes.timeMillis() - steps) / 1000. + "s";
TetradLogger.getInstance().log(message);
this.runtime = ((MillisecondTimes.timeMillis() - start) / 1000.);
logGraphs("\nReturning output (" + this.output.size() + " Graphs):", this.output);
double currentUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
if (currentUsage > this.maxMemory) this.maxMemory = currentUsage;
return this.output;
}
/**
 * Returns timing statistics, each formatted with {@link Double#toString(double)} and measured in seconds.
 * In order: the running time of the last search, the total time spent computing hitting sets, the longest
 * single hitting-set computation and the mean hitting-set computation time. Hitting-set times are those
 * recorded so far on this instance; all three are 0 when none have been recorded.
 *
 * @return A four-element list of strings.
 */
public List<String> getRuntime() {
    final int n = this.recHitTimes.size();
    double total = 0;
    double longest = 0;
    double average = 0;

    for (Double seconds : this.recHitTimes) {
        total += seconds;
        // accumulate the mean term by term (same arithmetic order as getStats)
        average += seconds / n;
        longest = (seconds > longest) ? seconds : longest;
    }

    return new ArrayList<>(Arrays.asList(
            Double.toString(this.runtime),
            Double.toString(total),
            Double.toString(longest),
            Double.toString(average)));
}
/**
 * Returns the peak memory usage sampled while searching, where each sample is
 * {@code Runtime.totalMemory() - Runtime.freeMemory()}.
 *
 * @return The maximum sampled memory usage, in bytes; 0 if search has not run.
 */
public double getMaxMemUsage() {
    return maxMemory;
}
//============================= Private Methods ============================//
/**
 * Returns statistics on the size of the step-3 search queue, sampled at the start of each branching pass.
 *
 * @return A three-element list: the total of all samples, the largest sample, and the mean sample size
 *         (integer division, 0 when nothing has been recorded).
 */
public List<Integer> getIterations() {
    int totalit = 0;
    int largestit = 0;
    for (Integer i : this.recGraphs) {
        totalit += i;
        if (i > largestit) {
            largestit = i;
        }
    }
    // Divide once at the end. Summing i / size term by term truncates every term and under-reports
    // the mean (e.g. [1, 1] gave 0); this now matches the average reported by getStats().
    int averageit = this.recGraphs.isEmpty() ? 0 : totalit / this.recGraphs.size();

    List<Integer> list = new ArrayList<>();
    list.add(totalit);
    list.add(largestit);
    list.add(averageit);
    return list;
}
/**
 * Summarizes the running time, step-3 iteration counts and hitting-set timings as LaTeX-friendly text,
 * with each fact terminated by a LaTeX line break ({@code \\}).
 *
 * @return A string summarizing this information. Averages are reported as 0 when nothing has been
 *         recorded yet (for example before {@link #search()} has run).
 */
public String getStats() {
    int totalit = 0;
    int largestit = 0;
    for (Integer i : this.recGraphs) {
        totalit += i;
        if (i > largestit) {
            largestit = i;
        }
    }
    // recGraphs is empty before search() runs, or when no separations were found; dividing by its
    // size would then throw ArithmeticException.
    int averageit = this.recGraphs.isEmpty() ? 0 : totalit / this.recGraphs.size();

    double totalhit = 0;
    double longesthit = 0;
    double averagehit = 0;
    for (Double i : this.recHitTimes) {
        totalhit += i;
        averagehit += i / this.recHitTimes.size();
        if (i > longesthit) {
            longesthit = i;
        }
    }

    return "Total running time: " + this.runtime + "\\\\"
            + "Total iterations in step 3: " + totalit + "\\\\"
            + "Largest set of iterations in step 3: " + largestit + "\\\\"
            + "Average iterations set in step 3: " + averageit + "\\\\"
            + "Total hitting sets calculation time: " + totalhit + "\\\\"
            + "Average hitting set calculation time: " + averagehit + "\\\\"
            + "Longest hitting set calculation time: " + longesthit + "\\\\";
}
/**
* Logs a set of graphs with a corresponding message
*/
private void logGraphs(String message, List extends Graph> graphs) {
if (message != null) {
TetradLogger.getInstance().log(message);
}
for (Graph graph : graphs) {
String message1 = graph.toString();
TetradLogger.getInstance().log(message1);
}
}
/**
 * Builds one {@link NodePair} for every unordered pair of distinct positions in {@code nodes}. The first
 * element of each pair is the one that comes earlier in the list.
 *
 * @param nodes the nodes to pair up.
 * @return all pairs, ordered by the first element's index, then by the second's.
 */
private List<NodePair> allNodePairs(List<Node> nodes) {
    final int n = nodes.size();
    List<NodePair> pairs = new ArrayList<>();
    for (int first = 0; first < n; first++) {
        Node a = nodes.get(first);
        for (int second = first + 1; second < n; second++) {
            pairs.add(new NodePair(a, nodes.get(second)));
        }
    }
    return pairs;
}
/*
* @return all triples in a graph
*/
/**
 * Collects every pair of variables that some input PAG contains but does not connect with an edge. The
 * returned pairs are expressed with the nodes of {@code graph}, looked up by name.
 *
 * @param graph the combined graph whose node objects should appear in the result.
 * @return the pairs that are nonadjacent in at least one input PAG.
 */
private Set<NodePair> nonadjacencies(Graph graph) {
    Set<NodePair> result = new HashSet<>();
    for (Graph inputPag : this.input) {
        for (NodePair inputPair : allNodePairs(inputPag.getNodes())) {
            Node a = inputPair.getFirst();
            Node b = inputPair.getSecond();
            if (inputPag.isAdjacentTo(a, b)) {
                continue;
            }
            Node graphA = graph.getNode(a.getName());
            Node graphB = graph.getNode(b.getName());
            result.add(new NodePair(graphA, graphB));
        }
    }
    return result;
}
/**
* Transfers local information from the input PAGs into {@code graph}, matching nodes by name. It copies edges
* with their endpoint marks, reconciles conflicting marks, and records definite noncolliders.
* <p>
* For each input PAG, in input order:
* <ul>
* <li>an edge whose pair is returned by {@code nonadjacencies(graph)} (nonadjacent in some input PAG) is skipped;</li>
* <li>an edge between nodes not yet adjacent in {@code graph} is copied with both of its endpoint marks;</li>
* <li>for an edge already present, each non-circle mark of the incoming edge is merged into the existing one:
* an arrow/tail disagreement resets that mark to a circle, otherwise the incoming mark is written;</li>
* <li>each underline (definite noncollider) triple of the PAG for which {@code alongPathIn(graph)} holds is
* underlined in {@code graph} and added to {@code definiteNoncolliders}.</li>
* </ul>
* NOTE(review): the merge is order-dependent. A mark reset to a circle by a conflict is overwritten again if a
* later input PAG supplies an arrow or tail for it. Confirm this is intended.
*
* @param graph the combined graph over all variables; modified in place.
*/
private void transferLocal(Graph graph) {
Set nonadjacencies = nonadjacencies(graph);
for (Graph pag : this.input) {
for (Edge edge : pag.getEdges()) {
// the same pair of nodes, expressed with the combined graph's node objects
NodePair graphNodePair = new NodePair(graph.getNode(edge.getNode1().getName()), graph.getNode(edge.getNode2().getName()));
if (nonadjacencies.contains(graphNodePair)) {
continue;
}
// first time this pair is seen: copy the edge with its endpoint marks unchanged
if (!graph.isAdjacentTo(graphNodePair.getFirst(), graphNodePair.getSecond())) {
graph.addEdge(new Edge(graphNodePair.getFirst(), graphNodePair.getSecond(), edge.getEndpoint1(), edge.getEndpoint2()));
continue;
}
// merge the mark at the first node (edge.getEndpoint1()); a circle in the incoming
// edge leaves the existing mark as it is
Endpoint first = edge.getEndpoint1();
Endpoint firstCurrent = graph.getEndpoint(graphNodePair.getSecond(), graphNodePair.getFirst());
if (!first.equals(Endpoint.CIRCLE)) {
// arrow vs. tail disagreement between PAGs: fall back to a circle
if ((first.equals(Endpoint.ARROW) && firstCurrent.equals(Endpoint.TAIL)) ||
(first.equals(Endpoint.TAIL) && firstCurrent.equals(Endpoint.ARROW))) {
graph.setEndpoint(graphNodePair.getSecond(), graphNodePair.getFirst(), Endpoint.CIRCLE);
} else {
graph.setEndpoint(graphNodePair.getSecond(), graphNodePair.getFirst(), edge.getEndpoint1());
}
}
// same merge for the mark at the second node (edge.getEndpoint2())
Endpoint second = edge.getEndpoint2();
Endpoint secondCurrent = graph.getEndpoint(graphNodePair.getFirst(), graphNodePair.getSecond());
if (!second.equals(Endpoint.CIRCLE)) {
if ((second.equals(Endpoint.ARROW) && secondCurrent.equals(Endpoint.TAIL)) ||
(second.equals(Endpoint.TAIL) && secondCurrent.equals(Endpoint.ARROW))) {
graph.setEndpoint(graphNodePair.getFirst(), graphNodePair.getSecond(), Endpoint.CIRCLE);
} else {
graph.setEndpoint(graphNodePair.getFirst(), graphNodePair.getSecond(), edge.getEndpoint2());
}
}
}
// transfer this PAG's definite noncolliders; presumably alongPathIn checks that both
// edges of the triple exist in the combined graph - TODO confirm
for (Triple triple : pag.getUnderLines()) {
Triple graphTriple = new Triple(graph.getNode(triple.getX().getName()), graph.getNode(triple.getY().getName()), graph.getNode(triple.getZ().getName()));
if (graphTriple.alongPathIn(graph)) {
graph.addUnderlineTriple(graphTriple.getX(), graphTriple.getY(), graphTriple.getZ());
this.definiteNoncolliders.add(graphTriple);
}
}
}
}
/**
 * Enumerates every triple X-Y-Z in which X and Z are both adjacent to the center node Y. Each unordered
 * pair of Y's neighbours yields one triple, and shielded triples (X adjacent to Z) are included.
 *
 * @param graph the graph to scan.
 * @return all such triples.
 */
private Set<Triple> getAllTriples(Graph graph) {
    Set<Triple> result = new HashSet<>();
    for (Node center : graph.getNodes()) {
        List<Node> neighbors = new ArrayList<>(graph.getAdjacentNodes(center));
        final int count = neighbors.size();
        for (int a = 0; a < count; a++) {
            Node left = neighbors.get(a);
            for (int b = a + 1; b < count; b++) {
                result.add(new Triple(left, center, neighbors.get(b)));
            }
        }
    }
    return result;
}
/**
 * Finds the pairs of variables that no single input PAG contains together, i.e. pairs that were never
 * jointly measured. Pairs use the nodes of {@code graph}, looked up by variable name, and follow the order
 * in which variables were first seen.
 *
 * @param graph the combined graph whose node objects should appear in the result.
 * @return variable pairs that do not both appear in any one input PAG.
 */
private List<NodePair> nonIntersection(Graph graph) {
    // the variable names measured by each input PAG
    List<Set<String>> nameSets = new ArrayList<>();
    for (Graph inputPag : this.input) {
        Set<String> names = new HashSet<>();
        for (Node node : inputPag.getNodes()) {
            names.add(node.getName());
        }
        nameSets.add(names);
    }

    List<NodePair> pairs = new ArrayList<>();
    final int n = this.variables.size();
    for (int i = 0; i < n; i++) {
        String a = this.variables.get(i);
        for (int j = i + 1; j < n; j++) {
            String b = this.variables.get(j);
            boolean measuredTogether = false;
            for (Set<String> names : nameSets) {
                if (names.contains(a) && names.contains(b)) {
                    measuredTogether = true;
                    break;
                }
            }
            if (!measuredTogether) {
                pairs.add(new NodePair(graph.getNode(a), graph.getNode(b)));
            }
        }
    }
    return pairs;
}
/**
* Finds the association or seperation sets for every pair of nodes.
*/
private List> findSepAndAssoc(Graph graph) {
Set separations = new HashSet<>();
Set associations = new HashSet<>();
List allNodes = allNodePairs(graph.getNodes());
for (NodePair pair : allNodes) {
Node x = pair.getFirst();
Node y = pair.getSecond();
List variables = new ArrayList<>(graph.getNodes());
variables.remove(x);
variables.remove(y);
List> subsets = GraphSearchUtils.powerSet(variables);
IonIndependenceFacts indep = new IonIndependenceFacts(x, y, new HashSet<>());
IonIndependenceFacts assoc = new IonIndependenceFacts(x, y, new HashSet<>());
boolean addIndep = false;
boolean addAssoc = false;
for (Graph pag : this.input) {
for (Set subset : subsets) {
if (containsAll(pag, subset, pair)) {
Node pagX = pag.getNode(x.getName());
Node pagY = pag.getNode(y.getName());
ArrayList pagSubset = new ArrayList<>();
for (Node node : subset) {
pagSubset.add(pag.getNode(node.getName()));
}
if (pag.paths().isMSeparatedFrom(pagX, pagY, new HashSet<>(pagSubset), false)) {
if (!pag.isAdjacentTo(pagX, pagY)) {
addIndep = true;
indep.addMoreZ(new HashSet<>(subset));
}
} else {
addAssoc = true;
assoc.addMoreZ(new HashSet<>(subset));
}
}
}
}
if (addIndep) separations.add(indep);
if (addAssoc) associations.add(assoc);
}
List> facts = new ArrayList<>(2);
facts.add(0, separations);
facts.add(1, associations);
return facts;
}
/**
 * Tells whether {@code g} has a node with the name of each member of {@code nodes} and of both members of
 * {@code pair}. Matching is by name.
 *
 * @return true only if every required name is present in {@code g}.
 */
private boolean containsAll(Graph g, Set<Node> nodes, NodePair pair) {
    List<String> names = g.getNodeNames();

    boolean pairPresent = names.contains(pair.getFirst().getName())
            && names.contains(pair.getSecond().getName());
    if (!pairPresent) {
        return false;
    }

    for (Node member : nodes) {
        if (!names.contains(member.getName())) {
            return false;
        }
    }
    return true;
}
/**
 * Checks whether {@code pag} implies an independence that is known not to hold. This happens when some
 * recorded association is m-separated in {@code pag} under one of its conditioning sets.
 *
 * @param associations the association facts known from the input PAGs.
 * @param pag          the candidate graph.
 * @return true as soon as one such conditioning set is found.
 */
private boolean predictsFalseIndependence(Set<IonIndependenceFacts> associations, Graph pag) {
    for (IonIndependenceFacts fact : associations) {
        Node x = fact.getX();
        Node y = fact.getY();
        for (Set<Node> z : fact.getZ()) {
            if (pag.paths().isMSeparatedFrom(x, y, z, false)) {
                return true;
            }
        }
    }
    return false;
}
/**
 * Returns the triples of {@code pag} whose center could still be oriented either as a collider or as a
 * noncollider. A triple qualifies when it is not underlined, is not a recorded definite noncollider, and is
 * not already a definite collider.
 *
 * @param pag the graph to scan.
 * @return the undecided triples.
 */
private Set<Triple> getPossibleTriples(Graph pag) {
    Set<Triple> undecided = new HashSet<>();
    for (Triple t : getAllTriples(pag)) {
        Node x = t.getX();
        Node y = t.getY();
        Node z = t.getZ();

        if (!(pag.isAdjacentTo(x, y) && pag.isAdjacentTo(y, z))) {
            continue;
        }
        boolean decided = pag.isUnderlineTriple(x, y, z)
                || this.definiteNoncolliders.contains(t)
                || pag.isDefCollider(x, y, z);
        if (!decided) {
            undecided.add(t);
        }
    }
    return undecided;
}
/**
* Given a map from conditioning sets to the possibly m-connecting paths found under them, builds, for each
* path, the set of alternative graph changes, any one of which is meant to block that path.
* <p>
* For a path p found under conditioning set Z the alternatives are:
* <ul>
* <li>removing any edge of p;</li>
* <li>for an interior node of p that is in Z, making its triple on p a noncollider;</li>
* <li>for an interior node of p that is not in Z, making its triple on p a collider. This is skipped when the
* triple is already underlined or either of its edges has a tail at the center. If Z has members off the path,
* the collider is instead combined with a removal or an {@code addOrient} on each edge of every possibly
* m-connecting path from the center to such a member (paths with an ARROW at {@code getEndpoint1()} of some
* edge are ignored), or added on its own when no such path exists.</li>
* </ul>
*
* @param paths map from conditioning set to the possibly m-connecting paths found under it
* @return one set of alternative changes per path, in map-iteration order
*/
private List> findChanges(Map, List> paths) {
List> pagChanges = new ArrayList<>();
Set, List>> entries = paths.entrySet();
/* Loop through each entry, i.e. each conditioning set of variables. */
for (Map.Entry, List> entry : entries) {
Collection conditions = entry.getKey();
List mConnectng = entry.getValue();
/* loop through each path */
for (PossibleMConnectingPath possible : mConnectng) {
List possPath = possible.getPath();
/* Initial capacity of twice the path length, as an estimate; the set grows if needed. */
Set pathChanges = new HashSet<>(2 * possPath.size());
/* find those conditions which are not along the path (used for colliders) */
List outsidePath = new ArrayList<>(conditions.size());
for (Node condition : conditions) {
if (!possPath.contains(condition))
outsidePath.add(condition);
}
/* Walk through path, node by node */
for (int i = 0; i < possPath.size() - 1; i++) {
Node current = possPath.get(i);
Node next = possPath.get(i + 1);
GraphChange gc;
/* for each pair of consecutive nodes, add the operation to remove their edge */
gc = new GraphChange();
gc.addRemove(possible.getPag().getEdge(current, next));
pathChanges.add(gc);
/* for each interior triple centered on a node in the conditioning set,
* add the operation to make it a noncollider */
if (conditions.contains(current) && i > 0) {
gc = new GraphChange();
Triple nonColider = new Triple(possPath.get(i - 1), current, next);
gc.addNonCollider(nonColider);
pathChanges.add(gc);
}
/* for each interior node not in the conditioning set, make a collider. It must
* also be ensured that no path implies that a conditioned variable (even one
* off the path) is a descendant of the collider */
if ((!conditions.contains(current)) && i > 0) {
Triple colider = new Triple(possPath.get(i - 1), current, next);
// already a definite noncollider: cannot become a collider
if (possible.getPag().isUnderlineTriple(possPath.get(i - 1), current, next))
continue;
Edge edge1 = possible.getPag().getEdge(colider.getX(), colider.getY());
Edge edge2 = possible.getPag().getEdge(colider.getZ(), colider.getY());
// a tail at the center on either edge rules out a collider here
if (edge1.getNode1().equals(colider.getY())) {
if (edge1.getEndpoint1().equals(Endpoint.TAIL)) {
continue;
}
} else if (edge1.getNode2().equals(colider.getY())) {
if (edge1.getEndpoint2().equals(Endpoint.TAIL)) {
continue;
}
}
if (edge2.getNode1().equals(colider.getY())) {
if (edge2.getEndpoint1().equals(Endpoint.TAIL)) {
continue;
}
} else if (edge2.getNode2().equals(colider.getY())) {
if (edge2.getEndpoint2().equals(Endpoint.TAIL)) {
continue;
}
}
/* Simple case, no conditions outside the path, so just add the collider */
if (outsidePath.size() == 0) {
gc = new GraphChange();
gc.addCollider(colider);
pathChanges.add(gc);
continue;
}
/* ensure non-descendancy along possible paths between the center node and each
* conditioned variable outside the path.
* NOTE(review): if every such path is skipped below (impliesDecendant == false), no
* collider change at all is added for this node. Also, with several outside variables,
* a plain collider change is added whenever ANY of them has no path, regardless of the
* others. Confirm both are intended. */
for (Node outside : outsidePath) {
/* list of possible descendant paths */
List decendantPaths = new ArrayList<>();
decendantPaths
= PossibleMConnectingPath.findMConnectingPaths
(possible.getPag(), current, outside, new ArrayList<>());
if (decendantPaths.isEmpty()) {
gc = new GraphChange();
gc.addCollider(colider);
pathChanges.add(gc);
continue;
}
/* loop over each possible path which might indicate descendancy */
for (PossibleMConnectingPath decendantPDCPath : decendantPaths) {
List decendantPath = decendantPDCPath.getPath();
/* walk down the path checking orientation (the path may already
* imply non-descendancy) and creating changes if need be */
boolean impliesDecendant = true;
Set colideChanges = new HashSet<>();
for (int j = 0; j < decendantPath.size() - 1; j++) {
Node from = decendantPath.get(j);
// the next node along the path (index j + 1)
Node to = decendantPath.get(j + 1);
Edge currentEdge = possible.getPag().getEdge(from, to);
// NOTE(review): getEndpoint1() is the mark at the edge's stored first node, which may be
// either `from` or `to`; confirm whether the mark at a specific node was intended here.
if (currentEdge.getEndpoint1().equals(Endpoint.ARROW)) {
impliesDecendant = false;
break;
}
// alternative 1: collider plus removal of this edge
gc = new GraphChange();
gc.addCollider(colider);
gc.addRemove(currentEdge);
colideChanges.add(gc);
// alternative 2: collider plus orientation of this edge (addOrient(to, from))
gc = new GraphChange();
gc.addCollider(colider);
gc.addOrient(to, from);
colideChanges.add(gc);
}
if (impliesDecendant)
pathChanges.addAll(colideChanges);
}
}
}
}
pagChanges.add(pathChanges);
}
}
return pagChanges;
}
/**
 * Builds every variant of {@code pag} obtained by removing some subset of its removable edges. Removable
 * edges are those mapped to {@code false} in {@code necEdges}. The unmodified copy (empty subset) is
 * included.
 *
 * @param pag      the graph to vary; it is copied, never modified.
 * @param necEdges map from each edge to whether it is necessary.
 * @return one new graph per subset of removable edges.
 */
private List<Graph> possRemove(Graph pag, Map<Edge, Boolean> necEdges) {
    List<Edge> removable = new ArrayList<>();
    for (Map.Entry<Edge, Boolean> entry : necEdges.entrySet()) {
        if (!entry.getValue()) {
            removable.add(entry.getKey());
        }
    }

    List<Graph> variants = new ArrayList<>();
    for (Set<Edge> toRemove : new PowerSet<>(removable)) {
        Graph variant = new EdgeListGraph(pag);
        for (Edge e : toRemove) {
            variant.removeEdge(e);
        }
        variants.add(variant);
    }
    return variants;
}
/*
* Does the final set of orientations after colliders have been oriented. It works round by round.
* Each round applies doubleTriangle, awayFromColliderAncestorCycle and discrimPaths to {@code graph}
* and to every graph spawned from it by discriminating-path branching (collected in
* {@code discrimGraphs}). It stops when a round ends with {@code changeFlag} still false. A graph that
* finishes a round with no discriminating-path orientation and no change flagged is added to
* {@code finalResult}.
*
* NOTE(review): the while loop only runs if {@code changeFlag} is true on entry. This method resets
* it to true on exit, which presumably guarantees that for later calls; confirm it holds for the first
* call. If the flag were false on entry, {@code graph} would be neither processed nor recorded.
*/
private void doFinalOrientation(Graph graph) {
// Queue of graphs produced by doDdpOrientation during a round; start empty.
this.discrimGraphs.clear();
// Graphs to process in the current round.
Set currentDiscrimGraphs = new HashSet<>();
currentDiscrimGraphs.add(graph);
while (this.changeFlag) {
// Cleared at the start of each round. Rules that record a change (e.g. doubleTriangle,
// doDdpOrientation) set it again.
this.changeFlag = false;
currentDiscrimGraphs.addAll(this.discrimGraphs);
this.discrimGraphs.clear();
for (Graph newGraph : currentDiscrimGraphs) {
doubleTriangle(newGraph);
awayFromColliderAncestorCycle(newGraph);
// When discrimPaths returns true, doDdpOrientation has already queued the resulting graph(s)
// in discrimGraphs, so this graph is neither re-queued nor recorded here.
if (!discrimPaths(newGraph)) {
// changeFlag is shared by every graph in the round. A change flagged while processing an
// earlier graph therefore also sends this one through another round.
if (this.changeFlag) {
this.discrimGraphs.add(newGraph);
} else {
this.finalResult.add(newGraph);
}
}
}
currentDiscrimGraphs.clear();
}
// Reset to true on exit (see the NOTE above).
this.changeFlag = true;
}
/**
 * Double-triangle orientation rule. The rule considers:
 * <ul>
 *   <li>every node B;</li>
 *   <li>every node D with a circle at B ({@code D *-o B});</li>
 *   <li>every ordered pair of distinct nodes A, C with arrowheads at B ({@code A *-> B <-* C}) that
 *       are both adjacent to D.</li>
 * </ul>
 * For each such combination it orients {@code D *-o B} as {@code D *-> B}, unless
 * {@code graph.isDefCollider(A, D, C)} holds or {@link #isArrowheadAllowed} refuses. It sets
 * {@code changeFlag} each time it writes the arrowhead.
 *
 * <p>NOTE(review): the textbook rule (Zhang's R3) requires {@code A *-o D o-* C} to be a noncollider
 * and A, C to be nonadjacent. This implementation only excludes definite colliders and does not
 * test A/C adjacency.
 */
private void doubleTriangle(Graph graph) {
    for (Node b : graph.getNodes()) {
        List<Node> arrowNeighbours = new ArrayList<>(graph.getNodesInTo(b, Endpoint.ARROW));
        List<Node> circleNeighbours = graph.getNodesInTo(b, Endpoint.CIRCLE);

        for (Node d : circleNeighbours) {
            for (Node a : arrowNeighbours) {
                for (Node c : arrowNeighbours) {
                    boolean orient = c != a
                            && graph.isAdjacentTo(a, d)
                            && graph.isAdjacentTo(c, d)
                            && !graph.isDefCollider(a, d, c)
                            && isArrowheadAllowed(graph, d, b);
                    if (orient) {
                        graph.setEndpoint(d, b, Endpoint.ARROW);
                        this.changeFlag = true;
                    }
                }
            }
        }
    }
}
/**
 * Repeatedly applies awayFromAncestor and awayFromCycle to every neighbour pair of every node until
 * a whole pass records no change. For each node B, it tries both orders of every unordered pair
 * {A, C} of B's neighbours. The loop only runs if {@code changeFlag} is true on entry, and
 * {@code changeFlag} is left true on exit.
 */
private void awayFromAncestorCycle(Graph graph) {
    while (this.changeFlag) {
        this.changeFlag = false;
        for (Node b : graph.getNodes()) {
            List<Node> neighbours = new ArrayList<>(graph.getAdjacentNodes(b));
            if (neighbours.size() >= 2) {
                ChoiceGenerator pairs = new ChoiceGenerator(neighbours.size(), 2);
                for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
                    Node first = neighbours.get(pair[0]);
                    Node second = neighbours.get(pair[1]);
                    // The generator yields each pair once, so both orders are tried explicitly.
                    awayFromAncestor(graph, first, b, second);
                    awayFromAncestor(graph, second, b, first);
                    awayFromCycle(graph, first, b, second);
                    awayFromCycle(graph, second, b, first);
                }
            }
        }
    }
    this.changeFlag = true;
}
/**
 * Single pass of ruleR1 (away from collider) and ruleR2 (away from ancestor). It covers every node B
 * and both orders of every unordered pair {A, C} of B's neighbours. Combining the rules avoids
 * walking all triples once per rule in each round of doFinalOrientation. Despite the method name,
 * no away-from-cycle rule is applied here.
 */
private void awayFromColliderAncestorCycle(Graph graph) {
    for (Node b : graph.getNodes()) {
        List<Node> neighbours = new ArrayList<>(graph.getAdjacentNodes(b));
        if (neighbours.size() < 2) {
            continue;
        }
        ChoiceGenerator pairs = new ChoiceGenerator(neighbours.size(), 2);
        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
            Node first = neighbours.get(pair[0]);
            Node second = neighbours.get(pair[1]);
            ruleR1(first, b, second, graph);
            ruleR1(second, b, first, graph);
            ruleR2(first, b, second, graph);
            ruleR2(second, b, first, graph);
        }
    }
}
/**
 * Zhang's rule R1 (away from collider): if {@code a *-> b o-* c} and {@code a} and {@code c} are not
 * adjacent, orient {@code b o-* c} as {@code b --> c}.
 *
 * <p>Sets {@code changeFlag} when it orients, like the other orientation rules in this class. This
 * lets {@link #doFinalOrientation(Graph)} run another round instead of treating the graph as final.
 * Each firing replaces a circle with a tail, so the flag cannot keep the loop alive forever.
 */
private void ruleR1(Node a, Node b, Node c, Graph graph) {
    if (graph.isAdjacentTo(a, c)) {
        return;
    }
    // Requires an arrowhead at b from a, and a circle at b on the c-b edge.
    if (graph.getEndpoint(a, b) != Endpoint.ARROW || graph.getEndpoint(c, b) != Endpoint.CIRCLE) {
        return;
    }
    if (!isArrowheadAllowed(graph, b, c)) {
        return;
    }
    graph.setEndpoint(c, b, Endpoint.TAIL);
    graph.setEndpoint(b, c, Endpoint.ARROW);
    this.changeFlag = true;
}
/**
 * Zhang's rule R2 (away from ancestor): if {@code a *-o c} and either {@code a --> b *-> c} or
 * {@code a *-> b --> c}, orient {@code a *-o c} as {@code a *-> c}.
 *
 * <p>Sets {@code changeFlag} when it orients, like the other orientation rules in this class. This
 * lets {@link #doFinalOrientation(Graph)} run another round instead of treating the graph as final.
 */
private void ruleR2(Node a, Node b, Node c, Graph graph) {
    if (!graph.isAdjacentTo(a, c)) {
        return;
    }
    // Both antecedent shapes need a *-> b, b *-> c, and a circle at c on the a-c edge.
    boolean arrowsAlongPath = graph.getEndpoint(a, b) == Endpoint.ARROW
            && graph.getEndpoint(b, c) == Endpoint.ARROW
            && graph.getEndpoint(a, c) == Endpoint.CIRCLE;
    if (!arrowsAlongPath) {
        return;
    }
    // a --> b (tail at a) or b --> c (tail at b).
    boolean directedSegment = graph.getEndpoint(b, a) == Endpoint.TAIL
            || graph.getEndpoint(c, b) == Endpoint.TAIL;
    if (!directedSegment || !isArrowheadAllowed(graph, a, c)) {
        return;
    }
    graph.setEndpoint(a, c, Endpoint.ARROW);
    this.changeFlag = true;
}
/**
 * Reports whether an arrowhead may be placed at {@code y} on the edge between {@code x} and
 * {@code y}. Only an existing tail at {@code y} forbids it. An existing arrowhead, a circle, or any
 * other endpoint value (including none) allows it. Background knowledge is not consulted here.
 *
 * @param graph the graph being oriented
 * @param x     the far end of the edge
 * @param y     the node at which the arrowhead would be placed
 * @return {@code false} iff {@code graph.getEndpoint(x, y)} is {@link Endpoint#TAIL}
 */
private boolean isArrowheadAllowed(Graph graph, Node x, Node y) {
    return graph.getEndpoint(x, y) != Endpoint.TAIL;
}
/**
 * Away-from-collider rule. It applies only when {@code a} and {@code c} are nonadjacent and
 * {@code a *-> b}. Both tests below use the b-c endpoints read before any change is made:
 * <ol>
 *   <li>If the endpoint at b is a circle or tail and the endpoint at c is a circle, put an arrowhead
 *       at c. The method returns immediately if {@link #isArrowheadAllowed} refuses.</li>
 *   <li>If the endpoint at c was a circle or arrowhead and the endpoint at b was a circle, put a
 *       tail at b.</li>
 * </ol>
 * Each change sets {@code changeFlag}.
 */
private void awayFromCollider(Graph graph, Node a, Node b, Node c) {
    final Endpoint atC = graph.getEndpoint(b, c);
    final Endpoint atB = graph.getEndpoint(c, b);

    if (graph.isAdjacentTo(a, c) || graph.getEndpoint(a, b) != Endpoint.ARROW) {
        return;
    }

    boolean bIsCircleOrTail = atB == Endpoint.CIRCLE || atB == Endpoint.TAIL;
    if (bIsCircleOrTail && atC == Endpoint.CIRCLE) {
        if (!isArrowheadAllowed(graph, b, c)) {
            return;
        }
        graph.setEndpoint(b, c, Endpoint.ARROW);
        this.changeFlag = true;
    }

    boolean cIsCircleOrArrow = atC == Endpoint.CIRCLE || atC == Endpoint.ARROW;
    if (cIsCircleOrArrow && atB == Endpoint.CIRCLE) {
        graph.setEndpoint(c, b, Endpoint.TAIL);
        this.changeFlag = true;
    }
}
/**
 * Away-from-ancestor rule: if {@code a *-o c}, {@code a *-> b *-> c}, and either {@code a --> b} or
 * {@code b --> c}, orient {@code a *-> c} (when {@link #isArrowheadAllowed} permits). Sets
 * {@code changeFlag} when it orients.
 */
private void awayFromAncestor(Graph graph, Node a, Node b, Node c) {
    if (!graph.isAdjacentTo(a, c) || graph.getEndpoint(a, c) != Endpoint.CIRCLE) {
        return;
    }
    if (graph.getEndpoint(a, b) != Endpoint.ARROW || graph.getEndpoint(b, c) != Endpoint.ARROW) {
        return;
    }
    boolean hasTail = graph.getEndpoint(b, a) == Endpoint.TAIL
            || graph.getEndpoint(c, b) == Endpoint.TAIL;
    if (!hasTail || !isArrowheadAllowed(graph, a, c)) {
        return;
    }
    graph.setEndpoint(a, c, Endpoint.ARROW);
    this.changeFlag = true;
}
/**
 * Away-from-cycle rule: if {@code a o-> c} and {@code graph.paths().isDirected} holds for both
 * (a, b) and (b, c) (i.e. {@code a --> b --> c}), orient {@code a o-> c} as {@code a --> c}. Sets
 * {@code changeFlag} when it orients.
 */
private void awayFromCycle(Graph graph, Node a, Node b, Node c) {
    boolean circleArrow = graph.isAdjacentTo(a, c)
            && graph.getEndpoint(a, c) == Endpoint.ARROW
            && graph.getEndpoint(c, a) == Endpoint.CIRCLE;
    if (!circleArrow) {
        return;
    }
    boolean directedChain = graph.paths().isDirected(a, b) && graph.paths().isDirected(b, c);
    if (directedChain) {
        graph.setEndpoint(c, a, Endpoint.TAIL);
        this.changeFlag = true;
    }
}
/**
 * Finds discriminating paths. For every node b it considers pairs of neighbours (a, c) where a is a
 * parent of c. Candidates for a have an arrowhead at their own end of the b-a edge and no tail at
 * b. Candidates for c have an arrowhead at c and a circle at b, i.e. the shapes
 * {@code a <-o b o-> c} or {@code a <-> b o-> c}. For each pair, reachablePathFindOrient is seeded
 * with a.
 *
 * @return {@code true} as soon as one search finds and orients a discriminating path, otherwise
 * {@code false}
 */
private boolean discrimPaths(Graph graph) {
    for (Node b : graph.getNodes()) {
        List<Node> arrowedFromB = graph.getNodesOutTo(b, Endpoint.ARROW);

        // Arrowheads and circles at b are fine for a; tails at b are not.
        List<Node> aCandidates = new ArrayList<>(arrowedFromB);
        aCandidates.removeAll(graph.getNodesInTo(b, Endpoint.TAIL));

        // c must have a circle at b.
        List<Node> cCandidates = new ArrayList<>(arrowedFromB);
        cCandidates.retainAll(graph.getNodesInTo(b, Endpoint.CIRCLE));

        for (Node a : aCandidates) {
            for (Node c : cCandidates) {
                if (!graph.isParentOf(a, c)) {
                    continue;
                }
                LinkedList<Node> frontier = new LinkedList<>();
                frontier.add(a);
                if (reachablePathFindOrient(graph, a, b, c, frontier)) {
                    return true;
                }
            }
        }
    }
    return false;
}
/**
* a method to search "back from a" to find a DDP (definite discriminating path). It is called with a
* reachability list (first consisting only of a). This is breadth-first (FIFO over {@code reachable}),
* utilizing the "reachability" concept from Geiger, Verma, and Pearl 1990. The body of a DDP consists of
* colliders that are parents of c.
*
* For each dequeued node x, it examines every unvisited l returned by
* {@code graph.getNodesInTo(x, Endpoint.ARROW)}:
* - If l is not adjacent to c, l is taken as the DDP endpoint: doDdpOrientation(graph, l, a, b, c) is
*   applied and the method returns true.
* - Otherwise, if l is a parent of c and {@code graph.getEndpoint(x, l)} is an arrowhead, l is enqueued
*   to extend the path.
*
* @param reachable the search frontier; it is mutated by this method
* @return true if a DDP endpoint was found and orientation was applied, false if the frontier ran out
*/
private boolean reachablePathFindOrient(Graph graph, Node a, Node b, Node c,
LinkedList reachable) {
Set cParents = new HashSet<>(graph.getParents(c));
// Needed to avoid cycles in failure case.
Set visited = new HashSet<>();
visited.add(b);
visited.add(c);
// We don't want to include a,b,or c on the path, so they are added to
// the "visited" set. b and c are added explicitly here; a will be
// added in the first while iteration.
// Nodes are marked visited when dequeued, so a node may be enqueued more than once before
// then; this only repeats work.
while (reachable.size() > 0) {
Node x = reachable.removeFirst();
visited.add(x);
// Possible DDP path endpoints.
// NOTE(review): removeAll mutates the list returned by getNodesInTo; assumes it is a fresh copy.
List pathExtensions = graph.getNodesInTo(x, Endpoint.ARROW);
pathExtensions.removeAll(visited);
for (Node l : pathExtensions) {
// If l is reachable and not adjacent to c, it's a DDP
// endpoint, so do DDP orientation. Otherwise, if l is a parent
// of c and l <-> x, add l to the list of reachable nodes.
if (!graph.isAdjacentTo(l, c)) {
// Check whether should be reoriented given
// that l is not adjacent to c; if so, orient and stop.
doDdpOrientation(graph, l, a, b, c);
return true;
} else if (cParents.contains(l)) {
if (graph.getEndpoint(x, l) == Endpoint.ARROW) {
reachable.add(l);
}
}
}
}
return false;
}
/**
* Orients the edges inside the definite discriminating path triangle. Takes the left endpoint l, and
* a, b, c as arguments.
*
* Always sets {@code changeFlag}. It looks for the first separation fact about the pair (l, c), in
* either order:
* - If one of that fact's conditioning sets contains b, b is treated as a noncollider. The endpoint at
*   b on the c-b edge is set to a tail in {@code graph} itself, and {@code graph} is queued in
*   {@code discrimGraphs}.
* - Otherwise (no such fact, or b in none of its conditioning sets), both alternatives are queued as
*   copies: one with arrowheads at b from a and from c (collider), and one with a tail at b on the c-b
*   edge. {@code graph} itself is not modified in that case.
*
* NOTE(review): when a separation fact for (l, c) exists but b is in none of its conditioning sets,
* this still branches instead of orienting only the collider. Confirm this is intended.
*/
private void doDdpOrientation(Graph graph, Node l, Node a, Node b, Node c) {
this.changeFlag = true;
for (IonIndependenceFacts iif : this.separations) {
if ((iif.getX().equals(l) && iif.getY().equals(c)) ||
iif.getY().equals(l) && iif.getX().equals(c)) {
for (Set condSet : iif.getZ()) {
if (condSet.contains(b)) {
// b separates l from c: noncollider at b; orient in place and queue this graph.
graph.setEndpoint(c, b, Endpoint.TAIL);
this.discrimGraphs.add(graph);
return;
}
}
// Only the first matching fact encountered is consulted.
break;
}
}
// Branch 1: collider at b.
Graph newGraph1 = new EdgeListGraph(graph);
newGraph1.setEndpoint(a, b, Endpoint.ARROW);
newGraph1.setEndpoint(c, b, Endpoint.ARROW);
this.discrimGraphs.add(newGraph1);
// Branch 2: tail at b on the c-b edge (noncollider).
Graph newGraph2 = new EdgeListGraph(graph);
newGraph2.setEndpoint(c, b, Endpoint.TAIL);
this.discrimGraphs.add(newGraph2);
}
/**
 * Removes from {@code outputPags} every PAG that is more specific than some other PAG in the set.
 * Such PAGs may be produced by the edge removals in step 4 of the search. "More specific" means the
 * two PAGs have the same number of edges and the same adjacencies, and the candidate matches the
 * other PAG wherever that PAG has an arrowhead or tail. The candidate must also differ somewhere
 * the other PAG has a circle.
 *
 * @param outputPags the PAGs to filter; modified in place
 * @return {@code outputPags}, after removal
 */
private Set<Graph> removeMoreSpecific(Set<Graph> outputPags) {
    Set<Graph> refinements = new HashSet<>();
    for (Graph pag : outputPags) {
        for (Graph other : outputPags) {
            if (isMoreSpecificThan(pag, other)) {
                refinements.add(pag);
                break;
            }
        }
    }
    // Removal happens in a separate pass so the set is not modified while being iterated.
    refinements.forEach(outputPags::remove);
    return outputPags;
}

/** Returns true if {@code pag} differs from {@code other} only by filling in some of its circles. */
private static boolean isMoreSpecificThan(Graph pag, Graph other) {
    if (pag.equals(other) || pag.getEdges().size() != other.getEdges().size()) {
        return false;
    }
    for (Edge pagEdge : pag.getEdges()) {
        if (!other.isAdjacentTo(pagEdge.getNode1(), pagEdge.getNode2())) {
            return false;
        }
    }

    boolean keepsArrowsAndTails = true;
    boolean keepsCircles = true;
    for (Edge otherEdge : other.getEdges()) {
        Edge pagEdge = pag.getEdge(otherEdge.getNode1(), otherEdge.getNode2());

        // Align pag's endpoints with otherEdge's node1/node2 order.
        Endpoint pagAtFirst;
        Endpoint pagAtSecond;
        if (pagEdge.getNode1().equals(otherEdge.getNode1())) {
            pagAtFirst = pagEdge.getEndpoint1();
            pagAtSecond = pagEdge.getEndpoint2();
        } else if (pagEdge.getNode1().equals(otherEdge.getNode2())) {
            pagAtFirst = pagEdge.getEndpoint2();
            pagAtSecond = pagEdge.getEndpoint1();
        } else {
            continue;
        }

        if (!pagAtFirst.equals(otherEdge.getEndpoint1())) {
            if (otherEdge.getEndpoint1().equals(Endpoint.CIRCLE)) {
                keepsCircles = false;
            } else {
                keepsArrowsAndTails = false;
            }
        }
        if (!pagAtSecond.equals(otherEdge.getEndpoint2())) {
            if (otherEdge.getEndpoint2().equals(Endpoint.CIRCLE)) {
                keepsCircles = false;
            } else {
                keepsArrowsAndTails = false;
            }
        }
    }
    return keepsArrowsAndTails && !keepsCircles;
}
/**
 * Keeps only the PAGs that account for the directed and partially oriented edges of every input
 * graph. Nodes are matched to the PAG by name.
 *
 * @param pags the candidate PAGs (not modified)
 * @return a new set holding the PAGs that pass
 */
private Set<Graph> checkPaths(Set<Graph> pags) {
    Set<Graph> accepted = new HashSet<>();
    for (Graph pag : pags) {
        if (accountsForInputEdges(pag)) {
            accepted.add(pag);
        }
    }
    return accepted;
}

/**
 * Checks {@code pag} against every edge of every input graph:
 * <ul>
 *   <li>for an edge that {@code Edges.isDirectedEdge} accepts, {@code pag} must have a semi-directed
 *       path from its node1 to its node2;</li>
 *   <li>for an edge that {@code Edges.isPartiallyOrientedEdge} accepts, {@code pag} must have no
 *       semi-directed path from its node2 to its node1.</li>
 * </ul>
 *
 * @return false at the first violation, true if there is none
 */
private boolean accountsForInputEdges(Graph pag) {
    for (Graph inGraph : this.input) {
        for (Edge edge : inGraph.getEdges()) {
            Node first = pag.getNode(edge.getNode1().getName());
            Node second = pag.getNode(edge.getNode2().getName());
            if (Edges.isDirectedEdge(edge)
                    && !pag.paths().existsSemiDirectedPath(first, Collections.singleton(second))) {
                return false;
            }
            if (Edges.isPartiallyOrientedEdge(edge)
                    && pag.paths().existsSemiDirectedPath(second, Collections.singleton(first))) {
                return false;
            }
        }
    }
    return true;
}
/**
 * Makes {@code pag} consistent with the background knowledge, modifying it in place. Knowledge edges
 * that name a variable absent from {@code pag} are ignored.
 *
 * <p>Forbidden edge {@code from --> to} (edges between the two nodes only):
 * <ul>
 *   <li>{@code from --> to} in the PAG rejects it;</li>
 *   <li>{@code from o-> to} becomes {@code from <-> to};</li>
 *   <li>{@code from o-o to} becomes {@code to o-> from}.</li>
 * </ul>
 *
 * <p>Required edge {@code from --> to}:
 * <ul>
 *   <li>the PAG is rejected if the two nodes are not adjacent;</li>
 *   <li>it is rejected if the edge has an arrowhead at {@code from} ({@code from <-- to},
 *       {@code from <-o to}, or {@code from <-> to}), since that denies {@code from} causing
 *       {@code to};</li>
 *   <li>{@code from o-> to} and {@code from o-o to} become {@code from --> to}.</li>
 * </ul>
 *
 * @param pag the PAG to screen; it may already have been partially modified when null is returned
 * @return {@code pag} itself, or {@code null} if it contradicts the knowledge
 */
private Graph screenForKnowledge(Graph pag) {
    for (Iterator<KnowledgeEdge> it = this.knowledge.forbiddenEdgesIterator(); it.hasNext(); ) {
        KnowledgeEdge forbidden = it.next();
        Node from = pag.getNode(forbidden.getFrom());
        Node to = pag.getNode(forbidden.getTo());
        if (to == null || from == null) {
            continue;
        }
        Edge edge = pag.getEdge(to, from);
        if (edge == null) {
            continue;
        }
        Endpoint atFrom = edge.getProximalEndpoint(from);
        Endpoint atTo = edge.getProximalEndpoint(to);
        if (atTo == Endpoint.ARROW && atFrom == Endpoint.TAIL) {
            // The PAG asserts exactly the forbidden edge.
            return null;
        } else if (atTo == Endpoint.ARROW && atFrom == Endpoint.CIRCLE) {
            pag.removeEdge(edge);
            pag.addEdge(Edges.bidirectedEdge(to, from));
        } else if (atTo == Endpoint.CIRCLE && atFrom == Endpoint.CIRCLE) {
            pag.removeEdge(edge);
            pag.addEdge(Edges.partiallyOrientedEdge(to, from));
        }
    }

    for (Iterator<KnowledgeEdge> it = this.knowledge.requiredEdgesIterator(); it.hasNext(); ) {
        KnowledgeEdge required = it.next();
        Node from = pag.getNode(required.getFrom());
        Node to = pag.getNode(required.getTo());
        if (from == null || to == null) {
            continue;
        }
        Edge edge = pag.getEdge(from, to);
        if (edge == null) {
            return null;
        }
        Endpoint atFrom = edge.getProximalEndpoint(from);
        Endpoint atTo = edge.getProximalEndpoint(to);
        if (atFrom == Endpoint.ARROW) {
            // An arrowhead at 'from' (including from <-> to) contradicts from --> to.
            return null;
        }
        if (atFrom == Endpoint.CIRCLE && (atTo == Endpoint.ARROW || atTo == Endpoint.CIRCLE)) {
            pag.removeEdge(edge);
            pag.addEdge(Edges.directedEdge(from, to));
        }
    }
    return pag;
}
/**
 * Runs screenForKnowledge on every graph in {@code outputSet} and collects the graphs it does not
 * reject. screenForKnowledge may modify the graphs in place.
 *
 * @param outputSet the graphs to screen
 * @return a new set containing the graphs that are consistent with the knowledge
 */
private Set<Graph> applyKnowledge(Set<Graph> outputSet) {
    Set<Graph> survivors = new HashSet<>();
    outputSet.forEach(candidate -> {
        Graph screened = screenForKnowledge(candidate);
        if (screened != null) {
            survivors.add(screened);
        }
    });
    return survivors;
}
/**
* Exactly the same as edu.cmu.tetrad.graph.IndependenceFact excepting this class allows for multiple conditioning
* sets to be associated with a single pair of nodes, which is necessary for the proper ordering of iterations in
* the ION search.
*/
private static final class IonIndependenceFacts {
private final Node x;
private final Node y;
private final Set> z;
/**
* Constructs a triple of nodes.
*/
public IonIndependenceFacts(Node x, Node y, Set> z) {
if (x == null || y == null || z == null) {
throw new NullPointerException();
}
this.x = x;
this.y = y;
this.z = z;
}
public Node getX() {
return this.x;
}
public Node getY() {
return this.y;
}
public Set> getZ() {
return this.z;
}
public void addMoreZ(Set moreZ) {
this.z.add(moreZ);
}
public int hashCode() {
int hash = 17;
hash += 19 * this.x.hashCode() * this.y.hashCode();
hash += 23 * this.z.hashCode();
return hash;
}
public boolean equals(Object obj) {
if (!(obj instanceof IonIndependenceFacts fact)) {
return false;
}
return (this.x.equals(fact.x) && this.y.equals(fact.y) &&
this.z.equals(fact.z))
|| (this.x.equals(fact.y) & this.y.equals(fact.x) &&
this.z.equals(fact.z));
}
public String toString() {
return "I(" + this.x + ", " + this.y + " | " + this.z + ")";
}
}
/**
* A PowerSet constructed with a collection with elements of type E can construct an Iterator which enumerates all
* possible subsets (of type Collection) of the collection used to construct the PowerSet.
*
* @param The type of elements in the Collection passed to the constructor.
* @author pingel
*/
private static class PowerSet implements Iterable> {
Collection all;
public PowerSet(Collection all) {
this.all = all;
}
/**
* @return an iterator over elements of type Collection which enumerates the PowerSet of the collection used
* in the constructor
*/
public Iterator> iterator() {
return new PowerSetIterator<>(this);
}
static class PowerSetIterator implements Iterator> {
PowerSet powerSet;
List canonicalOrder = new ArrayList<>();
List mask = new ArrayList<>();
boolean hasNext = true;
PowerSetIterator(PowerSet powerSet) {
this.powerSet = powerSet;
this.canonicalOrder.addAll(powerSet.all);
}
public void remove() {
throw new UnsupportedOperationException();
}
private boolean allOnes() {
for (InE bit : this.mask) {
if (bit == null) {
return false;
}
}
return true;
}
private void increment() {
int i = 0;
while (true) {
if (i < this.mask.size()) {
InE bit = this.mask.get(i);
if (bit == null) {
this.mask.set(i, this.canonicalOrder.get(i));
return;
} else {
this.mask.set(i, null);
i++;
}
} else {
this.mask.add(this.canonicalOrder.get(i));
return;
}
}
}
public boolean hasNext() {
return this.hasNext;
}
public Set next() {
Set result = new HashSet<>(this.mask);
result.remove(null);
this.hasNext = this.mask.size() < this.powerSet.all.size() || !allOnes();
if (this.hasNext) {
increment();
}
return result;
}
}
}
}