// org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.join;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
/**
* kNN float vector query that joins matching children vector documents with their parent doc id.
* The top documents returned are the child document ids and the calculated scores. Here is how to
* use this in conjunction with {@link ToParentBlockJoinQuery}.
*
* <pre class="prettyprint">
*   Query knnQuery = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, queryVector, ...);
*   // Rewrite executes kNN search and collects nearest children docIds and their scores
*   Query rewrittenKnnQuery = searcher.rewrite(knnQuery);
*   // Join the scored children docs with their parents and score the parents
*   Query childrenToParents = new ToParentBlockJoinQuery(rewrittenKnnQuery, parentsFilter, ScoreMode.MAX);
* </pre>
*/
public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery {
  /** Shared empty result returned whenever a segment cannot contribute any hit. */
  private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

  /** Produces, per segment, the bit set that marks parent documents. */
  private final BitSetProducer parentsFilter;

  /** Filter restricting which child documents are eligible; also handed to the superclass. */
  private final Query childFilter;

  /** Maximum number of hits to keep. */
  private final int k;

  /** The query vector; the array reference is stored as passed in (no defensive copy). */
  private final float[] query;

  /**
   * Create a DiversifyingChildrenFloatKnnVectorQuery.
   *
   * @param field the query field
   * @param query the vector query
   * @param childFilter the child filter
   * @param k how many parent documents to return given the matching children
   * @param parentsFilter Filter identifying the parent documents.
   */
  public DiversifyingChildrenFloatKnnVectorQuery(
      String field, float[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
    super(field, query, k, childFilter);
    this.childFilter = childFilter;
    this.parentsFilter = parentsFilter;
    this.k = k;
    this.query = query;
  }

  /**
   * Exhaustively scores the accepted children of one segment and keeps, for each parent, only its
   * best-scoring child. At most {@code k} such children are returned.
   *
   * @param context the segment to search
   * @param acceptIterator iterator over the child documents that may be scored
   * @param queryTimeout optional timeout; may be {@code null}. When it fires, the hits gathered so
   *     far are returned and the total-hits relation becomes {@code GREATER_THAN_OR_EQUAL_TO}.
   * @return the collected child hits, or an empty result when the segment has no float vectors for
   *     the field, no parent bit set, or no scorer for the query vector
   * @throws IOException if reading vectors or advancing an iterator fails
   */
  @Override
  protected TopDocs exactSearch(
      LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
      throws IOException {
    FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(field);
    if (floatVectorValues == null) {
      FloatVectorValues.checkField(context.reader(), field);
      return NO_RESULTS;
    }

    BitSet parentBitSet = parentsFilter.getBitSet(context);
    if (parentBitSet == null) {
      return NO_RESULTS;
    }

    VectorScorer floatVectorScorer = floatVectorValues.scorer(query);
    if (floatVectorScorer == null) {
      return NO_RESULTS;
    }

    DiversifyingChildrenVectorScorer vectorScorer =
        new DiversifyingChildrenVectorScorer(acceptIterator, parentBitSet, floatVectorScorer);

    // cost() is a long estimate. Take the minimum in long space and narrow afterwards: the result
    // never exceeds k, so the cast is lossless, whereas Math.toIntExact(cost) would throw
    // ArithmeticException for any cost above Integer.MAX_VALUE even though it is only a clamp.
    final int queueSize = (int) Math.min(k, acceptIterator.cost());

    // The queue is created pre-populated (second argument); the placeholder entries are reused
    // in place below and any leftovers are stripped after the loop.
    HitQueue queue = new HitQueue(queueSize, true);
    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
    ScoreDoc topDoc = queue.top();
    while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) {
      // Mark results as partial if timeout is met
      if (queryTimeout != null && queryTimeout.shouldExit()) {
        relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
        break;
      }
      float score = vectorScorer.score();
      if (score > topDoc.score) {
        // Overwrite the current top entry instead of allocating a new ScoreDoc.
        topDoc.score = score;
        topDoc.doc = vectorScorer.bestChild();
        topDoc = queue.updateTop();
      }
    }

    // Remove any remaining sentinel values (recognised here by their negative score)
    while (queue.size() > 0 && queue.top().score < 0) {
      queue.pop();
    }

    // Fill from the back so that the entry popped first ends up last in the array.
    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
      topScoreDocs[i] = queue.pop();
    }

    TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
    return new TopDocs(totalHits, topScoreDocs);
  }

  /**
   * Supplies the collector manager used for approximate search.
   *
   * @param k number of hits requested
   * @param searcher the searcher executing the query
   * @return a {@link DiversifyingNearestChildrenKnnCollectorManager} built from {@code k}, this
   *     query's parents filter and {@code searcher}
   */
  @Override
  protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
  }

  /**
   * Runs the reader's nearest-vector search for one segment using a collector obtained from
   * {@code knnCollectorManager}.
   *
   * @param context the segment to search
   * @param acceptDocs documents allowed to match, passed through to the reader
   * @param visitedLimit limit handed to the collector manager when creating the collector
   * @param knnCollectorManager factory for the per-segment collector
   * @return the collector's top docs, or an empty result when the manager returns no collector
   * @throws IOException if the underlying search fails
   */
  @Override
  protected TopDocs approximateSearch(
      LeafReaderContext context,
      Bits acceptDocs,
      int visitedLimit,
      KnnCollectorManager knnCollectorManager)
      throws IOException {
    FloatVectorValues.checkField(context.reader(), field);
    KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
    if (collector == null) {
      return NO_RESULTS;
    }
    context.reader().searchNearestVectors(field, query, collector, acceptDocs);
    return collector.topDocs();
  }

  /**
   * Renders the query as {@code SimpleName:field[firstComponent,...][k]}.
   *
   * <p>An empty query vector is rendered with empty brackets rather than failing on the missing
   * first component.
   *
   * @param field ignored; the query's own field is always printed
   */
  @Override
  public String toString(String field) {
    // Guard the empty-vector case: indexing query[0] unconditionally would throw
    // ArrayIndexOutOfBoundsException from toString().
    String vectorPreview = query.length == 0 ? "" : query[0] + ",...";
    return getClass().getSimpleName() + ":" + this.field + "[" + vectorPreview + "][" + k + "]";
  }

  /**
   * Two queries are equal when they have the same class, are equal according to the superclass,
   * and have equal {@code k}, parents filter, child filter and query vector contents.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    if (!super.equals(o)) return false;
    DiversifyingChildrenFloatKnnVectorQuery that = (DiversifyingChildrenFloatKnnVectorQuery) o;
    return k == that.k
        && Objects.equals(parentsFilter, that.parentsFilter)
        && Objects.equals(childFilter, that.childFilter)
        && Arrays.equals(query, that.query);
  }

  /** Hash code combining the same components that {@link #equals(Object)} compares. */
  @Override
  public int hashCode() {
    int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k);
    result = 31 * result + Arrays.hashCode(query);
    return result;
  }

  /**
   * Walks the accepted children one parent block at a time, remembering for the current block the
   * child with the highest vector score.
   */
  static class DiversifyingChildrenVectorScorer {
    private final VectorScorer vectorScorer;
    private final DocIdSetIterator vectorIterator;
    private final DocIdSetIterator acceptedChildrenIterator;
    private final BitSet parentBitSet;
    private int currentParent = -1;
    private int bestChild = -1;
    private float currentScore = Float.NEGATIVE_INFINITY;

    /**
     * @param acceptedChildrenIterator iterator over the children that may be scored
     * @param parentBitSet bit set with one bit set per parent document
     * @param vectorScorer scorer whose {@code iterator()} is advanced to each child before scoring
     */
    protected DiversifyingChildrenVectorScorer(
        DocIdSetIterator acceptedChildrenIterator, BitSet parentBitSet, VectorScorer vectorScorer) {
      this.acceptedChildrenIterator = acceptedChildrenIterator;
      this.vectorScorer = vectorScorer;
      this.vectorIterator = vectorScorer.iterator();
      this.parentBitSet = parentBitSet;
    }

    /** Returns the doc id of the best-scoring child recorded by {@link #nextParent()}. */
    public int bestChild() {
      return bestChild;
    }

    /**
     * Moves to the next parent block that contains accepted children and scores all of that
     * block's accepted children, recording the best one.
     *
     * @return the value looked up in the parent bit set for the block, or {@link
     *     DocIdSetIterator#NO_MORE_DOCS} once the accepted children are exhausted
     * @throws IOException if advancing an iterator or scoring fails
     */
    public int nextParent() throws IOException {
      // Either the iterator has not been started yet (-1), or the previous call left it on the
      // first child that belongs to the following block.
      int nextChild = acceptedChildrenIterator.docID();
      if (nextChild == -1) {
        nextChild = acceptedChildrenIterator.nextDoc();
      }
      if (nextChild == DocIdSetIterator.NO_MORE_DOCS) {
        currentParent = DocIdSetIterator.NO_MORE_DOCS;
        return currentParent;
      }
      currentScore = Float.NEGATIVE_INFINITY;
      currentParent = parentBitSet.nextSetBit(nextChild);
      do {
        // NOTE(review): the score is attributed to nextChild, which presumes advance() lands
        // exactly on it, i.e. every accepted child has a vector -- confirm against callers.
        vectorIterator.advance(nextChild);
        float score = vectorScorer.score();
        if (score > currentScore) {
          bestChild = nextChild;
          currentScore = score;
        }
      } while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS
          && nextChild < currentParent);
      return currentParent;
    }

    /**
     * Returns the best child score recorded by the most recent {@link #nextParent()} call.
     *
     * @throws IOException declared for callers; this implementation only returns a cached value
     */
    public float score() throws IOException {
      return currentScore;
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy