All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.util.bkd.BKDRadixSelector Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.util.bkd;

import java.io.IOException;
import java.util.Arrays;

import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FutureArrays;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.MSBRadixSorter;
import org.apache.lucene.util.RadixSelector;
import org.apache.lucene.util.Selector;
import org.apache.lucene.util.Sorter;

/**
 * Offline radix selector for the BKD tree.
 *
 * <p>Partitions points around a given position in one dimension. Points held by a
 * {@link HeapPointWriter} are selected on heap. Points held by an {@link OfflinePointWriter} are
 * partitioned by repeatedly building a byte histogram and spilling the points into left, right
 * and "delta" writers. Ties on the dimension value are broken on the data-only dimensions and
 * then on the docID. {@link #heapRadixSort} sorts on-heap points by a dimension using the same
 * ordering.
 *
 * <p>Instances reuse mutable internal buffers ({@code histogram}, {@code scratch},
 * {@code partitionBucket}, {@code offlineBuffer}) across calls without synchronization, so a
 * single instance should not be shared between threads.
 *
 *  @lucene.internal
 * */
public final class BKDRadixSelector {
  // size of the histogram: one bucket per possible byte value
  private static final int HISTOGRAM_SIZE = 256;
  // size of the buffer used when reading offline points: 8 KB
  private static final int MAX_SIZE_OFFLINE_BUFFER = 1024 * 8;
  // histogram of the byte values found at the current radix position
  private final long[] histogram;
  // number of bytes to be sorted: config.bytesPerDim + data-only dimension bytes + Integer.BYTES
  private final int bytesSorted;
  // maximum number of points that may be held on heap; larger sets are handled offline
  private final int maxPointsSortInHeap;
  // reusable buffer for offline point readers
  private final byte[] offlineBuffer;
  // selected byte value for each position of the sort key (the bytes of the partition point)
  private final int[] partitionBucket;
  // scratch array to hold temporary data (a copy of one sort key)
  private final byte[] scratch;
  // Directory to create new Offline writer
  private final Directory tempDir;
  // prefix for temp files
  private final String tempFileNamePrefix;
  // BKD tree configuration
  private final BKDConfig config;

  /**
   * Sole constructor.
   *
   * @param config BKD tree configuration
   * @param maxPointsSortInHeap maximum number of points that may be held on heap at once
   * @param tempDir directory used to create temporary offline writers
   * @param tempFileNamePrefix prefix for the names of temporary files
   */
  public BKDRadixSelector(BKDConfig config, int maxPointsSortInHeap, Directory tempDir, String tempFileNamePrefix) {
    this.config = config;
    this.maxPointsSortInHeap = maxPointsSortInHeap;
    this.tempDir = tempDir;
    this.tempFileNamePrefix = tempFileNamePrefix;
    // Selection and sorting is done in a given dimension. In case the value of the dimension are equal
    // between two points we tie break first using the data-only dimensions and if those are still equal
    // we tie-break on the docID. Here we account for all bytes used in the process.
    this.bytesSorted = config.bytesPerDim + (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
    // round the offline buffer down to a whole number of points
    final int numberOfPointsOffline = MAX_SIZE_OFFLINE_BUFFER / config.bytesPerDoc;
    this.offlineBuffer = new byte[numberOfPointsOffline * config.bytesPerDoc];
    this.partitionBucket = new int[bytesSorted];
    this.histogram = new long[HISTOGRAM_SIZE];
    this.scratch = new byte[bytesSorted];
  }

  /**
   * Partitions the points of {@code points} in {@code [from, to)} around {@code partitionPoint}
   * in dimension {@code dim}.
   *
   * <p>On return, {@code partitionSlices[0]} holds the {@code partitionPoint - from} points whose
   * sort key is lower than or equal to the sort keys of the {@code to - partitionPoint} points held
   * in {@code partitionSlices[1]}. The sort key is the value of {@code dim}, then the data-only
   * dimensions, then the docID.
   *
   * <p>{@code dimCommonPrefix} is a hint: the number of leading sort-key bytes already known to be
   * shared by all the points being partitioned.
   *
   * <p>If the provided {@code points} is wrapping an {@link OfflinePointWriter}, the
   * writer is destroyed in the process to save disk space.
   *
   * @param points the points to partition
   * @param partitionSlices output holder for the two resulting slices; length must be > 1
   * @param from first point of the range (inclusive)
   * @param to end of the range (exclusive)
   * @param partitionPoint position of the partition; must satisfy {@code from <= partitionPoint < to}
   * @param dim dimension to partition on
   * @param dimCommonPrefix number of leading sort-key bytes shared by all points in the range
   * @return the value of {@code dim} at the partition point
   * @throws IllegalArgumentException if {@code partitionPoint} is not in {@code [from, to)}
   * @throws IOException if reading or writing temporary points fails
   */
  public byte[] select(PathSlice points, PathSlice[] partitionSlices, long from, long to, long partitionPoint, int dim, int dimCommonPrefix) throws IOException {
    checkArgs(from, to, partitionPoint);

    assert partitionSlices.length > 1 : "[partition slices] must be > 1, got " + partitionSlices.length;

    // If we are on heap then we just select on heap
    if (points.writer instanceof HeapPointWriter) {
      byte[] partition = heapRadixSelect((HeapPointWriter) points.writer, dim, Math.toIntExact(from), Math.toIntExact(to),  Math.toIntExact(partitionPoint), dimCommonPrefix);
      partitionSlices[0] = new PathSlice(points.writer, from, partitionPoint - from);
      partitionSlices[1] = new PathSlice(points.writer, partitionPoint, to - partitionPoint);
      return partition;
    }

    OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer;

    // the left/right writers are closed here so that the returned slices can be read afterwards
    try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim);
         PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) {
      partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from);
      partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint);
      return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, dimCommonPrefix, dim);
    }
  }

  /**
   * Validates that {@code partitionPoint} lies within {@code [from, to)}.
   *
   * @throws IllegalArgumentException if it does not
   */
  void checkArgs(long from, long to, long partitionPoint) {
    if (partitionPoint < from) {
      throw new IllegalArgumentException("partitionPoint must be >= from");
    }
    if (partitionPoint >= to) {
      throw new IllegalArgumentException("partitionPoint must be < to");
    }
  }

  /**
   * Reads the points in {@code [from, to)} and computes the length of the sort-key prefix they
   * all share. Bytes before {@code dimCommonPrefix} are assumed to be shared and are not compared.
   * Unless all sort keys are equal, it also fills {@link #histogram} with the counts of the byte
   * that follows the common prefix. The common prefix bytes are copied into
   * {@link #partitionBucket}.
   *
   * @return the common prefix length, or {@code bytesSorted} if all sort keys are equal
   */
  private int findCommonPrefixAndHistogram(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException{
    // find common prefix
    int commonPrefixPosition = bytesSorted;
    final int offset = dim * config.bytesPerDim;
    // dimension bytes before dimCommonPrefix are known to be equal, so comparisons start there
    final int startIndex = Math.min(dimCommonPrefix, config.bytesPerDim);
    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
      assert commonPrefixPosition > dimCommonPrefix;
      reader.next();
      PointValue pointValue = reader.pointValue();
      BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
      // copy dimension
      System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + offset, scratch, 0, config.bytesPerDim);
      // copy data dimensions and docID
      System.arraycopy(packedValueDocID.bytes, packedValueDocID.offset + config.packedIndexBytesLength, scratch, config.bytesPerDim, (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES);

      for (long i = from + 1; i < to; i++) {
        reader.next();
        pointValue = reader.pointValue();
        if (commonPrefixPosition == dimCommonPrefix) {
          histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          // we do not need to check for common prefix anymore,
          // just finish the histogram and break
          for (long j = i + 1; j < to; j++) {
            reader.next();
            pointValue = reader.pointValue();
            histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          }
          break;
        } else {
          // Check common prefix and adjust histogram
          final int endIndex = Math.min(commonPrefixPosition, config.bytesPerDim);
          packedValueDocID = pointValue.packedValueDocIDBytes();
          int j = FutureArrays.mismatch(scratch, startIndex, endIndex, packedValueDocID.bytes, packedValueDocID.offset + offset + startIndex, packedValueDocID.offset + offset + endIndex);
          if (j == -1) {
            if (commonPrefixPosition > config.bytesPerDim) {
              // Tie-break on data dimensions + docID
              final int startTieBreak = config.packedIndexBytesLength;
              final int endTieBreak = startTieBreak + commonPrefixPosition - config.bytesPerDim;
              int k = FutureArrays.mismatch(scratch, config.bytesPerDim, commonPrefixPosition,
                      packedValueDocID.bytes, packedValueDocID.offset + startTieBreak, packedValueDocID.offset + endTieBreak);
              if (k != -1) {
                commonPrefixPosition = config.bytesPerDim + k;
                // all points seen so far share the shorter prefix and the byte at its end
                Arrays.fill(histogram, 0);
                histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
              }
            }
          } else {
            commonPrefixPosition = dimCommonPrefix + j;
            // all points seen so far share the shorter prefix and the byte at its end
            Arrays.fill(histogram, 0);
            histogram[scratch[commonPrefixPosition] & 0xff] = i - from;
          }
          if (commonPrefixPosition != bytesSorted) {
            histogram[getBucket(offset, commonPrefixPosition, pointValue)]++;
          }
        }
      }
    }

    // Build partition buckets up to commonPrefix
    for (int i = 0; i < commonPrefixPosition; i++) {
      partitionBucket[i] = scratch[i] & 0xff;
    }
    return commonPrefixPosition;
  }

  /**
   * Returns the unsigned value of the sort-key byte at {@code commonPrefixPosition} for the given
   * point. Positions below {@code config.bytesPerDim} are read from dimension {@code dim}, which
   * starts at byte {@code offset} of the packed value. Later positions are read from the data
   * dimensions and docID.
   */
  private int getBucket(int offset, int commonPrefixPosition, PointValue pointValue) {
    int bucket;
    if (commonPrefixPosition < config.bytesPerDim) {
      BytesRef packedValue = pointValue.packedValue();
      bucket = packedValue.bytes[packedValue.offset + offset + commonPrefixPosition] & 0xff;
    } else {
      BytesRef packedValueDocID = pointValue.packedValueDocIDBytes();
      bucket = packedValueDocID.bytes[packedValueDocID.offset + config.packedIndexBytesLength + commonPrefixPosition - config.bytesPerDim] & 0xff;
    }
    return bucket;
  }

  /**
   * Radix-selects the offline points in {@code [from, to)} around {@code partitionPoint}.
   *
   * <p>Points with a lower byte than the partition bucket go to {@code left} and points with a
   * higher byte go to {@code right}. Points in the partition bucket are written to a delta writer,
   * which is recursed on at the next byte position. The delta writer is on heap when it fits,
   * otherwise offline. {@code points} is destroyed.
   *
   * @return the value of {@code dim} at the partition point
   */
  private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right,
                                            long from, long to, long partitionPoint, int iteration,  int baseCommonPrefix, int dim) throws IOException {
    // Find common prefix from baseCommonPrefix and build histogram
    int commonPrefix = findCommonPrefixAndHistogram(points, from, to, dim, baseCommonPrefix);

    // If all sort keys are equal we just split the points by count: the first
    // (partitionPoint - from) points of the range go left, the rest go right.
    if (commonPrefix == bytesSorted) {
      offlinePartition(points, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint - from);
      return partitionPointFromCommonPrefix();
    }

    long leftCount = 0;
    long rightCount = 0;

    // Count left points and record the partition point
    for(int i = 0; i < HISTOGRAM_SIZE; i++) {
      long size = histogram[i];
      if (leftCount + size > partitionPoint - from) {
        partitionBucket[commonPrefix] = i;
        break;
      }
      leftCount += size;
    }
    // Count right points
    for(int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) {
      rightCount += histogram[i];
    }

    long delta = histogram[partitionBucket[commonPrefix]];
    assert leftCount + rightCount + delta == to - from : (leftCount + rightCount + delta) + " / " + (to - from);

    // Special case when points are equal except last byte, we can just tie-break
    if (commonPrefix == bytesSorted - 1) {
      long tieBreakCount =(partitionPoint - from - leftCount);
      offlinePartition(points, left,  right, null, from, to, dim, commonPrefix, tieBreakCount);
      return partitionPointFromCommonPrefix();
    }

    // Create the delta points writer
    PointWriter deltaPoints;
    try (PointWriter tempDeltaPoints = getDeltaPointWriter(left, right, delta, iteration)) {
      // Divide the points. This actually destroys the current writer
      offlinePartition(points, left, right, tempDeltaPoints, from, to, dim, commonPrefix, 0);
      deltaPoints = tempDeltaPoints;
    }

    // the partition point relative to the start of the delta points
    long newPartitionPoint = partitionPoint - from - leftCount;

    if (deltaPoints instanceof HeapPointWriter) {
      return heapPartition((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix);
    } else {
      return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim);
    }
  }

  /**
   * Distributes the offline points in {@code [from, to)} by comparing their sort-key byte at
   * {@code bytePosition} with {@code partitionBucket[bytePosition]}.
   *
   * <p>Lower bytes go to {@code left} and higher bytes go to {@code right}. Equal bytes go to
   * {@code deltaPoints}, except at the last sort-key byte: there the first
   * {@code numDocsTiebreak} equal points go to {@code left} and the remaining ones to
   * {@code right}. {@code points} is destroyed once read.
   */
  private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
                                long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
    assert bytePosition == bytesSorted -1 || deltaPoints != null;
    int offset =  dim * config.bytesPerDim;
    long tiebreakCounter = 0;
    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
      while (reader.next()) {
        PointValue pointValue = reader.pointValue();
        int bucket = getBucket(offset, bytePosition, pointValue);
        if (bucket < this.partitionBucket[bytePosition]) {
          // to the left side
          left.append(pointValue);
        } else if (bucket > this.partitionBucket[bytePosition]) {
          // to the right side
          right.append(pointValue);
        } else {
          if (bytePosition == bytesSorted - 1) {
            if (tiebreakCounter < numDocsTiebreak) {
              left.append(pointValue);
              tiebreakCounter++;
            } else {
              right.append(pointValue);
            }
          } else {
            deltaPoints.append(pointValue);
          }
        }
      }
    }
    // Delete original file
    points.destroy();
  }

  /** Returns the dimension value of the partition point, taken from {@link #partitionBucket}. */
  private byte[] partitionPointFromCommonPrefix() {
    byte[] partition = new byte[config.bytesPerDim];
    for (int i = 0; i < config.bytesPerDim; i++) {
      partition[i] = (byte)partitionBucket[i];
    }
    return partition;
  }

  /**
   * Selects the on-heap points in {@code [from, to)} around {@code partitionPoint}. It then
   * appends the points before the partition point to {@code left} and the remaining ones to
   * {@code right}.
   *
   * @return the value of {@code dim} at the partition point
   */
  private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
    byte[] partition = heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix);
    for (int i = from; i < to; i++) {
      PointValue value = points.getPackedValueSlice(i);
      if (i < partitionPoint) {
        left.append(value);
      } else {
        right.append(value);
      }
    }
    return partition;
  }

  /**
   * Reorders the on-heap points in {@code [from, to)} so that the point at {@code partitionPoint}
   * is in its sorted position. Points before it have sort keys lower than or equal to those
   * after it. The first {@code commonPrefixLength} sort-key bytes are assumed to be shared and
   * are skipped.
   *
   * @return the value of {@code dim} at the partition point
   */
  private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefixLength) {
    final int dimOffset = dim * config.bytesPerDim + commonPrefixLength;
    final int dimCmpBytes = config.bytesPerDim - commonPrefixLength;
    // chosen so that dataOffset + k addresses data byte (k - dimCmpBytes) of the point
    final int dataOffset = config.packedIndexBytesLength - dimCmpBytes;
    new RadixSelector(bytesSorted - commonPrefixLength) {

      @Override
      protected void swap(int i, int j) {
        points.swap(i, j);
      }

      @Override
      protected int byteAt(int i, int k) {
        assert k >= 0 : "negative prefix " + k;
        if (k  < dimCmpBytes) {
          // dim bytes
          return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
        } else {
          // data bytes
          return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
        }
      }

      @Override
      protected Selector getFallbackSelector(int d) {
        final int skippedBytes = d + commonPrefixLength;
        final int dimStart = dim * config.bytesPerDim + skippedBytes;
        final int dimEnd =  dim * config.bytesPerDim + config.bytesPerDim;
        // data length is composed by the data dimensions plus the docID
        final int dataLength = (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
        return new IntroSelector() {

          @Override
          protected void swap(int i, int j) {
            points.swap(i, j);
          }

          @Override
          protected void setPivot(int i) {
            if (skippedBytes < config.bytesPerDim) {
              System.arraycopy(points.block, i * config.bytesPerDoc + dim * config.bytesPerDim, scratch, 0, config.bytesPerDim);
            }
            System.arraycopy(points.block, i * config.bytesPerDoc + config.packedIndexBytesLength, scratch, config.bytesPerDim, dataLength);
          }

          @Override
          protected int compare(int i, int j) {
            if (skippedBytes < config.bytesPerDim) {
              int iOffset = i * config.bytesPerDoc;
              int jOffset = j * config.bytesPerDoc;
              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
            int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
            return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
          }

          @Override
          protected int comparePivot(int j) {
            if (skippedBytes < config.bytesPerDim) {
              int jOffset = j * config.bytesPerDoc;
              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, config.bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
            return FutureArrays.compareUnsigned(scratch, config.bytesPerDim, config.bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
          }
        };
      }
    }.select(from, to, partitionPoint);

    byte[] partition = new byte[config.bytesPerDim];
    PointValue pointValue = points.getPackedValueSlice(partitionPoint);
    BytesRef packedValue = pointValue.packedValue();
    System.arraycopy(packedValue.bytes, packedValue.offset + dim * config.bytesPerDim, partition, 0, config.bytesPerDim);
    return partition;
  }

  /**
   * Sort the heap writer by the specified dim. It is used to sort the leaves of the tree.
   *
   * <p>Points in {@code [from, to)} are ordered by the value of {@code dim}, then by the data-only
   * dimensions, then by docID. The first {@code commonPrefixLength} sort-key bytes are assumed to
   * be shared by all points in the range and are skipped.
   */
  public void heapRadixSort(final HeapPointWriter points, int from, int to, int dim, int commonPrefixLength) {
    final int dimOffset = dim * config.bytesPerDim + commonPrefixLength;
    final int dimCmpBytes = config.bytesPerDim - commonPrefixLength;
    // chosen so that dataOffset + k addresses data byte (k - dimCmpBytes) of the point
    final int dataOffset = config.packedIndexBytesLength - dimCmpBytes;
    new MSBRadixSorter(bytesSorted - commonPrefixLength) {

      @Override
      protected int byteAt(int i, int k) {
        assert k >= 0 : "negative prefix " + k;
        if (k  < dimCmpBytes) {
          // dim bytes
          return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
        } else {
          // data bytes
          return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
        }
      }

      @Override
      protected void swap(int i, int j) {
        points.swap(i, j);
      }

      @Override
      protected Sorter getFallbackSorter(int k) {
        final int skippedBytes = k + commonPrefixLength;
        final int dimStart = dim * config.bytesPerDim + skippedBytes;
        final int dimEnd =  dim * config.bytesPerDim + config.bytesPerDim;
        // data length is composed by the data dimensions plus the docID
        final int dataLength = (config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
        return new IntroSorter() {

          @Override
          protected void swap(int i, int j) {
            points.swap(i, j);
          }

          @Override
          protected void setPivot(int i) {
            if (skippedBytes < config.bytesPerDim) {
              System.arraycopy(points.block, i * config.bytesPerDoc + dim * config.bytesPerDim, scratch, 0, config.bytesPerDim);
            }
            System.arraycopy(points.block, i * config.bytesPerDoc + config.packedIndexBytesLength, scratch, config.bytesPerDim, dataLength);
          }

          @Override
          protected int compare(int i, int j) {
            if (skippedBytes < config.bytesPerDim) {
              int iOffset = i * config.bytesPerDoc;
              int jOffset = j * config.bytesPerDoc;
              int cmp = FutureArrays.compareUnsigned(points.block, iOffset + dimStart, iOffset + dimEnd, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
            int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
            return FutureArrays.compareUnsigned(points.block, iOffset, iOffset + dataLength, points.block, jOffset, jOffset + dataLength);
          }

          @Override
          protected int comparePivot(int j) {
            if (skippedBytes < config.bytesPerDim) {
              int jOffset = j * config.bytesPerDoc;
              int cmp = FutureArrays.compareUnsigned(scratch, skippedBytes, config.bytesPerDim, points.block, jOffset + dimStart, jOffset + dimEnd);
              if (cmp != 0) {
                return cmp;
              }
            }
            int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
            return FutureArrays.compareUnsigned(scratch, config.bytesPerDim, config.bytesPerDim + dataLength, points.block, jOffset, jOffset + dataLength);
          }
        };
      }
    }.sort(from, to);
  }

  /**
   * Creates the writer for the {@code delta} points that fall in the partition bucket. It is on
   * heap if they fit in the remaining on-heap budget, otherwise offline.
   */
  private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException {
    if (delta <= getMaxPointsSortInHeap(left, right)) {
      return  new HeapPointWriter(config, Math.toIntExact(delta));
    } else {
      return new OfflinePointWriter(config, tempDir, tempFileNamePrefix, "delta" + iteration, delta);
    }
  }

  /**
   * Returns how many more points may be held on heap. This is {@code maxPointsSortInHeap} minus
   * the capacity of {@code left} and {@code right} when they are on-heap writers.
   */
  private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) {
    int pointsUsed = 0;
    if (left instanceof HeapPointWriter) {
      pointsUsed += ((HeapPointWriter) left).size;
    }
    if (right instanceof HeapPointWriter) {
      pointsUsed += ((HeapPointWriter) right).size;
    }
    assert maxPointsSortInHeap >= pointsUsed;
    return maxPointsSortInHeap - pointsUsed;
  }

  /**
   * Creates a writer for {@code count} points. It is on heap when {@code count} is at most half of
   * {@code maxPointsSortInHeap}, otherwise offline, using {@code desc} in the temp file name.
   */
  PointWriter getPointWriter(long count, String desc) throws IOException {
    // As we recurse, we hold two on-heap point writers at any point. Therefore the
    // max size for these objects is half of the total points we can have on-heap.
    if (count <= maxPointsSortInHeap / 2) {
      int size = Math.toIntExact(count);
      return new HeapPointWriter(config, size);
    } else {
      return new OfflinePointWriter(config, tempDir, tempFileNamePrefix, desc, count);
    }
  }

  /** Sliced reference to points in a PointWriter: {@code count} points starting at {@code start}. */
  public static final class PathSlice {
    public final PointWriter writer;
    public final long start;
    public final long count;

    public PathSlice(PointWriter writer, long start, long count) {
      this.writer = writer;
      this.start = start;
      this.count = count;
    }

    @Override
    public String toString() {
      return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy