org.broadinstitute.hellbender.engine.spark.SparkSharder

package org.broadinstitute.hellbender.engine.spark;

import com.google.common.base.Function;
import com.google.common.collect.*;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.PartitionCoalescer;
import org.apache.spark.rdd.RDD;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import scala.Option;
import scala.Tuple2;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import javax.annotation.Nullable;
import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.broadinstitute.hellbender.utils.IntervalUtils.overlaps;

/**
 * Utility methods for sharding {@link Locatable} objects (such as reads) for given intervals, without using a shuffle.
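 * <p>A minimal usage sketch (illustrative: {@code reads} is assumed to be a coordinate-sorted
 * {@code JavaRDD<GATKRead>}, {@code dictionary} a {@code SAMSequenceDictionary}, {@code boundaries} a
 * coordinate-sorted {@code List<ShardBoundary>}, and {@code maxReadLength} a bound on read length, all
 * built elsewhere):</p>
 * <pre>{@code
 * JavaRDD<Shard<GATKRead>> shards =
 *         SparkSharder.shard(ctx, reads, GATKRead.class, dictionary, boundaries, maxReadLength);
 * }</pre>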
 */
public class SparkSharder {
    /**
     * Create an RDD of {@link Shard} from an RDD of coordinate sorted {@link Locatable} without using a shuffle.
     * Each shard contains the {@link Locatable} objects that overlap it (including overlapping only padding).
     * @param ctx the Spark Context
     * @param locatables the RDD of {@link Locatable}, must be coordinate sorted
     * @param locatableClass the class of the {@link Locatable} objects in the RDD
     * @param sequenceDictionary the sequence dictionary to use to find contig lengths
     * @param intervals the {@link ShardBoundary} objects to create shards for, must be coordinate sorted
     * @param maxLocatableLength the maximum length of a {@link Locatable}, if any is larger than this size then an exception will be thrown
     * @param <L> the {@link Locatable} type
     * @param <SB> the {@link ShardBoundary} type
     * @return an RDD of {@link Shard} of overlapping {@link Locatable} objects (including overlapping only padding)
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                           SAMSequenceDictionary sequenceDictionary, List<SB> intervals,
                                                                                           int maxLocatableLength) {
        return shard(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, false);
    }

    /**
     * Create an RDD of {@link Shard} from an RDD of coordinate sorted {@link Locatable} without using a shuffle,
     * and where the intervals for shards are specified as an RDD, rather than a list.
     * Each shard contains the {@link Locatable} objects that overlap it (including overlapping only padding).
     * @param ctx the Spark Context
     * @param locatables the RDD of {@link Locatable}, must be coordinate sorted
     * @param locatableClass the class of the {@link Locatable} objects in the RDD
     * @param sequenceDictionary the sequence dictionary to use to find contig lengths
     * @param intervals the {@link ShardBoundary} objects to create shards for, must be coordinate sorted
     * @param maxLocatableLength the maximum length of a {@link Locatable}, if any is larger than this size then an exception will be thrown
     * @param <L> the {@link Locatable} type
     * @param <SB> the {@link ShardBoundary} type
     * @return an RDD of {@link Shard} of overlapping {@link Locatable} objects (including overlapping only padding)
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                           SAMSequenceDictionary sequenceDictionary, JavaRDD<SB> intervals,
                                                                                           int maxLocatableLength) {
        return shard(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength, false);
    }

    /**
     * Create an RDD of {@link Shard} from an RDD of coordinate sorted {@link Locatable}, optionally using a shuffle.
     * A shuffle is typically only needed for correctness testing, since it usually has a significant performance impact.
     * @param ctx the Spark Context
     * @param locatables the RDD of {@link Locatable}, must be coordinate sorted
     * @param locatableClass the class of the {@link Locatable} objects in the RDD
     * @param sequenceDictionary the sequence dictionary to use to find contig lengths
     * @param intervals the {@link ShardBoundary} objects to create shards for, must be coordinate sorted
     * @param maxLocatableLength the maximum length of a {@link Locatable}, if any is larger than this size then an exception will be thrown
     * @param useShuffle whether to use a shuffle or not
     * @param <L> the {@link Locatable} type
     * @param <SB> the {@link ShardBoundary} type
     * @return an RDD of {@link Shard} of overlapping {@link Locatable} objects (including overlapping only padding)
     */
    public static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                           SAMSequenceDictionary sequenceDictionary, List<SB> intervals,
                                                                                           int maxLocatableLength, boolean useShuffle) {

        List<ShardBoundary> paddedIntervals = intervals.stream().map(ShardBoundary::paddedShardBoundary).collect(Collectors.toList());
        if (useShuffle) {
            OverlapDetector<ShardBoundary> overlapDetector = OverlapDetector.create(paddedIntervals);
            Broadcast<OverlapDetector<ShardBoundary>> overlapDetectorBroadcast = ctx.broadcast(overlapDetector);
            JavaPairRDD<ShardBoundary, L> intervalsToLocatables = locatables.flatMapToPair(locatable -> {
                Set<ShardBoundary> overlaps = overlapDetectorBroadcast.getValue().getOverlaps(locatable);
                return overlaps.stream().map(key -> new Tuple2<>(key, locatable)).collect(Collectors.toList()).iterator();
            });
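            // groupByKey moves every (shard, locatable) pair across the network to the partition owning
            // that shard key; this is the shuffle that the no-shuffle path below avoids.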
            JavaPairRDD<ShardBoundary, Iterable<L>> grouped = intervalsToLocatables.groupByKey();
            return grouped.map((org.apache.spark.api.java.function.Function<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>) value -> value._1().createShard(value._2()));
        }
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, paddedIntervals, maxLocatableLength,
                new MapFunction<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Shard<L> call(Tuple2<ShardBoundary, Iterable<L>> value) {
                return value._1().createShard(value._2());
            }
        });
    }

    private static <L extends Locatable, SB extends ShardBoundary> JavaRDD<Shard<L>> shard(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary, JavaRDD<SB> intervals,
                                                                                            int maxLocatableLength, boolean useShuffle) {

        JavaRDD<ShardBoundary> paddedIntervals = intervals.map(ShardBoundary::paddedShardBoundary);
        if (useShuffle) {
            throw new UnsupportedOperationException("Shuffle not supported when sharding an RDD of intervals.");
        }
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, paddedIntervals, maxLocatableLength,
                new MapFunction<Tuple2<ShardBoundary, Iterable<L>>, Shard<L>>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Shard<L> call(Tuple2<ShardBoundary, Iterable<L>> value) {
                        return value._1().createShard(value._2());
                    }
                });
    }

    /**
     * Join an RDD of locatables with a set of intervals, and apply a function to process the locatables that overlap each interval.
     * @param ctx the Spark Context
     * @param locatables the locatables RDD, must be coordinate sorted
     * @param locatableClass the class of the locatables, must be a subclass of {@link Locatable}
     * @param sequenceDictionary the sequence dictionary to use to find contig lengths
     * @param intervals the collection of intervals to apply the function to
     * @param maxLocatableLength the maximum length of a {@link Locatable}, if any is larger than this size then an exception will be thrown
     * @param f the function to process intervals and overlapping locatables with
     * @param <L> the {@link Locatable} type
     * @param <I> the interval type
     * @param <T> the return type of f
     * @return an RDD of the results of applying f to each interval and its overlapping locatables
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary, List<I> intervals,
                                                                                            int maxLocatableLength, MapFunction<Tuple2<I, Iterable<L>>, T> f) {
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength,
                (FlatMapFunction2<Iterator<L>, Iterator<I>, T>) (locatablesIterator, shardsIterator) -> Iterators.transform(locatablesPerShard(locatablesIterator, shardsIterator, sequenceDictionary, maxLocatableLength), new Function<Tuple2<I, Iterable<L>>, T>() {
                    @Nullable
                    @Override
                    public T apply(@Nullable Tuple2<I, Iterable<L>> input) {
                        try {
                            return f.call(input);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                }));
    }

    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals,
                                                                                            int maxLocatableLength, MapFunction<Tuple2<I, Iterable<L>>, T> f) {
        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, intervals, maxLocatableLength,
                (FlatMapFunction2<Iterator<L>, Iterator<I>, T>) (locatablesIterator, shardsIterator) -> Iterators.transform(locatablesPerShard(locatablesIterator, shardsIterator, sequenceDictionary, maxLocatableLength), new Function<Tuple2<I, Iterable<L>>, T>() {
                    @Nullable
                    @Override
                    public T apply(@Nullable Tuple2<I, Iterable<L>> input) {
                        try {
                            return f.call(input);
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                }));
    }

    /**
     * Join an RDD of locatables with a set of intervals, and apply a function to process the locatables that overlap each interval.
     * This differs from {@link #joinOverlapping(JavaSparkContext, JavaRDD, Class, SAMSequenceDictionary, List, int, MapFunction)}
     * in that the function to apply is given two iterators: one over intervals, and one over locatables (for the partition),
     * and it is up to the function implementation to find overlaps between intervals and locatables.
     * @param ctx the Spark Context
     * @param locatables the locatables RDD, must be coordinate sorted
     * @param locatableClass the class of the locatables, must be a subclass of {@link Locatable}
     * @param sequenceDictionary the sequence dictionary to use to find contig lengths
     * @param intervals the collection of intervals to apply the function to
     * @param maxLocatableLength the maximum length of a {@link Locatable}, if any is larger than this size then an exception will be thrown
     * @param f the function to process intervals and overlapping locatables with
     * @param <L> the {@link Locatable} type
     * @param <I> the interval type
     * @param <T> the return type of f
     * @return an RDD of the results of applying f to each interval and its overlapping locatables
     */
    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary, List<I> intervals,
                                                                                            int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {

        return joinOverlapping(ctx, locatables, locatableClass, sequenceDictionary, ctx.parallelize(intervals), maxLocatableLength, f);
    }

    private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                            SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals,
                                                                                            int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {

        List<PartitionLocatable<SimpleInterval>> partitionReadExtents = computePartitionReadExtents(locatables, sequenceDictionary, maxLocatableLength);
        List<SimpleInterval> firstLocatablesList = partitionReadExtents.stream().map(PartitionLocatable::getLocatable).collect(Collectors.toList());
        Broadcast<List<SimpleInterval>> firstLocatablesBroadcast = ctx.broadcast(firstLocatablesList);

        // For each interval find which partition it starts and ends in.
        // An interval is processed in the partition it starts in. However, we need to make sure that
        // subsequent partitions are coalesced if needed, so for each partition p find the latest subsequent
        // partition that is needed to read all of the intervals that start in p.
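        // For example (illustrative): if an interval starts in partition 2 but overlaps the read extents of
        // partitions 3 and 4 as well, its end partition index is 4, so partition 2 must be widened to read
        // partitions 2 through 4 before the zipPartitions below can see all of its overlapping locatables.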
        OverlapDetector<PartitionLocatable<SimpleInterval>> overlapDetector = OverlapDetector.create(partitionReadExtents);
        Broadcast<OverlapDetector<PartitionLocatable<SimpleInterval>>> overlapDetectorBroadcast = ctx.broadcast(overlapDetector);
        JavaRDD<PartitionLocatable<I>> indexedIntervals = intervals.map(interval -> {
            int[] partitionIndexes = overlapDetectorBroadcast.getValue().getOverlaps(interval).stream()
                    .mapToInt(PartitionLocatable::getPartitionIndex).toArray();
            if (partitionIndexes.length == 0) {
                final List<SimpleInterval> firstLocatables = firstLocatablesBroadcast.getValue();
                // interval does not overlap any partition - add it to the one after the interval start
                int i = Collections.binarySearch(firstLocatables, new SimpleInterval(interval), (o1, o2) -> IntervalUtils.compareLocatables(o1, o2, sequenceDictionary));
                if (i >= 0) {
                    throw new IllegalStateException(); // TODO: no overlaps, yet start of interval matches a partition read extent start
                }
                int insertionPoint = -i - 1;
                if (insertionPoint == firstLocatables.size()) {
                    insertionPoint = firstLocatables.size() - 1;
                }
                return new PartitionLocatable<>(insertionPoint, interval);
            }
            Arrays.sort(partitionIndexes);
            int startIndex = partitionIndexes[0];
            int endIndex = partitionIndexes[partitionIndexes.length - 1];
            return new PartitionLocatable<>(startIndex, endIndex, interval);
        });

        // Create an RDD of intervals with the same number of partitions as the locatables, and where each interval
        // is in its start partition. Within each partition, intervals are sorted by IntervalUtils#compareLocatables.
        JavaRDD<PartitionLocatable<I>> indexedIntervalsRepartitioned = indexedIntervals
                .mapToPair(interval ->
                        new Tuple2<>(interval, (Void) null))
                .repartitionAndSortWithinPartitions(new PartitionLocatablePartitioner(locatables.getNumPartitions()), new PartitionLocatableComparator<I>(sequenceDictionary))
                .keys();

        indexedIntervalsRepartitioned.cache(); // cache since we need to do two calculations on the intervals

        // Find the end partition index for each partition.
        Map<Integer, Integer> maxEndPartitionIndexesMap = indexedIntervalsRepartitioned.mapToPair((PairFunction<PartitionLocatable<I>, Integer, Integer>) partitionLocatable ->
                new Tuple2<>(partitionLocatable.getPartitionIndex(), partitionLocatable.getEndPartitionIndex()))
                .reduceByKey((Function2<Integer, Integer, Integer>) Math::max)
                .collectAsMap();
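        // Start from the identity mapping (partition i reads only itself), then widen each start partition
        // to the furthest end partition required by any interval that starts in it.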
        List<Integer> maxEndPartitionIndexes = IntStream.range(0, locatables.getNumPartitions()).boxed().collect(Collectors.toList());
        maxEndPartitionIndexesMap.forEach((startIndex, endIndex) -> {
            if (endIndex > maxEndPartitionIndexes.get(startIndex)) {
                maxEndPartitionIndexes.set(startIndex, endIndex);
            }
        });

        JavaRDD<L> coalescedRdd = coalesce(locatables, locatableClass, new RangePartitionCoalescer(maxEndPartitionIndexes));

        // zipPartitions on coalesced locatable partitions and intervals, and apply the function f
        return coalescedRdd.zipPartitions(indexedIntervalsRepartitioned.map(PartitionLocatable::getLocatable), f);
    }

    /**
     * Turn a pair of iterators over intervals and locatables, into a single iterator over pairs made up of an interval and
     * the locatables that overlap it. Intervals with no overlapping locatables are included.
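     * <p>For example (illustrative coordinates): given shards chr1:1-100 and chr1:101-200, and reads at
     * chr1:50-60 and chr1:90-110, the result pairs chr1:1-100 with both reads (the second read straddles
     * the shard boundary) and chr1:101-200 with the second read only.</p>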
     */
    static <L extends Locatable, I extends Locatable> Iterator<Tuple2<I, Iterable<L>>> locatablesPerShard(Iterator<L> locatables, Iterator<I> shards, SAMSequenceDictionary sequenceDictionary, int maxLocatableLength) {
        if (!shards.hasNext()) {
            return Collections.emptyIterator();
        }
        PeekingIterator<I> peekingShards = Iterators.peekingIterator(shards);
        Iterator<Tuple2<I, Iterable<L>>> iterator = new AbstractIterator<Tuple2<I, Iterable<L>>>() {
            Queue<PendingShard<L, I>> pendingShards = new ArrayDeque<>();

            @Override
            protected Tuple2<I, Iterable<L>> computeNext() {
                Tuple2<I, Iterable<L>> nextShard = null;

                while (locatables.hasNext() && nextShard == null) {
                    L locatable = locatables.next();
                    if (locatable.getContig() != null) {
                        int size = locatable.getEnd() - locatable.getStart() + 1;
                        if (size > maxLocatableLength) {
                            throw new UserException(String.format("Max size of locatable exceeded. Max size is %s, but locatable size is %s. Try increasing shard size and/or padding. Locatable: %s", maxLocatableLength, size, locatable));
                        }
                    }

                    // Add any new shards that start before the end of the read to the queue
                    while (peekingShards.hasNext() && !IntervalUtils.isAfter(peekingShards.peek(), locatable, sequenceDictionary)) {
                        pendingShards.add(new PendingShard<>(peekingShards.next()));
                    }

                    // Add the read to any shards that it overlaps
                    for (PendingShard<L, I> pendingShard : pendingShards) {
                        if (overlaps(pendingShard, locatable)) {
                            pendingShard.addLocatable(locatable);
                        }
                    }

                    // A pending shard only becomes ready once our reads iterator has advanced beyond the end of its extended span
                    // (this ensures that we've loaded all reads that belong in the new shard)
                    if (!pendingShards.isEmpty() && IntervalUtils.isAfter(locatable, pendingShards.peek(), sequenceDictionary)) {
                        nextShard = pendingShards.poll().get();
                    }
                }

                if (!locatables.hasNext()) {
                    // Pull on intervals until it is exhausted
                    while (peekingShards.hasNext()) {
                        pendingShards.add(new PendingShard<>(peekingShards.next()));
                    }

                    // Grab the next pending shard if there is one, unless we already have a shard ready to go
                    if (!pendingShards.isEmpty() && nextShard == null) {
                        nextShard = pendingShards.poll().get();
                    }
                }

                if (nextShard == null) {
                    return endOfData();
                }

                return nextShard;
            }
        };
        return iterator;
    }

    private static class PendingShard<L extends Locatable, I extends Locatable> implements Locatable {
        private I interval;
        private List<L> locatables = new ArrayList<>();

        public PendingShard(I interval) {
            this.interval = interval;
        }

        public void addLocatable(L locatable) {
            locatables.add(locatable);
        }

        @Override
        public String getContig() {
            return interval.getContig();
        }

        @Override
        public int getStart() {
            return interval.getStart();
        }

        @Override
        public int getEnd() {
            return interval.getEnd();
        }

        public Tuple2<I, Iterable<L>> get() {
            return new Tuple2<>(interval, locatables);
        }
    }

    /**
     * @return true if the locatable is to the right of the given interval
     */
    private static <I extends Locatable, L extends Locatable> boolean toRightOf(I interval, L locatable, SAMSequenceDictionary sequenceDictionary) {
        int intervalContigIndex = sequenceDictionary.getSequenceIndex(interval.getContig());
        int locatableContigIndex = sequenceDictionary.getSequenceIndex(locatable.getContig());
        return (intervalContigIndex == locatableContigIndex && interval.getEnd() < locatable.getStart()) // locatable on same contig, to the right
                || intervalContigIndex < locatableContigIndex; // locatable on subsequent contig
    }

    /**
     * For each partition, find the extents (as intervals) that the reads in that partition may span, ordered by start position.
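     * <p>For example (illustrative coordinates): if partition 0 starts with a read at chr1:100, partition 1
     * starts with a read at chr1:5000, and maxLocatableLength is 1000, then the extent for partition 0 is
     * chr1:100-6000; reads in partition 0 all start before chr1:5000 (the input is coordinate sorted) and
     * are at most 1000 bases long, so none can end beyond chr1:6000.</p>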
     */
    static <L extends Locatable> List<PartitionLocatable<SimpleInterval>> computePartitionReadExtents(JavaRDD<L> locatables, SAMSequenceDictionary sequenceDictionary, int maxLocatableLength) {
        // Find the first locatable in each partition. This is very efficient since only the first record in each partition is read.
        // If a partition is empty then set the locatable to null
        List<PartitionLocatable<L>> allSplitPoints = locatables.mapPartitions(
                (FlatMapFunction<Iterator<L>, PartitionLocatable<L>>) it -> ImmutableList.of(new PartitionLocatable<>(-1, it.hasNext() ? it.next() : null)).iterator()
        ).collect();
        List<PartitionLocatable<L>> splitPoints = new ArrayList<>(); // fill in index and remove nulls (empty partitions)
        for (int i = 0; i < allSplitPoints.size(); i++) {
            L locatable = allSplitPoints.get(i).getLocatable();
            if (locatable != null) {
                splitPoints.add(new PartitionLocatable<>(i, locatable));
            }
        }
        List<PartitionLocatable<SimpleInterval>> extents = new ArrayList<>();
        for (int i = 0; i < splitPoints.size(); i++) {
            PartitionLocatable<L> splitPoint = splitPoints.get(i);
            int partitionIndex = splitPoint.getPartitionIndex();
            Locatable current = splitPoint.getLocatable();
            int intervalContigIndex = sequenceDictionary.getSequenceIndex(current.getContig());
            Utils.validate(intervalContigIndex != -1, "Contig not found in sequence dictionary: " + current.getContig());
            final Locatable next;
            final int nextContigIndex;
            if (i < splitPoints.size() - 1) {
                next = splitPoints.get(i + 1);
                nextContigIndex = sequenceDictionary.getSequenceIndex(next.getContig());
                Utils.validate(nextContigIndex != -1, "Contig not found in sequence dictionary: " + next.getContig());
            } else {
                next = null;
                nextContigIndex = sequenceDictionary.getSequences().size();
            }
            if (intervalContigIndex == nextContigIndex) { // same contig
                addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), next.getStart() + maxLocatableLength);
            } else {
                // complete current contig
                SAMSequenceRecord seq = sequenceDictionary.getSequence(current.getContig());
                Utils.validate(seq != null, "Contig not found in sequence dictionary: " + current.getContig());
                int contigEnd = seq.getSequenceLength();
                addPartitionReadExtent(extents, partitionIndex, current.getContig(), current.getStart(), contigEnd);
                // add any whole contigs up to next (exclusive)
                for (int contigIndex = intervalContigIndex + 1; contigIndex < nextContigIndex; contigIndex++) {
                    SAMSequenceRecord sequence = sequenceDictionary.getSequence(contigIndex);
                    Utils.validate(sequence != null, "Contig index not found in sequence dictionary: " + contigIndex);
                    addPartitionReadExtent(extents, partitionIndex, sequence.getSequenceName(), 1, sequence.getSequenceLength());
                }
                // add start of next contig
                if (next != null) {
                    addPartitionReadExtent(extents, partitionIndex, next.getContig(), 1, next.getStart() + maxLocatableLength);
                }
            }
        }
        return extents;
    }

    private static void addPartitionReadExtent(List<PartitionLocatable<SimpleInterval>> extents, int partitionIndex, String contig, int start, int end) {
        SimpleInterval extent = new SimpleInterval(contig, start, end);
        extents.add(new PartitionLocatable<>(partitionIndex, extent));
    }

    private static <L extends Locatable> JavaRDD<L> coalesce(JavaRDD<L> rdd, Class<L> cls, PartitionCoalescer partitionCoalescer) {
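        // numPartitions is unchanged and shuffle is false, so the supplied PartitionCoalescer alone
        // determines which parent partitions each output partition reads.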
        RDD<L> coalescedRdd = rdd.rdd().coalesce(rdd.getNumPartitions(), false, Option.apply(partitionCoalescer), null);
        ClassTag<L> tag = ClassTag$.MODULE$.apply(cls);
        return new JavaRDD<>(coalescedRdd, tag);
    }

    /**
     * Assigns {@link PartitionLocatable} objects to their start partition.
     */
    private static class PartitionLocatablePartitioner extends Partitioner {
        private static final long serialVersionUID = 1L;

        private int numPartitions;

        public PartitionLocatablePartitioner(int numPartitions) {
            this.numPartitions = numPartitions;
        }

        @Override
        public int numPartitions() {
            return numPartitions;
        }

        @Override
        public int getPartition(Object key) {
            return ((PartitionLocatable) key).getPartitionIndex();
        }
    }

    /**
     * Compares {@link PartitionLocatable} objects using a {@link htsjdk.samtools.SAMSequenceDictionary} sequence ordering.
     * @param <L> the interval type
     */
    private static class PartitionLocatableComparator<L extends Locatable> implements Comparator<PartitionLocatable<L>>, Serializable {
        private static final long serialVersionUID = 1L;
        private final SAMSequenceDictionary sequenceDictionary;

        private PartitionLocatableComparator(SAMSequenceDictionary sequenceDictionary) {
            this.sequenceDictionary = sequenceDictionary;
        }

        @Override
        public int compare(PartitionLocatable<L> pl1, PartitionLocatable<L> pl2) {
            return IntervalUtils.compareLocatables(pl1.getLocatable(), pl2.getLocatable(), this.sequenceDictionary);
        }
    }

    /**
     * Encapsulates the start and end partitions for an interval.
     * @param <L> the interval type
     */
    static class PartitionLocatable<L extends Locatable> implements Locatable {
        private static final long serialVersionUID = 1L;

        private final int partitionIndex;
        private final int endPartitionIndex;
        private final L interval;

        public PartitionLocatable(int partitionIndex, L interval) {
            this(partitionIndex, partitionIndex, interval);
        }

        public PartitionLocatable(int partitionIndex, int endPartitionIndex, L interval) {
            this.partitionIndex = partitionIndex;
            this.endPartitionIndex = endPartitionIndex;
            this.interval = interval;
        }

        public int getPartitionIndex() {
            return partitionIndex;
        }

        public int getEndPartitionIndex() {
            return endPartitionIndex;
        }

        public L getLocatable() {
            return interval;
        }

        @Override
        public String getContig() {
            return interval.getContig();
        }

        @Override
        public int getStart() {
            return interval.getStart();
        }

        @Override
        public int getEnd() {
            return interval.getEnd();
        }

        @Override
        public String toString() {
            return "PartitionLocatable{" +
                    "partitionIndex=" + partitionIndex +
                    ", interval='" + interval + '\'' +
                    '}';
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

        PartitionLocatable<?> that = (PartitionLocatable<?>) o;

            if (partitionIndex != that.partitionIndex) return false;
            return interval.equals(that.interval);

        }

        @Override
        public int hashCode() {
            int result = partitionIndex;
            result = 31 * result + interval.hashCode();
            return result;
        }
    }
}