All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.cassandra.utils.IntervalTree Maven / Gradle / Ivy

Go to download

A fork of the Apache Cassandra project that uses Lucene indexes to provide near-real-time search, as Elasticsearch or Solr do, including full-text search capabilities, multi-dimensional queries, and relevance scoring.

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.utils;

import java.io.DataInput;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.*;

import com.google.common.base.Joiner;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterators;
import com.google.common.collect.Ordering;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.db.TypeSizes;
import org.apache.cassandra.io.ISerializer;
import org.apache.cassandra.io.IVersionedSerializer;
import org.apache.cassandra.io.util.DataOutputPlus;

public class IntervalTree> implements Iterable
{
    private static final Logger logger = LoggerFactory.getLogger(IntervalTree.class);

    @SuppressWarnings("unchecked")
    private static final IntervalTree EMPTY_TREE = new IntervalTree(null, null);

    private final IntervalNode head;
    private final int count;
    private final Comparator comparator;

    final Ordering minOrdering;
    final Ordering maxOrdering;

    protected IntervalTree(Collection intervals, Comparator comparator)
    {
        this.comparator = comparator;

        final IntervalTree it = this;
        this.minOrdering = new Ordering()
        {
            public int compare(I interval1, I interval2)
            {
                return it.comparePoints(interval1.min, interval2.min);
            }
        };
        this.maxOrdering = new Ordering()
        {
            public int compare(I interval1, I interval2)
            {
                return it.comparePoints(interval1.max, interval2.max);
            }
        };

        this.head = intervals == null || intervals.isEmpty() ? null : new IntervalNode(intervals);
        this.count = intervals == null ? 0 : intervals.size();
    }

    public static > IntervalTree build(Collection intervals, Comparator comparator)
    {
        if (intervals == null || intervals.isEmpty())
            return emptyTree();

        return new IntervalTree(intervals, comparator);
    }

    public static , D, I extends Interval> IntervalTree build(Collection intervals)
    {
        if (intervals == null || intervals.isEmpty())
            return emptyTree();

        return new IntervalTree(intervals, null);
    }

    public static > Serializer serializer(ISerializer pointSerializer, ISerializer dataSerializer, Constructor constructor)
    {
        return new Serializer<>(pointSerializer, dataSerializer, constructor);
    }

    @SuppressWarnings("unchecked")
    public static > IntervalTree emptyTree()
    {
        return (IntervalTree)EMPTY_TREE;
    }

    public Comparator comparator()
    {
        return comparator;
    }

    public int intervalCount()
    {
        return count;
    }

    public boolean isEmpty()
    {
        return head == null;
    }

    public C max()
    {
        if (head == null)
            throw new IllegalStateException();

        return head.high;
    }

    public C min()
    {
        if (head == null)
            throw new IllegalStateException();

        return head.low;
    }

    public List search(Interval searchInterval)
    {
        if (head == null)
            return Collections.emptyList();

        List results = new ArrayList();
        head.searchInternal(searchInterval, results);
        return results;
    }

    public List search(C point)
    {
        return search(Interval.create(point, point, null));
    }

    public Iterator iterator()
    {
        if (head == null)
            return Iterators.emptyIterator();

        return new TreeIterator(head);
    }

    @Override
    public String toString()
    {
        return "<" + Joiner.on(", ").join(this) + ">";
    }

    @Override
    public boolean equals(Object o)
    {
        if(!(o instanceof IntervalTree))
            return false;
        IntervalTree that = (IntervalTree)o;
        return Iterators.elementsEqual(iterator(), that.iterator());
    }

    @Override
    public final int hashCode()
    {
        int result = comparator.hashCode();
        for (Interval interval : this)
            result = 31 * result + interval.hashCode();
        return result;
    }

    private int comparePoints(C point1, C point2)
    {
        if (comparator != null)
        {
            return comparator.compare(point1, point2);
        }
        else
        {
            assert point1 instanceof Comparable;
            assert point2 instanceof Comparable;
            return ((Comparable)point1).compareTo(point2);
        }
    }

    private boolean encloses(Interval enclosing, Interval enclosed)
    {
        return comparePoints(enclosing.min, enclosed.min) <= 0
            && comparePoints(enclosing.max, enclosed.max) >= 0;
    }

    private boolean contains(Interval interval, C point)
    {
        return comparePoints(interval.min, point) <= 0
            && comparePoints(interval.max, point) >= 0;
    }

    private boolean intersects(Interval interval1, Interval interval2)
    {
        return contains(interval1, interval2.min) || contains(interval1, interval2.max);
    }

    private class IntervalNode
    {
        final C center;
        final C low;
        final C high;

        final List intersectsLeft;
        final List intersectsRight;

        final IntervalNode left;
        final IntervalNode right;

        public IntervalNode(Collection toBisect)
        {
            assert !toBisect.isEmpty();
            logger.trace("Creating IntervalNode from {}", toBisect);

            // Building IntervalTree with one interval will be a reasonably
            // common case for range tombstones, so it's worth optimizing
            if (toBisect.size() == 1)
            {
                I interval = toBisect.iterator().next();
                low = interval.min;
                center = interval.max;
                high = interval.max;
                List l = Collections.singletonList(interval);
                intersectsLeft = l;
                intersectsRight = l;
                left = null;
                right = null;
            }
            else
            {
                // Find min, median and max
                List allEndpoints = new ArrayList(toBisect.size() * 2);
                for (I interval : toBisect)
                {
                    assert (comparator == null ? ((Comparable)interval.min).compareTo(interval.max)
                                               : comparator.compare(interval.min, interval.max)) <= 0 : "Interval min > max";
                    allEndpoints.add(interval.min);
                    allEndpoints.add(interval.max);
                }
                if (comparator != null)
                    Collections.sort(allEndpoints, comparator);
                else
                    Collections.sort((List)allEndpoints);

                low = allEndpoints.get(0);
                center = allEndpoints.get(toBisect.size());
                high = allEndpoints.get(allEndpoints.size() - 1);

                // Separate interval in intersecting center, left of center and right of center
                List intersects = new ArrayList();
                List leftSegment = new ArrayList();
                List rightSegment = new ArrayList();

                for (I candidate : toBisect)
                {
                    if (comparePoints(candidate.max, center) < 0)
                        leftSegment.add(candidate);
                    else if (comparePoints(candidate.min, center) > 0)
                        rightSegment.add(candidate);
                    else
                        intersects.add(candidate);
                }

                intersectsLeft = minOrdering.sortedCopy(intersects);
                intersectsRight = maxOrdering.reverse().sortedCopy(intersects);
                left = leftSegment.isEmpty() ? null : new IntervalNode(leftSegment);
                right = rightSegment.isEmpty() ? null : new IntervalNode(rightSegment);

                assert (intersects.size() + leftSegment.size() + rightSegment.size()) == toBisect.size() :
                        "intersects (" + String.valueOf(intersects.size()) +
                        ") + leftSegment (" + String.valueOf(leftSegment.size()) +
                        ") + rightSegment (" + String.valueOf(rightSegment.size()) +
                        ") != toBisect (" + String.valueOf(toBisect.size()) + ")";
            }
        }

        void searchInternal(Interval searchInterval, List results)
        {
            if (comparePoints(searchInterval.max, low) < 0 || comparePoints(searchInterval.min, high) > 0)
                return;

            if (contains(searchInterval, center))
            {
                // Adds every interval contained in this node to the result set then search left and right for further
                // overlapping intervals
                for (Interval interval : intersectsLeft)
                    results.add(interval.data);

                if (left != null)
                    left.searchInternal(searchInterval, results);
                if (right != null)
                    right.searchInternal(searchInterval, results);
            }
            else if (comparePoints(center, searchInterval.min) < 0)
            {
                // Adds intervals i in intersects right as long as i.max >= searchInterval.min
                // then search right
                for (Interval interval : intersectsRight)
                {
                    if (comparePoints(interval.max, searchInterval.min) >= 0)
                        results.add(interval.data);
                    else
                        break;
                }
                if (right != null)
                    right.searchInternal(searchInterval, results);
            }
            else
            {
                assert comparePoints(center, searchInterval.max) > 0;
                // Adds intervals i in intersects left as long as i.min >= searchInterval.max
                // then search left
                for (Interval interval : intersectsLeft)
                {
                    if (comparePoints(interval.min, searchInterval.max) <= 0)
                        results.add(interval.data);
                    else
                        break;
                }
                if (left != null)
                    left.searchInternal(searchInterval, results);
            }
        }
    }

    private class TreeIterator extends AbstractIterator
    {
        private final Deque stack = new ArrayDeque();
        private Iterator current;

        TreeIterator(IntervalNode node)
        {
            super();
            gotoMinOf(node);
        }

        protected I computeNext()
        {
            if (current != null && current.hasNext())
                return current.next();

            IntervalNode node = stack.pollFirst();
            if (node == null)
                return endOfData();

            current = node.intersectsLeft.iterator();

            // We know this is the smaller not returned yet, but before doing
            // its parent, we must do everyone on it's right.
            gotoMinOf(node.right);

            return computeNext();
        }

        private void gotoMinOf(IntervalNode node)
        {
            while (node != null)
            {
                stack.offerFirst(node);
                node = node.left;
            }

        }
    }

    public static class Serializer> implements IVersionedSerializer>
    {
        private final ISerializer pointSerializer;
        private final ISerializer dataSerializer;
        private final Constructor constructor;

        private Serializer(ISerializer pointSerializer, ISerializer dataSerializer, Constructor constructor)
        {
            this.pointSerializer = pointSerializer;
            this.dataSerializer = dataSerializer;
            this.constructor = constructor;
        }

        public void serialize(IntervalTree it, DataOutputPlus out, int version) throws IOException
        {
            out.writeInt(it.count);
            for (Interval interval : it)
            {
                pointSerializer.serialize(interval.min, out);
                pointSerializer.serialize(interval.max, out);
                dataSerializer.serialize(interval.data, out);
            }
        }

        /**
         * Deserialize an IntervalTree whose keys use the natural ordering.
         * Use deserialize(DataInput, int, Comparator) instead if the interval
         * tree is to use a custom comparator, as the comparator is *not*
         * serialized.
         */
        public IntervalTree deserialize(DataInput in, int version) throws IOException
        {
            return deserialize(in, version, null);
        }

        public IntervalTree deserialize(DataInput in, int version, Comparator comparator) throws IOException
        {
            try
            {
                int count = in.readInt();
                List> intervals = new ArrayList>(count);
                for (int i = 0; i < count; i++)
                {
                    C min = pointSerializer.deserialize(in);
                    C max = pointSerializer.deserialize(in);
                    D data = dataSerializer.deserialize(in);
                    intervals.add(constructor.newInstance(min, max, data));
                }
                return new IntervalTree(intervals, comparator);
            }
            catch (InstantiationException e)
            {
                throw new RuntimeException(e);
            }
            catch (InvocationTargetException e)
            {
                throw new RuntimeException(e);
            }
            catch (IllegalAccessException e)
            {
                throw new RuntimeException(e);
            }
        }

        public long serializedSize(IntervalTree it, TypeSizes typeSizes, int version)
        {
            long size = typeSizes.sizeof(0);
            for (Interval interval : it)
            {
                size += pointSerializer.serializedSize(interval.min, typeSizes);
                size += pointSerializer.serializedSize(interval.max, typeSizes);
                size += dataSerializer.serializedSize(interval.data, typeSizes);
            }
            return size;
        }

        public long serializedSize(IntervalTree it, int version)
        {
            return serializedSize(it, TypeSizes.NATIVE, version);
        }
    }
}