Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
io.anserini.server.SearchService Maven / Gradle / Ivy
/*
* Anserini: A Lucene toolkit for reproducible information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.anserini.server;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.anserini.index.Constants;
import io.anserini.index.IndexInfo;
import io.anserini.index.ShardInfo;
import io.anserini.search.HnswDenseSearcher;
import io.anserini.search.ScoredDoc;
import io.anserini.search.SimpleSearcher;
import io.anserini.util.PrebuiltIndexHandler;
import org.apache.lucene.document.Document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
public class SearchService {
private final String indexDir;
private final String prebuiltIndex;
private final float k1 = 0.9f;
private final float b = 0.4f;
private final ObjectMapper mapper = new ObjectMapper();
private final boolean isHnswIndex;
private final Map indexOverrides = new ConcurrentHashMap<>();
public SearchService(String prebuiltIndex) {
this.prebuiltIndex = prebuiltIndex;
IndexInitializationResult result = initializeIndex(prebuiltIndex);
this.indexDir = result.indexDir;
this.isHnswIndex = result.isHnswIndex;
if (result.error != null) {
throw new RuntimeException(result.error);
}
}
public List> search(String query, int hits) {
return search(query, hits, null, null, null);
}
public List> search(String query, int hits, Integer efSearch, String encoder,
String queryGenerator) {
validateSearchParameters(query, hits);
validateSettings(efSearch, encoder, queryGenerator);
try {
if (!isHnswIndex) {
try (SimpleSearcher searcher = new SimpleSearcher(indexDir)) {
searcher.set_bm25(k1, b);
ScoredDoc[] results = searcher.search(query, hits);
List> candidates = new ArrayList<>();
for (ScoredDoc r : results) {
Map candidate = new LinkedHashMap<>();
candidate.put("docid", r.docid);
candidate.put("score", r.score);
String raw = r.lucene_document.get(Constants.RAW);
if (raw != null) {
JsonNode rootNode = mapper.readTree(raw);
Map content = mapper.convertValue(rootNode, Map.class);
content.remove("docid");
content.remove("id");
content.remove("_id");
candidate.put("doc", content);
} else {
candidate.put("doc", null);
}
candidates.add(candidate);
}
return candidates;
}
} else {
IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);
HnswDenseSearcher.Args args = new HnswDenseSearcher.Args();
// Various fallbacks for if the user doesn't provide a parameter
args.index = indexDir;
args.efSearch = efSearch != null ? efSearch : getEfSearchOverride() != null ? getEfSearchOverride(): IndexInfo.DEFAULT_EF_SEARCH;
args.encoder = encoder != null ? encoder : getEncoderOverride() != null ? getEncoderOverride(): indexInfo.encoder;
args.queryGenerator = queryGenerator != null ? queryGenerator : getQueryGeneratorOverride() != null ? getQueryGeneratorOverride(): indexInfo.queryGenerator;
try (HnswDenseSearcher searcher = new HnswDenseSearcher(args)) {
ScoredDoc[] results = searcher.search(query, hits);
List> candidates = new ArrayList<>();
for (ScoredDoc r : results) {
candidates.add(Map.of("docid", r.docid, "score", r.score));
}
return candidates;
}
}
} catch (Exception e) {
e.printStackTrace();
return List.of();
}
}
public Map getDocument(String docid) {
if (!isHnswIndex)
throw new IllegalArgumentException("getDocument is only supported for HNSW indexes");
try (SimpleSearcher searcher = new SimpleSearcher(indexDir)) {
Document lucene_document = searcher.doc(docid);
if (lucene_document == null) {
return Map.of("error", "Document not found: " + docid);
}
String raw = lucene_document.get(Constants.RAW);
Map candidate = new LinkedHashMap<>();
if (raw != null) {
JsonNode rootNode = mapper.readTree(raw);
Map content = mapper.convertValue(rootNode, Map.class);
content.remove("docid");
content.remove("id");
content.remove("_id");
candidate.put("doc", content);
} else {
candidate.put("doc", null);
}
return candidate;
} catch (Exception e) {
e.printStackTrace();
return Map.of("error", "Error retrieving document: " + e.getMessage());
}
}
public Integer getEfSearchOverride() {
return (Integer) indexOverrides.get("efSearch");
}
public String getEncoderOverride() {
return (String) indexOverrides.get("encoder");
}
public String getQueryGeneratorOverride() {
return (String) indexOverrides.get("queryGenerator");
}
public void setEfSearchOverride(String value) {
if (value == null || value.trim().isEmpty()) {
throw new IllegalArgumentException("efSearch cannot be empty");
}
int efSearch;
try {
efSearch = Integer.parseInt(value.trim());
} catch (NumberFormatException e) {
throw new IllegalArgumentException("efSearch must be a valid integer, but got: " + value);
}
validateSettings(efSearch, getEncoderOverride(), getQueryGeneratorOverride());
indexOverrides.put("efSearch", efSearch);
}
public void setEncoderOverride(String value) {
if (value == null || value.trim().isEmpty()) {
throw new IllegalArgumentException("Encoder cannot be empty");
}
validateSettings(getEfSearchOverride(), value, getQueryGeneratorOverride());
indexOverrides.put("encoder", value);
}
public void setQueryGeneratorOverride(String value) {
if (value == null || value.trim().isEmpty()) {
throw new IllegalArgumentException("QueryGenerator cannot be empty");
}
validateSettings(getEfSearchOverride(), getEncoderOverride(), value);
indexOverrides.put("queryGenerator", value);
}
private void validateSearchParameters(String query, int hits) {
if (query == null || query.trim().isEmpty()) {
throw new IllegalArgumentException("Query cannot be empty");
}
if (hits <= 0) {
throw new IllegalArgumentException("Number of hits must be positive");
}
}
private void validateSettings(Integer efSearch, String encoder, String queryGenerator) {
IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);
if (efSearch != null) {
if (efSearch <= 0) {
throw new IllegalArgumentException("efSearch must be positive but got " + efSearch);
}
if (!isHnswIndex) {
throw new IllegalArgumentException(
"efSearch parameter is only supported for HNSW indexes, but index " + prebuiltIndex + " is not HNSW");
}
}
if (encoder != null) {
if (!encoder.equals(indexInfo.encoder)) {
throw new IllegalArgumentException("Unsupported encoder: " + encoder + " for index " + prebuiltIndex);
}
}
if (queryGenerator != null) {
if (!queryGenerator.equals(indexInfo.queryGenerator)) {
throw new IllegalArgumentException(
"Unsupported queryGenerator: " + queryGenerator + " for index " + prebuiltIndex);
}
}
}
private IndexInitializationResult initializeIndex(String prebuiltIndex) {
try {
PrebuiltIndexHandler handler = new PrebuiltIndexHandler(prebuiltIndex);
handler.initialize();
handler.download();
String indexDir = handler.decompressIndex();
IndexInfo indexInfo = IndexInfo.get(prebuiltIndex);
boolean isHnsw = indexInfo.indexType == IndexInfo.IndexType.DENSE_HNSW;
return new IndexInitializationResult(indexDir, isHnsw, null);
} catch (Exception e) {
return new IndexInitializationResult(null, false, e);
}
}
static List> searchSharded(String identifier, String query, int hits, Integer efSearch, String encoder, String queryGenerator) {
// Retrieve the shards from ShardInfo organizer enum
ShardInfo shardInfo = ShardInfo.fromIdentifier(identifier);
IndexInfo[] shards = shardInfo.getShards();
return Arrays.stream(shards)
.parallel()
.map(shard -> {
SearchService service = new SearchService(shard.indexName);
return service.search(query, hits, efSearch, encoder, queryGenerator);
})
.flatMap(List::stream)
.sorted((a, b) -> ((Float) b.get("score")).compareTo((Float) a.get("score")))
.limit(hits)
.collect(Collectors.toList());
}
private static class IndexInitializationResult {
final String indexDir;
final boolean isHnswIndex;
final Exception error;
IndexInitializationResult(String indexDir, boolean isHnswIndex, Exception error) {
this.indexDir = indexDir;
this.isHnswIndex = isHnswIndex;
this.error = error;
}
}
}