// src.it.unimi.dsi.util.ImmutableBinaryTrie Maven / Gradle / Ivy
package it.unimi.dsi.util;
/*
* DSI utilities
*
* Copyright (C) 2005-2017 Sebastiano Vigna
*
* This library is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as published by the Free
* Software Foundation; either version 3 of the License, or (at your option)
* any later version.
*
* This library is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program; if not, see <http://www.gnu.org/licenses/>.
*
*/
import it.unimi.dsi.bits.BitVector;
import it.unimi.dsi.bits.LongArrayBitVector;
import it.unimi.dsi.bits.TransformationStrategy;
import it.unimi.dsi.fastutil.booleans.BooleanIterator;
import it.unimi.dsi.fastutil.objects.AbstractObject2LongFunction;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.fastutil.objects.ObjectList;
import it.unimi.dsi.lang.MutableString;
import java.io.Serializable;
import java.util.Iterator;
import java.util.ListIterator;
/** An immutable implementation of binary tries.
*
* Instance of this class are built starting from a lexicographically ordered
* list of {@link BitVector}s representing binary words. Each word
* is assigned its position (starting from 0) in the list. The words are then organised in a
* binary trie with path compression.
*
*
Once the trie has been
* built, it is possible to ask whether a word w is {@linkplain #get(BooleanIterator) contained in the trie}
* (getting back its position in the list), the {@linkplain #getInterval(BooleanIterator) interval given by the words extending w} and the
* {@linkplain #getApproximatedInterval(BooleanIterator) approximated interval defined by w}.
* @author Sebastiano Vigna
* @since 0.9.2
*/
public class ImmutableBinaryTrie extends AbstractObject2LongFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final static boolean ASSERTS = false;
/** A node in the trie. */
protected static class Node implements Serializable {
private static final long serialVersionUID = 1L;
public Node left, right;
/** An array containing the path compacted in this node ({@code null} if there is no compaction at this node). */
public final long[] path;
/** The length of the path compacted in this node (0 if there is no compaction at this node). */
public final int pathLength;
/** If nonnegative, this node represent the word
-th word. */
public final int word ;
/** Creates a node representing a word.
*
* Note that the long array contained in path
will be stored inside the node.
*
* @param path the path compacted in this node, or {@code null} for the empty path.
* @param word the index of the word represented by this node.
*/
public Node(final BitVector path, final int word) {
if (path == null) {
this.path = null;
this.pathLength = 0;
}
else {
this.path = path.bits();
this.pathLength = (int) path.length();
}
this.word = word;
}
/** Creates a node that does not represent a word.
*
* @param path the path compacted in this node, or {@code null} for the empty path.
*/
public Node(final BitVector path) {
this(path, -1);
}
/** Returns true if this node is a leaf.
*
* @return true if this node is a leaf.
*/
public boolean isLeaf() {
return right == null && left == null;
}
@Override
public String toString() {
return "[" + path + ", " + word + "]";
}
}
/** The root of the trie. */
protected final Node root;
/** The number of words in this trie. */
private int size;
private final TransformationStrategy super T> transformationStrategy;
private static boolean get(final long[] a, final long i) {
return (a[(int)(i >>> LongArrayBitVector.LOG2_BITS_PER_WORD)] & (1L << (i & LongArrayBitVector.WORD_MASK))) != 0;
}
/** Creates a trie from a set of elements.
*
* @param elements a set of elements
* @param transformationStrategy a transformation strategy that must turn elements
into a list of
* distinct, lexicographically increasing (in iteration order) binary words.
*/
public ImmutableBinaryTrie(final Iterable extends T> elements, final TransformationStrategy super T> transformationStrategy) {
this.transformationStrategy = transformationStrategy;
defRetValue = -1;
// Check order
final Iterator extends T> iterator = elements.iterator();
final ObjectList words = new ObjectArrayList<>();
int cmp;
if (iterator.hasNext()) {
final LongArrayBitVector prev = LongArrayBitVector.copy(transformationStrategy.toBitVector(iterator.next()));
words.add(prev.copy());
BitVector curr;
while(iterator.hasNext()) {
curr = transformationStrategy.toBitVector(iterator.next());
cmp = prev.compareTo(curr);
if (cmp == 0) throw new IllegalArgumentException("The trie elements are not unique");
if (cmp > 0) throw new IllegalArgumentException("The trie elements are not sorted");
prev.replace(curr);
words.add(prev.copy());
}
}
root = buildTrie(words, 0);
}
/** Builds a trie recursively.
*
* The trie will contain the suffixes of words in words
starting at pos
.
*
* @param elements a list of elements.
* @param pos a starting position.
* @return a trie containing the suffixes of words in words
starting at pos
.
*/
protected Node buildTrie(final ObjectList elements, final int pos) {
// TODO: on-the-fly check for lexicographical order
if (elements.size() == 0) return null;
BitVector first = elements.get(0), curr;
int prefix = (int) first.length(), change = -1, j;
// We rule out the case of a single word (it would work anyway, but it's faster)
if (elements.size() == 1) return new Node(pos < prefix ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null, size++);
// Find maximum common prefix. change records the point of change (for splitting the word set).
for(ListIterator i = elements.listIterator(1); i.hasNext();) {
curr = i.next();
if (curr.length() < prefix) prefix = (int) curr.length();
for(j = pos; j < prefix; j++) if (first.getBoolean(j) != curr.getBoolean(j)) break;
if (j < prefix) {
change = i.previousIndex();
prefix = j;
}
}
final Node n;
if (prefix == first.length()) {
// Special case: the first word is the common prefix. We must store it in the node,
// and explicitly search for the actual change point, which is the first
// word with prefix-th bit true.
change = 1;
for(ListIterator i = elements.listIterator(1); i.hasNext();) {
curr = i.next();
if (curr.getBoolean(prefix)) break;
change++;
}
n = new Node(prefix > pos ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null, size++);
n.left = buildTrie(elements.subList(1, change), prefix + 1);
n.right = buildTrie(elements.subList(change, elements.size()), prefix + 1);
}
else {
n = new Node(prefix > pos ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null); // There's some common prefix
n.left = buildTrie(elements.subList(0, change), prefix + 1);
n.right = buildTrie(elements.subList(change, elements.size()), prefix + 1);
}
return n;
}
/** Returns the number of binary words in this trie.
*
* @return the number of binary words in this trie.
*/
@Override
public int size() {
return size;
}
@SuppressWarnings("unchecked")
public long getIndex(final Object element) {
final BitVector word = transformationStrategy.toBitVector((T)element);
final int length = (int) word.length();
Node n = root;
int pos = 0; // Current position in word
long[] path;
while(n != null) {
if (pos == length) return n.word;
path = n.path;
if (path != null) {
int minLength = Math.min(length - pos, n.pathLength), i;
for(i = 0; i < minLength; i++) if (word.getBoolean(pos + i) != get(path, i)) break;
// Incompatible with current path.
if (i < minLength) return -1;
pos += i;
// Completely contained in the current path (note that n.word == -1 if this is not a word).
if (pos == length) return n.word;
}
n = word.getBoolean(pos++) ? n.right : n.left;
}
return -1;
}
@Override
public long getLong(final Object element) {
final long result = getIndex(element);
return result == -1 ? defRetValue : result;
}
@Override
public boolean containsKey(final Object element) {
return getIndex(element) != -1;
}
/** Return the index of the word returned by the given iterator, or -1 if the word is not in this trie.
*
* @param iterator a boolean iterator that will be used to find a word in this trie.
* @return the index of the specified word, or -1 if the word returned by the iterator is not in this trie.
* @see #getLong(Object)
*/
public int get(final BooleanIterator iterator) {
Node n = root;
int pathLength;
long[] path;
while(n != null) {
if (! iterator.hasNext()) return n.word;
pathLength = n.pathLength;
if (pathLength != 0) {
int i;
path = n.path;
for(i = 0; i < pathLength && iterator.hasNext(); i++) if (iterator.nextBoolean() != get(path, i)) break;
// Incompatible with current path.
if (i < pathLength) return -1;
// Completely contained in the current path (note that n.word == -1 if this is not a word).
if (! iterator.hasNext()) return n.word;
}
n = iterator.nextBoolean() ? n.right : n.left;
}
return -1;
}
/** Returns an interval given by the smallest and the largest word in the trie starting with the specified word.
*
* @param word a word.
* @return an interval given by the smallest and the largest word in the trie
* that start with word
(thus, the {@linkplain Intervals#EMPTY_INTERVAL empty inteval}
* if no such words exist).
* @see #getInterval(BooleanIterator)
*/
public Interval getInterval(final BitVector word) {
final int length = (int) word.length();
Node n = root;
long[] path;
int pos = 0; // Current position in word
while(n != null) {
// We found the current path: we go searching for left and right delimiters.
if (pos == length) break;
path = n.path;
if (path != null) {
int maxLength = Math.min(length - pos, n.pathLength);
int i;
for(i = 0; i < maxLength; i++) if (word.getBoolean(pos + i) != get(path, i)) break;
// Incompatible with current path--we return the empty interval.
if (i < maxLength) return Intervals.EMPTY_INTERVAL;
pos += i;
// Completely contained in the current path: we go searching for left and right delimiters.
if (pos == length) break;
}
n = word.getBoolean(pos++) ? n.right : n.left;
}
// If n == null, we did not found the path. Otherwise, it's the current node,
// and we must search for left and right delimiters.
if (n == null) return Intervals.EMPTY_INTERVAL;
Node l = n;
// Searching for the left extreme...
while(l.word < 0) l = l.left != null ? l.left : l.right;
// Searching for the right extreme, unless we're on a leaf.
while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(l.word, n.word);
}
/** Returns an interval given by the smallest and the largest word in the trie starting with
* the word returned by the given iterator.
*
* @param iterator an iterator.
* @return an interval given by the smallest and the largest word in the trie
* that start with the word returned by iterator
(thus, the {@linkplain Intervals#EMPTY_INTERVAL empty inteval}
* if no such words exist).
* @see #getInterval(BitVector)
*/
public Interval getInterval(final BooleanIterator iterator) {
Node n = root;
boolean mismatch = false;
long[] path;
int pathLength;
while(n != null) {
// We found the current path: we go searching for left and right delimiters.
if (! iterator.hasNext()) break;
pathLength = n.pathLength;
if (pathLength != 0) {
int i;
path = n.path;
for(i = 0; i < pathLength && iterator.hasNext(); i++) if ((mismatch = (iterator.nextBoolean() != get(path, i)))) break;
// Incompatible with current path--we return the empty interval.
if (mismatch) return Intervals.EMPTY_INTERVAL;
// Completely contained in the current path: we go searching for left and right delimiters.
if (! iterator.hasNext()) break;
}
n = iterator.nextBoolean() ? n.right : n.left;
}
// If n == null, we did not found the path. Otherwise, it's the current node,
// and we must search for left and right delimiters.
if (n == null) return Intervals.EMPTY_INTERVAL;
Node l = n;
// Searching for the left extreme...
while(l.word < 0) l = l.left != null ? l.left : l.right;
// Searching for the right extreme, unless we're on a leaf.
while (! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(l.word, n.word);
}
/** Returns an approximated interval around the specified word.
*
* Given a word w, the corresponding approximated interval is
* defined as follows: if the words in the approximator are thought of as left interval extremes in a
* larger lexicographically ordered set of words, and we number these word intervals using the
* indices of their left extremes, then the first word extending w would be in the
* word interval given by the left extreme of the interval returned by this method, whereas
* the last word extending w would be in the word interval given by the right
* extreme of the interval returned by this method. If no word in the larger set could possibly extend
* w (because w is smaller than the lexicographically smallest word in the approximator)
* the result is just an {@linkplain it.unimi.dsi.util.Intervals#EMPTY_INTERVAL empty interval}.
*
* @param element an element.
* @return an approximated interval around the specified word.
* @see #getApproximatedInterval(BooleanIterator)
*/
public Interval getApproximatedInterval(final T element) {
final BitVector word = transformationStrategy.toBitVector(element);
final int length = (int) word.length();
Node n = root;
long[] path;
boolean exactMatch = false, mismatch = false, nextBit;
int pos = 0; // Current position in word
while(n != null) {
// We found the current path: we go searching for left and right delimiters.
path = n.path;
if (pos == length) {
if (n.word >= 0 && path == null) exactMatch = true;
break;
}
if (path != null) {
int maxLength = Math.min(length - pos, n.pathLength);
int i;
for(i = 0; i < maxLength; i++) if (mismatch = (word.getBoolean(pos + i) != get(path, i))) break;
if (mismatch) {
// System.err.println("Exit 1");
// A mismatch. In this case, it is guaranteed that all
// strings starting with the prefix examined so far lie
// in a single block. The block index depends, however
// on the bit that went wrong.
if (get(path, i)) {
while(n.word < 0) n = n.left != null ? n.left : n.right;
return n.word > 0 ? Interval.valueOf(n.word - 1) : Intervals.EMPTY_INTERVAL;
}
else {
while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(n.word);
}
}
pos += i;
// Completely contained in the current path
if (pos == length) {
if (ASSERTS) assert n.pathLength == maxLength;
if (i == n.pathLength && n.word >= 0) exactMatch = true;
break;
}
}
if (n.isLeaf()) break;
nextBit = word.getBoolean(pos++);
// We would like to take an impossible turn. This case is similar to
// prefix mismatches, with subtly different off-by-ones.
if (nextBit && n.right == null) {
while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(n.word);
}
else if (! nextBit && n.left == null) {
while(n.word < 0) n = n.left != null ? n.left : n.right;
return Interval.valueOf(n.word);
}
n = nextBit ? n.right : n.left;
}
Node l = n;
// Searching for the left extreme...
//System.err.println("Going for exit 2: l:" + l + " n:" + n);
while(l.word < 0) l = l.left != null ? l.left : l.right;
// Searching for the right extreme, unless we're on a leaf.
while (! n.isLeaf()) n = n.right != null ? n.right : n.left;
// System.err.println("Following exit 2: l:" + l + " n:" + n);
// If did not find an exact match and l.word is 0 we are lexicographically before every word.
if (pos == length && ! exactMatch) {
if (l.word == 0) return mismatch ? Intervals.EMPTY_INTERVAL : Interval.valueOf(l.word, n.word);
else return Interval.valueOf(l.word - 1, n.word);
}
// System.err.println("Exit 2 (exactMatch: " + exactMatch +")");
return Interval.valueOf(l.word, n.word);
}
/** Returns an approximated prefix interval around the word returned by the specified iterator.
*
* @param iterator an iterator.
* @return an approximated interval around the specified word: if the words in this trie
* are thought of as left interval extremes in a larger lexicographically ordered set of words,
* and we number these word intervals using the indices of their left extremes,
* then the first word extending word
would be in the word interval given by
* the left extreme of the {@link Interval} returned by this method, whereas
* the last word extending word
would be in the word
* interval given by the right extreme of the {@link Interval} returned by this method.
* @see #getApproximatedInterval(Object)
*/
public Interval getApproximatedInterval(final BooleanIterator iterator) {
Node n = root;
long[] path;
boolean exactMatch = false, mismatch = false, nextBit;
for(;;) {
// We found the current path: we go searching for left and right delimiters.
path = n.path;
if (! iterator.hasNext()) {
if (n.word >= 0 && path == null) exactMatch = true;
break;
}
if (path != null) {
int i;
final int pathSize = n.pathLength;
for(i = 0; i < pathSize && iterator.hasNext(); i++) if ((mismatch = (iterator.nextBoolean() != get(path, i)))) break;
if (mismatch) {
// System.err.println("Exit 1");
// A mismatch. In this case, it is guaranteed that all
// strings starting with the prefix examined so far lie
// in a single block. The block index depends, however
// on the bit that went wrong.
if (get(path, i)) {
while(n.word < 0) n = n.left != null ? n.left : n.right;
return n.word > 0 ? Interval.valueOf(n.word - 1) : Intervals.EMPTY_INTERVAL;
}
else {
while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(n.word);
}
}
// Completely contained in the current path
if (! iterator.hasNext()) {
if (i == pathSize && n.word >= 0) exactMatch = true;
break;
}
}
if (n.isLeaf()) break;
nextBit = iterator.nextBoolean();
// We would like to take an impossible turn. This case is similar to
// prefix mismatches, with subtly different off-by-ones.
if (nextBit && n.right == null) {
while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
return Interval.valueOf(n.word);
}
else if (! nextBit && n.left == null) {
while(n.word < 0) n = n.left != null ? n.left : n.right;
return Interval.valueOf(n.word);
}
n = nextBit ? n.right : n.left;
}
Node l = n;
// Searching for the left extreme...
//System.err.println("Going for exit 2: l:" + l + " n:" + n);
while(l.word < 0) l = l.left != null ? l.left : l.right;
// Searching for the right extreme, unless we're on a leaf.
while (! n.isLeaf()) n = n.right != null ? n.right : n.left;
// If did not find an exact match and l.word is 0 we are lexicographically before every word.
if (! iterator.hasNext() && ! exactMatch) {
if (l.word == 0) return mismatch ? Intervals.EMPTY_INTERVAL : Interval.valueOf(0);
else return Interval.valueOf(l.word - 1, n.word);
}
// System.err.println("Exit 2 (hasNext: " +iterator.hasNext() + " exactMatch: " + exactMatch +")");
return Interval.valueOf(l.word, n.word);
}
private void recToString(final Node n, final MutableString printPrefix, final MutableString result, final MutableString path, final int level) {
if (n == null) return;
//System.err.println("Called with prefix " + printPrefix);
result.append(printPrefix).append('(').append(level).append(')');
if (n.path != null) {
path.append(LongArrayBitVector.wrap(n.path, n.pathLength));
result.append(" path:").append(LongArrayBitVector.wrap(n.path, n.pathLength));
}
if (n.word >= 0) result.append(" word: ").append(n.word).append(" (").append(path).append(')');
result.append('\n');
path.append('0');
recToString(n.left, printPrefix.append('\t').append("0 => "), result, path, level + 1);
path.charAt(path.length() - 1, '1');
recToString(n.right, printPrefix.replace(printPrefix.length() - 5, printPrefix.length(), "1 => "), result, path, level + 1);
path.delete(path.length() - 1, path.length());
printPrefix.delete(printPrefix.length() - 6, printPrefix.length());
//System.err.println("Path now: " + path + " Going to delete from " + (path.length() - n.pathLength));
path.delete(path.length() - n.pathLength, path.length());
}
@Override
public String toString() {
MutableString s = new MutableString();
recToString(root, new MutableString(), s, new MutableString(), 0);
return s.toString();
}
}