All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.lucene.document.FloatPointNearestNeighbor Maven / Gradle / Ivy

There is a newer version: 10.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.lucene.document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.bkd.BKDReader;

/**
 * KNN search on top of N dimensional indexed float points.
 *
 * @lucene.experimental
 */
public class FloatPointNearestNeighbor {

  /**
   * One node of a reader's BKD tree, waiting in the search queue. Cells are ordered by the smallest
   * squared distance any point inside their bounding box could have to the query origin.
   */
  static class Cell implements Comparable<Cell> {
    /** Position of this cell's reader in the list of readers being searched. */
    final int readerIndex;
    /** Lower corner of this cell's bounding box, in packed (encoded) form. */
    final byte[] minPacked;
    /** Upper corner of this cell's bounding box, in packed (encoded) form. */
    final byte[] maxPacked;
    /** Tree cursor positioned on this cell's node. */
    final BKDReader.IndexTree index;
    /** The closest possible distance^2 of all points in this cell */
    final double distanceSquared;

    Cell(BKDReader.IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSquared) {
      this.index = index;
      this.readerIndex = readerIndex;
      // defensive copies: the cell owns its bounds independently of the caller's arrays
      this.minPacked = minPacked.clone();
      this.maxPacked = maxPacked.clone();
      this.distanceSquared = distanceSquared;
    }

    /** Orders cells by ascending {@link #distanceSquared}, so the closest cell is dequeued first. */
    @Override
    public int compareTo(Cell other) {
      return Double.compare(distanceSquared, other.distanceSquared);
    }

    @Override
    public String toString() {
      return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID()
          + " isLeaf=" + index.isLeafNode() + " distanceSquared=" + distanceSquared + ")";
    }
  }

  /**
   * Intersect visitor that keeps the {@code topN} points closest to {@code origin} in
   * {@code hitQueue}.
   *
   * <p>The queue is expected to keep its worst hit (largest distance, then highest doc ID) at the
   * head, as the comparator built in {@code nearest} does. {@link #updateBottomNearestDistance()}
   * caches that head so the pruning checks do not have to peek the queue.
   */
  private static class NearestVisitor implements PointValues.IntersectVisitor {
    /** Doc base of the segment currently being visited; added to segment-local doc IDs. */
    int curDocBase;
    /** Live docs of the segment currently being visited; {@code null} means every doc is kept. */
    Bits curLiveDocs;
    /** Maximum number of hits to keep. */
    final int topN;
    /** Kept hits, worst hit at the head. */
    final PriorityQueue<NearestHit> hitQueue;
    /** Query point. */
    final float[] origin;
    /** Number of dimensions compared, taken from {@code origin.length}. */
    private final int dims;
    /** Distance^2 of the worst kept hit once the queue is full; +Infinity until then. */
    double bottomNearestDistanceSquared = Double.POSITIVE_INFINITY;
    /** Doc ID of the worst kept hit once the queue is full; used to break distance ties. */
    int bottomNearestDistanceDoc = Integer.MAX_VALUE;

    public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, float[] origin) {
      this.hitQueue = hitQueue;
      this.topN = topN;
      this.origin = origin;
      this.dims = origin.length;
    }

    /**
     * Not expected to be called. This callback presumably only fires for cells reported as
     * CELL_INSIDE_QUERY, which {@link #compare} never returns.
     */
    @Override
    public void visit(int docID) {
      throw new AssertionError();
    }

    /**
     * Considers one point and inserts it if it beats the worst kept hit.
     *
     * <p>Deleted docs are skipped. The squared distance to the origin is accumulated per dimension,
     * and the method bails out as soon as the partial sum exceeds the current bottom. Equal distances
     * are broken in favor of the lower top-level doc ID.
     */
    @Override
    public void visit(int docID, byte[] packedValue) {
      if (curLiveDocs != null && curLiveDocs.get(docID) == false) {
        return;
      }

      double distanceSquared = 0.0d;
      for (int d = 0, offset = 0; d < dims; ++d, offset += Float.BYTES) {
        double diff = (double) FloatPoint.decodeDimension(packedValue, offset) - (double) origin[d];
        distanceSquared += diff * diff;
        // partial sums never shrink, so this point can no longer make it into the queue
        if (distanceSquared > bottomNearestDistanceSquared) {
          return;
        }
      }

      int fullDocID = curDocBase + docID;

      if (hitQueue.size() == topN) {
        // queue already full; distance <= bottom here, so only a tie with a higher doc ID loses
        if (distanceSquared == bottomNearestDistanceSquared && fullDocID > bottomNearestDistanceDoc) {
          return;
        }
        // recycle the evicted bottom hit instead of allocating a new one
        NearestHit bottom = hitQueue.poll();
        bottom.docID = fullDocID;
        bottom.distanceSquared = distanceSquared;
        hitQueue.offer(bottom);
        updateBottomNearestDistance();
      } else {
        NearestHit hit = new NearestHit();
        hit.docID = fullDocID;
        hit.distanceSquared = distanceSquared;
        hitQueue.offer(hit);
        // as soon as the queue fills up, start pruning against its worst hit
        if (hitQueue.size() == topN) {
          updateBottomNearestDistance();
        }
      }
    }

    /** Caches the distance and doc ID of the queue head (the worst kept hit). */
    private void updateBottomNearestDistance() {
      NearestHit newBottom = hitQueue.peek();
      bottomNearestDistanceSquared = newBottom.distanceSquared;
      bottomNearestDistanceDoc = newBottom.docID;
    }

    /**
     * Reports a cell as outside the query when the queue is full and even the cell's closest point
     * is farther than the worst kept hit. Otherwise the cell must be visited.
     */
    @Override
    public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
      if (hitQueue.size() == topN && pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin) > bottomNearestDistanceSquared) {
        return PointValues.Relation.CELL_OUTSIDE_QUERY;
      }
      return PointValues.Relation.CELL_CROSSES_QUERY;
    }
  }

  /**
   * Holds one hit from {@link FloatPointNearestNeighbor#nearest}: a top-level document and its
   * squared distance to the query origin.
   */
  static class NearestHit {
    /** Top-level doc ID (the segment's doc base is already added). */
    public int docID;
    /** Squared Euclidean distance from the query origin. */
    public double distanceSquared;

    @Override
    public String toString() {
      return new StringBuilder("NearestHit(docID=")
          .append(docID)
          .append(" distanceSquared=")
          .append(distanceSquared)
          .append(')')
          .toString();
    }
  }

  /**
   * Best-first search over the BKD trees of all readers for the {@code topN} points closest to
   * {@code origin}.
   *
   * <p>Cells from every reader share one priority queue, ordered by the smallest possible squared
   * distance of any point they contain. The search stops once the closest remaining cell is
   * farther than the worst hit kept so far.
   *
   * @param readers BKD readers, one per segment that has the field
   * @param liveDocs live docs per reader, in the same order as {@code readers}; a {@code null}
   *     entry means every document of that reader is kept
   * @param docBases doc base per reader, in the same order as {@code readers}
   * @param topN maximum number of hits to return
   * @param origin query point, one value per dimension
   * @return the hits, closest first; hits at equal distance are ordered by ascending doc ID
   * @throws IllegalStateException if the readers disagree on bytes per dimension
   */
  private static NearestHit[] nearest(List<BKDReader> readers, List<Bits> liveDocs, List<Integer> docBases, final int topN, float[] origin) throws IOException {

    // Holds closest collected points seen so far, worst hit at the head.
    // PriorityQueue allocates its initial capacity eagerly, so cap it: a huge topN (e.g.
    // Integer.MAX_VALUE) would otherwise try to allocate a giant array up front. The queue grows
    // on demand.
    // TODO: if we used lucene's PQ we could just updateTop instead of poll/offer:
    final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(Math.min(topN, 1024), (a, b) -> {
      // sort by opposite distance natural order
      int cmp = Double.compare(a.distanceSquared, b.distanceSquared);
      return cmp != 0 ? -cmp : Integer.compare(b.docID, a.docID); // tie-break by higher docID
    });

    // Holds all cells, sorted by closest to the point:
    PriorityQueue<Cell> cellQueue = new PriorityQueue<>();

    NearestVisitor visitor = new NearestVisitor(hitQueue, topN, origin);
    List<BKDReader.IntersectState> states = new ArrayList<>();

    // Add root cell for each reader into the queue:
    int bytesPerDim = -1;

    for (int i = 0; i < readers.size(); ++i) {
      BKDReader reader = readers.get(i);
      if (bytesPerDim == -1) {
        bytesPerDim = reader.getBytesPerDimension();
      } else if (bytesPerDim != reader.getBytesPerDimension()) {
        throw new IllegalStateException("bytesPerDim changed from " + bytesPerDim
            + " to " + reader.getBytesPerDimension() + " across readers");
      }
      byte[] minPackedValue = reader.getMinPackedValue();
      byte[] maxPackedValue = reader.getMaxPackedValue();
      BKDReader.IntersectState state = reader.getIntersectState(visitor);
      states.add(state);

      cellQueue.offer(new Cell(state.index, i, minPackedValue, maxPackedValue,
          pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin)));
    }

    while (!cellQueue.isEmpty()) {
      Cell cell = cellQueue.poll();

      // every remaining cell is at least this far away, so nothing left can improve the hits
      if (cell.distanceSquared > visitor.bottomNearestDistanceSquared) {
        break;
      }

      BKDReader reader = readers.get(cell.readerIndex);
      if (cell.index.isLeafNode()) {
        // Leaf block: visit all points and possibly collect them:
        visitor.curDocBase = docBases.get(cell.readerIndex);
        visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
        reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
      } else {
        // Non-leaf block: split into two cells and put them back into the queue:

        BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
        int splitDim = cell.index.getSplitDim();

        // we must clone the index so that we can recurse left and right "concurrently":
        BKDReader.IndexTree newIndex = cell.index.clone();
        byte[] splitPackedValue = cell.maxPacked.clone();
        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);

        // left child: [cell.min, split] on the split dimension
        cell.index.pushLeft();
        double distanceLeft = pointToRectangleDistanceSquared(cell.minPacked, splitPackedValue, origin);
        if (distanceLeft <= visitor.bottomNearestDistanceSquared) {
          cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, distanceLeft));
        }

        splitPackedValue = cell.minPacked.clone();
        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);

        // right child: [split, cell.max] on the split dimension
        newIndex.pushRight();
        double distanceRight = pointToRectangleDistanceSquared(splitPackedValue, cell.maxPacked, origin);
        if (distanceRight <= visitor.bottomNearestDistanceSquared) {
          cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, distanceRight));
        }
      }
    }

    // hitQueue yields the worst hit first, so fill the array from the back for closest-first order
    NearestHit[] hits = new NearestHit[hitQueue.size()];
    for (int downTo = hits.length - 1; downTo >= 0; downTo--) {
      hits[downTo] = hitQueue.poll();
    }
    return hits;
  }

  /**
   * Computes the squared Euclidean distance from {@code value} to the nearest point of the
   * axis-aligned box with corners {@code minPackedValue} and {@code maxPackedValue}.
   *
   * <p>Each dimension is decoded with {@code FloatPoint.decodeDimension} at a stride of
   * {@link Float#BYTES}. A dimension where the point lies within the box contributes nothing, so
   * a point inside the box yields {@code 0}.
   */
  private static double pointToRectangleDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) {
    double total = 0.0d;
    int dim = 0;
    while (dim < value.length) {
      final int byteOffset = dim * Float.BYTES;
      final float coordinate = value[dim];
      final double lower = FloatPoint.decodeDimension(minPackedValue, byteOffset);
      double gap = 0.0d;
      if (coordinate < lower) {
        gap = lower - (double) coordinate;
      } else {
        // only decode the upper bound when the point is not already below the box
        final double upper = FloatPoint.decodeDimension(maxPackedValue, byteOffset);
        if (coordinate > upper) {
          gap = upper - (double) coordinate;
        }
      }
      total += gap * gap;
      dim++;
    }
    return total;
  }

  public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float... origin) throws IOException {
    if (topN < 1) {
      throw new IllegalArgumentException("topN must be at least 1; got " + topN);
    }
    if (field == null) {
      throw new IllegalArgumentException("field must not be null");
    }
    if (searcher == null) {
      throw new IllegalArgumentException("searcher must not be null");
    }
    List readers = new ArrayList<>();
    List docBases = new ArrayList<>();
    List liveDocs = new ArrayList<>();
    int totalHits = 0;
    for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
      PointValues points = leaf.reader().getPointValues(field);
      if (points != null) {
        if (points instanceof BKDReader == false) {
          throw new IllegalArgumentException("can only run on Lucene60PointsReader points implementation, but got " + points);
        }
        totalHits += points.getDocCount();
        readers.add((BKDReader)points);
        docBases.add(leaf.docBase);
        liveDocs.add(leaf.reader().getLiveDocs());
      }
    }

    NearestHit[] hits = nearest(readers, liveDocs, docBases, topN, origin);

    // Convert to TopFieldDocs:
    ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
    for(int i=0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy