All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery Maven / Gradle / Ivy

There is a newer version: 10.1.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;

/**
 * kNN byte vector query that joins matching children vector documents with their parent doc id. The
 * top documents returned are the child document ids and the calculated scores. Here is how to use
 * this in conjunction with {@link ToParentBlockJoinQuery}.
 *
 * 
 *   Query knnQuery = new DiversifyingChildrenByteKnnVectorQuery(fieldName, queryVector, ...);
 *   // Rewrite executes kNN search and collects nearest children docIds and their scores
 *   Query rewrittenKnnQuery = searcher.rewrite(knnQuery);
 *   // Join the scored children docs with their parents and score the parents
 *   Query childrenToParents = new ToParentBlockJoinQuery(rewrittenKnnQuery, parentsFilter, ScoreMode.MAX);
 * 
*/ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; private final BitSetProducer parentsFilter; private final Query childFilter; private final int k; private final byte[] query; /** * Create a ToParentBlockJoinByteVectorQuery. * * @param field the query field * @param query the vector query * @param childFilter the child filter * @param k how many parent documents to return given the matching children * @param parentsFilter Filter identifying the parent documents. */ public DiversifyingChildrenByteKnnVectorQuery( String field, byte[] query, Query childFilter, int k, BitSetProducer parentsFilter) { super(field, query, k, childFilter); this.childFilter = childFilter; this.parentsFilter = parentsFilter; this.k = k; this.query = query; } @Override protected TopDocs exactSearch( LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException { ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(field); if (byteVectorValues == null) { ByteVectorValues.checkField(context.reader(), field); return NO_RESULTS; } BitSet parentBitSet = parentsFilter.getBitSet(context); if (parentBitSet == null) { return NO_RESULTS; } VectorScorer scorer = byteVectorValues.scorer(query); if (scorer == null) { return NO_RESULTS; } DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer vectorScorer = new DiversifyingChildrenFloatKnnVectorQuery.DiversifyingChildrenVectorScorer( acceptIterator, parentBitSet, scorer); final int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost())); HitQueue queue = new HitQueue(queueSize, true); TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; ScoreDoc topDoc = queue.top(); while (vectorScorer.nextParent() != DocIdSetIterator.NO_MORE_DOCS) { // Mark results as partial if timeout is met if (queryTimeout != null && queryTimeout.shouldExit()) { relation = 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; break; } float score = vectorScorer.score(); if (score > topDoc.score) { topDoc.score = score; topDoc.doc = vectorScorer.bestChild(); topDoc = queue.updateTop(); } } // Remove any remaining sentinel values while (queue.size() > 0 && queue.top().score < 0) { queue.pop(); } ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()]; for (int i = topScoreDocs.length - 1; i >= 0; i--) { topScoreDocs[i] = queue.pop(); } TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation); return new TopDocs(totalHits, topScoreDocs); } @Override protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher); } @Override protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { ByteVectorValues.checkField(context.reader(), field); KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context); if (collector == null) { return NO_RESULTS; } context.reader().searchNearestVectors(field, query, collector, acceptDocs); return collector.topDocs(); } @Override public String toString(String field) { return getClass().getSimpleName() + ":" + this.field + "[" + query[0] + ",...][" + k + "]"; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; if (!super.equals(o)) return false; DiversifyingChildrenByteKnnVectorQuery that = (DiversifyingChildrenByteKnnVectorQuery) o; return k == that.k && Objects.equals(parentsFilter, that.parentsFilter) && Objects.equals(childFilter, that.childFilter) && Arrays.equals(query, that.query); } @Override public int hashCode() { int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k); result = 31 * result + Arrays.hashCode(query); return result; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy