Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.codelibs.elasticsearch.dynarank.ranker.DynamicRanker Maven / Gradle / Ivy
Go to download
This plugin provides a feature to re-rank a search result at the search time.
package org.codelibs.elasticsearch.dynarank.ranker;
import static org.elasticsearch.action.search.ShardSearchFailure.readShardSearchFailure;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.codelibs.elasticsearch.dynarank.script.DiversitySortScriptEngine;
import org.codelibs.elasticsearch.dynarank.script.DynaRankScript;
import org.codelibs.elasticsearch.dynarank.script.DynaRankScript.Factory;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchResponse.Clusters;
import org.elasticsearch.action.search.SearchResponseSections;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexAbstraction;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.MappingMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.profile.SearchProfileResults;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.threadpool.ThreadPool;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
public class DynamicRanker extends AbstractLifecycleComponent {
private static final Logger logger = LogManager.getLogger(DynamicRanker.class);
private static DynamicRanker instance = null;
public static final String DEFAULT_SCRIPT_TYPE = "inline";
public static final String DEFAULT_SCRIPT_LANG = "painless";
public static final Setting SETTING_INDEX_DYNARANK_SCRIPT =
Setting.simpleString("index.dynarank.script_sort.script", Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_INDEX_DYNARANK_LANG =
Setting.simpleString("index.dynarank.script_sort.lang", Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_INDEX_DYNARANK_TYPE = new Setting<>("index.dynarank.script_sort.type",
s -> DEFAULT_SCRIPT_TYPE, Function.identity(), Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_INDEX_DYNARANK_PARAMS =
Setting.groupSetting("index.dynarank.script_sort.params.", Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_INDEX_DYNARANK_REORDER_SIZE =
Setting.intSetting("index.dynarank.reorder_size", 100, Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_INDEX_DYNARANK_KEEP_TOPN =
Setting.intSetting("index.dynarank.keep_topn", 0, Property.IndexScope, Property.Dynamic);
public static final Setting SETTING_DYNARANK_CACHE_EXPIRE =
Setting.timeSetting("dynarank.cache.expire", TimeValue.MINUS_ONE, Property.NodeScope);
public static final Setting SETTING_DYNARANK_CACHE_CLEAN_INTERVAL =
Setting.timeSetting("dynarank.cache.clean_interval", TimeValue.timeValueSeconds(60), Property.NodeScope);
public static final String DYNARANK_RERANK_ENABLE = "Dynarank-Rerank";
public static final String DYNARANK_MIN_TOTAL_HITS = "Dynarank-Min-Total-Hits";
private final ClusterService clusterService;
private final ScriptService scriptService;
private final Cache scriptInfoCache;
private final ThreadPool threadPool;
private final NamedWriteableRegistry namedWriteableRegistry;
private final TimeValue cleanInterval;
private Reaper reaper;
private final Client client;
public static DynamicRanker getInstance() {
return instance;
}
@Inject
public DynamicRanker(final Settings settings, final Client client, final ClusterService clusterService,
final ScriptService scriptService, final ThreadPool threadPool, final ActionFilters filters,
final NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.clusterService = clusterService;
this.scriptService = scriptService;
this.threadPool = threadPool;
this.namedWriteableRegistry = namedWriteableRegistry;
logger.info("Initializing DynamicRanker");
final TimeValue expire = SETTING_DYNARANK_CACHE_EXPIRE.get(settings);
cleanInterval = SETTING_DYNARANK_CACHE_CLEAN_INTERVAL.get(settings);
final CacheBuilder builder = CacheBuilder.newBuilder().concurrencyLevel(16);
if (expire.millis() >= 0) {
builder.expireAfterAccess(expire.millis(), TimeUnit.MILLISECONDS);
}
scriptInfoCache = builder.build();
}
@Override
protected void doStart() throws ElasticsearchException {
instance = this;
reaper = new Reaper();
threadPool.schedule(reaper, cleanInterval, ThreadPool.Names.SAME);
}
@Override
protected void doStop() throws ElasticsearchException {
}
@Override
protected void doClose() throws ElasticsearchException {
reaper.close();
scriptInfoCache.invalidateAll();
}
public ActionListener wrapActionListener(final String action, final SearchRequest request,
final ActionListener listener) {
switch (request.searchType()) {
case DFS_QUERY_THEN_FETCH:
case QUERY_THEN_FETCH:
break;
default:
return null;
}
if (request.scroll() != null) {
return null;
}
final ThreadContext threadContext = threadPool.getThreadContext();
final String isRerank = threadContext.getHeader(DYNARANK_RERANK_ENABLE);
if (isRerank != null && !Boolean.valueOf(isRerank)) {
return null;
}
final SearchSourceBuilder source = request.source();
if (source == null) {
return null;
}
final String[] indices = request.indices();
if (indices == null || indices.length != 1) {
return null;
}
final String index = indices[0];
final ScriptInfo scriptInfo = getScriptInfo(index);
if (scriptInfo == null || scriptInfo.getScript() == null) {
return null;
}
final long startTime = System.nanoTime();
final int size = getInt(source.size(), 10);
final int from = getInt(source.from(), 0);
if (size < 0 || from < 0) {
return null;
}
if (from >= scriptInfo.getReorderSize()) {
return null;
}
int maxSize = scriptInfo.getReorderSize();
if (from + size > scriptInfo.getReorderSize()) {
maxSize = from + size;
}
source.size(maxSize);
source.from(0);
if (logger.isDebugEnabled()) {
logger.debug("Rewrite query: from:{}->{} size:{}->{}", from, 0, size, maxSize);
}
final ActionListener searchResponseListener =
createSearchResponseListener(request, listener, from, size, startTime, scriptInfo);
return new ActionListener() {
@Override
public void onResponse(final Response response) {
try {
searchResponseListener.onResponse(response);
} catch (final RetrySearchException e) {
threadPool.getThreadContext().putHeader(DYNARANK_RERANK_ENABLE, Boolean.FALSE.toString());
source.size(size);
source.from(from);
source.toString();
final SearchSourceBuilder newSource = e.rewrite(source);
if (newSource == null) {
throw new ElasticsearchException("Failed to rewrite source: " + source);
}
if (logger.isDebugEnabled()) {
logger.debug("Original Query: \n{}\nRewrited Query: \n{}", source, newSource);
}
request.source(newSource);
@SuppressWarnings("unchecked")
final ActionListener actionListener = (ActionListener) listener;
client.search(request, actionListener);
}
}
@Override
public void onFailure(final Exception e) {
searchResponseListener.onFailure(e);
}
};
}
public ScriptInfo getScriptInfo(final String index) {
try {
return scriptInfoCache.get(index, () -> {
final Metadata metaData = clusterService.state().getMetadata();
IndexAbstraction indexAbstraction = metaData.getIndicesLookup().get(index);
if (indexAbstraction == null) {
return ScriptInfo.NO_SCRIPT_INFO;
}
final ScriptInfo[] scriptInfos = indexAbstraction.getIndices().stream()
.map(metaData::index)
.filter(idx -> SETTING_INDEX_DYNARANK_LANG.get(idx.getSettings()).length() > 0)
.map(idx ->
new ScriptInfo(SETTING_INDEX_DYNARANK_SCRIPT.get(idx.getSettings()), SETTING_INDEX_DYNARANK_LANG.get(idx.getSettings()),
SETTING_INDEX_DYNARANK_TYPE.get(idx.getSettings()), SETTING_INDEX_DYNARANK_PARAMS.get(idx.getSettings()),
SETTING_INDEX_DYNARANK_REORDER_SIZE.get(idx.getSettings()), SETTING_INDEX_DYNARANK_KEEP_TOPN.get(idx.getSettings()),
idx.mapping())
)
.toArray(n -> new ScriptInfo[n]);
if (scriptInfos.length == 0) {
return ScriptInfo.NO_SCRIPT_INFO;
} else if (scriptInfos.length == 1) {
return scriptInfos[0];
} else {
for (final ScriptInfo scriptInfo : scriptInfos) {
if (!scriptInfo.getLang().equals(DiversitySortScriptEngine.SCRIPT_NAME)) {
return ScriptInfo.NO_SCRIPT_INFO;
}
}
return scriptInfos[0];
}
});
} catch (final Exception e) {
logger.warn("Failed to load ScriptInfo for {}.", e, index);
return null;
}
}
private ActionListener createSearchResponseListener(final SearchRequest request,
final ActionListener listener, final int from, final int size, final long startTime,
final ScriptInfo scriptInfo) {
return new ActionListener() {
@Override
public void onResponse(final Response response) {
final SearchResponse searchResponse = (SearchResponse) response;
final long totalHits = searchResponse.getHits().getTotalHits().value;
if (totalHits == 0) {
if (logger.isDebugEnabled()) {
logger.debug("totalHits is {}. No reranking results: {}", totalHits, searchResponse);
}
listener.onResponse(response);
return;
}
final String minTotalHitsValue = threadPool.getThreadContext().getHeader(DYNARANK_MIN_TOTAL_HITS);
if (minTotalHitsValue != null) {
final long minTotalHits = Long.parseLong(minTotalHitsValue);
if (totalHits < minTotalHits) {
if (logger.isDebugEnabled()) {
logger.debug("totalHits is {} < {}. No reranking results: {}", totalHits, minTotalHits, searchResponse);
}
listener.onResponse(response);
return;
}
}
if (logger.isDebugEnabled()) {
logger.debug("Reranking results: {}", searchResponse);
}
try {
final BytesStreamOutput out = new BytesStreamOutput();
searchResponse.writeTo(out);
if (logger.isDebugEnabled()) {
logger.debug("Reading headers...");
}
final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry);
if (logger.isDebugEnabled()) {
logger.debug("Reading hits...");
}
// BEGIN: SearchResponse#writeTo
// BEGIN: InternalSearchResponse#writeTo
final SearchHits hits = new SearchHits(in);
final SearchHits newHits = doReorder(hits, from, size, scriptInfo);
if (logger.isDebugEnabled()) {
logger.debug("Reading aggregations...");
}
final InternalAggregations aggregations = in.readBoolean() ? InternalAggregations.readFrom(in) : null;
if (logger.isDebugEnabled()) {
logger.debug("Reading suggest...");
}
final Suggest suggest = in.readBoolean() ? new Suggest(in) : null;
final boolean timedOut = in.readBoolean();
final Boolean terminatedEarly = in.readOptionalBoolean();
final SearchProfileResults profileResults = in.readOptionalWriteable(SearchProfileResults::new);
final int numReducePhases = in.readVInt();
final SearchResponseSections internalResponse = new InternalSearchResponse(newHits, aggregations, suggest,
profileResults, timedOut, terminatedEarly, numReducePhases);
// END: InternalSearchResponse
final int totalShards = in.readVInt();
final int successfulShards = in.readVInt();
final int size = in.readVInt();
final ShardSearchFailure[] shardFailures;
if (size == 0) {
shardFailures = ShardSearchFailure.EMPTY_ARRAY;
} else {
shardFailures = new ShardSearchFailure[size];
for (int i = 0; i < shardFailures.length; i++) {
shardFailures[i] = readShardSearchFailure(in);
}
}
final Clusters clusters;
if (in.getVersion().onOrAfter(Version.V_6_1_0)) {
clusters = new Clusters(in.readVInt(), in.readVInt(), in.readVInt());
} else {
clusters = Clusters.EMPTY;
}
final String scrollId = in.readOptionalString();
/* tookInMillis = */ in.readVLong();
final int skippedShards = in.readVInt();
// END: SearchResponse
final long tookInMillis = (System.nanoTime() - startTime) / 1000000;
if (logger.isDebugEnabled()) {
logger.debug("Creating new SearchResponse...");
}
@SuppressWarnings("unchecked")
final Response newResponse = (Response) new SearchResponse(internalResponse, scrollId, totalShards, successfulShards,
skippedShards, tookInMillis, shardFailures, clusters);
listener.onResponse(newResponse);
if (logger.isDebugEnabled()) {
logger.debug("Rewriting overhead time: {} - {} = {}ms", tookInMillis, searchResponse.getTook().getMillis(),
tookInMillis - searchResponse.getTook().getMillis());
}
} catch (final RetrySearchException e) {
throw e;
} catch (final Exception e) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to parse a search response.", e);
}
throw new ElasticsearchException("Failed to parse a search response.", e);
}
}
@Override
public void onFailure(final Exception e) {
listener.onFailure(e);
}
};
}
private SearchHits doReorder(final SearchHits hits, final int from, final int size,
final ScriptInfo scriptInfo) {
final SearchHit[] searchHits = hits.getHits();
SearchHit[] newSearchHits;
if (logger.isDebugEnabled()) {
logger.debug("searchHits.length <= reorderSize: {}", searchHits.length <= scriptInfo.getReorderSize());
}
if (searchHits.length <= scriptInfo.getReorderSize()) {
final SearchHit[] targets = onReorder(searchHits, scriptInfo);
if (from >= targets.length) {
newSearchHits = new SearchHit[0];
if (logger.isDebugEnabled()) {
logger.debug("Invalid argument: {} >= {}", from, targets.length);
}
} else {
int end = from + size;
if (end > targets.length) {
end = targets.length;
}
newSearchHits = Arrays.copyOfRange(targets, from, end);
}
} else {
SearchHit[] targets = Arrays.copyOfRange(searchHits, 0, scriptInfo.getReorderSize());
targets = onReorder(targets, scriptInfo);
final List list = new ArrayList<>(size);
for (int i = from; i < targets.length; i++) {
list.add(targets[i]);
}
for (int i = targets.length; i < searchHits.length; i++) {
list.add(searchHits[i]);
}
newSearchHits = list.toArray(new SearchHit[list.size()]);
}
return new SearchHits(newSearchHits, hits.getTotalHits(), hits.getMaxScore());
}
private SearchHit[] onReorder(final SearchHit[] searchHits,
final ScriptInfo scriptInfo) {
final int keepTopN = scriptInfo.getKeepTopN();
if (searchHits.length <= keepTopN) {
return searchHits;
}
final Factory factory = scriptService.compile(
new Script(scriptInfo.getScriptType(), scriptInfo.getLang(),
scriptInfo.getScript(), scriptInfo.getSettings()),
DynaRankScript.CONTEXT);
if (keepTopN == 0) {
return factory.newInstance(scriptInfo.getSettings())
.execute(searchHits);
}
final SearchHit[] hits = Arrays.copyOfRange(searchHits, keepTopN,
searchHits.length);
final SearchHit[] reordered = factory
.newInstance(scriptInfo.getSettings()).execute(hits);
for (int i = keepTopN; i < searchHits.length; i++) {
searchHits[i] = reordered[i - keepTopN];
}
return searchHits;
}
private int getInt(final Object value, final int defaultValue) {
if (value instanceof Number) {
final int v = ((Number) value).intValue();
if (v < 0) {
return defaultValue;
}
return v;
} else if (value instanceof String) {
return Integer.parseInt(value.toString());
}
return defaultValue;
}
public static class ScriptInfo {
protected final static ScriptInfo NO_SCRIPT_INFO = new ScriptInfo();
private String script;
private String lang;
private ScriptType scriptType;
private Map settings;
private int reorderSize;
private int keepTopN;
ScriptInfo() {
// nothing
}
ScriptInfo(final String script, final String lang, final String scriptType, final Settings settings, final int reorderSize, final int keepTopN, final MappingMetadata mappingMetadata) {
this.script = script;
this.lang = lang;
this.reorderSize = reorderSize;
this.keepTopN=keepTopN;
this.settings = new HashMap<>();
for (final String name : settings.keySet()) {
final List list = settings.getAsList(name);
this.settings.put(name, list.toArray(new String[list.size()]));
}
this.settings.put("source_as_map", mappingMetadata.getSourceAsMap());
if ("STORED".equalsIgnoreCase(scriptType)) {
this.scriptType = ScriptType.STORED;
} else {
this.scriptType = ScriptType.INLINE;
}
}
public String getScript() {
return script;
}
public String getLang() {
return lang;
}
public ScriptType getScriptType() {
return scriptType;
}
public Map getSettings() {
return settings;
}
public int getReorderSize() {
return reorderSize;
}
public int getKeepTopN() {
return keepTopN;
}
@Override
public String toString() {
return "ScriptInfo [script=" + script + ", lang=" + lang + ", scriptType=" + scriptType + ", settings=" + settings
+ ", reorderSize=" + reorderSize + ", keepTopN=" + keepTopN + "]";
}
}
private class Reaper implements Runnable {
private volatile boolean closed;
void close() {
closed = true;
}
@Override
public void run() {
if (closed) {
return;
}
try {
for (final Map.Entry entry : scriptInfoCache.asMap().entrySet()) {
final String index = entry.getKey();
final IndexMetadata indexMD = clusterService.state().getMetadata().index(index);
if (indexMD == null) {
scriptInfoCache.invalidate(index);
if (logger.isDebugEnabled()) {
logger.debug("Invalidate cache for {}", index);
}
continue;
}
final Settings indexSettings = indexMD.getSettings();
final String script = SETTING_INDEX_DYNARANK_SCRIPT.get(indexSettings);
if (script == null || script.length() == 0) {
scriptInfoCache.invalidate(index);
if (logger.isDebugEnabled()) {
logger.debug("Invalidate cache for {}", index);
}
continue;
}
final ScriptInfo scriptInfo = new ScriptInfo(script, SETTING_INDEX_DYNARANK_LANG.get(indexSettings),
SETTING_INDEX_DYNARANK_TYPE.get(indexSettings), SETTING_INDEX_DYNARANK_PARAMS.get(indexSettings),
SETTING_INDEX_DYNARANK_REORDER_SIZE.get(indexSettings), SETTING_INDEX_DYNARANK_KEEP_TOPN.get(indexSettings),
indexMD.mapping());
if (logger.isDebugEnabled()) {
logger.debug("Reload cache for {} => {}", index, scriptInfo);
}
scriptInfoCache.put(index, scriptInfo);
}
} catch (final Exception e) {
logger.warn("Failed to update a cache for ScriptInfo.", e);
} finally {
threadPool.schedule(reaper, cleanInterval, ThreadPool.Names.GENERIC);
}
}
}
}