All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.anserini.fusion.ScoredDocsFuser Maven / Gradle / Ivy

/*
 * Anserini: A Lucene toolkit for reproducible information retrieval research
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.anserini.fusion;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.StoredField;

import io.anserini.search.ScoredDocs;

//replace topic wtih const
public class ScoredDocsFuser {
  public static final String TOPIC = "TOPIC";

  /**
   * Reads a TREC run file and returns a ScoredDocs containing the data.
   * 
   * @param filepath Path to the TREC run file.
   * @throws IOException If the file cannot be read.
   * @return A ScoredDocs object containing the data from the TREC run file.
   */
  public static ScoredDocs readRun(Path filepath, boolean reSort) throws IOException {
    ScoredDocs scoredDocs = new ScoredDocs();
    try (BufferedReader br = new BufferedReader(new FileReader(filepath.toFile()))) {
      List lucene_documents = new ArrayList<>(); // topic
      List docids = new ArrayList<>(); // docid
      List scores = new ArrayList<>(); // score
      List rank = new ArrayList<>(); // rank

      String line;
      while ((line = br.readLine()) != null) {
        String[] data = line.split("\\s+");
  
        // Populate the lists with the parsed topic and docid
        Document doc = new Document();
        doc.add(new StoredField(TOPIC, data[0]));
        lucene_documents.add(doc);
        docids.add(data[2]);

        // Parse RANK as integer
        int rankInt = Integer.parseInt(data[3]);
        rank.add(rankInt);

        // Parse SCORE as float
        float scoreFloat = Float.parseFloat(data[4]);
        scores.add(scoreFloat);
      }

      scoredDocs.lucene_documents = lucene_documents.toArray(new Document[0]);
      scoredDocs.docids = docids.toArray(new String[0]);
      scoredDocs.scores = ArrayUtils.toPrimitive(scores.toArray(new Float[scores.size()]), Float.NaN);
      scoredDocs.lucene_docids = ArrayUtils.toPrimitive(rank.toArray(new Integer[0]));
    }
  
    if (reSort) {
      ScoredDocsFuser.sortScoredDocs(scoredDocs);
    }

    return scoredDocs;
  }

  /**
   * Rescored given ScoredDocs using the specified method.
   *
   * @param method  Rescore method to be applied (e.g., RRF, SCALE, NORMALIZE).
   * @param rrfK    Parameter k needed for reciprocal rank fusion.
   * @param scale   Scaling factor needed for rescoring by scaling.
   * @param scoredDocs ScoredDocs object to be rescored.
   * @throws UnsupportedOperationException If an unsupported rescore method is provided.
   */
  public static void rescore(RescoreMethod method, int rrfK, double scale, ScoredDocs scoredDocs) {
    switch (method) {
      case RRF -> ScoredDocsFuser.rescoreRRF(rrfK, scoredDocs);
      case SCALE -> ScoredDocsFuser.rescoreScale(scale, scoredDocs);
      case NORMALIZE -> ScoredDocsFuser.normalizeScores(scoredDocs);
      default -> throw new UnsupportedOperationException("Unknown rescore method: " + method);
    }
  }

  private static void rescoreRRF(int rrfK, ScoredDocs scoredDocs) {
    int length = scoredDocs.lucene_documents.length;
    for (int i = 0; i < length; i++) {
      float score = (float)(1.0 / (rrfK + scoredDocs.lucene_docids[i]));
      scoredDocs.scores[i] = score;
    }
  }

  private static void rescoreScale(double scale, ScoredDocs scoredDocs) {
    int length = scoredDocs.lucene_documents.length;
    for (int i = 0; i < length; i++) {
      float score = (float) (scoredDocs.scores[i] * scale);
      scoredDocs.scores[i] = score;
    }
  }

  private static void normalizeScores(ScoredDocs scoredDocs) {
    Map> indicesForTopics = new HashMap>(); // topic, list of indices for that topic
    int length = scoredDocs.lucene_documents.length;
    for (int i = 0; i < length; i++) {
      indicesForTopics.computeIfAbsent(scoredDocs.lucene_documents[i].get(TOPIC), k -> new ArrayList<>()).add(i);
    }

    for (List topicIndices : indicesForTopics.values()) {
      int numRecords = topicIndices.size();
      float minScore = scoredDocs.scores[topicIndices.get(0)];
      float maxScore = scoredDocs.scores[topicIndices.get(numRecords - 1)];
      for (int i = 0; i < numRecords; i++) {
        int index = topicIndices.get(i);
        minScore = Float.min(minScore, scoredDocs.scores[index]);
        maxScore = Float.max(maxScore, scoredDocs.scores[index]);
      }

      for (int i = 0; i < numRecords; i++) {
        int index = topicIndices.get(i);
        float normalizedScore = ((float) scoredDocs.scores[index] - minScore) / (maxScore - minScore);
        scoredDocs.scores[index] = normalizedScore;
      }
    }
  }

  /**
   * Merges multiple ScoredDocs instances into a single ScoredDocs instance.
   * The merged ScoredDocs will contain the top documents for each topic, with scores summed across the input runs.
   *
   * @param runs  List of ScoredDocs instances to merge.
   * @param depth Maximum number of documents to consider from each run for each topic (null for no limit).
   * @param k     Maximum number of top documents to include in the merged run for each topic (null for no limit).
   * @return A new ScoredDocs instance containing the merged results.
   * @throws IllegalArgumentException if less than 2 runs are provided.
   */
  public static ScoredDocs merge(List runs, Integer depth, Integer k) {
    if (runs.size() < 2) {
      throw new IllegalArgumentException("Merge requires at least 2 runs.");
    }

    // for every topic, produce a map of docid to score, num of accumulated
    HashMap>> docScores = new HashMap<>();
    for (ScoredDocs run : runs) {
      for (int i = 0; i < run.lucene_documents.length; i++) {
        String query = run.lucene_documents[i].get(TOPIC);
        String docid = run.docids[i];
        Float score = run.scores[i];
        docScores.computeIfAbsent(query, key -> new HashMap<>())
          .merge(docid, new AbstractMap.SimpleEntry<>(score, 1), 
                (existing, newValue) -> 
                  existing.getValue() >= depth ? existing : new AbstractMap.SimpleEntry<>(existing.getKey() + newValue.getKey(), existing.getValue() + 1));
      }
    }
    
    List lucene_documents = new ArrayList<>(); // topic
    List docids = new ArrayList<>(); // docid
    List score = new ArrayList<>(); // score
    List rank = new ArrayList<>(); // rank
    for (String query : docScores.keySet()) {
      // for the current query, a list of all docids and scores, sorted by scores
      List> sortedDocScores = docScores.get(query).entrySet().stream()
        .map(entry -> Map.entry(entry.getKey(), entry.getValue().getKey()))
        .sorted(Map.Entry.comparingByValue().reversed())
        .limit(k != null ? k : Integer.MAX_VALUE)
        .collect(Collectors.toList());

      for (int i = 0; i < sortedDocScores.size(); i++) {
        Map.Entry entry = sortedDocScores.get(i);
        Document doc = new Document();
        doc.add(new StoredField(TOPIC, query));
        lucene_documents.add(doc);
        docids.add(entry.getKey());
        rank.add(i + 1);
        score.add(entry.getValue());
      }
    }

    ScoredDocs mergedRun = new ScoredDocs();
    mergedRun.lucene_documents = lucene_documents.toArray(new Document[0]);
    mergedRun.docids = docids.toArray(new String[0]);
    mergedRun.scores = ArrayUtils.toPrimitive(score.toArray(new Float[score.size()]), Float.NaN);
    mergedRun.lucene_docids = ArrayUtils.toPrimitive(rank.toArray(new Integer[0]));
    
    return mergedRun;
  }

  /**
   * Sorts given ScoredDocs by topic, then by score.
   * 
   * @param scoredDocs ScoredDocs object to be sorted.
   */
  public static void sortScoredDocs(ScoredDocs scoredDocs){
    Integer[] indices = new Integer[scoredDocs.lucene_documents.length];
    for (int i = 0; i < indices.length; i++) {
      indices[i] = i;
    }

    Arrays.sort(indices, (index1, index2) -> {
      String topic1 = scoredDocs.lucene_documents[index1].get(TOPIC);
      String topic2 = scoredDocs.lucene_documents[index2].get(TOPIC);
      int topicComparison = (topic1.compareTo(topic2));
      if (topicComparison != 0) {
        return topicComparison;
      }
      return Float.compare(scoredDocs.scores[index2], scoredDocs.scores[index1]);
    });

    Document[] sorted_lucene_documents = new Document[indices.length];
    String[] sortedDocids = new String[indices.length];
    float[] sortedScores = new float[indices.length];
    int[] sortedRanks = new int[indices.length];
    for (int i = 0; i < indices.length; i++) {
      int index = indices[i];
      sorted_lucene_documents[i] = scoredDocs.lucene_documents[index];
      sortedDocids[i] = scoredDocs.docids[index];
      sortedScores[i] = scoredDocs.scores[index];
      sortedRanks[i] = scoredDocs.lucene_docids[index];
    }

    scoredDocs.lucene_documents = sorted_lucene_documents;
    scoredDocs.docids = sortedDocids;
    scoredDocs.scores = sortedScores;
    scoredDocs.lucene_docids = sortedRanks;
  }

  /**
   * Saves a ScoredDocs run data to a text file in the TREC run format.
   * 
   * @param outputPath Path to the output file.
   * @param tag Tag to be added to each record in the TREC run file. If null, the existing tags are retained.
   * @param run ScoredDocs object to be saved.
   * @throws IOException If an I/O error occurs while writing to the file.
   * @throws IllegalStateException If the ScoredDocs is empty.
   */
  public static void saveToTxt(Path outputPath, String tag, ScoredDocs run) throws IOException {
    if (run.lucene_documents == null || run.lucene_documents.length == 0) {
      throw new IllegalStateException("Nothing to save. ScoredDocs is empty");
    }

    ScoredDocsFuser.sortScoredDocs(run);
    try (BufferedWriter writer = Files.newBufferedWriter(outputPath)) {
      for (int i = 0; i < run.lucene_documents.length; i++) {
        writer.write(String.format("%s Q0 %s %d %.6f %s%n", 
          run.lucene_documents[i].get(TOPIC), run.docids[i], run.lucene_docids[i], run.scores[i], tag));
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy