All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.search.join.JoinUtil Maven / Gradle / Ivy

There is a newer version: 10.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import java.util.function.LongFunction;

import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.join.DocValuesTermsCollector.Function;
import org.apache.lucene.util.BytesRef;

/**
 * Utility for query time joining.
 *
 * @lucene.experimental
 */
public final class JoinUtil {

  // No instances allowed
  private JoinUtil() {
  }

  /**
   * Method for query time joining.
   * <p>
   * Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have the same terms in the
   * to field that match with documents matching the specified fromQuery and have the same terms in the from field.
   * <p>
   * In the case a single document relates to more than one document the {@code multipleValuesPerDocument} option
   * should be set to true. When {@code multipleValuesPerDocument} is set to {@code true} only the score from the
   * first encountered join value originating from the 'from' side is mapped into the 'to' side, even in the case
   * when a second join value related to a specific document yields a higher score. Obviously this doesn't apply in
   * the case that {@link ScoreMode#None} is used, since no scores are computed at all.
   * <p>
   * Memory considerations: During joining all unique join values are kept in memory. On top of that when the scoreMode
   * isn't set to {@link ScoreMode#None} a float value per unique join value is kept in memory for computing scores.
   * When scoreMode is set to {@link ScoreMode#Avg} also an additional integer value is kept in memory per unique
   * join value.
   *
   * @param fromField                 The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document
   * @param toField                   The to field to join to
   * @param fromQuery                 The query to match documents on the from side
   * @param fromSearcher              The searcher that executed the specified fromQuery
   * @param scoreMode                 Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the
   * terms in the from and to field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(String fromField,
                                      boolean multipleValuesPerDocument,
                                      String toField,
                                      Query fromQuery,
                                      IndexSearcher fromSearcher,
                                      ScoreMode scoreMode) throws IOException {
    final GenericTermsCollector termsWithScoreCollector;

    if (multipleValuesPerDocument) {
      Function<SortedSetDocValues> mvFunction = DocValuesTermsCollector.sortedSetDocValues(fromField);
      termsWithScoreCollector = GenericTermsCollector.createCollectorMV(mvFunction, scoreMode);
    } else {
      Function<BinaryDocValues> svFunction = DocValuesTermsCollector.binaryDocValues(fromField);
      termsWithScoreCollector = GenericTermsCollector.createCollectorSV(svFunction, scoreMode);
    }

    return createJoinQuery(multipleValuesPerDocument, toField, fromQuery, fromField, fromSearcher, scoreMode,
        termsWithScoreCollector);
  }

  /**
   * Method for query time joining for numeric fields. It supports multi- and single-valued longs, ints, floats and
   * doubles.
   * <p>
   * All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)}
   * are applicable here too, though memory consumption might be higher.
   * <p>
   * In the single-valued case, a matching 'from' document that has no value in {@code fromField} contributes the
   * join value {@code 0}.
   *
   * @param fromField                 The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document;
   *                                  when true fromField might be {@link DocValuesType#SORTED_NUMERIC},
   *                                  otherwise fromField should be {@link DocValuesType#NUMERIC}
   * @param toField                   The to field to join to, should be {@link IntPoint}, {@link LongPoint},
   *                                  {@link FloatPoint} or {@link DoublePoint}.
   * @param numericType               either {@link java.lang.Integer}, {@link java.lang.Long},
   *                                  {@link java.lang.Float} or {@link java.lang.Double}; it should correspond to
   *                                  the toField type
   * @param fromQuery                 The query to match documents on the from side
   * @param fromSearcher              The searcher that executed the specified fromQuery
   * @param scoreMode                 Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the
   * terms in the from and to field
   * @throws IOException              If I/O related errors occur
   * @throws IllegalArgumentException If numericType is not Integer, Long, Float or Double
   */
  public static Query createJoinQuery(String fromField,
                                      boolean multipleValuesPerDocument,
                                      String toField, Class<? extends Number> numericType,
                                      Query fromQuery,
                                      IndexSearcher fromSearcher,
                                      ScoreMode scoreMode) throws IOException {
    TreeSet<Long> joinValues = new TreeSet<>();
    Map<Long, Float> aggregatedScores = new HashMap<>();
    Map<Long, Integer> occurrences = new HashMap<>();
    boolean needsScore = scoreMode != ScoreMode.None;

    // Folds every (join value, score) pair seen on the 'from' side into aggregatedScores
    // (and, for Avg, counts the pairs per join value in occurrences).
    BiConsumer<Long, Float> scoreAggregator;
    if (scoreMode == ScoreMode.Max) {
      scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Math::max);
    } else if (scoreMode == ScoreMode.Min) {
      scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Math::min);
    } else if (scoreMode == ScoreMode.Total) {
      scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Float::sum);
    } else if (scoreMode == ScoreMode.Avg) {
      scoreAggregator = (key, score) -> {
        aggregatedScores.merge(key, score, Float::sum);
        occurrences.merge(key, 1, Integer::sum);
      };
    } else {
      // ScoreMode.None: the collectors below only call the aggregator when needsScore is true.
      scoreAggregator = (key, score) -> {
        throw new UnsupportedOperationException();
      };
    }

    // Produces the final score for a join value once collection has finished.
    LongFunction<Float> joinScorer;
    if (scoreMode == ScoreMode.Avg) {
      joinScorer = (joinValue) -> {
        Float aggregatedScore = aggregatedScores.get(joinValue);
        Integer occurrence = occurrences.get(joinValue);
        return aggregatedScore / occurrence;
      };
    } else {
      joinScorer = aggregatedScores::get;
    }

    Collector collector;
    if (multipleValuesPerDocument) {
      collector = new SimpleCollector() {

        SortedNumericDocValues sortedNumericDocValues;
        Scorable scorer;

        @Override
        public void collect(int doc) throws IOException {
          if (sortedNumericDocValues.advanceExact(doc)) {
            for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
              long value = sortedNumericDocValues.nextValue();
              joinValues.add(value);
              if (needsScore) {
                scoreAggregator.accept(value, scorer.score());
              }
            }
          }
        }

        @Override
        protected void doSetNextReader(LeafReaderContext context) throws IOException {
          sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
        }

        @Override
        public void setScorer(Scorable scorer) throws IOException {
          this.scorer = scorer;
        }

        @Override
        public org.apache.lucene.search.ScoreMode scoreMode() {
          return needsScore ? org.apache.lucene.search.ScoreMode.COMPLETE
              : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
        }
      };
    } else {
      collector = new SimpleCollector() {

        NumericDocValues numericDocValues;
        Scorable scorer;
        private int lastDocID = -1;

        // Assertion helper: advanceExact requires non-decreasing doc ids within a segment.
        private boolean docsInOrder(int docID) {
          if (docID < lastDocID) {
            throw new AssertionError("docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
          }
          lastDocID = docID;
          return true;
        }

        @Override
        public void collect(int doc) throws IOException {
          assert docsInOrder(doc);
          // Documents without a value join on 0.
          long value = 0;
          if (numericDocValues.advanceExact(doc)) {
            value = numericDocValues.longValue();
          }
          joinValues.add(value);
          if (needsScore) {
            scoreAggregator.accept(value, scorer.score());
          }
        }

        @Override
        protected void doSetNextReader(LeafReaderContext context) throws IOException {
          numericDocValues = DocValues.getNumeric(context.reader(), fromField);
          lastDocID = -1;
        }

        @Override
        public void setScorer(Scorable scorer) throws IOException {
          this.scorer = scorer;
        }

        @Override
        public org.apache.lucene.search.ScoreMode scoreMode() {
          return needsScore ? org.apache.lucene.search.ScoreMode.COMPLETE
              : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
        }
      };
    }
    fromSearcher.search(fromQuery, collector);

    // The point-set queries expect their stream in ascending encoded order, which for
    // negative floats/doubles is not the natural order of the raw long bits.
    final Iterator<Long> iterator = joinValuesInPointOrder(joinValues, numericType);
    final int bytesPerDim;
    final BytesRef encoded = new BytesRef();
    final PointInSetIncludingScoreQuery.Stream stream;
    if (Integer.class.equals(numericType)) {
      bytesPerDim = Integer.BYTES;
      stream = new PointInSetIncludingScoreQuery.Stream() {
        @Override
        public BytesRef next() {
          if (iterator.hasNext()) {
            long value = iterator.next();
            IntPoint.encodeDimension((int) value, encoded.bytes, 0);
            if (needsScore) {
              score = joinScorer.apply(value);
            }
            return encoded;
          } else {
            return null;
          }
        }
      };
    } else if (Long.class.equals(numericType)) {
      bytesPerDim = Long.BYTES;
      stream = new PointInSetIncludingScoreQuery.Stream() {
        @Override
        public BytesRef next() {
          if (iterator.hasNext()) {
            long value = iterator.next();
            LongPoint.encodeDimension(value, encoded.bytes, 0);
            if (needsScore) {
              score = joinScorer.apply(value);
            }
            return encoded;
          } else {
            return null;
          }
        }
      };
    } else if (Float.class.equals(numericType)) {
      bytesPerDim = Float.BYTES;
      stream = new PointInSetIncludingScoreQuery.Stream() {
        @Override
        public BytesRef next() {
          if (iterator.hasNext()) {
            long value = iterator.next();
            FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0);
            if (needsScore) {
              score = joinScorer.apply(value);
            }
            return encoded;
          } else {
            return null;
          }
        }
      };
    } else if (Double.class.equals(numericType)) {
      bytesPerDim = Double.BYTES;
      stream = new PointInSetIncludingScoreQuery.Stream() {
        @Override
        public BytesRef next() {
          if (iterator.hasNext()) {
            long value = iterator.next();
            DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0);
            if (needsScore) {
              score = joinScorer.apply(value);
            }
            return encoded;
          } else {
            return null;
          }
        }
      };
    } else {
      throw new IllegalArgumentException("unsupported numeric type, only Integer, Long, Float and Double are supported");
    }

    // The streams above encode into this shared buffer; it is sized here, before any stream is consumed.
    encoded.bytes = new byte[bytesPerDim];
    encoded.length = bytesPerDim;

    if (needsScore) {
      return new PointInSetIncludingScoreQuery(scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim,
          stream) {

        @Override
        protected String toString(byte[] value) {
          return toString.apply(value, numericType);
        }
      };
    } else {
      return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
        @Override
        protected String toString(byte[] value) {
          return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
        }
      };
    }
  }

  /**
   * Returns the collected join values ordered the way their point encodings sort.
   * <p>
   * The values are raw doc-value longs. For Float/Double they are decoded with
   * {@link Float#intBitsToFloat} / {@link Double#longBitsToDouble}. For those types the values are re-ordered by
   * {@link Float#compare} / {@link Double#compare}, which matches the point encoding's order. Ties are broken by the
   * raw long, so the set of values returned is exactly the input set. Integer and Long values are returned in their
   * natural long order.
   *
   * @param joinValues  the distinct raw join values, in natural long order
   * @param numericType the numeric type of the 'to' field
   * @return an iterator over joinValues in ascending encoded-point order
   */
  private static Iterator<Long> joinValuesInPointOrder(TreeSet<Long> joinValues,
                                                       Class<? extends Number> numericType) {
    final TreeSet<Long> ordered;
    if (Float.class.equals(numericType)) {
      ordered = new TreeSet<>((Long a, Long b) -> {
        int cmp = Float.compare(Float.intBitsToFloat(a.intValue()), Float.intBitsToFloat(b.intValue()));
        return cmp != 0 ? cmp : Long.compare(a, b);
      });
    } else if (Double.class.equals(numericType)) {
      ordered = new TreeSet<>((Long a, Long b) -> {
        int cmp = Double.compare(Double.longBitsToDouble(a), Double.longBitsToDouble(b));
        return cmp != 0 ? cmp : Long.compare(a, b);
      });
    } else {
      // Integer (for values that fit in an int) and Long encodings follow natural long order.
      return joinValues.iterator();
    }
    ordered.addAll(joinValues);
    return ordered.iterator();
  }

  /**
   * Runs {@code fromQuery} with the given terms collector and wraps the collected terms in a terms-based join query.
   *
   * @throws IOException              If I/O related errors occur
   * @throws IllegalArgumentException If scoreMode is not one of None, Total, Max, Min or Avg
   */
  private static Query createJoinQuery(boolean multipleValuesPerDocument, String toField, Query fromQuery,
      String fromField, IndexSearcher fromSearcher, ScoreMode scoreMode, final GenericTermsCollector collector)
      throws IOException {

    fromSearcher.search(fromQuery, collector);
    switch (scoreMode) {
      case None:
        return new TermsQuery(toField, collector.getCollectedTerms(), fromField, fromQuery,
            fromSearcher.getTopReaderContext().id());
      case Total:
      case Max:
      case Min:
      case Avg:
        return new TermsIncludingScoreQuery(
            scoreMode,
            toField,
            multipleValuesPerDocument,
            collector.getCollectedTerms(),
            collector.getScoresPerTerm(),
            fromField,
            fromQuery,
            fromSearcher.getTopReaderContext().id()
        );
      default:
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
    }
  }

  /**
   * Delegates to {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode, OrdinalMap, int, int)},
   * but disables the min and max filtering.
   *
   * @param joinField   The {@link SortedDocValues} field containing the join values
   * @param fromQuery   The query containing the actual user query. Also the fromQuery can only match "from" documents.
   * @param toQuery     The query identifying all documents on the "to" side.
   * @param searcher    The index searcher used to execute the from query
   * @param scoreMode   Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap  The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map
   *                    needs to be provided.
   * @return a {@link Query} instance that can be used to join documents based on the join field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(String joinField,
                                      Query fromQuery,
                                      Query toQuery,
                                      IndexSearcher searcher,
                                      ScoreMode scoreMode,
                                      OrdinalMap ordinalMap) throws IOException {
    return createJoinQuery(joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, 0, Integer.MAX_VALUE);
  }

  /**
   * A query time join using global ordinals over a dedicated join field.
   * <p>
   * This join has certain restrictions and requirements:
   * <ol>
   *   <li>A document can only refer to one other document (but can be referred to by one or more documents).</li>
   *   <li>Documents on each side of the join must be distinguishable. Typically this can be done by adding an extra
   *   field that identifies the "from" and "to" side and then the fromQuery and toQuery must take this into
   *   account.</li>
   *   <li>There must be a single sorted doc values join field used by both the "from" and "to" documents. This join
   *   field should store the join values as UTF-8 strings.</li>
   *   <li>An ordinal map must be provided that is created on top of the join field.</li>
   * </ol>
   * <p>
   * Note: min and max filtering and the avg score mode will require this join to keep track of the number of times
   * a document matches per join value. This will increase the per join cost in terms of execution time and memory.
   *
   * @param joinField   The {@link SortedDocValues} field containing the join values
   * @param fromQuery   The query containing the actual user query. Also the fromQuery can only match "from" documents.
   * @param toQuery     The query identifying all documents on the "to" side.
   * @param searcher    The index searcher used to execute the from query
   * @param scoreMode   Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap  The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map
   *                    needs to be provided.
   * @param min         Optionally the minimum number of "from" documents that are required to match for a "to" document
   *                    to be a match. The min is inclusive. Setting min to 0 and max to {@code Integer.MAX_VALUE}
   *                    disables the min and max "from" documents filtering
   * @param max         Optionally the maximum number of "from" documents that are allowed to match for a "to" document
   *                    to be a match. The max is inclusive. Setting min to 0 and max to {@code Integer.MAX_VALUE}
   *                    disables the min and max "from" documents filtering
   * @return a {@link Query} instance that can be used to join documents based on the join field; a
   * {@link MatchNoDocsQuery} when the index has no segments, or has a single segment without join values
   * @throws IOException              If I/O related errors occur
   * @throws IllegalArgumentException If the index has more than one segment and ordinalMap is null, or if the
   *                                  score mode isn't supported
   */
  public static Query createJoinQuery(String joinField,
                                      Query fromQuery,
                                      Query toQuery,
                                      IndexSearcher searcher,
                                      ScoreMode scoreMode,
                                      OrdinalMap ordinalMap,
                                      int min,
                                      int max) throws IOException {
    int numSegments = searcher.getIndexReader().leaves().size();
    final long valueCount;
    if (numSegments == 0) {
      return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments");
    } else if (numSegments == 1) {
      // No need to use the ordinal map, because there is just one segment.
      ordinalMap = null;
      LeafReader leafReader = searcher.getIndexReader().leaves().get(0).reader();
      SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField);
      if (joinSortedDocValues != null) {
        valueCount = joinSortedDocValues.getValueCount();
      } else {
        return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values");
      }
    } else {
      if (ordinalMap == null) {
        throw new IllegalArgumentException("OrdinalMap is required, because there is more than 1 segment");
      }
      valueCount = ordinalMap.getValueCount();
    }

    final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
    final Query rewrittenToQuery = searcher.rewrite(toQuery);
    GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector;
    switch (scoreMode) {
      case Total:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max);
        break;
      case Min:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max);
        break;
      case Max:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max);
        break;
      case Avg:
        globalOrdinalsWithScoreCollector =
            new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max);
        break;
      case None:
        if (min <= 1 && max == Integer.MAX_VALUE) {
          // No count-based filtering to apply: a plain ordinals collector is sufficient.
          GlobalOrdinalsCollector globalOrdinalsCollector =
              new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
          searcher.search(rewrittenFromQuery, globalOrdinalsCollector);
          return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap,
              rewrittenToQuery, rewrittenFromQuery, searcher.getTopReaderContext().id());
        } else {
          globalOrdinalsWithScoreCollector =
              new GlobalOrdinalsWithScoreCollector.NoScore(joinField, ordinalMap, valueCount, min, max);
          break;
        }
      default:
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
    }
    searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector);
    return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, scoreMode, joinField, ordinalMap,
        rewrittenToQuery, rewrittenFromQuery, min, max, searcher.getTopReaderContext().id());
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy