All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.search.join.JoinUtil Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Locale;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.internal.hppc.LongCursor;
import org.apache.lucene.internal.hppc.LongFloatHashMap;
import org.apache.lucene.internal.hppc.LongHashSet;
import org.apache.lucene.internal.hppc.LongIntHashMap;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.join.DocValuesTermsCollector.Function;
import org.apache.lucene.util.BytesRef;

/**
 * Utility for query time joining.
 *
 * @lucene.experimental
 */
public final class JoinUtil {

  // No instances allowed
  private JoinUtil() {}

  /**
   * Method for query time joining.
   *
   * <p>Execute the returned query with an {@link IndexSearcher} to retrieve all documents that
   * have the same terms in the to field that match with documents matching the specified fromQuery
   * and have the same terms in the from field.
   *
   * <p>In the case a single document relates to more than one document the
   * multipleValuesPerDocument option should be set to true. When the multipleValuesPerDocument is
   * set to true only the score from the first encountered join value originating from the 'from'
   * side is mapped into the 'to' side. Even in the case when a second join value related to a
   * specific document yields a higher score. Obviously this doesn't apply in the case that {@link
   * ScoreMode#None} is used, since no scores are computed at all.
   *
   * <p>Memory considerations: During joining all unique join values are kept in memory. On top of
   * that when the scoreMode isn't set to {@link ScoreMode#None} a float value per unique join value
   * is kept in memory for computing scores. When scoreMode is set to {@link ScoreMode#Avg} also an
   * additional integer value is kept in memory per unique join value.
   *
   * @param fromField The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document
   * @param toField The to field to join to
   * @param fromQuery The query to match documents on the from side
   * @param fromSearcher The searcher that executed the specified fromQuery
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the terms in the
   *     from and to field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(
      String fromField,
      boolean multipleValuesPerDocument,
      String toField,
      Query fromQuery,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode)
      throws IOException {

    final GenericTermsCollector termsWithScoreCollector;

    if (multipleValuesPerDocument) {
      Function<SortedSetDocValues> mvFunction =
          DocValuesTermsCollector.sortedSetDocValues(fromField);
      termsWithScoreCollector = GenericTermsCollector.createCollectorMV(mvFunction, scoreMode);
    } else {
      Function<SortedDocValues> svFunction = DocValuesTermsCollector.sortedDocValues(fromField);
      termsWithScoreCollector = GenericTermsCollector.createCollectorSV(svFunction, scoreMode);
    }

    return createJoinQuery(
        multipleValuesPerDocument,
        toField,
        fromQuery,
        fromField,
        fromSearcher,
        scoreMode,
        termsWithScoreCollector);
  }

  /**
   * Method for query time joining for numeric fields. It supports multi- and single-valued longs,
   * ints, floats and doubles.
   *
   * <p>All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query,
   * IndexSearcher, ScoreMode)} are applicable here too, though memory consumption might be higher.
   *
   * <p>Float and double join values are decoded from the from-side doc values with {@link
   * Float#intBitsToFloat(int)} (on the low 32 bits) and {@link Double#longBitsToDouble(long)}, i.e.
   * the doc values are expected to hold raw IEEE-754 bits. In the single-valued case a matching
   * "from" document without a doc value contributes the join value {@code 0}.
   *
   * @param fromField The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document when
   *     true fromField might be {@link DocValuesType#SORTED_NUMERIC}, otherwise fromField should be
   *     {@link DocValuesType#NUMERIC}
   * @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link
   *     FloatPoint} or {@link DoublePoint}.
   * @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link
   *     java.lang.Float} or {@link java.lang.Double} it should correspond to toField types
   * @param fromQuery The query to match documents on the from side
   * @param fromSearcher The searcher that executed the specified fromQuery
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the terms in the
   *     from and to field
   * @throws IOException If I/O related errors occur
   * @throws IllegalArgumentException if numericType is not Integer, Long, Float or Double
   */
  public static Query createJoinQuery(
      String fromField,
      boolean multipleValuesPerDocument,
      String toField,
      Class<?> numericType,
      Query fromQuery,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode)
      throws IOException {
    LongHashSet joinValues = new LongHashSet();
    LongFloatHashMap aggregatedScores = new LongFloatHashMap();
    LongIntHashMap occurrences = new LongIntHashMap();
    boolean needsScore = scoreMode != ScoreMode.None;

    // Folds one (join value, score) pair into aggregatedScores according to scoreMode.
    LongFloatProcedure scoreAggregator;
    if (scoreMode == ScoreMode.Max) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.max(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Min) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.min(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Total) {
      scoreAggregator = aggregatedScores::addTo;
    } else if (scoreMode == ScoreMode.Avg) {
      scoreAggregator =
          (key, score) -> {
            aggregatedScores.addTo(key, score);
            occurrences.addTo(key, 1);
          };
    } else {
      // ScoreMode.None: the collectors only call the aggregator when needsScore is true.
      scoreAggregator =
          (key, score) -> {
            throw new UnsupportedOperationException();
          };
    }

    // Turns the aggregated state for one join value into the score of the joined documents.
    LongFloatFunction joinScorer;
    if (scoreMode == ScoreMode.Avg) {
      joinScorer =
          (joinValue) -> {
            float aggregatedScore = aggregatedScores.get(joinValue);
            int occurrence = occurrences.get(joinValue);
            return aggregatedScore / occurrence;
          };
    } else {
      joinScorer = aggregatedScores::get;
    }

    Collector collector;
    if (multipleValuesPerDocument) {
      collector =
          new SimpleCollector() {

            SortedNumericDocValues sortedNumericDocValues;
            Scorable scorer;

            @Override
            public void collect(int doc) throws IOException {
              if (sortedNumericDocValues.advanceExact(doc)) {
                for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
                  long value = sortedNumericDocValues.nextValue();
                  joinValues.add(value);
                  if (needsScore) {
                    scoreAggregator.apply(value, scorer.score());
                  }
                }
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    } else {
      collector =
          new SimpleCollector() {

            NumericDocValues numericDocValues;
            Scorable scorer;
            private int lastDocID = -1;

            // Only invoked from an assert; fails if docs arrive out of order within a segment.
            private boolean docsInOrder(int docID) {
              if (docID < lastDocID) {
                throw new AssertionError(
                    "docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
              }
              lastDocID = docID;
              return true;
            }

            @Override
            public void collect(int doc) throws IOException {
              assert docsInOrder(doc);
              // A document without a doc value joins with the value 0.
              long value = 0;
              if (numericDocValues.advanceExact(doc)) {
                value = numericDocValues.longValue();
              }
              joinValues.add(value);
              if (needsScore) {
                scoreAggregator.apply(value, scorer.score());
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              numericDocValues = DocValues.getNumeric(context.reader(), fromField);
              lastDocID = -1;
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    }
    fromSearcher.search(fromQuery, collector);

    LongArrayList joinValuesList = new LongArrayList(joinValues.size());
    joinValuesList.addAll(joinValues);
    // The point-set queries below require their points in ascending encoded order.
    sortInPointOrder(joinValuesList.buffer, joinValuesList.size(), numericType);
    Iterator<LongCursor> iterator = joinValuesList.iterator();

    final int bytesPerDim;
    // Single reusable buffer returned by every next() call; its bytes are allocated below,
    // before the query constructors start consuming the stream.
    final BytesRef encoded = new BytesRef();
    final PointInSetIncludingScoreQuery.Stream stream;
    if (Integer.class.equals(numericType)) {
      bytesPerDim = Integer.BYTES;
      stream =
          new PointInSetIncludingScoreQuery.Stream() {
            @Override
            public BytesRef next() {
              if (iterator.hasNext()) {
                LongCursor value = iterator.next();
                IntPoint.encodeDimension((int) value.value, encoded.bytes, 0);
                if (needsScore) {
                  score = joinScorer.apply(value.value);
                }
                return encoded;
              } else {
                return null;
              }
            }
          };
    } else if (Long.class.equals(numericType)) {
      bytesPerDim = Long.BYTES;
      stream =
          new PointInSetIncludingScoreQuery.Stream() {
            @Override
            public BytesRef next() {
              if (iterator.hasNext()) {
                LongCursor value = iterator.next();
                LongPoint.encodeDimension(value.value, encoded.bytes, 0);
                if (needsScore) {
                  score = joinScorer.apply(value.value);
                }
                return encoded;
              } else {
                return null;
              }
            }
          };
    } else if (Float.class.equals(numericType)) {
      bytesPerDim = Float.BYTES;
      stream =
          new PointInSetIncludingScoreQuery.Stream() {
            @Override
            public BytesRef next() {
              if (iterator.hasNext()) {
                LongCursor value = iterator.next();
                FloatPoint.encodeDimension(
                    Float.intBitsToFloat((int) value.value), encoded.bytes, 0);
                if (needsScore) {
                  score = joinScorer.apply(value.value);
                }
                return encoded;
              } else {
                return null;
              }
            }
          };
    } else if (Double.class.equals(numericType)) {
      bytesPerDim = Double.BYTES;
      stream =
          new PointInSetIncludingScoreQuery.Stream() {
            @Override
            public BytesRef next() {
              if (iterator.hasNext()) {
                LongCursor value = iterator.next();
                DoublePoint.encodeDimension(Double.longBitsToDouble(value.value), encoded.bytes, 0);
                if (needsScore) {
                  score = joinScorer.apply(value.value);
                }
                return encoded;
              } else {
                return null;
              }
            }
          };
    } else {
      throw new IllegalArgumentException(
          "unsupported numeric type, only Integer, Long, Float and Double are supported");
    }

    encoded.bytes = new byte[bytesPerDim];
    encoded.length = bytesPerDim;

    if (needsScore) {
      return new PointInSetIncludingScoreQuery(
          scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {

        @Override
        protected String toString(byte[] value) {
          return toString.apply(value, numericType);
        }
      };
    } else {
      return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
        @Override
        protected String toString(byte[] value) {
          return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
        }
      };
    }
  }

  /**
   * Sorts the first {@code size} join values in place so that their point encodings come out in
   * ascending byte order, which the point-set queries built from them require.
   *
   * <p>Integer and long join values already sort correctly as signed longs. Float and double join
   * values hold raw IEEE-754 bits, and a signed comparison of raw bits orders negative numbers
   * backwards (for example -1.0 before -2.0). For those types the magnitude bits of every negative
   * value are flipped before sorting, which makes signed order equal numeric order. They are
   * flipped back afterwards. The flip never changes the sign bit, so it is its own inverse and
   * every value is restored exactly, keeping it a valid key into the score maps.
   *
   * @param values buffer holding the join values
   * @param size number of leading entries of {@code values} that are in use
   * @param numericType the numeric type the values are decoded as
   */
  private static void sortInPointOrder(long[] values, int size, Class<?> numericType) {
    final long signBit;
    final long magnitudeMask;
    if (Float.class.equals(numericType)) {
      // Floats are decoded from the low 32 bits, so bit 31 is the float's sign bit.
      signBit = 0x80000000L;
      magnitudeMask = 0x7FFFFFFFL;
    } else if (Double.class.equals(numericType)) {
      signBit = Long.MIN_VALUE;
      magnitudeMask = Long.MAX_VALUE;
    } else {
      Arrays.sort(values, 0, size);
      return;
    }
    flipNegativeMagnitudes(values, size, signBit, magnitudeMask);
    Arrays.sort(values, 0, size);
    flipNegativeMagnitudes(values, size, signBit, magnitudeMask);
  }

  /** XORs {@code magnitudeMask} into every value whose {@code signBit} is set. */
  private static void flipNegativeMagnitudes(
      long[] values, int size, long signBit, long magnitudeMask) {
    for (int i = 0; i < size; i++) {
      if ((values[i] & signBit) != 0) {
        values[i] ^= magnitudeMask;
      }
    }
  }

  /**
   * Runs {@code fromQuery} with the given terms collector and wraps the collected terms into a
   * {@link TermsQuery} ({@link ScoreMode#None}) or a {@link TermsIncludingScoreQuery} (all other
   * score modes).
   *
   * @throws IllegalArgumentException if scoreMode is not one of the handled modes
   */
  private static Query createJoinQuery(
      boolean multipleValuesPerDocument,
      String toField,
      Query fromQuery,
      String fromField,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode,
      final GenericTermsCollector collector)
      throws IOException {

    fromSearcher.search(fromQuery, collector);
    switch (scoreMode) {
      case None:
        return new TermsQuery(
            toField,
            collector.getCollectedTerms(),
            fromField,
            fromQuery,
            fromSearcher.getTopReaderContext().id());
      case Total:
      case Max:
      case Min:
      case Avg:
        return new TermsIncludingScoreQuery(
            scoreMode,
            toField,
            multipleValuesPerDocument,
            collector.getCollectedTerms(),
            collector.getScoresPerTerm(),
            fromField,
            fromQuery,
            fromSearcher.getTopReaderContext().id());
      default:
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
    }
  }

  /**
   * Delegates to {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode,
   * OrdinalMap, int, int)}, but disables the min and max filtering.
   *
   * @param joinField The {@link SortedDocValues} field containing the join values
   * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
   *     "from" documents.
   * @param toQuery The query identifying all documents on the "to" side.
   * @param searcher The index searcher used to execute the from query
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
   *     index, no ordinal map needs to be provided.
   * @return a {@link Query} instance that can be used to join documents based on the join field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(
      String joinField,
      Query fromQuery,
      Query toQuery,
      IndexSearcher searcher,
      ScoreMode scoreMode,
      OrdinalMap ordinalMap)
      throws IOException {
    return createJoinQuery(
        joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, 0, Integer.MAX_VALUE);
  }

  /**
   * A query time join using global ordinals over a dedicated join field.
   *
   * <p>This join has certain restrictions and requirements: 1) A document can only refer to one
   * other document. (but can be referred by one or more documents) 2) Documents on each side of the
   * join must be distinguishable. Typically this can be done by adding an extra field that
   * identifies the "from" and "to" side and then the fromQuery and toQuery must take this into
   * account. 3) There must be a single sorted doc values join field used by both the "from" and
   * "to" documents. This join field should store the join values as UTF-8 strings. 4) An ordinal
   * map must be provided that is created on top of the join field.
   *
   * <p>Note: min and max filtering and the avg score mode will require this join to keep track of
   * the number of times a document matches per join value. This will increase the per join cost in
   * terms of execution time and memory.
   *
   * @param joinField The {@link SortedDocValues} field containing the join values
   * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
   *     "from" documents.
   * @param toQuery The query identifying all documents on the "to" side.
   * @param searcher The index searcher used to execute the from query
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
   *     index, no ordinal map needs to be provided.
   * @param min Optionally the minimum number of "from" documents that are required to match for a
   *     "to" document to be a match. The min is inclusive. Setting min to 0 and max to
   *     Integer.MAX_VALUE disables the min and max "from" documents filtering
   * @param max Optionally the maximum number of "from" documents that are allowed to match for a
   *     "to" document to be a match. The max is inclusive. Setting min to 0 and max to
   *     Integer.MAX_VALUE disables the min and max "from" documents filtering
   * @return a {@link Query} instance that can be used to join documents based on the join field
   * @throws IOException If I/O related errors occur
   * @throws IllegalArgumentException if the index has more than one segment and ordinalMap is null,
   *     or if scoreMode is not supported
   */
  public static Query createJoinQuery(
      String joinField,
      Query fromQuery,
      Query toQuery,
      IndexSearcher searcher,
      ScoreMode scoreMode,
      OrdinalMap ordinalMap,
      int min,
      int max)
      throws IOException {
    int numSegments = searcher.getIndexReader().leaves().size();
    final long valueCount;
    if (numSegments == 0) {
      return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments");
    } else if (numSegments == 1) {
      // No need to use the ordinal map, because there is just one segment.
      ordinalMap = null;
      LeafReader leafReader = searcher.getIndexReader().leaves().get(0).reader();
      SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField);
      if (joinSortedDocValues != null) {
        valueCount = joinSortedDocValues.getValueCount();
      } else {
        return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values");
      }
    } else {
      if (ordinalMap == null) {
        throw new IllegalArgumentException(
            "OrdinalMap is required, because there is more than 1 segment");
      }
      valueCount = ordinalMap.getValueCount();
    }

    final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
    final Query rewrittenToQuery = searcher.rewrite(toQuery);
    GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector;
    switch (scoreMode) {
      case Total:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max);
        break;
      case Min:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max);
        break;
      case Max:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max);
        break;
      case Avg:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max);
        break;
      case None:
        if (min <= 1 && max == Integer.MAX_VALUE) {
          // No scores and no match-count filtering: the cheaper ordinals-only collector suffices.
          GlobalOrdinalsCollector globalOrdinalsCollector =
              new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
          searcher.search(rewrittenFromQuery, globalOrdinalsCollector);
          return new GlobalOrdinalsQuery(
              globalOrdinalsCollector.getCollectorOrdinals(),
              joinField,
              ordinalMap,
              rewrittenToQuery,
              rewrittenFromQuery,
              searcher.getTopReaderContext().id());
        } else {
          globalOrdinalsWithScoreCollector =
              new GlobalOrdinalsWithScoreCollector.NoScore(
                  joinField, ordinalMap, valueCount, min, max);
          break;
        }
      default:
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
    }
    searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector);
    return new GlobalOrdinalsWithScoreQuery(
        globalOrdinalsWithScoreCollector,
        scoreMode,
        joinField,
        ordinalMap,
        rewrittenToQuery,
        rewrittenFromQuery,
        min,
        max,
        searcher.getTopReaderContext().id());
  }

  /** Similar to {@link java.util.function.LongFunction} for primitive argument and result. */
  private interface LongFloatFunction {
    float apply(long value);
  }

  /** Similar to {@link java.util.function.BiConsumer} for primitive arguments. */
  private interface LongFloatProcedure {
    void apply(long key, float value);
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy