All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.search.utils.BesPermutation Maven / Gradle / Ivy

There is a newer version: 7.6.5
Show newest version
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.Boss;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import org.jetbrains.annotations.NotNull;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

import static edu.cmu.tetrad.graph.Edges.directedEdge;
import static org.apache.commons.math3.util.FastMath.min;


/**
 * 

Implements a version of the BES (Best Equivalent Search) algorithm * that takes a permutation as input and yields a permtuation as output, where the related DAG or CPDAG models are * implied by the ordering or variables in these permutations. BES is the second step of the GES algorithm (e.g., FGES). * The first step in GES starts with an empty graph and adds edges (with corresponding reorientations of edges), * yielding a Markov model. The second step, this one, BES, starts with this Markov model and then tries to remove edges * from it (with corresponding reorientation) to improve the BES scores.

*

The advantage of doing this is that BES can then be used as * a step in certain permutation-based algorithms like BOSS to allow correct models to be inferred under the assumption * of faithfulness.

* * @author bryanandrews * @author josephramsey * @see Fges * @see Bes * @see Boss */ public class BesPermutation { private final List variables; private final Score score; private Knowledge knowledge = new Knowledge(); private boolean verbose = true; private int depth = -1; /** * Constructor. * * @param score The score that BES (from FGES) will use. */ public BesPermutation(@NotNull Score score) { this.score = score; this.variables = score.getVariables(); } /** * Returns the variables. * * @return This list. */ @NotNull public List getVariables() { return this.variables; } /** * Sets whether verbose output should be printed. * * @param verbose True, if so. */ public void setVerbose(boolean verbose) { this.verbose = verbose; } private void buildIndexing(List nodes, Map hashIndices) { int i = -1; for (Node n : nodes) { hashIndices.put(n, ++i); } } // public void setDepth(int depth) { // if (depth < -1) throw new IllegalArgumentException("Depth should be >= -1."); // this.depth = depth; // } /** * Runs BES. * @param graph The graph. * @param order The order. * @param suborder The suborder. 
*/ public void bes(Graph graph, List order, List suborder) { Map hashIndices = new HashMap<>(); SortedSet sortedArrowsBack = new ConcurrentSkipListSet<>(); Map arrowsMapBackward = new ConcurrentHashMap<>(); int[] arrowIndex = new int[1]; buildIndexing(order, hashIndices); reevaluateBackward(new HashSet<>(order), graph, hashIndices, arrowIndex, sortedArrowsBack, arrowsMapBackward); while (!sortedArrowsBack.isEmpty()) { Arrow arrow = sortedArrowsBack.first(); sortedArrowsBack.remove(arrow); Node x = arrow.getA(); Node y = arrow.getB(); if (!graph.isAdjacentTo(x, y)) { continue; } Edge edge = graph.getEdge(x, y); if (edge.pointsTowards(x)) { continue; } if (!getNaYX(x, y, graph).equals(arrow.getNaYX())) { continue; } if (!new HashSet<>(graph.getParents(y)).equals(new HashSet<>(arrow.getParents()))) { continue; } if (!validDelete(x, y, arrow.getHOrT(), arrow.getNaYX(), graph, suborder)) { continue; } Set complement = new HashSet<>(arrow.getNaYX()); complement.removeAll(arrow.getHOrT()); double _bump = deleteEval(x, y, complement, arrow.parents, hashIndices); delete(x, y, arrow.getHOrT(), _bump, arrow.getNaYX(), graph); Set process = revertToCPDAG(graph); process.add(x); process.add(y); process.addAll(graph.getAdjacentNodes(x)); process.addAll(graph.getAdjacentNodes(y)); reevaluateBackward(new HashSet<>(process), graph, hashIndices, arrowIndex, sortedArrowsBack, arrowsMapBackward); } } private void delete(Node x, Node y, Set H, double bump, Set naYX, Graph graph) { Edge oldxy = graph.getEdge(x, y); Set diff = new HashSet<>(naYX); diff.removeAll(H); graph.removeEdge(oldxy); int numEdges = graph.getNumEdges(); if (numEdges % 1000 == 0 && numEdges > 0) { System.out.println("Num edges (backwards) = " + numEdges); } if (verbose) { int cond = diff.size() + graph.getParents(y).size(); String message = (graph.getNumEdges()) + ". 
DELETE " + x + " --> " + y + " H = " + H + " NaYX = " + naYX + " degree = " + GraphUtils.getDegree(graph) + " indegree = " + GraphUtils.getIndegree(graph) + " diff = " + diff + " (" + bump + ") " + " cond = " + cond; TetradLogger.getInstance().forceLogMessage(message); } for (Node h : H) { if (graph.isParentOf(h, y) || graph.isParentOf(h, x)) { continue; } Edge oldyh = graph.getEdge(y, h); graph.removeEdge(oldyh); graph.addEdge(directedEdge(y, h)); if (verbose) { TetradLogger.getInstance().forceLogMessage("--- Directing " + oldyh + " to " + graph.getEdge(y, h)); } Edge oldxh = graph.getEdge(x, h); if (Edges.isUndirectedEdge(oldxh)) { graph.removeEdge(oldxh); graph.addEdge(directedEdge(x, h)); if (verbose) { TetradLogger.getInstance().forceLogMessage("--- Directing " + oldxh + " to " + graph.getEdge(x, h)); } } } } private double deleteEval(Node x, Node y, Set complement, Set parents, Map hashIndices) { Set set = new HashSet<>(complement); set.addAll(parents); set.remove(x); return -scoreGraphChange(x, y, set, hashIndices); } private double scoreGraphChange(Node x, Node y, Set parents, Map hashIndices) { int xIndex = hashIndices.get(x); int yIndex = hashIndices.get(y); if (x == y) { throw new IllegalArgumentException(); } if (parents.contains(y)) { throw new IllegalArgumentException(); } int[] parentIndices = new int[parents.size()]; int count = 0; for (Node parent : parents) { parentIndices[count++] = hashIndices.get(parent); } return score.localScoreDiff(xIndex, yIndex, parentIndices); } public Knowledge getKnowledge() { return knowledge; } /** * Sets the knowledge that BES will use. * * @param knowledge This knowledge. 
*/ public void setKnowledge(Knowledge knowledge) { this.knowledge = new Knowledge((Knowledge) knowledge); } private Set revertToCPDAG(Graph graph) { MeekRules rules = new MeekRules(); rules.setKnowledge(getKnowledge()); rules.setMeekPreventCycles(true); boolean meekVerbose = false; rules.setVerbose(meekVerbose); return rules.orientImplied(graph); } private boolean validDelete(Node x, Node y, Set H, Set naYX, Graph graph, List suborder) { if (existsKnowledge()) { for (Node h : H) { if (knowledge.isForbidden(x.getName(), h.getName())) return false; if (knowledge.isForbidden(y.getName(), h.getName())) return false; } } Set diff = new HashSet<>(naYX); diff.removeAll(H); if (!isClique(diff, graph)) return false; if (existsKnowledge()) { graph = new EdgeListGraph(graph); Edge oldxy = graph.getEdge(x, y); graph.removeEdge(oldxy); for (Node h : H) { if (graph.isParentOf(h, y) || graph.isParentOf(h, x)) continue; Edge oldyh = graph.getEdge(y, h); graph.removeEdge(oldyh); graph.addEdge(directedEdge(y, h)); Edge oldxh = graph.getEdge(x, h); if (!Edges.isUndirectedEdge(oldxh)) continue; graph.removeEdge(oldxh); graph.addEdge(directedEdge(x, h)); } revertToCPDAG(graph); List initialOrder = new ArrayList<>(suborder); Collections.reverse(initialOrder); while (!initialOrder.isEmpty()) { Iterator itr = initialOrder.iterator(); Node b; do { if (itr.hasNext()) b = itr.next(); else return false; } while (invalidSink(b, graph)); graph.removeNode(b); itr.remove(); } } return true; } private boolean invalidSink(Node x, Graph graph) { LinkedList neighbors = new LinkedList<>(); for (Edge edge : graph.getEdges(x)) { if (edge.getDistalEndpoint(x) == Endpoint.ARROW) return true; if (edge.getProximalEndpoint(x) == Endpoint.TAIL) neighbors.add(edge.getDistalNode(x)); } while (!neighbors.isEmpty()) { Node y = neighbors.pop(); for (Node z : neighbors) if (!graph.isAdjacentTo(y, z)) return true; } return false; } private boolean existsKnowledge() { return !knowledge.isEmpty(); } private boolean 
isClique(Set nodes, Graph graph) { List _nodes = new ArrayList<>(nodes); for (int i = 0; i < _nodes.size(); i++) { for (int j = i + 1; j < _nodes.size(); j++) { if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { return false; } } } return true; } private Set getNaYX(Node x, Node y, Graph graph) { List adj = graph.getAdjacentNodes(y); Set nayx = new HashSet<>(); for (Node z : adj) { if (z == x) { continue; } Edge yz = graph.getEdge(y, z); if (!Edges.isUndirectedEdge(yz)) { continue; } if (!graph.isAdjacentTo(z, x)) { continue; } nayx.add(z); } return nayx; } private void reevaluateBackward(Set toProcess, Graph graph, Map hashIndices, int[] arrowIndex, SortedSet sortedArrowsBack, Map arrowsMapBackward) { class BackwardTask extends RecursiveTask { final Map arrowsMapBackward; private final Node r; private final List adj; private final Map hashIndices; private final int chunk; private final int from; private final int to; private final SortedSet sortedArrowsBack; private BackwardTask(Node r, List adj, int chunk, int from, int to, Map hashIndices, SortedSet sortedArrowsBack, Map arrowsMapBackward) { this.adj = adj; this.hashIndices = hashIndices; this.chunk = chunk; this.from = from; this.to = to; this.r = r; this.sortedArrowsBack = sortedArrowsBack; this.arrowsMapBackward = arrowsMapBackward; } @Override protected Boolean compute() { if (to - from <= chunk) { for (int _w = from; _w < to; _w++) { final Node w = adj.get(_w); Edge e = graph.getEdge(w, r); if (e != null) { if (e.pointsTowards(r)) { calculateArrowsBackward(w, r, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack); } else if (e.pointsTowards(w)) { calculateArrowsBackward(r, w, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack); } else { calculateArrowsBackward(w, r, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack); calculateArrowsBackward(r, w, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack); } } } } else { int mid = (to - 
from) / 2; List tasks = new ArrayList<>(); tasks.add(new BackwardTask(r, adj, chunk, from, from + mid, hashIndices, sortedArrowsBack, arrowsMapBackward)); tasks.add(new BackwardTask(r, adj, chunk, from + mid, to, hashIndices, sortedArrowsBack, arrowsMapBackward)); invokeAll(tasks); } return true; } } for (Node r : toProcess) { List adjacentNodes = new ArrayList<>(toProcess); ForkJoinPool.commonPool().invoke(new BackwardTask(r, adjacentNodes, getChunkSize(adjacentNodes.size()), 0, adjacentNodes.size(), hashIndices, sortedArrowsBack, arrowsMapBackward)); } } private int getChunkSize(int n) { int chunk = n / Runtime.getRuntime().availableProcessors(); if (chunk < 100) chunk = 100; return chunk; } private void calculateArrowsBackward(Node a, Node b, Graph graph, Map arrowsMapBackward, Map hashIndices, int[] arrowIndex, SortedSet sortedArrowsBack) { if (existsKnowledge()) { if (!getKnowledge().noEdgeRequired(a.getName(), b.getName())) { return; } } Set naYX = getNaYX(a, b, graph); Set parents = new HashSet<>(graph.getParents(b)); List _naYX = new ArrayList<>(naYX); ArrowConfigBackward config = new ArrowConfigBackward(naYX, parents); ArrowConfigBackward storedConfig = arrowsMapBackward.get(directedEdge(a, b)); if (storedConfig != null && storedConfig.equals(config)) return; arrowsMapBackward.put(directedEdge(a, b), new ArrowConfigBackward(naYX, parents)); int _depth = min(depth, _naYX.size()); final SublistGenerator gen = new SublistGenerator(_naYX.size(), _depth);//_naYX.size()); int[] choice; Set maxComplement = null; double maxBump = Double.NEGATIVE_INFINITY; while ((choice = gen.next()) != null) { Set complement = GraphUtils.asSet(choice, _naYX); double _bump = deleteEval(a, b, complement, parents, hashIndices); if (_bump > maxBump) { maxBump = _bump; maxComplement = complement; } } if (maxBump > 0) { Set _H = new HashSet<>(naYX); _H.removeAll(maxComplement); addArrowBackward(a, b, _H, naYX, parents, maxBump, arrowIndex, sortedArrowsBack); } } private void 
addArrowBackward(Node a, Node b, Set hOrT, Set naYX, Set parents, double bump, int[] arrowIndex, SortedSet sortedArrowsBack) { Arrow arrow = new Arrow(bump, a, b, hOrT, null, naYX, parents, arrowIndex[0]++); sortedArrowsBack.add(arrow); } private static class ArrowConfigBackward { private Set nayx; private Set parents; public ArrowConfigBackward(Set nayx, Set parents) { this.setNayx(nayx); this.setParents(parents); } public void setNayx(Set nayx) { this.nayx = nayx; } public Set getParents() { return parents; } public void setParents(Set parents) { this.parents = parents; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ArrowConfigBackward that = (ArrowConfigBackward) o; return nayx.equals(that.nayx) && parents.equals(that.parents); } @Override public int hashCode() { return Objects.hash(nayx, parents); } } private static class Arrow implements Comparable { private final double bump; private final Node a; private final Node b; private final Set hOrT; private final Set naYX; private final Set parents; private final int index; private Set TNeighbors; Arrow(double bump, Node a, Node b, Set hOrT, Set capTorH, Set naYX, Set parents, int index) { this.bump = bump; this.a = a; this.b = b; this.setTNeighbors(capTorH); this.hOrT = hOrT; this.naYX = naYX; this.index = index; this.parents = parents; } public double getBump() { return bump; } public Node getA() { return a; } public Node getB() { return b; } Set getHOrT() { return hOrT; } Set getNaYX() { return naYX; } // Sorting by bump, high to low. The problem is the SortedSet contains won't add a new element if it compares // to zero with an existing element, so for the cases where the comparison is to zero (i.e. have the same // bump), we need to determine as quickly as possible a determinate ordering (fixed) ordering for two variables. 
// The fastest way to do this is using a hash code, though it's still possible for two Arrows to have the // same hash code but not be equal. If we're paranoid, in this case we calculate a determinate comparison // not equal to zero by keeping a list. This last part is commened out by default. public int compareTo(@NotNull Arrow arrow) { final int compare = Double.compare(arrow.getBump(), getBump()); if (compare == 0) { return Integer.compare(getIndex(), arrow.getIndex()); } return compare; } public String toString() { return "Arrow<" + a + "->" + b + " bump = " + bump + " t/h = " + hOrT + " TNeighbors = " + getTNeighbors() + " parents = " + parents + " naYX = " + naYX + ">"; } public int getIndex() { return index; } public Set getTNeighbors() { return TNeighbors; } public void setTNeighbors(Set TNeighbors) { this.TNeighbors = TNeighbors; } public Set getParents() { return parents; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy