All downloads are free. Search and download functionality uses the official Maven repository.

io.anserini.server.SearchService Maven / Gradle / Ivy

/*
 * Anserini: A Lucene toolkit for reproducible information retrieval research
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.anserini.server;

import io.anserini.index.Constants;
import io.anserini.search.ScoredDoc;
import io.anserini.search.SimpleSearcher;
import io.anserini.search.HnswDenseSearcher;
import io.anserini.util.PrebuiltIndexHandler;
import io.anserini.index.IndexInfo;
import io.anserini.index.ShardInfo;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.apache.lucene.document.Document;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

public class SearchService {

  private final String indexDir;
  private final String prebuiltIndex;
  private final float k1 = 0.9f;
  private final float b = 0.4f;
  private final ObjectMapper mapper = new ObjectMapper();
  private final boolean isHnswIndex;
  private final Map indexOverrides = new ConcurrentHashMap<>();

  public SearchService(String prebuiltIndex) {
    this.prebuiltIndex = prebuiltIndex;
    IndexInitializationResult result = initializeIndex(prebuiltIndex);
    this.indexDir = result.indexDir;
    this.isHnswIndex = result.isHnswIndex;
    if (result.error != null) {
      throw new RuntimeException(result.error);
    }
  }

  public List> search(String query, int hits) {
    return search(query, hits, null, null, null);
  }

  public List> search(String query, int hits, Integer efSearch, String encoder,
      String queryGenerator) {
    validateSearchParameters(query, hits);
    validateSettings(efSearch, encoder, queryGenerator);

    try {
      if (!isHnswIndex) {
        try (SimpleSearcher searcher = new SimpleSearcher(indexDir)) {
          searcher.set_bm25(k1, b);
          ScoredDoc[] results = searcher.search(query, hits);
          List> candidates = new ArrayList<>();
          for (ScoredDoc r : results) {
            Map candidate = new LinkedHashMap<>();
            candidate.put("docid", r.docid);
            candidate.put("score", r.score);

            String raw = r.lucene_document.get(Constants.RAW);
            if (raw != null) {
              JsonNode rootNode = mapper.readTree(raw);
              Map content = mapper.convertValue(rootNode, Map.class);
              content.remove("docid");
              content.remove("id");
              content.remove("_id");
              candidate.put("doc", content);
            } else {
              candidate.put("doc", null);
            }
            candidates.add(candidate);
          }
          return candidates;
        }
      } else {
        IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);
        HnswDenseSearcher.Args args = new HnswDenseSearcher.Args();
        // Various fallbacks for if the user doesn't provide a parameter
        args.index = indexDir;
        args.efSearch = efSearch != null ? efSearch : getEfSearchOverride() != null ? getEfSearchOverride(): IndexInfo.DEFAULT_EF_SEARCH;
        args.encoder = encoder != null ? encoder : getEncoderOverride() != null ? getEncoderOverride(): indexInfo.encoder;
        args.queryGenerator = queryGenerator != null ? queryGenerator : getQueryGeneratorOverride() != null ? getQueryGeneratorOverride(): indexInfo.queryGenerator;
        try (HnswDenseSearcher searcher = new HnswDenseSearcher(args)) {
          ScoredDoc[] results = searcher.search(query, hits);
          List> candidates = new ArrayList<>();
          for (ScoredDoc r : results) {
            candidates.add(Map.of("docid", r.docid, "score", r.score));
          }
          return candidates;
        }
      }
    } catch (Exception e) {
      e.printStackTrace();
      return List.of();
    }
  }

  public Map getDocument(String docid) {
    if (!isHnswIndex)
      throw new IllegalArgumentException("getDocument is only supported for HNSW indexes");
    try (SimpleSearcher searcher = new SimpleSearcher(indexDir)) {
      Document lucene_document = searcher.doc(docid);
      if (lucene_document == null) {
        return Map.of("error", "Document not found: " + docid);
      }

      String raw = lucene_document.get(Constants.RAW);
      Map candidate = new LinkedHashMap<>();
      if (raw != null) {
        JsonNode rootNode = mapper.readTree(raw);
        Map content = mapper.convertValue(rootNode, Map.class);
        content.remove("docid");
        content.remove("id");
        content.remove("_id");
        candidate.put("doc", content);
      } else {
        candidate.put("doc", null);
      }
      return candidate;
    } catch (Exception e) {
      e.printStackTrace();
      return Map.of("error", "Error retrieving document: " + e.getMessage());
    }
  }

  public Integer getEfSearchOverride() {
    return (Integer) indexOverrides.get("efSearch");
  }

  public String getEncoderOverride() {
    return (String) indexOverrides.get("encoder");
  }

  public String getQueryGeneratorOverride() {
    return (String) indexOverrides.get("queryGenerator");
  }

  public void setEfSearchOverride(String value) {
    if (value == null || value.trim().isEmpty()) {
      throw new IllegalArgumentException("efSearch cannot be empty");
    }

    int efSearch;
    try {
      efSearch = Integer.parseInt(value.trim());
    } catch (NumberFormatException e) {
      throw new IllegalArgumentException("efSearch must be a valid integer, but got: " + value);
    }
    validateSettings(efSearch, getEncoderOverride(), getQueryGeneratorOverride());
    indexOverrides.put("efSearch", efSearch);
  }

  public void setEncoderOverride(String value) {
    if (value == null || value.trim().isEmpty()) {
      throw new IllegalArgumentException("Encoder cannot be empty");
    }
    validateSettings(getEfSearchOverride(), value, getQueryGeneratorOverride());
    indexOverrides.put("encoder", value);
  }

  public void setQueryGeneratorOverride(String value) {
    if (value == null || value.trim().isEmpty()) {
      throw new IllegalArgumentException("QueryGenerator cannot be empty");
    }
    validateSettings(getEfSearchOverride(), getEncoderOverride(), value);
    indexOverrides.put("queryGenerator", value);
  }

  private void validateSearchParameters(String query, int hits) {
    if (query == null || query.trim().isEmpty()) {
      throw new IllegalArgumentException("Query cannot be empty");
    }
    if (hits <= 0) {
      throw new IllegalArgumentException("Number of hits must be positive");
    }
  }

  private void validateSettings(Integer efSearch, String encoder, String queryGenerator) {
    IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);

    if (efSearch != null) {
      if (efSearch <= 0) {
        throw new IllegalArgumentException("efSearch must be positive but got " + efSearch);
      }
      if (!isHnswIndex) {
        throw new IllegalArgumentException(
            "efSearch parameter is only supported for HNSW indexes, but index " + prebuiltIndex + " is not HNSW");
      }
    }

    if (encoder != null) {
      if (!encoder.equals(indexInfo.encoder)) {
        throw new IllegalArgumentException("Unsupported encoder: " + encoder + " for index " + prebuiltIndex);
      }
    }

    if (queryGenerator != null) {
      if (!queryGenerator.equals(indexInfo.queryGenerator)) {
        throw new IllegalArgumentException(
            "Unsupported queryGenerator: " + queryGenerator + " for index " + prebuiltIndex);
      }
    }
  }

  private IndexInitializationResult initializeIndex(String prebuiltIndex) {
    try {
      PrebuiltIndexHandler handler = new PrebuiltIndexHandler(prebuiltIndex);
      handler.initialize();
      handler.download();
      String indexDir = handler.decompressIndex();
      IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);
      boolean isHnsw = indexInfo.indexType == IndexInfo.IndexType.DENSE_HNSW;
      return new IndexInitializationResult(indexDir, isHnsw, null);
    } catch (Exception e) {
      return new IndexInitializationResult(null, false, e);
    }
  }

  static List> searchSharded(String identifier, String query, int hits, Integer efSearch, String encoder, String queryGenerator) {

    // Retrieve the shards from ShardInfo organizer enum
    ShardInfo shardInfo = ShardInfo.fromIdentifier(identifier);
    IndexInfo[] shards = shardInfo.getShards();

    return Arrays.stream(shards)
      .parallel()
      .map(shard -> {
        SearchService service = new SearchService(shard.indexName);
        return service.search(query, hits, efSearch, encoder, queryGenerator);
      })
      .flatMap(List::stream)
      .sorted((a, b) -> ((Float) b.get("score")).compareTo((Float) a.get("score")))
      .limit(hits)
      .collect(Collectors.toList());
  }

  private static class IndexInitializationResult {
    final String indexDir;
    final boolean isHnswIndex;
    final Exception error;

    IndexInitializationResult(String indexDir, boolean isHnswIndex, Exception error) {
      this.indexDir = indexDir;
      this.isHnswIndex = isHnswIndex;
      this.error = error;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy