// org.apache.lucene.search.join.JoinUtil Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.join;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Locale;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.internal.hppc.LongCursor;
import org.apache.lucene.internal.hppc.LongFloatHashMap;
import org.apache.lucene.internal.hppc.LongHashSet;
import org.apache.lucene.internal.hppc.LongIntHashMap;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.join.DocValuesTermsCollector.Function;
import org.apache.lucene.util.BytesRef;
/**
* Utility for query time joining.
*
* @lucene.experimental
*/
public final class JoinUtil {
/** Private constructor: this class only exposes static methods, so no instances are allowed. */
private JoinUtil() {}
/**
 * Method for query time joining.
 *
 * <p>Execute the returned query with a {@link IndexSearcher} to retrieve all documents that have
 * the same terms in the to field that match with documents matching the specified fromQuery and
 * have the same terms in the from field.
 *
 * <p>In the case a single document relates to more than one document the {@code
 * multipleValuesPerDocument} option should be set to true. When {@code
 * multipleValuesPerDocument} is set to {@code true} only the score from the first encountered
 * join value originating from the 'from' side is mapped into the 'to' side. Even in the case when
 * a second join value related to a specific document yields a higher score. Obviously this
 * doesn't apply in the case that {@link ScoreMode#None} is used, since no scores are computed at
 * all.
 *
 * <p>Memory considerations: During joining all unique join values are kept in memory. On top of
 * that when the scoreMode isn't set to {@link ScoreMode#None} a float value per unique join value
 * is kept in memory for computing scores. When scoreMode is set to {@link ScoreMode#Avg} also an
 * additional integer value is kept in memory per unique join value.
 *
 * @param fromField The from field to join from
 * @param multipleValuesPerDocument Whether the from field has multiple terms per document
 * @param toField The to field to join to
 * @param fromQuery The query to match documents on the from side
 * @param fromSearcher The searcher that executed the specified fromQuery
 * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
 * @return a {@link Query} instance that can be used to join documents based on the terms in the
 *     from and to field
 * @throws IOException If I/O related errors occur
 */
public static Query createJoinQuery(
    String fromField,
    boolean multipleValuesPerDocument,
    String toField,
    Query fromQuery,
    IndexSearcher fromSearcher,
    ScoreMode scoreMode)
    throws IOException {
  final GenericTermsCollector termsWithScoreCollector;
  if (multipleValuesPerDocument) {
    // Multi-valued from field: join terms are read through sorted-set doc values.
    Function<SortedSetDocValues> mvFunction =
        DocValuesTermsCollector.sortedSetDocValues(fromField);
    termsWithScoreCollector = GenericTermsCollector.createCollectorMV(mvFunction, scoreMode);
  } else {
    // Single-valued from field: join terms are read through sorted doc values.
    Function<SortedDocValues> svFunction = DocValuesTermsCollector.sortedDocValues(fromField);
    termsWithScoreCollector = GenericTermsCollector.createCollectorSV(svFunction, scoreMode);
  }

  // The shared private overload runs the from query and wraps the collected terms in a query.
  return createJoinQuery(
      multipleValuesPerDocument,
      toField,
      fromQuery,
      fromField,
      fromSearcher,
      scoreMode,
      termsWithScoreCollector);
}
/**
* Method for query time joining for numeric fields. It supports multi- and single- values longs,
* ints, floats and longs. All considerations from {@link JoinUtil#createJoinQuery(String,
* boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, though memory
* consumption might be higher.
*
* @param fromField The from field to join from
* @param multipleValuesPerDocument Whether the from field has multiple terms per document when
* true fromField might be {@link DocValuesType#SORTED_NUMERIC}, otherwise fromField should be
* {@link DocValuesType#NUMERIC}
* @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link
* FloatPoint} or {@link DoublePoint}.
* @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link
* java.lang.Float} or {@link java.lang.Double} it should correspond to toField types
* @param fromQuery The query to match documents on the from side
* @param fromSearcher The searcher that executed the specified fromQuery
* @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
* @return a {@link Query} instance that can be used to join documents based on the terms in the
* from and to field
* @throws IOException If I/O related errors occur
*/
public static Query createJoinQuery(
String fromField,
boolean multipleValuesPerDocument,
String toField,
Class extends Number> numericType,
Query fromQuery,
IndexSearcher fromSearcher,
ScoreMode scoreMode)
throws IOException {
LongHashSet joinValues = new LongHashSet();
LongFloatHashMap aggregatedScores = new LongFloatHashMap();
LongIntHashMap occurrences = new LongIntHashMap();
boolean needsScore = scoreMode != ScoreMode.None;
LongFloatProcedure scoreAggregator;
if (scoreMode == ScoreMode.Max) {
scoreAggregator =
(key, score) -> {
int index = aggregatedScores.indexOf(key);
if (index < 0) {
aggregatedScores.indexInsert(index, key, score);
} else {
float currentScore = aggregatedScores.indexGet(index);
aggregatedScores.indexReplace(index, Math.max(currentScore, score));
}
};
} else if (scoreMode == ScoreMode.Min) {
scoreAggregator =
(key, score) -> {
int index = aggregatedScores.indexOf(key);
if (index < 0) {
aggregatedScores.indexInsert(index, key, score);
} else {
float currentScore = aggregatedScores.indexGet(index);
aggregatedScores.indexReplace(index, Math.min(currentScore, score));
}
};
} else if (scoreMode == ScoreMode.Total) {
scoreAggregator = aggregatedScores::addTo;
} else if (scoreMode == ScoreMode.Avg) {
scoreAggregator =
(key, score) -> {
aggregatedScores.addTo(key, score);
occurrences.addTo(key, 1);
};
} else {
scoreAggregator =
(key, score) -> {
throw new UnsupportedOperationException();
};
}
LongFloatFunction joinScorer;
if (scoreMode == ScoreMode.Avg) {
joinScorer =
(joinValue) -> {
float aggregatedScore = aggregatedScores.get(joinValue);
int occurrence = occurrences.get(joinValue);
return aggregatedScore / occurrence;
};
} else {
joinScorer = aggregatedScores::get;
}
Collector collector;
if (multipleValuesPerDocument) {
collector =
new SimpleCollector() {
SortedNumericDocValues sortedNumericDocValues;
Scorable scorer;
@Override
public void collect(int doc) throws IOException {
if (sortedNumericDocValues.advanceExact(doc)) {
for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
long value = sortedNumericDocValues.nextValue();
joinValues.add(value);
if (needsScore) {
scoreAggregator.apply(value, scorer.score());
}
}
}
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
}
@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
}
@Override
public org.apache.lucene.search.ScoreMode scoreMode() {
return needsScore
? org.apache.lucene.search.ScoreMode.COMPLETE
: org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
}
};
} else {
collector =
new SimpleCollector() {
NumericDocValues numericDocValues;
Scorable scorer;
private int lastDocID = -1;
private boolean docsInOrder(int docID) {
if (docID < lastDocID) {
throw new AssertionError(
"docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
}
lastDocID = docID;
return true;
}
@Override
public void collect(int doc) throws IOException {
assert docsInOrder(doc);
long value = 0;
if (numericDocValues.advanceExact(doc)) {
value = numericDocValues.longValue();
}
joinValues.add(value);
if (needsScore) {
scoreAggregator.apply(value, scorer.score());
}
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
numericDocValues = DocValues.getNumeric(context.reader(), fromField);
lastDocID = -1;
}
@Override
public void setScorer(Scorable scorer) throws IOException {
this.scorer = scorer;
}
@Override
public org.apache.lucene.search.ScoreMode scoreMode() {
return needsScore
? org.apache.lucene.search.ScoreMode.COMPLETE
: org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
}
};
}
fromSearcher.search(fromQuery, collector);
LongArrayList joinValuesList = new LongArrayList(joinValues.size());
joinValuesList.addAll(joinValues);
Arrays.sort(joinValuesList.buffer, 0, joinValuesList.size());
Iterator iterator = joinValuesList.iterator();
final int bytesPerDim;
final BytesRef encoded = new BytesRef();
final PointInSetIncludingScoreQuery.Stream stream;
if (Integer.class.equals(numericType)) {
bytesPerDim = Integer.BYTES;
stream =
new PointInSetIncludingScoreQuery.Stream() {
@Override
public BytesRef next() {
if (iterator.hasNext()) {
LongCursor value = iterator.next();
IntPoint.encodeDimension((int) value.value, encoded.bytes, 0);
if (needsScore) {
score = joinScorer.apply(value.value);
}
return encoded;
} else {
return null;
}
}
};
} else if (Long.class.equals(numericType)) {
bytesPerDim = Long.BYTES;
stream =
new PointInSetIncludingScoreQuery.Stream() {
@Override
public BytesRef next() {
if (iterator.hasNext()) {
LongCursor value = iterator.next();
LongPoint.encodeDimension(value.value, encoded.bytes, 0);
if (needsScore) {
score = joinScorer.apply(value.value);
}
return encoded;
} else {
return null;
}
}
};
} else if (Float.class.equals(numericType)) {
bytesPerDim = Float.BYTES;
stream =
new PointInSetIncludingScoreQuery.Stream() {
@Override
public BytesRef next() {
if (iterator.hasNext()) {
LongCursor value = iterator.next();
FloatPoint.encodeDimension(
Float.intBitsToFloat((int) value.value), encoded.bytes, 0);
if (needsScore) {
score = joinScorer.apply(value.value);
}
return encoded;
} else {
return null;
}
}
};
} else if (Double.class.equals(numericType)) {
bytesPerDim = Double.BYTES;
stream =
new PointInSetIncludingScoreQuery.Stream() {
@Override
public BytesRef next() {
if (iterator.hasNext()) {
LongCursor value = iterator.next();
DoublePoint.encodeDimension(Double.longBitsToDouble(value.value), encoded.bytes, 0);
if (needsScore) {
score = joinScorer.apply(value.value);
}
return encoded;
} else {
return null;
}
}
};
} else {
throw new IllegalArgumentException(
"unsupported numeric type, only Integer, Long, Float and Double are supported");
}
encoded.bytes = new byte[bytesPerDim];
encoded.length = bytesPerDim;
if (needsScore) {
return new PointInSetIncludingScoreQuery(
scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {
@Override
protected String toString(byte[] value) {
return toString.apply(value, numericType);
}
};
} else {
return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
@Override
protected String toString(byte[] value) {
return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
}
};
}
}
/**
 * Runs {@code fromQuery} against {@code fromSearcher} with the supplied terms collector and wraps
 * what was collected in the join query matching {@code scoreMode}: a {@link TermsQuery} for
 * {@link ScoreMode#None}, otherwise a {@link TermsIncludingScoreQuery}.
 *
 * @throws IOException If I/O related errors occur while executing the from query
 * @throws IllegalArgumentException If the score mode is not one of the handled modes
 */
private static Query createJoinQuery(
    boolean multipleValuesPerDocument,
    String toField,
    Query fromQuery,
    String fromField,
    IndexSearcher fromSearcher,
    ScoreMode scoreMode,
    final GenericTermsCollector collector)
    throws IOException {
  // Populate the collector with the join terms (and scores) of all matching from documents.
  fromSearcher.search(fromQuery, collector);

  // Score-less join: only the collected terms are needed.
  if (scoreMode == ScoreMode.None) {
    return new TermsQuery(
        toField,
        collector.getCollectedTerms(),
        fromField,
        fromQuery,
        fromSearcher.getTopReaderContext().id());
  }

  switch (scoreMode) {
    case Avg:
    case Min:
    case Max:
    case Total:
      // Scoring join: the per-term scores travel along with the terms.
      return new TermsIncludingScoreQuery(
          scoreMode,
          toField,
          multipleValuesPerDocument,
          collector.getCollectedTerms(),
          collector.getScoresPerTerm(),
          fromField,
          fromQuery,
          fromSearcher.getTopReaderContext().id());
    default:
      throw new IllegalArgumentException(
          String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
  }
}
/**
 * Convenience overload of {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode,
 * OrdinalMap, int, int)} that passes {@code 0} as min and {@link Integer#MAX_VALUE} as max, i.e.
 * the min and max filtering is disabled.
 *
 * @param joinField The {@link SortedDocValues} field containing the join values
 * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
 *     "from" documents.
 * @param toQuery The query identifying all documents on the "to" side.
 * @param searcher The index searcher used to execute the from query
 * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
 * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
 *     index, no ordinal map needs to be provided.
 * @return a {@link Query} instance that can be used to join documents based on the join field
 * @throws IOException If I/O related errors occur
 */
public static Query createJoinQuery(
    String joinField,
    Query fromQuery,
    Query toQuery,
    IndexSearcher searcher,
    ScoreMode scoreMode,
    OrdinalMap ordinalMap)
    throws IOException {
  // This pair of bounds switches the "from" document count filtering off.
  final int unboundedMin = 0;
  final int unboundedMax = Integer.MAX_VALUE;
  return createJoinQuery(
      joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, unboundedMin, unboundedMax);
}
/**
 * A query time join using global ordinals over a dedicated join field.
 *
 * <p>This join has certain restrictions and requirements:
 *
 * <ol>
 *   <li>A document can only refer to one other document (but can be referred to by one or more
 *       documents).
 *   <li>Documents on each side of the join must be distinguishable. Typically this can be done by
 *       adding an extra field that identifies the "from" and "to" side and then the fromQuery and
 *       toQuery must take this into account.
 *   <li>There must be a single sorted doc values join field used by both the "from" and "to"
 *       documents. This join field should store the join values as UTF-8 strings.
 *   <li>An ordinal map must be provided that is created on top of the join field.
 * </ol>
 *
 * <p>Note: min and max filtering and the avg score mode will require this join to keep track of
 * the number of times a document matches per join value. This will increase the per join cost in
 * terms of execution time and memory.
 *
 * @param joinField The {@link SortedDocValues} field containing the join values
 * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
 *     "from" documents.
 * @param toQuery The query identifying all documents on the "to" side.
 * @param searcher The index searcher used to execute the from query
 * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
 * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
 *     index, no ordinal map needs to be provided.
 * @param min Optionally the minimum number of "from" documents that are required to match for a
 *     "to" document to be a match. The min is inclusive. Setting min to 0 and max to {@code
 *     Integer.MAX_VALUE} disables the min and max "from" documents filtering
 * @param max Optionally the maximum number of "from" documents that are allowed to match for a
 *     "to" document to be a match. The max is inclusive. Setting min to 0 and max to {@code
 *     Integer.MAX_VALUE} disables the min and max "from" documents filtering
 * @return a {@link Query} instance that can be used to join documents based on the join field
 * @throws IOException If I/O related errors occur
 * @throws IllegalArgumentException If the index has more than one segment and no ordinal map is
 *     given, or if the score mode is not one of the handled modes
 */
public static Query createJoinQuery(
    String joinField,
    Query fromQuery,
    Query toQuery,
    IndexSearcher searcher,
    ScoreMode scoreMode,
    OrdinalMap ordinalMap,
    int min,
    int max)
    throws IOException {
  final int segmentCount = searcher.getIndexReader().leaves().size();
  if (segmentCount == 0) {
    return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments");
  }

  // The ordinal map actually handed to the collectors and queries; null for a single segment.
  final OrdinalMap globalOrds;
  final long valueCount;
  if (segmentCount > 1) {
    if (ordinalMap == null) {
      throw new IllegalArgumentException(
          "OrdinalMap is required, because there is more than 1 segment");
    }
    globalOrds = ordinalMap;
    valueCount = ordinalMap.getValueCount();
  } else {
    // No need to use the ordinal map, because there is just one segment.
    globalOrds = null;
    LeafReader onlyLeaf = searcher.getIndexReader().leaves().get(0).reader();
    SortedDocValues joinValues = onlyLeaf.getSortedDocValues(joinField);
    if (joinValues == null) {
      return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values");
    }
    valueCount = joinValues.getValueCount();
  }

  final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
  final Query rewrittenToQuery = searcher.rewrite(toQuery);

  // Without scores and without effective min/max filtering, collecting the matching ordinals is
  // enough.
  if (scoreMode == ScoreMode.None && min <= 1 && max == Integer.MAX_VALUE) {
    GlobalOrdinalsCollector ordinalsOnlyCollector =
        new GlobalOrdinalsCollector(joinField, globalOrds, valueCount);
    searcher.search(rewrittenFromQuery, ordinalsOnlyCollector);
    return new GlobalOrdinalsQuery(
        ordinalsOnlyCollector.getCollectorOrdinals(),
        joinField,
        globalOrds,
        rewrittenToQuery,
        rewrittenFromQuery,
        searcher.getTopReaderContext().id());
  }

  // All remaining combinations go through a collector variant chosen by score mode.
  final GlobalOrdinalsWithScoreCollector scoringCollector;
  switch (scoreMode) {
    case None:
      scoringCollector =
          new GlobalOrdinalsWithScoreCollector.NoScore(
              joinField, globalOrds, valueCount, min, max);
      break;
    case Avg:
      scoringCollector =
          new GlobalOrdinalsWithScoreCollector.Avg(joinField, globalOrds, valueCount, min, max);
      break;
    case Max:
      scoringCollector =
          new GlobalOrdinalsWithScoreCollector.Max(joinField, globalOrds, valueCount, min, max);
      break;
    case Min:
      scoringCollector =
          new GlobalOrdinalsWithScoreCollector.Min(joinField, globalOrds, valueCount, min, max);
      break;
    case Total:
      scoringCollector =
          new GlobalOrdinalsWithScoreCollector.Sum(joinField, globalOrds, valueCount, min, max);
      break;
    default:
      throw new IllegalArgumentException(
          String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
  }

  searcher.search(rewrittenFromQuery, scoringCollector);
  return new GlobalOrdinalsWithScoreQuery(
      scoringCollector,
      scoreMode,
      joinField,
      globalOrds,
      rewrittenToQuery,
      rewrittenFromQuery,
      min,
      max,
      searcher.getTopReaderContext().id());
}
/**
 * Similar to {@link java.util.function.LongFunction} for primitive argument and result: maps a
 * {@code long} to a {@code float} without boxing.
 */
@FunctionalInterface
private interface LongFloatFunction {
  /**
   * Computes the {@code float} result for the given {@code long}.
   *
   * @param value the primitive input
   * @return the primitive result
   */
  float apply(long value);
}
/**
 * Similar to {@link java.util.function.BiConsumer} for primitive arguments: consumes a {@code
 * long} key and a {@code float} value without boxing and returns nothing.
 */
@FunctionalInterface
private interface LongFloatProcedure {
  /**
   * Processes one key/value pair.
   *
   * @param key the primitive key
   * @param value the primitive value associated with the key
   */
  void apply(long key, float value);
}
}