/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.highlight;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.analysis.CachingTokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.memory.MemoryIndex;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.queries.CustomScoreQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiPhraseQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.lucene.search.join.ToParentBlockJoinQuery;
import org.apache.lucene.search.spans.FieldMaskingSpanQuery;
import org.apache.lucene.search.spans.SpanFirstQuery;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanNotQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.SpanWeight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;

/**
 * Class used to extract {@link WeightedSpanTerm}s from a {@link Query} based on whether 
 * {@link Term}s from the {@link Query} are contained in a supplied {@link TokenStream}.
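 *
 * <p>A minimal usage sketch (assumes an existing {@code query} and an analyzed
 * {@code tokenStream} for the text to highlight; the field name {@code "contents"}
 * is illustrative):
 * <pre>{@code
 * WeightedSpanTermExtractor extractor = new WeightedSpanTermExtractor();
 * extractor.setExpandMultiTermQuery(true); // also expand wildcard/fuzzy queries
 * Map<String,WeightedSpanTerm> weightedTerms =
 *     extractor.getWeightedSpanTerms(query, 1f, tokenStream, "contents");
 * }</pre>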
 */
public class WeightedSpanTermExtractor {

  private String fieldName;
  private TokenStream tokenStream; // set by the getWeightedSpanTerms* methods
  private String defaultField;
  private boolean expandMultiTermQuery;
  private boolean cachedTokenStream;
  private boolean wrapToCaching = true;
  private int maxDocCharsToAnalyze;
  private boolean usePayloads = false;
  private LeafReader internalReader = null;

  public WeightedSpanTermExtractor() {
  }

  public WeightedSpanTermExtractor(String defaultField) {
    if (defaultField != null) {
      this.defaultField = defaultField;
    }
  }

  /**
   * Fills a Map with {@link WeightedSpanTerm}s using the terms from the supplied Query.
   * 
   * @param query
   *          Query to extract Terms from
   * @param boost
   *          boost to apply to the extracted terms
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @throws IOException If there is a low-level I/O error
   */
  protected void extract(Query query, float boost, Map<String,WeightedSpanTerm> terms) throws IOException {
    if (query instanceof BoostQuery) {
      BoostQuery boostQuery = (BoostQuery) query;
      extract(boostQuery.getQuery(), boost * boostQuery.getBoost(), terms);
    } else if (query instanceof BooleanQuery) {
      for (BooleanClause clause : (BooleanQuery) query) {
        if (!clause.isProhibited()) {
          extract(clause.getQuery(), boost, terms);
        }
      }
    } else if (query instanceof PhraseQuery) {
      PhraseQuery phraseQuery = ((PhraseQuery) query);
      Term[] phraseQueryTerms = phraseQuery.getTerms();
      if (phraseQueryTerms.length == 1) {
        extractWeightedSpanTerms(terms, new SpanTermQuery(phraseQueryTerms[0]), boost);
      }
      else {
        SpanQuery[] clauses = new SpanQuery[phraseQueryTerms.length];
        for (int i = 0; i < phraseQueryTerms.length; i++) {
          clauses[i] = new SpanTermQuery(phraseQueryTerms[i]);
        }

        // sum position increments beyond 1
        int positionGaps = 0;
        int[] positions = phraseQuery.getPositions();
        if (positions.length >= 2) {
          // positions are in increasing order.   max(0,...) is just a safeguard.
          positionGaps = Math.max(0, positions[positions.length - 1] - positions[0] - positions.length + 1);
        }
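
        // Worked example (hypothetical positions): for three terms at positions
        // [0, 2, 5], positionGaps = max(0, 5 - 0 - 3 + 1) = 3; it is added to the
        // slop below so the SpanNearQuery still spans the original gaps.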

        //if original slop is 0 then require inOrder
        boolean inorder = (phraseQuery.getSlop() == 0);

        SpanNearQuery sp = new SpanNearQuery(clauses, phraseQuery.getSlop() + positionGaps, inorder);
        extractWeightedSpanTerms(terms, sp, boost);
      }
    } else if (query instanceof TermQuery) {
      extractWeightedTerms(terms, query, boost);
    } else if (query instanceof SpanQuery) {
      extractWeightedSpanTerms(terms, (SpanQuery) query, boost);
    } else if (query instanceof ConstantScoreQuery) {
      final Query q = ((ConstantScoreQuery) query).getQuery();
      if (q != null) {
        extract(q, boost, terms);
      }
    } else if (query instanceof CommonTermsQuery) {
      // specialized since rewriting would change the result query;
      // this query is TermContext sensitive
      extractWeightedTerms(terms, query, boost);
    } else if (query instanceof DisjunctionMaxQuery) {
      for (Iterator<Query> iterator = ((DisjunctionMaxQuery) query).iterator(); iterator.hasNext();) {
        extract(iterator.next(), boost, terms);
      }
    } else if (query instanceof ToParentBlockJoinQuery) {
      extract(((ToParentBlockJoinQuery) query).getChildQuery(), boost, terms);
    } else if (query instanceof ToChildBlockJoinQuery) {
      extract(((ToChildBlockJoinQuery) query).getParentQuery(), boost, terms);
    } else if (query instanceof MultiPhraseQuery) {
      final MultiPhraseQuery mpq = (MultiPhraseQuery) query;
      final Term[][] termArrays = mpq.getTermArrays();
      final int[] positions = mpq.getPositions();
      if (positions.length > 0) {

        int maxPosition = positions[positions.length - 1];
        for (int i = 0; i < positions.length - 1; ++i) {
          if (positions[i] > maxPosition) {
            maxPosition = positions[i];
          }
        }

        @SuppressWarnings({"unchecked","rawtypes"})
        final List<SpanQuery>[] disjunctLists = new List[maxPosition + 1];
        int distinctPositions = 0;

        for (int i = 0; i < termArrays.length; ++i) {
          final Term[] termArray = termArrays[i];
          List<SpanQuery> disjuncts = disjunctLists[positions[i]];
          if (disjuncts == null) {
            disjuncts = (disjunctLists[positions[i]] = new ArrayList<>(termArray.length));
            ++distinctPositions;
          }
          for (int j = 0; j < termArray.length; ++j) {
            disjuncts.add(new SpanTermQuery(termArray[j]));
          }
        }

        int positionGaps = 0;
        int position = 0;
        final SpanQuery[] clauses = new SpanQuery[distinctPositions];
        for (int i = 0; i < disjunctLists.length; ++i) {
          List<SpanQuery> disjuncts = disjunctLists[i];
          if (disjuncts != null) {
            clauses[position++] = new SpanOrQuery(disjuncts
                .toArray(new SpanQuery[disjuncts.size()]));
          } else {
            ++positionGaps;
          }
        }

        final int slop = mpq.getSlop();
        final boolean inorder = (slop == 0);
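
        // Sketch (hypothetical query): for {"quick","fast"} at position 0 and
        // {"fox"} at position 2, clauses == [SpanOr(quick,fast), SpanOr(fox)],
        // the empty slot at position 1 gives positionGaps == 1, and the
        // SpanNearQuery below is built with slop + 1.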

        SpanNearQuery sp = new SpanNearQuery(clauses, slop + positionGaps, inorder);
        extractWeightedSpanTerms(terms, sp, boost);
      }
    } else if (query instanceof MatchAllDocsQuery) {
      //nothing
    } else if (query instanceof CustomScoreQuery){
      extract(((CustomScoreQuery) query).getSubQuery(), boost, terms);
    } else if (isQueryUnsupported(query.getClass())) {
      // nothing
    } else {
      Query origQuery = query;
      final IndexReader reader = getLeafContext().reader();
      Query rewritten;
      if (query instanceof MultiTermQuery) {
        if (!expandMultiTermQuery) {
          return;
        }
        rewritten = MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite(reader, (MultiTermQuery) query);
      } else {
        rewritten = origQuery.rewrite(reader);
      }
      if (rewritten != origQuery) {
        // only rewrite once and then flatten again - the rewritten query could have a special treatment
        // if this method is overridden in a subclass, or above in the next recursion
        extract(rewritten, boost, terms);
      } else {
        extractUnknownQuery(query, terms);
      }
    }
  }

  protected boolean isQueryUnsupported(Class<? extends Query> clazz) {
    // spatial queries do not support highlighting:
    if (clazz.getName().startsWith("org.apache.lucene.spatial.")) {
      return true;
    }
    // spatial3d queries are also not supported:
    if (clazz.getName().startsWith("org.apache.lucene.spatial3d.")) {
      return true;
    }
    return false;
  }

  protected void extractUnknownQuery(Query query,
      Map<String,WeightedSpanTerm> terms) throws IOException {
    
    // for sub-classing to extract custom queries
  }
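
  // A hypothetical override (sketch only; MyWrapperQuery is not a real Lucene type):
  // a subclass could unwrap a custom query here and feed it back into extract():
  //
  //   @Override
  //   protected void extractUnknownQuery(Query query,
  //       Map<String,WeightedSpanTerm> terms) throws IOException {
  //     if (query instanceof MyWrapperQuery) {
  //       extract(((MyWrapperQuery) query).getWrappedQuery(), 1f, terms);
  //     }
  //   }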

  /**
   * Fills a Map with {@link WeightedSpanTerm}s using the terms from the supplied SpanQuery.
   * 
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @param spanQuery
   *          SpanQuery to extract Terms from
   * @param boost
   *          boost to apply to the extracted terms
   * @throws IOException If there is a low-level I/O error
   */
  protected void extractWeightedSpanTerms(Map<String,WeightedSpanTerm> terms, SpanQuery spanQuery, float boost) throws IOException {
    Set<String> fieldNames;

    if (fieldName == null) {
      fieldNames = new HashSet<>();
      collectSpanQueryFields(spanQuery, fieldNames);
    } else {
      fieldNames = new HashSet<>(1);
      fieldNames.add(fieldName);
    }
    // To support the use of the default field name
    if (defaultField != null) {
      fieldNames.add(defaultField);
    }
    
    Map<String, SpanQuery> queries = new HashMap<>();
 
    Set<Term> nonWeightedTerms = new HashSet<>();
    final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
    final IndexSearcher searcher = new IndexSearcher(getLeafContext());
    searcher.setQueryCache(null);
    if (mustRewriteQuery) {
      for (final String field : fieldNames) {
        final SpanQuery rewrittenQuery = (SpanQuery) spanQuery.rewrite(getLeafContext().reader());
        queries.put(field, rewrittenQuery);
        rewrittenQuery.createWeight(searcher, false).extractTerms(nonWeightedTerms);
      }
    } else {
      spanQuery.createWeight(searcher, false).extractTerms(nonWeightedTerms);
    }

    List<PositionSpan> spanPositions = new ArrayList<>();

    for (final String field : fieldNames) {
      final SpanQuery q;
      if (mustRewriteQuery) {
        q = queries.get(field);
      } else {
        q = spanQuery;
      }
      LeafReaderContext context = getLeafContext();
      SpanWeight w = (SpanWeight) searcher.createNormalizedWeight(q, false);
      Bits acceptDocs = context.reader().getLiveDocs();
      final Spans spans = w.getSpans(context, SpanWeight.Postings.POSITIONS);
      if (spans == null) {
        return;
      }

      // collect span positions
      while (spans.nextDoc() != Spans.NO_MORE_DOCS) {
        if (acceptDocs != null && acceptDocs.get(spans.docID()) == false) {
          continue;
        }
        while (spans.nextStartPosition() != Spans.NO_MORE_POSITIONS) {
          spanPositions.add(new PositionSpan(spans.startPosition(), spans.endPosition() - 1));
        }
      }
    }

    if (spanPositions.size() == 0) {
      // no spans found
      return;
    }

    for (final Term queryTerm : nonWeightedTerms) {

      if (fieldNameComparator(queryTerm.field())) {
        WeightedSpanTerm weightedSpanTerm = terms.get(queryTerm.text());

        if (weightedSpanTerm == null) {
          weightedSpanTerm = new WeightedSpanTerm(boost, queryTerm.text());
          weightedSpanTerm.addPositionSpans(spanPositions);
          weightedSpanTerm.positionSensitive = true;
          terms.put(queryTerm.text(), weightedSpanTerm);
        } else {
          if (spanPositions.size() > 0) {
            weightedSpanTerm.addPositionSpans(spanPositions);
          }
        }
      }
    }
  }

  /**
   * Fills a Map with {@link WeightedSpanTerm}s using the terms from the supplied Query.
   * 
   * @param terms
   *          Map to place created WeightedSpanTerms in
   * @param query
   *          Query to extract Terms from
   * @param boost
   *          boost to apply to the extracted terms
   * @throws IOException If there is a low-level I/O error
   */
  protected void extractWeightedTerms(Map<String,WeightedSpanTerm> terms, Query query, float boost) throws IOException {
    Set<Term> nonWeightedTerms = new HashSet<>();
    final IndexSearcher searcher = new IndexSearcher(getLeafContext());
    searcher.createNormalizedWeight(query, false).extractTerms(nonWeightedTerms);

    for (final Term queryTerm : nonWeightedTerms) {

      if (fieldNameComparator(queryTerm.field())) {
        WeightedSpanTerm weightedSpanTerm = new WeightedSpanTerm(boost, queryTerm.text());
        terms.put(queryTerm.text(), weightedSpanTerm);
      }
    }
  }

  /**
   * Returns whether the given field name matches this extractor's fieldName or the
   * defaultField; a null fieldName matches any field.
   */
  protected boolean fieldNameComparator(String fieldNameToCheck) {
    boolean rv = fieldName == null || fieldName.equals(fieldNameToCheck)
      || (defaultField != null && defaultField.equals(fieldNameToCheck));
    return rv;
  }

  protected LeafReaderContext getLeafContext() throws IOException {
    if (internalReader == null) {
      boolean cacheIt = wrapToCaching && !(tokenStream instanceof CachingTokenFilter);

      // If it's from term vectors, simply wrap the underlying Terms in a reader
      if (tokenStream instanceof TokenStreamFromTermVector) {
        cacheIt = false;
        Terms termVectorTerms = ((TokenStreamFromTermVector) tokenStream).getTermVectorTerms();
        if (termVectorTerms.hasPositions() && termVectorTerms.hasOffsets()) {
          internalReader = new TermVectorLeafReader(DelegatingLeafReader.FIELD_NAME, termVectorTerms);
        }
      }

      // Use MemoryIndex (index/invert this tokenStream now)
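      // (MemoryIndex is a single-document, in-memory index, so the span queries in
      // extractWeightedSpanTerms run against just the text being highlighted,
      // not against the full index.)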
      if (internalReader == null) {
        final MemoryIndex indexer = new MemoryIndex(true, usePayloads);//offsets and payloads
        if (cacheIt) {
          assert !cachedTokenStream;
          tokenStream = new CachingTokenFilter(new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
          cachedTokenStream = true;
          indexer.addField(DelegatingLeafReader.FIELD_NAME, tokenStream);
        } else {
          indexer.addField(DelegatingLeafReader.FIELD_NAME,
              new OffsetLimitTokenFilter(tokenStream, maxDocCharsToAnalyze));
        }
        final IndexSearcher searcher = indexer.createSearcher();
        // MEM index has only atomic ctx
        internalReader = ((LeafReaderContext) searcher.getTopReaderContext()).reader();
      }

      //Now wrap it so we always use a common field.
      this.internalReader = new DelegatingLeafReader(internalReader);
    }

    return internalReader.getContext();
  }
  
  /*
   * This reader will just delegate every call to a single field in the wrapped
   * LeafReader. This way we only need to build this field once rather than
   * N times.
   */
  static final class DelegatingLeafReader extends FilterLeafReader {
    private static final String FIELD_NAME = "shadowed_field";

    DelegatingLeafReader(LeafReader in) {
      super(in);
    }
    
    @Override
    public FieldInfos getFieldInfos() {
      throw new UnsupportedOperationException();
    }

    @Override
    public Fields fields() throws IOException {
      return new FilterFields(super.fields()) {
        @Override
        public Terms terms(String field) throws IOException {
          return super.terms(DelegatingLeafReader.FIELD_NAME);
        }

        @Override
        public Iterator<String> iterator() {
          return Collections.singletonList(DelegatingLeafReader.FIELD_NAME).iterator();
        }

        @Override
        public int size() {
          return 1;
        }
      };
    }

    @Override
    public NumericDocValues getNumericDocValues(String field) throws IOException {
      return super.getNumericDocValues(FIELD_NAME);
    }
    
    @Override
    public BinaryDocValues getBinaryDocValues(String field) throws IOException {
      return super.getBinaryDocValues(FIELD_NAME);
    }
    
    @Override
    public SortedDocValues getSortedDocValues(String field) throws IOException {
      return super.getSortedDocValues(FIELD_NAME);
    }
    
    @Override
    public NumericDocValues getNormValues(String field) throws IOException {
      return super.getNormValues(FIELD_NAME);
    }

    @Override
    public Bits getDocsWithField(String field) throws IOException {
      return super.getDocsWithField(FIELD_NAME);
    }
  }

  /**
   * Creates a Map of {@link WeightedSpanTerm}s from the given Query and TokenStream.
   * 
   * @param query
   *          Query that caused hit
   * @param boost
   *          boost to apply to the extracted terms
   * @param tokenStream
   *          TokenStream of text to be highlighted
   * @return Map containing WeightedSpanTerms
   * @throws IOException If there is a low-level I/O error
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, float boost, TokenStream tokenStream)
      throws IOException {
    return getWeightedSpanTerms(query, boost, tokenStream, null);
  }

  /**
   * Creates a Map of {@link WeightedSpanTerm}s from the given Query and TokenStream.
   * 
   * @param query
   *          Query that caused hit
   * @param boost
   *          boost to apply to the extracted terms
   * @param tokenStream
   *          TokenStream of text to be highlighted
   * @param fieldName
   *          restricts the Terms used to those from the given field
   * @return Map containing WeightedSpanTerms
   * @throws IOException If there is a low-level I/O error
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTerms(Query query, float boost, TokenStream tokenStream,
      String fieldName) throws IOException {
    this.fieldName = fieldName;

    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<>();
    this.tokenStream = tokenStream;
    try {
      extract(query, boost, terms);
    } finally {
      IOUtils.close(internalReader);
    }

    return terms;
  }
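
  // Worked sketch of the idf weighting applied in getWeightedSpanTermsWithScores
  // below (hypothetical numbers): with reader.maxDoc() == 1000 and docFreq == 9,
  // idf = ln(1000 / (9 + 1)) + 1 = ln(100) + 1 ≈ 5.61, so that term's weight is
  // multiplied by about 5.61.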

  /**
   * Creates a Map of {@link WeightedSpanTerm}s from the given Query and TokenStream. Uses a supplied
   * IndexReader to properly weight terms (for gradient highlighting).
   * 
   * @param query
   *          Query that caused hit
   * @param boost
   *          boost to apply to the extracted terms
   * @param tokenStream
   *          TokenStream of text to be highlighted
   * @param fieldName
   *          restricts the Terms used to those from the given field
   * @param reader
   *          IndexReader to use for scoring
   * @return Map of WeightedSpanTerms with quasi tf/idf scores
   * @throws IOException If there is a low-level I/O error
   */
  public Map<String,WeightedSpanTerm> getWeightedSpanTermsWithScores(Query query, float boost, TokenStream tokenStream,
      String fieldName, IndexReader reader) throws IOException {
    this.fieldName = fieldName;
    this.tokenStream = tokenStream;

    Map<String,WeightedSpanTerm> terms = new PositionCheckingMap<>();
    extract(query, boost, terms);

    int totalNumDocs = reader.maxDoc();

    try {
      for (String term : terms.keySet()) {
        WeightedSpanTerm weightedSpanTerm = terms.get(term);
        int docFreq = reader.docFreq(new Term(fieldName, weightedSpanTerm.term));
        // IDF algorithm taken from ClassicSimilarity class
        float idf = (float) (Math.log(totalNumDocs / (double) (docFreq + 1)) + 1.0);
        weightedSpanTerm.weight *= idf;
      }
    } finally {
      IOUtils.close(internalReader);
    }

    return terms;
  }

  protected void collectSpanQueryFields(SpanQuery spanQuery, Set<String> fieldNames) {
    if (spanQuery instanceof FieldMaskingSpanQuery) {
      collectSpanQueryFields(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery(), fieldNames);
    } else if (spanQuery instanceof SpanFirstQuery) {
      collectSpanQueryFields(((SpanFirstQuery)spanQuery).getMatch(), fieldNames);
    } else if (spanQuery instanceof SpanNearQuery) {
      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
        collectSpanQueryFields(clause, fieldNames);
      }
    } else if (spanQuery instanceof SpanNotQuery) {
      collectSpanQueryFields(((SpanNotQuery)spanQuery).getInclude(), fieldNames);
    } else if (spanQuery instanceof SpanOrQuery) {
      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
        collectSpanQueryFields(clause, fieldNames);
      }
    } else {
      fieldNames.add(spanQuery.getField());
    }
  }

  protected boolean mustRewriteQuery(SpanQuery spanQuery) {
    if (!expandMultiTermQuery) {
      return false; // Will throw UnsupportedOperationException in case of a SpanRegexQuery.
    } else if (spanQuery instanceof FieldMaskingSpanQuery) {
      return mustRewriteQuery(((FieldMaskingSpanQuery)spanQuery).getMaskedQuery());
    } else if (spanQuery instanceof SpanFirstQuery) {
      return mustRewriteQuery(((SpanFirstQuery)spanQuery).getMatch());
    } else if (spanQuery instanceof SpanNearQuery) {
      for (final SpanQuery clause : ((SpanNearQuery)spanQuery).getClauses()) {
        if (mustRewriteQuery(clause)) {
          return true;
        }
      }
      return false;
    } else if (spanQuery instanceof SpanNotQuery) {
      SpanNotQuery spanNotQuery = (SpanNotQuery)spanQuery;
      return mustRewriteQuery(spanNotQuery.getInclude()) || mustRewriteQuery(spanNotQuery.getExclude());
    } else if (spanQuery instanceof SpanOrQuery) {
      for (final SpanQuery clause : ((SpanOrQuery)spanQuery).getClauses()) {
        if (mustRewriteQuery(clause)) {
          return true;
        }
      }
      return false;
    } else if (spanQuery instanceof SpanTermQuery) {
      return false;
    } else {
      return true;
    }
  }
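
  // Note on mustRewriteQuery above: a SpanMultiTermQueryWrapper around e.g. a
  // WildcardQuery is not a SpanTermQuery, so it falls through to the final
  // "return true" branch and is rewritten before span extraction; with
  // expandMultiTermQuery == false it is used unrewritten, which (as the comment
  // in the method notes) can throw UnsupportedOperationException.
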
  /**
   * This class makes sure that if both position sensitive and insensitive
   * versions of the same term are added, the position insensitive one wins.
   */
  @SuppressWarnings("serial")
  protected static class PositionCheckingMap<K> extends HashMap<K,WeightedSpanTerm> {

    @Override
    public void putAll(Map<? extends K,? extends WeightedSpanTerm> m) {
      for (Map.Entry<? extends K,? extends WeightedSpanTerm> entry : m.entrySet())
        this.put(entry.getKey(), entry.getValue());
    }

    @Override
    public WeightedSpanTerm put(K key, WeightedSpanTerm value) {
      WeightedSpanTerm prev = super.put(key, value);
      if (prev == null) return prev;
      // a previously added position-insensitive entry stays position-insensitive
      if (!prev.positionSensitive) {
        value.positionSensitive = false;
      }
      return prev;
    }

  }

  public boolean getExpandMultiTermQuery() {
    return expandMultiTermQuery;
  }

  public void setExpandMultiTermQuery(boolean expandMultiTermQuery) {
    this.expandMultiTermQuery = expandMultiTermQuery;
  }

  public boolean isUsePayloads() {
    return usePayloads;
  }

  public void setUsePayloads(boolean usePayloads) {
    this.usePayloads = usePayloads;
  }

  public boolean isCachedTokenStream() {
    return cachedTokenStream;
  }

  /** Returns the tokenStream, which may have been wrapped in a CachingTokenFilter.
   * The getWeightedSpanTerms* methods set the tokenStream, so don't call this
   * before one of them has been called. */
  public TokenStream getTokenStream() {
    assert tokenStream != null;
    return tokenStream;
  }

  /**
   * By default, {@link TokenStream}s that are not of the type
   * {@link CachingTokenFilter} are wrapped in a {@link CachingTokenFilter} to
   * ensure an efficient reset - if you are already using a different caching
   * {@link TokenStream} impl and you don't want it to be wrapped, set this to
   * false. This setting is ignored when a term vector based TokenStream is
   * supplied, since it can be reset efficiently.
   */
  public void setWrapIfNotCachingTokenFilter(boolean wrap) {
    this.wrapToCaching = wrap;
  }

  /** A threshold of the number of characters to analyze. When a TokenStream based
   * on term vectors with offsets and positions is supplied, this setting
   * does not apply. */
  protected final void setMaxDocCharsToAnalyze(int maxDocCharsToAnalyze) {
    this.maxDocCharsToAnalyze = maxDocCharsToAnalyze;
  }

}




