package org.apache.lucene.search.highlight;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.lucene.analysis.CachingTokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.FilterAtomicReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.search.*;
import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
/**
* Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether
* {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
*/
public class WeightedSpanTermExtractor {
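// Usage sketch (illustrative, not part of the original class): a typical
// highlighting setup. The analyzer, query, and field name "content" are
// assumptions for the example; the extractor calls are the real API.
//
//   WeightedSpanTermExtractor extractor = new WeightedSpanTermExtractor();
//   TokenStream ts = analyzer.tokenStream("content", "the quick brown fox");
//   Map<String,WeightedSpanTerm> terms =
//       extractor.getWeightedSpanTerms(query, ts, "content");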
private String fieldName;
private TokenStream tokenStream;
private String defaultField;
private boolean expandMultiTermQuery;
private boolean cachedTokenStream;
private boolean wrapToCaching = true;
private int maxDocCharsToAnalyze;
private AtomicReader internalReader = null;
public WeightedSpanTermExtractor() {
}
public WeightedSpanTermExtractor(String defaultField) {
if (defaultField != null) {
this.defaultField = defaultField;
}
}
/**
* Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>Query</code>.
*
* @param query
* Query to extract Terms from
* @param terms
* Map to place created WeightedSpanTerms in
* @throws IOException If there is a low-level I/O error
*/
protected void extract(Query query, Map<String,WeightedSpanTerm> terms) throws IOException {
if (query instanceof BooleanQuery) {
BooleanClause[] queryClauses = ((BooleanQuery) query).getClauses();
for (int i = 0; i < queryClauses.length; i++) {
if (!queryClauses[i].isProhibited()) {
extract(queryClauses[i].getQuery(), terms);
}
}
} else if (query instanceof PhraseQuery) {
PhraseQuery phraseQuery = ((PhraseQuery) query);
Term[] phraseQueryTerms = phraseQuery.getTerms();
SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
for (int i = 0; i < phraseQueryTerms.length; i++) {
clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
}
int slop = phraseQuery.getSlop();
int[] positions = phraseQuery.getPositions();
// add largest position increment to slop
if (positions.length > 0) {
int lastPos = positions[0];
int largestInc = 0;
int sz = positions.length;
for (int i = 1; i < sz; i++) {
int pos = positions[i];
int inc = pos - lastPos;
if (inc > largestInc) {
largestInc = inc;
}
lastPos = pos;
}
if(largestInc > 1) {
slop += largestInc;
}
}
boolean inorder = false;
if (slop == 0) {
inorder = true;
}
SpanNearQuery sp = new SpanNearQuery(clauses, slop, inorder);
sp.setBoost(query.getBoost());
extractWeightedSpanTerms(terms, sp);
} else if (query instanceof TermQuery) {
extractWeightedTerms(terms, query);
} else if (query instanceof SpanQuery) {
extractWeightedSpanTerms(terms, (SpanQuery) query);
} else if (query instanceof FilteredQuery) {
extract(((FilteredQuery) query).getQuery(), terms);
} else if (query instanceof ConstantScoreQuery) {
final Query q = ((ConstantScoreQuery) query).getQuery();
if (q != null) {
extract(q, terms);
}
} else if (query instanceof CommonTermsQuery) {
// specialized since rewriting would change the result query
// this query is TermContext sensitive.
extractWeightedTerms(terms, query);
} else if (query instanceof DisjunctionMaxQuery) {
for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
extract(iterator.next(), terms);
}
} else if (query instanceof MultiPhraseQuery) {
final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
final List<Term[]> termArrays = mpq.getTermArrays();
final int[] positions = mpq.getPositions();
if (positions.length > 0) {
int maxPosition = positions[positions.length - 1];
for (int i = 0; i < positions.length - 1; ++i) {
if (positions[i] > maxPosition) {
maxPosition = positions[i];
}
}
@SuppressWarnings("unchecked")
final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
int distinctPositions = 0;
for (int i = 0; i < termArrays.size(); ++i) {
final Term[] termArray = termArrays.get(i);
List<SpanQuery> disjuncts = disjunctLists[positions[i]];
if (disjuncts == null) {
disjuncts = (disjunctLists[positions[i]] = new ArrayList<SpanQuery>(termArray.length));
++distinctPositions;
}
for (int j = 0; j < termArray.length; ++j) {
disjuncts.add(new SpanTermQuery(termArray[j]));
}
}
int positionGaps = 0;
int position = 0;
final SpanQuery[] clauses = new SpanQuery[distinctPositions];
for (int i = 0; i < disjunctLists.length; ++i) {
List disjuncts = disjunctLists[i];
if (disjuncts != null) {
clauses[position++] = new SpanOrQuery(disjuncts
.toArray(new SpanQuery[disjuncts.size()]));
} else {
++positionGaps;
}
}
final int slop = mpq.getSlop();
final boolean inorder = (slop == 0);
SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
sp.setBoost(query.getBoost());
extractWeightedSpanTerms(terms, sp);
}
} else {
Query origQuery = query;
if (query instanceof MultiTermQuery) {
if (!expandMultiTermQuery) {
return;
}
MultiTermQuery copy = (MultiTermQuery) query.clone();
copy.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_QUERY_REWRITE);
origQuery = copy;
}
final IndexReader reader = getLeafContext().reader();
Query rewritten = origQuery.rewrite(reader);
if (rewritten != origQuery) {
// only rewrite once and then flatten again - the rewritten query could get special treatment
// if this method is overridden in a subclass, or further up in the next recursion
extract(rewritten, terms);
}
}
extractUnknownQuery(query, terms);
}
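// Worked example of the PhraseQuery branch above (a sketch with assumed
// field and terms): the phrase query "quick brown"~1 on field "f" is converted
// to an equivalent SpanNearQuery before span extraction; slop 1 means the
// match need not be in order (inorder is true only when slop == 0):
//
//   PhraseQuery pq = new PhraseQuery();
//   pq.add(new Term("f", "quick"));
//   pq.add(new Term("f", "brown"));
//   pq.setSlop(1);
//   // is treated roughly as:
//   new SpanNearQuery(new SpanQuery[] {
//       new SpanTermQuery(new Term("f", "quick")),
//       new SpanTermQuery(new Term("f", "brown")) }, 1, false);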
protected void extractUnknownQuery(Query query,
Map<String,WeightedSpanTerm> terms) throws IOException {
// for sub-classing to extract custom queries
}
/**
* Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>SpanQuery</code>.
*
* @param terms
* Map to place created WeightedSpanTerms in
* @param spanQuery
* SpanQuery to extract Terms from
* @throws IOException If there is a low-level I/O error
*/
protected void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery) throws IOException {
Set<String> fieldNames;
if (fieldName == null) {
fieldNames = new HashSet<String>();
collectSpanQueryFields(spanQuery, fieldNames);
} else {
fieldNames = new HashSet<String>(1);
fieldNames.add(fieldName);
}
// To support the use of the default field name
if (defaultField != null) {
fieldNames.add(defaultField);
}
Map<String,SpanQuery> queries = new HashMap<String,SpanQuery>();
Set<Term> nonWeightedTerms = new HashSet<Term>();
final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
if (mustRewriteQuery) {
for (final String field : fieldNames) {
final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getLeafContext().reader());
queries.put(field, rewrittenQuery);
rewrittenQuery.extractTerms(nonWeightedTerms);
}
} else {
spanQuery.extractTerms(nonWeightedTerms);
}
List<PositionSpan> spanPositions = new ArrayList<PositionSpan>();
for (final String field : fieldNames) {
final SpanQuery q;
if (mustRewriteQuery) {
q = queries.get(field);
} else {
q = spanQuery;
}
AtomicReaderContext context = getLeafContext();
Map<Term,TermContext> termContexts = new HashMap<Term,TermContext>();
TreeSet<Term> extractedTerms = new TreeSet<Term>();
q.extractTerms(extractedTerms);
for (Term term : extractedTerms) {
termContexts.put(term, TermContext.build(context, term));
}
Bits acceptDocs = context.reader().getLiveDocs();
final Spans spans = q.getSpans(context, acceptDocs, termContexts);
// collect span positions
while (spans.next()) {
spanPositions.add(new PositionSpan(spans.start(), spans.end() - 1));
}
}
if (spanPositions.size() == 0) {
// no spans found
return;
}
for (final Term queryTerm : nonWeightedTerms) {
if (fieldNameComparator(queryTerm.field())) {
WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());
if (weightedSpanTerm == null) {
weightedSpanTerm = new WeightedSpanTerm(spanQuery.getBoost(), queryTerm.text());
weightedSpanTerm.addPositionSpans(spanPositions);
weightedSpanTerm.positionSensitive = true;
terms.put(queryTerm.text(), weightedSpanTerm);
} else {
if (spanPositions.size() > 0) {
weightedSpanTerm.addPositionSpans(spanPositions);
}
}
}
}
}
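// Note on the span collection above: the Lucene 4.x Spans API is iterated via
// next()/start()/end(), and each hit is recorded as a PositionSpan covering
// [start, end - 1]; the highlighter later uses these spans to skip term
// occurrences that fall outside any matched span.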
/**
* Fills a <code>Map</code> with {@link WeightedSpanTerm}s using the terms from the supplied <code>Query</code>.
*
* @param terms
* Map to place created WeightedSpanTerms in
* @param query
* Query to extract Terms from
* @throws IOException If there is a low-level I/O error
*/
protected void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query) throws IOException {
Set<Term> nonWeightedTerms = new HashSet<Term>();
query.extractTerms(nonWeightedTerms);
for (final Term queryTerm : nonWeightedTerms) {
if (fieldNameComparator(queryTerm.field())) {
WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(query.getBoost(), queryTerm.text());
terms.put(queryTerm.text(), weightedSpanTerm);
}
}
}
/**
* Necessary to implement matches for queries against defaultField
*/
protected boolean fieldNameComparator(String fieldNameToCheck) {
boolean rv = fieldName == null || fieldName.equals(fieldNameToCheck)
|| (defaultField != null && defaultField.equals(fieldNameToCheck));
return rv;
}
protected AtomicReaderContext getLeafContext() throws IOException {
if (internalReader == null) {
if(wrapToCaching && !(tokenStream instanceof CachingTokenFilter)) {
assert !cachedTokenStream;
tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
cachedTokenStream = true;
}
final MemoryIndex indexer = new MemoryIndex(true);
indexer.addField(DelegatingAtomicReader.FIELD_NAME, tokenStream);
tokenStream.reset();
final IndexSearcher searcher = indexer.createSearcher();
// MEM index has only atomic ctx
internalReader = new DelegatingAtomicReader(((AtomicReaderContext)searcher.getTopReaderContext()).reader());
}
return internalReader.getContext();
}
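// The method above builds a one-document, in-memory index from the TokenStream
// being highlighted, so position-sensitive (Span) queries can actually be
// executed against that text. A minimal sketch of the same MemoryIndex
// pattern, with an assumed analyzer and text variable:
//
//   MemoryIndex idx = new MemoryIndex(true);
//   idx.addField("shadowed_field", analyzer.tokenStream("shadowed_field", text));
//   IndexSearcher searcher = idx.createSearcher();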
/*
* This reader will just delegate every call to a single field in the wrapped
* AtomicReader. This way we only need to build this field once rather than
* N-Times
*/
static final class DelegatingAtomicReader extends FilterAtomicReader {
private static final String FIELD_NAME = "shadowed_field";
DelegatingAtomicReader(AtomicReader in) {
super(in);
}
@Override
public FieldInfos getFieldInfos() {
throw new UnsupportedOperationException();
}
@Override
public Fields fields() throws IOException {
return new FilterFields(super.fields()) {
@Override
public Terms terms(String field) throws IOException {
return super.terms(DelegatingAtomicReader.FIELD_NAME);
}
@Override
public Iterator<String> iterator() {
return Collections.singletonList(DelegatingAtomicReader.FIELD_NAME).iterator();
}
@Override
public int size() {
return 1;
}
};
}
@Override
public NumericDocValues getNumericDocValues(String field) throws IOException {
return super.getNumericDocValues(FIELD_NAME);
}
@Override
public BinaryDocValues getBinaryDocValues(String field) throws IOException {
return super.getBinaryDocValues(FIELD_NAME);
}
@Override
public SortedDocValues getSortedDocValues(String field) throws IOException {
return super.getSortedDocValues(FIELD_NAME);
}
@Override
public NumericDocValues getNormValues(String field) throws IOException {
return super.getNormValues(FIELD_NAME);
}
@Override
public Bits getDocsWithField(String field) throws IOException {
return super.getDocsWithField(FIELD_NAME);
}
}
/**
* Creates a <code>Map</code> of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
*
* @param query
* that caused the hit
* @param tokenStream
* of text to be highlighted
* @return Map containing WeightedSpanTerms
* @throws IOException If there is a low-level I/O error
*/
public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream)
throws IOException {
return getWeightedSpanTerms(query, tokenStream, null);
}
/**
* Creates a <code>Map</code> of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>.
*
* @param query
* that caused the hit
* @param tokenStream
* of text to be highlighted
* @param fieldName
* restricts Terms used based on field name
* @return Map containing WeightedSpanTerms
* @throws IOException If there is a low-level I/O error
*/
public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, TokenStream tokenStream,
String fieldName) throws IOException {
if (fieldName != null) {
this.fieldName = fieldName;
} else {
this.fieldName = null;
}
Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
this.tokenStream = tokenStream;
try {
extract(query, terms);
} finally {
IOUtils.close(internalReader);
}
return terms;
}
/**
* Creates a <code>Map</code> of <code>WeightedSpanTerms</code> from the given <code>Query</code> and <code>TokenStream</code>. Uses a supplied
* <code>IndexReader</code> to properly weight terms (for gradient highlighting).
*
* @param query
* that caused the hit
* @param tokenStream
* of text to be highlighted
* @param fieldName
* restricts Terms used based on field name
* @param reader
* to use for scoring
* @return Map of WeightedSpanTerms with quasi tf/idf scores
* @throws IOException If there is a low-level I/O error
*/
public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, TokenStream tokenStream, String fieldName,
IndexReader reader) throws IOException {
if (fieldName != null) {
this.fieldName = fieldName;
} else {
this.fieldName = null;
}
this.tokenStream = tokenStream;
Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<String>();
extract(query, terms);
int totalNumDocs = reader.maxDoc();
Set<String> weightedTerms = terms.keySet();
Iterator<String> it = weightedTerms.iterator();
try {
while (it.hasNext()) {
WeightedSpanTerm weightedSpanTerm = terms.get(it.next());
int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
// IDF algorithm taken from DefaultSimilarity class
float idf = (float) (Math.log(totalNumDocs / (double) (docFreq + 1)) + 1.0);
weightedSpanTerm.weight *= idf;
}
} finally {
IOUtils.close(internalReader);
}
return terms;
}
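// Worked example of the IDF weighting above, with assumed numbers: for
// totalNumDocs = 1000 and docFreq = 9, idf = log(1000 / 10.0) + 1.0 = ~5.61,
// so that term's weight is multiplied by about 5.61 - rarer terms end up
// highlighted more strongly in gradient highlighting.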
protected void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
if (spanQuery instanceof FieldMaskingSpanQuery) {
collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
} else if (spanQuery instanceof SpanFirstQuery) {
collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
} else if (spanQuery instanceof SpanNearQuery) {
for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
collectSpanQueryFields(clause, fieldNames);
}
} else if (spanQuery instanceof SpanNotQuery) {
collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
} else if (spanQuery instanceof SpanOrQuery) {
for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
collectSpanQueryFields(clause, fieldNames);
}
} else {
fieldNames.add(spanQuery.getField());
}
}
protected boolean mustRewriteQuery(SpanQuery spanQuery) {
if (!expandMultiTermQuery) {
return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
} else if (spanQuery instanceof FieldMaskingSpanQuery) {
return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
} else if (spanQuery instanceof SpanFirstQuery) {
return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
} else if (spanQuery instanceof SpanNearQuery) {
for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
if (mustRewriteQuery(clause)) {
return true;
}
}
return false;
} else if (spanQuery instanceof SpanNotQuery) {
SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
} else if (spanQuery instanceof SpanOrQuery) {
for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
if (mustRewriteQuery(clause)) {
return true;
}
}
return false;
} else if (spanQuery instanceof SpanTermQuery) {
return false;
} else {
return true;
}
}
/**
* This class makes sure that if both position sensitive and insensitive
* versions of the same term are added, the position insensitive one wins.
*/
@SuppressWarnings("serial")
protected static class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {
@Override
public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
this.put(entry.getKey(), entry.getValue());
}
@Override
public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
WeightedSpanTerm prev = super.put(key, value);
if (prev == null) return prev;
WeightedSpanTerm prevTerm = prev;
WeightedSpanTerm newTerm = value;
if (!prevTerm.positionSensitive) {
newTerm.positionSensitive = false;
}
return prev;
}
}
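// Behavior sketch for PositionCheckingMap (the insensitiveTerm/sensitiveTerm
// values are hypothetical, for illustration only): once a position insensitive
// entry exists for a key, a later position sensitive put is demoted, so the
// insensitive variant wins and the term is highlighted everywhere:
//
//   PositionCheckingMap<String> m = new PositionCheckingMap<String>();
//   m.put("fox", insensitiveTerm); // e.g. from a plain TermQuery
//   m.put("fox", sensitiveTerm);   // from a PhraseQuery; demoted to insensitive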
public boolean getExpandMultiTermQuery() {
return expandMultiTermQuery;
}
public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
this.expandMultiTermQuery = expandMultiTermQuery;
}
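// Illustrative only: with expansion enabled, a MultiTermQuery such as
// new WildcardQuery(new Term("content", "hig*")) is cloned and rewritten with
// SCORING_BOOLEAN_QUERY_REWRITE inside extract(...), so the concrete terms it
// matches can be highlighted; with expansion disabled such queries are skipped.
//
//   extractor.setExpandMultiTermQuery(true);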
public boolean isCachedTokenStream() {
return cachedTokenStream;
}
public TokenStream getTokenStream() {
return tokenStream;
}
/**
* By default, {@link TokenStream}s that are not of the type
* {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
* ensure an efficient reset - if you are already using a different caching
* {@link TokenStream} impl and you don't want it to be wrapped, set this to
* false.
*/
public void setWrapIfNotCachingTokenFilter(boolean wrap) {
this.wrapToCaching = wrap;
}
protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
}
}