edu.stanford.nlp.util.IntervalTree Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-parser Show documentation
Show all versions of stanford-parser Show documentation
Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
package edu.stanford.nlp.util;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.ToDoubleFunction;
/**
* An interval tree maintains a tree so that all intervals to the left start
* before current interval and all intervals to the right start after.
*
* @author Angel Chang
*/
public class IntervalTree, T extends HasInterval> extends AbstractCollection
{
// Balance factor handed to add(node, target, alpha) by the overloads that do not take an alpha.
private static final double defaultAlpha = 0.65; // How balanced we want this tree (between 0.5 and 1.0)
// When true, check() is run after rebalancing and after each rotation step to validate the tree.
private static final boolean debug = false;
// Root of the tree. Starts out as an empty node (value == null) rather than as null.
private TreeNode root = new TreeNode<>();
// Tree node
/**
 * A single node of the interval tree. A node whose {@code value} is null is
 * considered empty (see {@code isEmpty()}).
 */
public static class TreeNode, T extends HasInterval> {
T value; // Item stored at this node; null marks an empty node
E maxEnd; // Maximum end in this subtree
int size; // Number of items in the subtree rooted at this node, this node included
TreeNode left; // Left child, or null if none
TreeNode right; // Right child, or null if none
TreeNode parent; // Parent for convenience
/** Returns true when this node stores no value. */
public boolean isEmpty() { return value == null; }
/**
 * Resets this node to the empty state: drops the value, the cached maxEnd,
 * the size and both child links. The parent link is deliberately kept.
 */
public void clear() {
value = null;
maxEnd = null;
size = 0;
left = null;
right = null;
// parent = null;
}
}
/**
 * Reports whether the tree currently stores no items.
 *
 * @return true if the root node holds no value
 */
@Override
public boolean isEmpty() {
  return this.root.isEmpty();
}
/**
 * Empties the tree by resetting the root node in place.
 */
@Override
public void clear() {
  this.root.clear();
}
/**
 * Returns a short summary of the tree, consisting of the number of stored items.
 *
 * @return the string {@code "Size: "} followed by the size recorded on the root node
 */
@Override
public String toString() {
  return "Size: " + root.size;
}
/**
 * Inserts {@code target} into the tree, starting at the root and using the
 * default balance factor.
 *
 * @param target item to insert
 * @return the result of the three-argument {@code add}
 */
@Override
public boolean add(T target) {
  return add(this.root, target, defaultAlpha);
}
/**
 * Inserts {@code target} into the subtree rooted at {@code node} using the
 * default balance factor.
 *
 * @param node subtree root where the insertion starts
 * @param target item to insert
 * @return the result of the three-argument {@code add}
 */
public boolean add(TreeNode node, T target) {
  final double alpha = defaultAlpha;
  return add(node, target, alpha);
}
/**
 * Adds {@code target} to the subtree rooted at {@code node}, attempting to
 * maintain alpha balance.
 *
 * The method walks down from {@code node}; every non-empty node passed on the
 * way has its {@code maxEnd} widened and its {@code size} incremented. Items
 * whose interval compares less than or equal to a node's interval go left,
 * all others go right. The item is stored in the first empty node reached
 * (child nodes are created on demand).
 *
 * NOTE(review): only nodes from {@code node} downwards are updated; ancestors
 * of {@code node} are not touched by this method.
 *
 * @param node subtree root where the insertion starts
 * @param target item to insert; null is rejected
 * @param alpha balance factor used for the depth threshold and the balance test
 * @return false if {@code target} is null, true once the item has been stored
 */
public boolean add(TreeNode node, T target, double alpha) {
if (target == null) return false;
TreeNode n = node;
int depth = 0;
// Insertion depth beyond which a rebalance is attempted: derived from the
// subtree size and alpha for subtrees with more than 10 items, otherwise 10.
int thresholdDepth = (node.size > 10)? ((int) (-Math.log(node.size)/Math.log(alpha)+1)):10;
while (n != null) {
if (n.value == null) {
// Reached an empty node: store the item here.
n.value = target;
n.maxEnd = target.getInterval().getEnd();
n.size = 1;
if (depth > thresholdDepth) {
// Do rebalancing
// Walk up from the new node and rebalance the first ancestor that has
// more than 10 items and fails the alpha-balance test.
TreeNode p = n.parent;
while (p != null) {
if (p.size > 10 && !isAlphaBalanced(p,alpha)) {
TreeNode newParent = balance(p);
// Balancing may put a different node at the top of that subtree.
if (p == root) root = newParent;
if (debug) this.check();
break;
}
p = p.parent;
}
}
return true;
} else {
// Passing through an occupied node: account for the item being added below it.
depth++;
n.maxEnd = Interval.max(n.maxEnd, target.getInterval().getEnd());
n.size++;
if (target.getInterval().compareTo(n.value.getInterval()) <= 0) {
// Should go on left
if (n.left == null) {
n.left = new TreeNode<>();
n.left.parent = n;
}
n = n.left;
} else {
// Should go on right
if (n.right == null) {
n.right = new TreeNode<>();
n.right.parent = n;
}
n = n.right;
}
}
}
// Only reached if the starting node is null-like; the loop otherwise always returns.
return false;
}
/**
 * Returns the number of items in the tree, as recorded on the root node.
 *
 * @return the root node's subtree size
 */
@Override
public int size() {
  return this.root.size;
}
/**
 * Returns an iterator that walks the tree starting at the root node
 * (left subtree, then the node itself, then the right subtree).
 *
 * @return a new iterator over the stored items
 */
@Override
public Iterator iterator() {
  return new TreeNodeIterator<>(this.root);
}
/**
 * In-order iterator over the subtree rooted at one node. It is built
 * recursively: the left and right subtrees are each walked by a nested
 * TreeNodeIterator.
 */
private static class TreeNodeIterator, T extends HasInterval> extends AbstractIterator {
TreeNode node; // Root of the subtree being iterated
Iterator curIter; // Iterator over the child subtree currently being walked, if any
// Traversal stage: -1 not started, 0 left subtree, 1 this node's value,
// 2 right subtree, 3 (or more) exhausted.
int stage = -1;
T next; // Look-ahead element; null means "not fetched yet" or "no more elements"
public TreeNodeIterator(TreeNode node) {
this.node = node;
// An empty node has nothing to yield, so start in the exhausted stage.
if (node.isEmpty()) {
stage = 3;
}
}
@Override
public boolean hasNext() {
// Fetch the look-ahead lazily.
if (next == null) {
next = getNext();
}
return next != null;
}
@Override
public T next() {
if (hasNext()) {
T x = next;
// Pre-fetch the following element before handing out the current one.
next = getNext();
return x;
} else throw new NoSuchElementException();
}
// Produces the next element of the traversal, or null when the subtree is exhausted.
private T getNext() {
// TODO: Do more efficient traversal down the tree
if (stage > 2) return null;
// Advance through the stages until there is a child iterator with elements left,
// or until this node's own value is due.
while (curIter == null || !curIter.hasNext()) {
stage++;
switch (stage) {
case 0:
curIter = (node.left != null)? new TreeNodeIterator<>(node.left):null;
break;
case 1:
curIter = null;
return node.value;
case 2:
curIter = (node.right != null)? new TreeNodeIterator<>(node.right):null;
break;
default:
return null;
}
}
if (curIter != null && curIter.hasNext()) {
return curIter.next();
} else return null;
}
}
/**
 * Removes every element of {@code c} from this tree by calling
 * {@code remove} once per element.
 *
 * @param c elements to remove
 * @return true if at least one call to {@code remove} reported a removal
 */
@Override
public boolean removeAll(Collection> c) {
boolean modified = false;
for (Object t:c) {
if (remove(t)) { modified = true; }
}
return modified;
}
/**
 * Not supported by this collection.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public boolean retainAll(Collection> c) {
throw new UnsupportedOperationException("retainAll not implemented");
}
/**
 * Collection-style membership test. The argument is cast to the item type
 * and passed to the typed {@code contains}; a ClassCastException raised
 * along the way is reported as "not contained".
 *
 * @param o object to look for
 * @return true if the typed lookup finds the object, false otherwise or on ClassCastException
 */
@Override
public boolean contains(Object o) {
try {
return contains((T) o);
} catch (ClassCastException ex) {
return false;
}
}
/**
 * Collection-style removal. The argument is cast to the item type and
 * passed to the typed {@code remove}; a ClassCastException raised along the
 * way is reported as "nothing removed".
 *
 * @param o object to remove
 * @return true if the typed removal removed the object, false otherwise or on ClassCastException
 */
@Override
public boolean remove(Object o) {
try {
return remove((T) o);
} catch (ClassCastException ex) {
return false;
}
}
/**
 * Removes {@code target} from the tree, searching from the root.
 *
 * @param target item to remove
 * @return true if the item was found and removed
 */
public boolean remove(T target) {
  return remove(this.root, target);
}
/**
 * Removes {@code target} from the subtree rooted at {@code node}.
 *
 * The search follows the same ordering as insertion: items whose interval
 * compares less than or equal to a node's interval are looked up on the left,
 * all others on the right. When the matching node is found it is emptied
 * (leaf), replaced by its only child, or replaced by its left child with the
 * original right subtree re-attached below the rightmost node.
 *
 * On the way back up, each ancestor's {@code size} is decremented and its
 * {@code maxEnd} is recomputed from its own interval and its remaining
 * children, so {@code maxEnd} can shrink after a removal. A child that became
 * empty is unlinked from its parent.
 *
 * @param node subtree root where the search starts
 * @param target item to remove; null is never found
 * @return true if the item was found and removed, false otherwise
 */
public boolean remove(TreeNode<E, T> node, T target)
{
  if (target == null) return false;
  if (node.value == null) return false;
  if (target.equals(node.value)) {
    int leftSize = (node.left != null) ? node.left.size : 0;
    int rightSize = (node.right != null) ? node.right.size : 0;
    if (leftSize == 0) {
      if (rightSize == 0) {
        // Leaf: just empty the node in place
        node.clear();
      } else {
        // Only a right subtree: pull the right child up into this node
        promoteChild(node, node.right);
      }
    } else if (rightSize == 0) {
      // Only a left subtree: pull the left child up into this node.
      // The child is captured first so both of its links are read from the
      // same (original) child.
      promoteChild(node, node.left);
    } else {
      // Rotate left up
      TreeNode<E, T> origLeft = node.left;
      TreeNode<E, T> origRight = node.right;
      node.value = origLeft.value;
      node.size--;
      node.maxEnd = Interval.max(origLeft.maxEnd, origRight.maxEnd);
      node.right = origLeft.right;
      node.left = origLeft.left;
      if (node.left != null) node.left.parent = node;
      if (node.right != null) node.right.parent = node;
      // Re-attach the original right subtree below the rightmost node,
      // where everything it holds still sorts after all remaining items.
      TreeNode<E, T> rightmost = getRightmostNode(node);
      rightmost.right = origRight;
      origRight.parent = rightmost;
      // adjust maxEnd and sizes on the right
      adjustUpwards(origRight, node);
    }
    return true;
  }

  boolean goLeft = target.getInterval().compareTo(node.value.getInterval()) <= 0;
  TreeNode<E, T> child = goLeft ? node.left : node.right;
  if (child == null) {
    return false;
  }
  boolean res = remove(child, target);
  if (res) {
    if (child.isEmpty()) {
      // Unlink the emptied child so no node with a null maxEnd stays in the tree
      if (goLeft) {
        node.left = null;
      } else {
        node.right = null;
      }
    }
    node.size--;
    recomputeMaxEnd(node);
  }
  return res;
}

// Copies value, size, maxEnd and both child links of child into node and re-parents the grandchildren.
private void promoteChild(TreeNode<E, T> node, TreeNode<E, T> child) {
  node.value = child.value;
  node.size = child.size;
  node.maxEnd = child.maxEnd;
  node.left = child.left;
  node.right = child.right;
  if (node.left != null) node.left.parent = node;
  if (node.right != null) node.right.parent = node;
}

// Recomputes node.maxEnd from the node's own interval end and the maxEnd of its non-empty children.
private void recomputeMaxEnd(TreeNode<E, T> node) {
  E maxEnd = node.value.getInterval().getEnd();
  if (node.left != null && !node.left.isEmpty()) {
    maxEnd = Interval.max(maxEnd, node.left.maxEnd);
  }
  if (node.right != null && !node.right.isEmpty()) {
    maxEnd = Interval.max(maxEnd, node.right.maxEnd);
  }
  node.maxEnd = maxEnd;
}
// Recomputes size and maxEnd from this node all the way up to the top of the tree
private void adjustUpwards(TreeNode node) {
  final TreeNode noStop = null;
  adjustUpwards(node, noStop);
}
// Walks from node towards the root, stopping before stopAt, and recomputes
// maxEnd and size of every visited node from its value and its children
private void adjustUpwards(TreeNode node, TreeNode stopAt) {
  for (TreeNode cur = node; cur != null && cur != stopAt; cur = cur.parent) {
    cur.maxEnd = cur.value.getInterval().getEnd();
    int total = 1;
    if (cur.left != null) {
      cur.maxEnd = Interval.max(cur.maxEnd, cur.left.maxEnd);
      total += cur.left.size;
    }
    if (cur.right != null) {
      cur.maxEnd = Interval.max(cur.maxEnd, cur.right.maxEnd);
      total += cur.right.size;
    }
    cur.size = total;
    // Guard against a self-referencing parent link, which would loop forever
    if (cur == cur.parent) {
      throw new IllegalStateException("node is same as parent!!!");
    }
  }
}
// Recomputes size and maxEnd of this single node (stops before its parent)
private void adjust(TreeNode node) {
  TreeNode stopAt = node.parent;
  adjustUpwards(node, stopAt);
}
/**
 * Validates the invariants of the whole tree, starting at the root.
 *
 * @throws IllegalStateException if an invariant is violated
 */
public void check() {
  check(this.root);
}
/**
 * Validates the invariants of the subtree rooted at {@code treeNode} and
 * throws as soon as one is violated. For every node it checks that:
 * the node is not its own parent; an empty node has no children; the cached
 * maxEnd equals the maximum of the node's own interval end and its children's
 * maxEnd; the size is one plus the children's sizes; the children point back
 * to the node as parent; and the node sits on the correct side of each of
 * its ancestors.
 *
 * @param treeNode root of the subtree to validate
 * @throws IllegalStateException if any invariant is violated
 */
public void check(TreeNode treeNode) {
// Explicit stack instead of recursion
Stack> todo = new Stack<>();
todo.add(treeNode);
while (!todo.isEmpty()) {
TreeNode node = todo.pop();
if (node == node.parent) {
throw new IllegalStateException("node is same as parent!!!");
}
if (node.isEmpty()) {
if (node.left != null) throw new IllegalStateException("Empty node shouldn't have left branch");
if (node.right != null) throw new IllegalStateException("Empty node shouldn't have right branch");
continue;
}
int leftSize = (node.left != null)? node.left.size:0;
int rightSize = (node.right != null)? node.right.size:0;
E leftMax = (node.left != null)? node.left.maxEnd:null;
E rightMax = (node.right != null)? node.right.maxEnd:null;
// Expected maxEnd: the largest of the node's own end and the children's maxEnd
// (a child with a null maxEnd is ignored).
E maxEnd = node.value.getInterval().getEnd();
if (leftMax != null && leftMax.compareTo(maxEnd) > 0) {
maxEnd = leftMax;
}
if (rightMax != null && rightMax.compareTo(maxEnd) > 0) {
maxEnd = rightMax;
}
if (!maxEnd.equals(node.maxEnd)) {
throw new IllegalStateException("max end is not as expected!!!");
}
if (node.size != leftSize + rightSize + 1) {
throw new IllegalStateException("node size is not one plus the sum of left and right!!!");
}
if (node.left != null) {
if (node.left.parent != node) {
throw new IllegalStateException("node left parent is not same as node!!!");
}
}
if (node.right != null) {
if (node.right.parent != node) {
throw new IllegalStateException("node right parent is not same as node!!!");
}
}
if (node.parent != null) {
// Go up parent and make sure we are on correct side
// n climbs the ancestor chain; node's value is compared against every ancestor.
TreeNode n = node;
while (n != null && n.parent != null) {
// Check we are either right or left
if (n == n.parent.left) {
// Check that node is less than the parent
if (node.value != null) {
if (node.value.getInterval().compareTo(n.parent.value.getInterval()) > 0) {
throw new IllegalStateException("node is not on the correct side!!!");
}
}
} else if (n == n.parent.right) {
// Check that node is greater than the parent
if (node.value.getInterval().compareTo(n.parent.value.getInterval()) <= 0) {
throw new IllegalStateException("node is not on the correct side!!!");
}
} else {
throw new IllegalStateException("node is not parent's left or right child!!!");
}
n = n.parent;
}
}
// Queue the children for validation
if (node.left != null) todo.add(node.left);
if (node.right != null) todo.add(node.right);
}
}
/**
 * Tests whether neither child subtree of {@code node} holds more than
 * {@code (int) (alpha * node.size) + 1} items.
 *
 * @param node subtree root to test
 * @param alpha balance factor
 * @return true if both child subtrees are within the limit
 */
public boolean isAlphaBalanced(TreeNode node, double alpha) {
  int limit = (int) (alpha * node.size) + 1;
  int leftCount = (node.left == null) ? 0 : node.left.size;
  if (leftCount > limit) {
    return false;
  }
  int rightCount = (node.right == null) ? 0 : node.right.size;
  return rightCount <= limit;
}
/**
 * Rebalances the whole tree and records the node that ends up on top as the new root.
 */
public void balance() {
  this.root = balance(this.root);
}
/**
 * Balances the subtree rooted at {@code node}.
 *
 * For each subtree, the median item (by in-order position) is rotated up to
 * become that subtree's root, and the procedure is then repeated for the
 * median's left and right subtrees. Subtrees that hold no items are skipped.
 *
 * @param node root of the subtree to balance
 * @return the node that is the root of the subtree after balancing;
 *         {@code node} itself if the subtree holds no items
 */
public TreeNode<E, T> balance(TreeNode<E, T> node) {
  if (debug) check(node);
  Deque<TreeNode<E, T>> todo = new ArrayDeque<>();
  todo.push(node);
  TreeNode<E, T> newRoot = null;
  while (!todo.isEmpty()) {
    TreeNode<E, T> n = todo.pop();
    // Select the median node of this subtree; it becomes the subtree's root
    int medianAt = n.size / 2;
    TreeNode<E, T> median = getNode(n, medianAt);
    if (median == null) {
      // Nothing stored below n (size 0), so there is nothing to balance here
      continue;
    }
    if (median != n) {
      // Yes, there is indeed something to be done
      rotateUp(median, n);
    }
    // The first median found replaces the node the caller passed in
    if (newRoot == null) {
      newRoot = median;
    }
    if (median.left != null) todo.push(median.left);
    if (median.right != null) todo.push(median.right);
  }
  return (newRoot == null) ? node : newRoot;
}
/**
 * Moves {@code node} up the tree, one rotation at a time, until it has taken
 * the place of {@code target}.
 *
 * @param node node to move up
 * @param target node whose position {@code node} should end up in
 * @throws IllegalStateException if a node on the way is neither the left nor the right child of its parent
 */
public void rotateUp(TreeNode node, TreeNode target) {
TreeNode n = node;
boolean done = false;
while (n != null && n.parent != null && !done) {
// Check if we are the left or right child
// The rotation below replaces the parent, so decide beforehand whether this is the last step.
done = (n.parent == target);
if (n == n.parent.left) {
n = rightRotate(n.parent);
} else if (n == n.parent.right) {
n = leftRotate(n.parent);
} else {
throw new IllegalStateException("Not on parent's left or right branches.");
}
if (debug) check(n);
}
}
/**
 * Moves this node to the right and the left child up and returns the new root.
 * The former left child's right subtree becomes the left subtree of
 * {@code oldRoot}.
 *
 * @param oldRoot root of the subtree to rotate
 * @return the new subtree root, or {@code oldRoot} unchanged if it is null, empty or has no left child
 * @throws IllegalStateException if {@code oldRoot} is not a child of its recorded parent
 */
public TreeNode rightRotate(TreeNode oldRoot) {
if (oldRoot == null || oldRoot.isEmpty() || oldRoot.left == null) return oldRoot;
TreeNode oldLeftRight = oldRoot.left.right;
TreeNode newRoot = oldRoot.left;
newRoot.right = oldRoot;
oldRoot.left = oldLeftRight;
// Adjust parents and such
// The new root covers exactly the items the old root covered, so it inherits maxEnd and size.
newRoot.parent = oldRoot.parent;
newRoot.maxEnd = oldRoot.maxEnd;
newRoot.size = oldRoot.size;
if (newRoot.parent != null) {
// Re-point the parent's child link from the old root to the new root
if (newRoot.parent.left == oldRoot) {
newRoot.parent.left = newRoot;
} else if (newRoot.parent.right == oldRoot) {
newRoot.parent.right = newRoot;
} else {
throw new IllegalStateException("Old root not a child of it's parent");
}
}
oldRoot.parent = newRoot;
if (oldLeftRight != null) oldLeftRight.parent = oldRoot;
// The old root lost part of its subtree; recompute its size and maxEnd
adjust(oldRoot);
return newRoot;
}
/**
 * Moves this node to the left and the right child up and returns the new root.
 * The former right child's left subtree becomes the right subtree of
 * {@code oldRoot}.
 *
 * @param oldRoot root of the subtree to rotate
 * @return the new subtree root, or {@code oldRoot} unchanged if it is null, empty or has no right child
 * @throws IllegalStateException if {@code oldRoot} is not a child of its recorded parent
 */
public TreeNode leftRotate(TreeNode oldRoot) {
if (oldRoot == null || oldRoot.isEmpty() || oldRoot.right == null) return oldRoot;
TreeNode oldRightLeft = oldRoot.right.left;
TreeNode newRoot = oldRoot.right;
newRoot.left = oldRoot;
oldRoot.right = oldRightLeft;
// Adjust parents and such
// The new root covers exactly the items the old root covered, so it inherits maxEnd and size.
newRoot.parent = oldRoot.parent;
newRoot.maxEnd = oldRoot.maxEnd;
newRoot.size = oldRoot.size;
if (newRoot.parent != null) {
// Re-point the parent's child link from the old root to the new root
if (newRoot.parent.left == oldRoot) {
newRoot.parent.left = newRoot;
} else if (newRoot.parent.right == oldRoot) {
newRoot.parent.right = newRoot;
} else {
throw new IllegalStateException("Old root not a child of it's parent");
}
}
oldRoot.parent = newRoot;
if (oldRightLeft != null) oldRightLeft.parent = oldRoot;
// The old root lost part of its subtree; recompute its size and maxEnd
adjust(oldRoot);
return newRoot;
}
/**
 * Returns the height of the whole tree.
 *
 * @return number of levels below and including the root; 0 for an empty tree
 */
public int height() {
  return height(this.root);
}
/**
 * Returns the height of the subtree rooted at {@code node}.
 *
 * @param node subtree root
 * @return 0 if the node is empty, otherwise one plus the height of its taller child subtree
 */
public int height(TreeNode node) {
  if (node.value == null) {
    return 0;
  }
  int leftHeight = 0;
  if (node.left != null) {
    leftHeight = height(node.left);
  }
  int rightHeight = 0;
  if (node.right != null) {
    rightHeight = height(node.right);
  }
  return 1 + Math.max(leftHeight, rightHeight);
}
/**
 * Follows left links from {@code node} until a node without a left child is reached.
 *
 * @param node starting node
 * @return the leftmost node of the subtree
 */
public TreeNode getLeftmostNode(TreeNode node)
{
  TreeNode cur = node;
  for (TreeNode next = cur.left; next != null; next = cur.left) {
    cur = next;
  }
  return cur;
}
/**
 * Follows right links from {@code node} until a node without a right child is reached.
 *
 * @param node starting node
 * @return the rightmost node of the subtree
 */
public TreeNode getRightmostNode(TreeNode node)
{
  TreeNode cur = node;
  for (TreeNode next = cur.right; next != null; next = cur.right) {
    cur = next;
  }
  return cur;
}
/**
 * Returns the node at in-order position {@code nodeIndex} (0-based) within
 * the subtree rooted at {@code node}, using the cached subtree sizes.
 *
 * @param node subtree root
 * @param nodeIndex 0-based in-order position
 * @return the node at that position, or null if the index is out of range
 */
public TreeNode getNode(TreeNode node, int nodeIndex) {
  int remaining = nodeIndex;
  TreeNode cur = node;
  while (cur != null) {
    if (remaining < 0 || remaining >= cur.size) {
      return null;
    }
    int leftCount = (cur.left == null) ? 0 : cur.left.size;
    if (remaining < leftCount) {
      cur = cur.left;
    } else if (remaining > leftCount) {
      // Skip the left subtree and the current node, then continue on the right
      remaining -= leftCount + 1;
      cur = cur.right;
    } else {
      return cur;
    }
  }
  return null;
}
/**
 * Adds {@code target} only if it does not overlap any item already in the tree.
 *
 * @param target item to add
 * @return false if an overlapping item exists (nothing is added), true otherwise
 */
public boolean addNonOverlapping(T target)
{
  boolean clash = overlaps(target);
  if (!clash) {
    add(target);
  }
  return !clash;
}
/**
 * Adds {@code target} only if {@code containsInterval(target, false)} reports no match.
 *
 * @param target item to add
 * @return false if a match exists (nothing is added), true otherwise
 */
public boolean addNonNested(T target)
{
  boolean nested = containsInterval(target, false);
  if (!nested) {
    add(target);
  }
  return !nested;
}
/**
 * Tests whether any item in the tree overlaps the interval of {@code target}.
 *
 * @param target item whose interval is tested
 * @return true if at least one stored item overlaps
 */
public boolean overlaps(T target) {
return overlaps(root, target.getInterval());
}
/**
 * Collects all items in the tree that overlap the interval of {@code target}.
 *
 * @param target item whose interval is tested
 * @return a new list with the overlapping items
 */
public List getOverlapping(T target) {
return getOverlapping(root, target.getInterval());
}
/**
 * Collects, into a new list, the items below {@code n} reported by the
 * point-based {@code getOverlapping(n, p, result)}.
 *
 * @param n subtree root to search
 * @param p point to look up
 * @return a new list with the matching items
 */
public static , T extends HasInterval> List getOverlapping(TreeNode n, E p)
{
List overlapping = new ArrayList<>();
getOverlapping(n, p, overlapping);
return overlapping;
}
/**
 * Collects, into a new list, the items below {@code n} whose interval
 * overlaps {@code target}.
 *
 * @param n subtree root to search
 * @param target interval to test against
 * @return a new list with the overlapping items
 */
public static , T extends HasInterval> List getOverlapping(TreeNode n, Interval target)
{
List overlapping = new ArrayList<>();
getOverlapping(n, target, overlapping);
return overlapping;
}
// Search for all intervals which contain p, starting with the
// node "n" and adding matching intervals to the list "result".
// The point is turned into the interval (p, p) and handed to the
// interval-based search.
public static , T extends HasInterval> void getOverlapping(TreeNode n, E p, List result) {
getOverlapping(n, Interval.toInterval(p,p), result);
}
/**
 * Appends to {@code result} every item below {@code node} whose interval
 * overlaps {@code target}. Subtrees are pruned using the cached maxEnd and
 * the start of each node's interval. Nodes are processed through a FIFO queue.
 *
 * @param node subtree root to search; null or empty nodes are skipped
 * @param target interval to test against
 * @param result list that receives the overlapping items
 */
public static , T extends HasInterval> void getOverlapping(TreeNode node, Interval target, List result) {
Queue> todo = new LinkedList<>();
todo.add(node);
while (!todo.isEmpty()) {
TreeNode n = todo.poll();
// Don't search nodes that don't exist
if (n == null || n.isEmpty())
continue;
// If target is to the right of the rightmost point of any interval
// in this node and all children, there won't be any matches.
if (target.first.compareTo(n.maxEnd) > 0)
continue;
// Search left children
if (n.left != null) {
todo.add(n.left);
}
// Check this node
if (n.value.getInterval().overlaps(target)) {
result.add(n.value);
}
// If target is to the left of the start of this interval,
// then it can't be in any child to the right.
if (target.second.compareTo(n.value.getInterval().first()) < 0) {
continue;
}
// Otherwise, search right children
if (n.right != null) {
todo.add(n.right);
}
}
}
/**
 * Point variant of the overlap test: turns {@code p} into the interval
 * (p, p) and delegates to the interval-based {@code overlaps}.
 *
 * @param n subtree root to search
 * @param p point to look up
 * @return true if at least one item below {@code n} overlaps the point
 */
public static , T extends HasInterval> boolean overlaps(TreeNode n, E p) {
return overlaps(n, Interval.toInterval(p,p));
}
/**
 * Tests whether any item below {@code node} has an interval overlapping
 * {@code target}. Returns as soon as the first overlap is found; subtrees are
 * pruned using the cached maxEnd and the start of each node's interval.
 *
 * @param node subtree root to search; null or empty nodes are skipped
 * @param target interval to test against
 * @return true if an overlapping item exists
 */
public static , T extends HasInterval> boolean overlaps(TreeNode node, Interval target) {
// Explicit stack instead of recursion
Stack> todo = new Stack<>();
todo.push(node);
while (!todo.isEmpty()) {
TreeNode n = todo.pop();
// Don't search nodes that don't exist
if (n == null || n.isEmpty()) continue;
// If target is to the right of the rightmost point of any interval
// in this node and all children, there won't be any matches.
if (target.first.compareTo(n.maxEnd) > 0)
continue;
// Check this node
if (n.value.getInterval().overlaps(target)) {
return true;
}
// Search left children
if (n.left != null) {
todo.add(n.left);
}
// If target is to the left of the start of this interval,
// then it can't be in any child to the right.
if (target.second.compareTo(n.value.getInterval().first()) < 0) {
continue;
}
if (n.right != null) {
todo.add(n.right);
}
}
return false;
}
/**
 * Tests whether the tree holds an item equal to {@code target}.
 *
 * @param target item to look for; null is never contained
 * @return true if an equal item is stored in the tree, false otherwise
 */
public boolean contains(T target) {
  // Mirror add/remove, which report false for null instead of throwing
  if (target == null) {
    return false;
  }
  return containsValue(this, target);
}
/**
 * Tests whether the tree holds an item whose interval matches the interval of
 * {@code target}: equal to it when {@code exact} is true, containing it otherwise.
 *
 * @param target item whose interval is looked up
 * @param exact true to require an equal interval, false to accept a containing one
 * @return true if a matching item is stored in the tree
 */
public boolean containsInterval(T target, boolean exact) {
return containsInterval(this, target.getInterval(), exact);
}
/**
 * Point variant of {@code containsInterval}: turns {@code p} into the
 * interval (p, p) and delegates to the interval-based overload.
 *
 * @param n tree to search
 * @param p point to look up
 * @param exact true to require an equal interval, false to accept a containing one
 * @return true if a matching item is stored in the tree
 */
public static , T extends HasInterval> boolean containsInterval(IntervalTree n, E p, boolean exact) {
return containsInterval(n, Interval.toInterval(p, p), exact);
}
/**
 * Tests whether the tree holds an item whose interval equals {@code target}
 * (when {@code exact} is true) or contains it (otherwise). The test itself is
 * a ContainsIntervalFunction handed to the shared {@code contains} search.
 *
 * @param node tree to search
 * @param target interval to look up
 * @param exact true to require an equal interval, false to accept a containing one
 * @return true if a matching item is stored in the tree
 */
public static , T extends HasInterval> boolean containsInterval(IntervalTree node, Interval target, boolean exact) {
Predicate containsTargetFunction = new ContainsIntervalFunction(target, exact);
return contains(node, target.getInterval(), containsTargetFunction);
}
/**
 * Tests whether the tree holds an item equal to {@code target}. The search is
 * guided by the interval of {@code target}; equality is checked by a
 * ContainsValueFunction.
 *
 * @param node tree to search
 * @param target item to look for; its interval is read, so it must not be null
 * @return true if an equal item is stored in the tree
 */
public static , T extends HasInterval> boolean containsValue(IntervalTree node, T target) {
Predicate containsTargetFunction = new ContainsValueFunction(target);
return contains(node, target.getInterval(), containsTargetFunction);
}
/**
 * Predicate that accepts an item when it equals the target item given at construction.
 */
private static class ContainsValueFunction, T extends HasInterval>
implements Predicate {
private T target; // Item to compare candidates against
public ContainsValueFunction(T target) {
this.target = target;
}
@Override
public boolean test(T in) {
return in.equals(target);
}
}
/**
 * Predicate that accepts an item based on how its interval relates to a
 * target interval: equal to it in exact mode, containing it otherwise.
 */
private static class ContainsIntervalFunction, T extends HasInterval>
implements Predicate {
private Interval target; // Interval to compare candidates against
private boolean exact; // true: candidate interval must equal target; false: it must contain target
public ContainsIntervalFunction(Interval target, boolean exact) {
this.target = target;
this.exact = exact;
}
@Override
public boolean test(T in) {
if (exact) {
return in.getInterval().equals(target);
} else {
return in.getInterval().contains(target);
}
}
}
// Runs the predicate-driven search over a whole tree, starting at its root.
private static , T extends HasInterval>
boolean contains(IntervalTree tree, Interval target, Predicate containsTargetFunction) {
return contains(tree.root, target, containsTargetFunction);
}
// Searches the subtree rooted at node for an item accepted by containsTargetFunction.
// target is only used to prune the search (via the cached maxEnd and the start
// of each node's interval); the accept/reject decision is made by the predicate.
// Returns true as soon as one item is accepted, false if none is.
private static , T extends HasInterval>
boolean contains(TreeNode node, Interval target, Predicate containsTargetFunction) {
// Explicit stack instead of recursion
Stack> todo = new Stack<>();
todo.push(node);
// Don't search nodes that don't exist
while (!todo.isEmpty()) {
TreeNode n = todo.pop();
// Don't search nodes that don't exist
if (n == null || n.isEmpty()) continue;
// If target is to the right of the rightmost point of any interval
// in this node and all children, there won't be any matches.
if (target.first.compareTo(n.maxEnd) > 0) {
continue;
}
// Check this node
if (containsTargetFunction.test(n.value))
return true;
if (n.left != null) {
todo.push(n.left);
}
// If target is to the left of the start of this interval, then no need to search right
if (target.second.compareTo(n.value.getInterval().first()) <= 0) {
continue;
}
// Need to check right children
if (n.right != null) {
todo.push(n.right);
}
}
return false;
}
/**
 * Greedily selects non-overlapping items in the order given: an item is kept
 * if its interval does not overlap the interval of any item kept before it.
 *
 * @param items candidate items, visited in list order
 * @param toIntervalFunc maps an item to its interval
 * @return a new list with the kept items, in their original order
 */
public static > List getNonOverlapping(
List extends T> items, Function super T,Interval> toIntervalFunc)
{
List nonOverlapping = new ArrayList<>();
// Tree of the intervals accepted so far
IntervalTree> intervals = new IntervalTree<>();
for (T item:items) {
Interval i = toIntervalFunc.apply(item);
boolean addOk = intervals.addNonOverlapping(i);
if (addOk) {
nonOverlapping.add(item);
}
}
return nonOverlapping;
}
/**
 * Sorts a copy of {@code items} with {@code compareFunc} and then greedily
 * selects non-overlapping items in that order. The input list is not modified.
 *
 * @param items candidate items
 * @param toIntervalFunc maps an item to its interval
 * @param compareFunc ordering that decides which items get priority (earlier wins)
 * @return a new list with the kept items, in sorted order
 */
public static > List getNonOverlapping(
List extends T> items, Function super T,Interval> toIntervalFunc, Comparator super T> compareFunc)
{
List sorted = new ArrayList<>(items);
Collections.sort(sorted, compareFunc);
return getNonOverlapping(sorted, toIntervalFunc);
}
/**
 * Variant for items that carry their own interval: uses {@code getInterval()}
 * as the interval function, then sorts and selects as the three-argument overload does.
 *
 * @param items candidate items
 * @param compareFunc ordering that decides which items get priority (earlier wins)
 * @return a new list with the kept items
 */
public static , E extends Comparable> List getNonOverlapping(
List extends T> items, Comparator super T> compareFunc)
{
Function> toIntervalFunc = in -> in.getInterval();
return getNonOverlapping(items, toIntervalFunc, compareFunc);
}
/**
 * Variant for items that carry their own interval: uses {@code getInterval()}
 * as the interval function and selects greedily in list order.
 *
 * @param items candidate items, visited in list order
 * @return a new list with the kept items
 */
public static , E extends Comparable> List getNonOverlapping(
List extends T> items)
{
Function> toIntervalFunc = in -> in.getInterval();
return getNonOverlapping(items, toIntervalFunc);
}
private static class PartialScoredList {
T object;
E lastMatchKey;
int size;
double score;
}
public static > List getNonOverlappingMaxScore(
List extends T> items, Function super T,Interval> toIntervalFunc, ToDoubleFunction super T> scoreFunc)
{
if (items.size() > 1) {
Map> bestNonOverlapping = new TreeMap<>();
for (T item:items) {
Interval itemInterval = toIntervalFunc.apply(item);
E mBegin = itemInterval.getBegin();
E mEnd = itemInterval.getEnd();
PartialScoredList bestk = bestNonOverlapping.get(mEnd);
double itemScore = scoreFunc.applyAsDouble(item);
if (bestk == null) {
bestk = new PartialScoredList<>();
bestk.size = 1;
bestk.score = itemScore;
bestk.object = item;
bestNonOverlapping.put(mEnd, bestk);
}
// Assumes map is ordered
for (E j:bestNonOverlapping.keySet()) {
if (j.compareTo(mBegin) > 0) break;
// Consider adding this match into the bestNonOverlapping strand at j
PartialScoredList bestj = bestNonOverlapping.get(j);
double withMatchScore = bestj.score + itemScore;
boolean better = false;
if (withMatchScore > bestk.score) {
better = true;
} else if (withMatchScore == bestk.score) {
if (bestj.size + 1 < bestk.size) {
better = true;
}
}
if (better) {
bestk.size = bestj.size + 1;
bestk.score = withMatchScore;
bestk.object = item;
bestk.lastMatchKey = j;
}
}
}
PartialScoredList best = null;
for (PartialScoredList v: bestNonOverlapping.values()) {
if (best == null || v.score > best.score) {
best = v;
}
}
List nonOverlapping = new ArrayList<>(best.size);
PartialScoredList prev = best;
while (prev != null) {
if (prev.object != null) {
nonOverlapping.add(prev.object);
}
if (prev.lastMatchKey != null) {
prev = bestNonOverlapping.get(prev.lastMatchKey);
} else {
prev = null;
}
}
Collections.reverse(nonOverlapping);
return nonOverlapping;
} else {
List nonOverlapping = new ArrayList<>(items);
return nonOverlapping;
}
}
/**
 * Variant for items that carry their own interval: uses {@code getInterval()}
 * as the interval function and delegates to the three-argument overload.
 *
 * @param items candidate items
 * @param scoreFunc maps an item to its score
 * @return a new list with the selected items
 */
public static , E extends Comparable> List getNonOverlappingMaxScore(
List extends T> items, ToDoubleFunction super T> scoreFunc)
{
Function> toIntervalFunc = in -> in.getInterval();
return getNonOverlappingMaxScore(items, toIntervalFunc, scoreFunc);
}
/**
 * Sorts a copy of {@code items} with {@code compareFunc} and keeps each item
 * for which {@code addNonNested} accepts its interval, given the intervals of
 * the items kept before it. The input list is not modified.
 *
 * @param items candidate items
 * @param toIntervalFunc maps an item to its interval
 * @param compareFunc ordering that decides which items get priority (earlier wins)
 * @return a new list with the kept items, in sorted order
 */
public static > List getNonNested(
List extends T> items, Function super T,Interval> toIntervalFunc, Comparator super T> compareFunc)
{
List sorted = new ArrayList<>(items);
Collections.sort(sorted, compareFunc);
List res = new ArrayList<>();
// Tree of the intervals accepted so far
IntervalTree> intervals = new IntervalTree<>();
for (T item:sorted) {
Interval i = toIntervalFunc.apply(item);
boolean addOk = intervals.addNonNested(i);
if (addOk) {
res.add(item);
} else {
// log.info("Discarding " + item);
}
}
return res;
}
}