All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.solr.search.ReRankCollector Maven / Gradle / Ivy

There is a newer version: 9.7.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.search;

import com.carrotsearch.hppc.IntFloatHashMap;
import com.carrotsearch.hppc.IntIntHashMap;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.handler.component.QueryElevationComponent;
import org.apache.solr.request.SolrRequestInfo;

/**
 * A {@link TopDocsCollector} used by reranking queries.
 *
 * <p>Hit collection is delegated to an inner "main" collector (score-ordered when the command has
 * no sort, field-sorted otherwise). {@link #topDocs(int, int)} then takes the leading {@code
 * reRankDocs} hits of the main result, runs them through a {@link Rescorer}, optionally re-sorts
 * them so boosted (elevated) documents stay on top, optionally scales the scores, and finally
 * assembles the page that is returned.
 */
public class ReRankCollector extends TopDocsCollector<ScoreDoc> {

  // Wildcard because the delegate is either a TopScoreDocCollector or a TopFieldCollector.
  private final TopDocsCollector<?> mainCollector;
  private final IndexSearcher searcher;
  private final int reRankDocs;
  private final int length;
  private final Set<BytesRef> boostedPriority; // order is the "priority"
  private final Rescorer reRankQueryRescorer;
  private final Sort sort;
  private final Query query;
  private ReRankScaler reRankScaler;
  private ReRankOperator reRankOperator;

  /**
   * Creates a collector that additionally knows about score scaling.
   *
   * @param reRankDocs number of leading main-query hits to hand to the rescorer
   * @param length number of hits requested; the main collector gathers {@code max(reRankDocs,
   *     length)} hits
   * @param reRankQueryRescorer rescorer used when scores are not being scaled
   * @param cmd supplies the query, the optional sort and the minimum exact count
   * @param searcher searcher used for sort rewriting, score population and rescoring
   * @param boostedPriority ids of boosted documents, or {@code null} when there are none
   * @param reRankScaler scaler consulted in {@link #topDocs(int, int)}; may be {@code null}
   * @param reRankOperator stored on the collector; not read by the code in this class
   * @throws IOException if rewriting the sort fails
   */
  public ReRankCollector(
      int reRankDocs,
      int length,
      Rescorer reRankQueryRescorer,
      QueryCommand cmd,
      IndexSearcher searcher,
      Set<BytesRef> boostedPriority,
      ReRankScaler reRankScaler,
      ReRankOperator reRankOperator)
      throws IOException {
    this(reRankDocs, length, reRankQueryRescorer, cmd, searcher, boostedPriority);
    this.reRankScaler = reRankScaler;
    this.reRankOperator = reRankOperator;
  }

  /**
   * Creates a collector without score scaling.
   *
   * @param reRankDocs number of leading main-query hits to hand to the rescorer
   * @param length number of hits requested; the main collector gathers {@code max(reRankDocs,
   *     length)} hits
   * @param reRankQueryRescorer rescorer applied to the leading hits
   * @param cmd supplies the query, the optional sort and the minimum exact count
   * @param searcher searcher used for sort rewriting, score population and rescoring
   * @param boostedPriority ids of boosted documents, or {@code null} when there are none
   * @throws IOException if rewriting the sort fails
   */
  public ReRankCollector(
      int reRankDocs,
      int length,
      Rescorer reRankQueryRescorer,
      QueryCommand cmd,
      IndexSearcher searcher,
      Set<BytesRef> boostedPriority)
      throws IOException {
    // No priority queue of our own: every collection call is forwarded to mainCollector.
    super(null);
    this.reRankDocs = reRankDocs;
    this.length = length;
    this.boostedPriority = boostedPriority;
    this.query = cmd.getQuery();
    Sort sort = cmd.getSort();
    if (sort == null) {
      this.sort = null;
      this.mainCollector =
          TopScoreDocCollector.create(Math.max(this.reRankDocs, length), cmd.getMinExactCount());
    } else {
      this.sort = sort = sort.rewrite(searcher);
      // scores are needed for Rescorer (regardless of whether sort needs it)
      this.mainCollector =
          TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), cmd.getMinExactCount());
    }
    this.searcher = searcher;
    this.reRankQueryRescorer = reRankQueryRescorer;
  }

  /** Returns the total hit count reported by the main collector. */
  @Override
  public int getTotalHits() {
    return mainCollector.getTotalHits();
  }

  /** Returns the main collector's leaf collector for {@code context}. */
  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    return mainCollector.getLeafCollector(context);
  }

  /** Returns the main collector's score mode. */
  @Override
  public ScoreMode scoreMode() {
    return this.mainCollector.scoreMode();
  }

  /**
   * Builds the re-ranked result page.
   *
   * @param start not used by this implementation; the page always starts at the first hit
   * @param howMany maximum number of hits to return; lowered to the number of hits the main
   *     collector produced when that is smaller
   * @return the main collector's result untouched when it is empty, otherwise a {@link TopDocs}
   *     whose {@code scoreDocs} hold the re-ranked hits, followed by main-query hits when more
   *     hits are requested than were re-ranked
   * @throws SolrException with {@code BAD_REQUEST}, wrapping any exception raised while building
   *     the page
   */
  @Override
  public TopDocs topDocs(int start, int howMany) {

    try {

      TopDocs mainDocs = mainCollector.topDocs(0, Math.max(reRankDocs, length));

      if (mainDocs.totalHits.value == 0 || mainDocs.scoreDocs.length == 0) {
        return mainDocs;
      }

      if (sort != null) {
        TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query);
      }

      ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
      boolean zeroOutScores = reRankScaler != null && reRankScaler.scaleScores();
      // The clone keeps the original scores; when scaling, the originals are zeroed in place.
      ScoreDoc[] mainScoreDocsClone = deepClone(mainScoreDocs, zeroOutScores);
      ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)];
      System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length);

      mainDocs.scoreDocs = reRankScoreDocs;

      // If we're scaling scores use the replace rescorer because we just want the re-rank score.
      TopDocs rescoredDocs;
      try {
        rescoredDocs =
            zeroOutScores // previously zero-ed out scores are to be replaced
                ? reRankScaler
                    .getReplaceRescorer()
                    .rescore(searcher, mainDocs, mainDocs.scoreDocs.length)
                : reRankQueryRescorer.rescore(searcher, mainDocs, mainDocs.scoreDocs.length);
      } catch (IncompleteRerankingException ex) {
        // Re-ranking did not complete: fall back to the main hits with their original scores.
        mainDocs.scoreDocs = mainScoreDocsClone;
        rescoredDocs = mainDocs;
      }

      // Lower howMany to return if we've collected fewer documents.
      howMany = Math.min(howMany, mainScoreDocs.length);

      if (boostedPriority != null) {
        SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
        Map<Object, Object> requestContext = null;
        if (info != null) {
          requestContext = info.getReq().getContext();
        }

        IntIntHashMap boostedDocs =
            QueryElevationComponent.getBoostDocs(
                (SolrIndexSearcher) searcher, boostedPriority, requestContext);

        float maxScore =
            rescoredDocs.scoreDocs.length == 0 ? Float.NaN : rescoredDocs.scoreDocs[0].score;
        Arrays.sort(
            rescoredDocs.scoreDocs, new BoostedComp(boostedDocs, mainDocs.scoreDocs, maxScore));
      }

      if (howMany == rescoredDocs.scoreDocs.length) {
        if (zeroOutScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(
                  mainScoreDocsClone, rescoredDocs.scoreDocs, reRankScoreDocs.length);
        }
        return rescoredDocs; // Just return the rescoredDocs
      } else if (howMany > rescoredDocs.scoreDocs.length) {
        // We need to return more than we've reRanked, so create the combined page.
        ScoreDoc[] scoreDocs = new ScoreDoc[howMany];
        System.arraycopy(
            mainScoreDocs, 0, scoreDocs, 0, scoreDocs.length); // lay down the initial docs
        System.arraycopy(
            rescoredDocs.scoreDocs,
            0,
            scoreDocs,
            0,
            rescoredDocs.scoreDocs.length); // overlay the re-ranked docs.
        rescoredDocs.scoreDocs = scoreDocs;
        if (zeroOutScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(
                  mainScoreDocsClone, rescoredDocs.scoreDocs, reRankScoreDocs.length);
        }
        return rescoredDocs;
      } else {
        // We've rescored more than we need to return.

        // Scale over the full rescored set first, then cut the page down to howMany.
        if (zeroOutScores) {
          rescoredDocs.scoreDocs =
              reRankScaler.scaleScores(
                  mainScoreDocsClone, rescoredDocs.scoreDocs, rescoredDocs.scoreDocs.length);
        }
        ScoreDoc[] scoreDocs = new ScoreDoc[howMany];
        System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, howMany);
        rescoredDocs.scoreDocs = scoreDocs;
        return rescoredDocs;
      }
    } catch (Exception e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }

  /**
   * Copies every non-null element of {@code scoreDocs} into a new array of fresh {@link ScoreDoc}
   * instances carrying the same doc id and score.
   *
   * @param scoreDocs the hits to copy; null elements stay null in the copy
   * @param zeroOut when {@code true}, the score of each <em>source</em> element is set to {@code
   *     0f} after it has been copied, so the input array is modified
   * @return the copy, holding the scores as they were before any zeroing
   */
  private ScoreDoc[] deepClone(ScoreDoc[] scoreDocs, boolean zeroOut) {
    ScoreDoc[] scoreDocs1 = new ScoreDoc[scoreDocs.length];
    for (int i = 0; i < scoreDocs.length; i++) {
      ScoreDoc scoreDoc = scoreDocs[i];
      if (scoreDoc != null) {
        scoreDocs1[i] = new ScoreDoc(scoreDoc.doc, scoreDoc.score);
        if (zeroOut) {
          scoreDoc.score = 0f;
        }
      }
    }
    return scoreDocs1;
  }

  /**
   * Orders {@link ScoreDoc}s by descending score, where the score of a boosted document is
   * replaced by an artificial value derived from {@code maxScore}.
   */
  public static class BoostedComp implements Comparator<ScoreDoc> {
    IntFloatHashMap boostedMap;

    /**
     * @param boostedDocs map keyed by the doc ids of boosted documents; the mapped value is added
     *     to {@code maxScore} to form the comparison score
     * @param scoreDocs hits to scan from the front; scanning stops at the first hit that is not
     *     a key of {@code boostedDocs}
     * @param maxScore base value for the comparison score of boosted documents
     */
    public BoostedComp(IntIntHashMap boostedDocs, ScoreDoc[] scoreDocs, float maxScore) {
      this.boostedMap = new IntFloatHashMap(boostedDocs.size() * 2);

      for (int i = 0; i < scoreDocs.length; i++) {
        final int idx;
        if ((idx = boostedDocs.indexOf(scoreDocs[i].doc)) >= 0) {
          boostedMap.put(scoreDocs[i].doc, maxScore + boostedDocs.indexGet(idx));
        } else {
          // Only the leading run of boosted hits is registered.
          break;
        }
      }
    }

    /**
     * Compares by score, highest first, using the registered boosted score for any document
     * present in {@code boostedMap}.
     */
    @Override
    public int compare(ScoreDoc doc1, ScoreDoc doc2) {
      float score1 = doc1.score;
      float score2 = doc2.score;
      int idx;
      if ((idx = boostedMap.indexOf(doc1.doc)) >= 0) {
        score1 = boostedMap.indexGet(idx);
      }

      if ((idx = boostedMap.indexOf(doc2.doc)) >= 0) {
        score2 = boostedMap.indexGet(idx);
      }

      return -Float.compare(score1, score2);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy