// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.codelibs.elasticsearch.dynarank.ranker.DynamicRanker (Maven / Gradle / Ivy)
//
// There is a newer version: 7.16.0
// Show newest version
package org.codelibs.elasticsearch.dynarank.ranker;

import static org.elasticsearch.action.search.ShardSearchFailure.readShardSearchFailure;
import static org.elasticsearch.search.internal.InternalSearchHits.readSearchHits;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;

import org.codelibs.elasticsearch.dynarank.filter.SearchActionFilter;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.Requests;
import org.elasticsearch.cluster.ClusterService;
import org.elasticsearch.cluster.metadata.AliasOrIndex;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.script.CompiledScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptService.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.internal.InternalSearchHit;
import org.elasticsearch.search.internal.InternalSearchHits;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.lookup.SourceLookup;
import org.elasticsearch.search.profile.InternalProfileShardResults;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.netty.ChannelBufferStreamInput;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

/**
 * Lifecycle component that re-orders the leading hits of a search response
 * with a script configured in the settings of the searched index.
 *
 * <p>
 * {@link #wrapActionListener(String, SearchRequest, ActionListener)} rewrites
 * the {@code from}/{@code size} of an eligible request so that at least
 * "reorder size" hits are fetched, and returns a listener that runs the
 * configured script over those hits before handing the requested page to the
 * caller's listener. The per-index script configuration is kept in a cache
 * that is refreshed periodically by an internal {@link Reaper} task.
 * </p>
 */
public class DynamicRanker extends AbstractLifecycleComponent<DynamicRanker> {

    /** Script type used when {@link #INDEX_DYNARANK_SCRIPT_TYPE} is not set. */
    public static final String DEFAULT_SCRIPT_TYPE = "inline";

    /** Script language used when {@link #INDEX_DYNARANK_SCRIPT_LANG} is not set. */
    public static final String DEFAULT_SCRIPT_LANG = "groovy";

    /** Index setting holding the re-ranking script (source, id or file name). */
    public static final String INDEX_DYNARANK_SCRIPT = "index.dynarank.script_sort.script";

    /** Index setting holding the language of the re-ranking script. */
    public static final String INDEX_DYNARANK_SCRIPT_LANG = "index.dynarank.script_sort.lang";

    /** Index setting holding the type of the re-ranking script (inline, indexed or file). */
    public static final String INDEX_DYNARANK_SCRIPT_TYPE = "index.dynarank.script_sort.type";

    /** Prefix of the index settings that are passed to the script as variables. */
    public static final String INDEX_DYNARANK_SCRIPT_PARAMS = "index.dynarank.script_sort.params.";

    /** Index setting holding the number of leading hits to re-order. */
    public static final String INDEX_DYNARANK_REORDER_SIZE = "index.dynarank.reorder_size";

    /** Node setting holding the default reorder size (100 when unset). */
    public static final String INDICES_DYNARANK_REORDER_SIZE = "indices.dynarank.reorder_size";

    /** Node setting holding the expire-after-access time of cached script info (no expiry when unset). */
    public static final String INDICES_DYNARANK_CACHE_EXPIRE = "indices.dynarank.cache.expire";

    /** Node setting holding the interval between cache refresh runs (60 seconds when unset). */
    public static final String INDICES_DYNARANK_CACHE_CLEAN_INTERVAL = "indices.dynarank.cache.clean_interval";

    /** Request header; a {@code Boolean.FALSE} value disables re-ranking for that request. */
    private static final String DYNARANK_RERANK_ENABLE = "_rerank";

    /** Request header; a numeric value below which (in total hits) no re-ranking is done. */
    private static final String DYNARANK_MIN_TOTAL_HITS = "_minTotalHits";

    private final ESLogger logger = ESLoggerFactory.getLogger("script.dynarank.sort");

    private final ClusterService clusterService;

    private final Integer defaultReorderSize;

    private final ScriptService scriptService;

    /** Script configuration keyed by the index or alias name used in the request. */
    private final Cache<String, ScriptInfo> scriptInfoCache;

    private final ThreadPool threadPool;

    private final TimeValue cleanInterval;

    /** Assigned in {@link #doStart()}; {@code null} until the component is started. */
    private Reaper reaper;

    private final Client client;

    /**
     * Creates the ranker, builds the script info cache from the node settings
     * and registers this instance on every {@link SearchActionFilter} found in
     * {@code filters}.
     *
     * @param settings node settings
     * @param client client used to re-issue a search when a retry is requested
     * @param clusterService source of the index metadata
     * @param scriptService service used to compile and run the ranking script
     * @param threadPool pool used to schedule the cache refresh task
     * @param filters registered action filters
     */
    @Inject
    public DynamicRanker(final Settings settings,
            final Client client,
            final ClusterService clusterService,
            final ScriptService scriptService, final ThreadPool threadPool,
            final ActionFilters filters) {
        super(settings);
        this.client = client;
        this.clusterService = clusterService;
        this.scriptService = scriptService;
        this.threadPool = threadPool;

        logger.info("Initializing DynamicRanker");

        defaultReorderSize = settings.getAsInt(INDICES_DYNARANK_REORDER_SIZE,
                100);
        final TimeValue expire = settings.getAsTime(
                INDICES_DYNARANK_CACHE_EXPIRE, null);
        cleanInterval = settings.getAsTime(
                INDICES_DYNARANK_CACHE_CLEAN_INTERVAL,
                TimeValue.timeValueSeconds(60));

        final CacheBuilder<Object, Object> builder = CacheBuilder.newBuilder()
                .concurrencyLevel(16);
        if (expire != null) {
            builder.expireAfterAccess(expire.millis(), TimeUnit.MILLISECONDS);
        }
        scriptInfoCache = builder.build();

        for (final ActionFilter filter : filters.filters()) {
            if (filter instanceof SearchActionFilter) {
                ((SearchActionFilter) filter).setDynamicRanker(this);
                if (logger.isDebugEnabled()) {
                    logger.debug("Set DynamicRanker to " + filter);
                }
            }
        }

    }

    /**
     * Schedules the first run of the cache refresh task.
     */
    @Override
    protected void doStart() throws ElasticsearchException {
        reaper = new Reaper();
        // Use the same executor as the task's own rescheduling so that every
        // run, including the first, happens on the generic pool.
        threadPool.schedule(cleanInterval, ThreadPool.Names.GENERIC, reaper);
    }

    @Override
    protected void doStop() throws ElasticsearchException {
    }

    /**
     * Stops the cache refresh task and drops all cached script info.
     */
    @Override
    protected void doClose() throws ElasticsearchException {
        // The reaper only exists once doStart() has run.
        if (reaper != null) {
            reaper.close();
        }
        scriptInfoCache.invalidateAll();
    }

    /**
     * Prepares a search request for re-ranking and returns the listener that
     * performs it.
     *
     * <p>
     * {@code null} is returned, and the request left untouched, when the
     * search type is not one of the handled types, the request is a scroll,
     * the {@code _rerank} header is {@code Boolean.FALSE}, the request has no
     * source, it does not target exactly one index, no script is configured
     * for that index, {@code from}/{@code size} is negative, or {@code from}
     * is at or beyond the reorder size. Otherwise the request source is
     * rewritten to {@code from=0} and
     * {@code size=max(reorderSize, from + size)}.
     * </p>
     *
     * @param action action name (not used here)
     * @param request search request; its source is rewritten in place
     * @param listener listener that receives the final response
     * @return the wrapping listener, or {@code null} when no re-ranking applies
     * @throws ElasticsearchException if the request source cannot be parsed or rebuilt
     */
    public ActionListener<SearchResponse> wrapActionListener(
            final String action, final SearchRequest request,
            final ActionListener<SearchResponse> listener) {
        // NOTE(review): DFS_QUERY_THEN_FETCH falls into the default branch and
        // is therefore never re-ranked - confirm this is intended.
        switch (request.searchType()) {
        case DFS_QUERY_AND_FETCH:
        case QUERY_AND_FETCH:
        case QUERY_THEN_FETCH:
            break;
        default:
            return null;
        }

        if (request.scroll() != null) {
            return null;
        }

        final Object isRerank = request.getHeader(DYNARANK_RERANK_ENABLE);
        if (isRerank instanceof Boolean && !((Boolean) isRerank).booleanValue()) {
            return null;
        }

        final BytesReference source = request.source();
        if (source == null) {
            return null;
        }

        final String[] indices = request.indices();
        if (indices == null || indices.length != 1) {
            return null;
        }

        final String index = indices[0];
        final ScriptInfo scriptInfo = getScriptInfo(index);
        if (scriptInfo == null || scriptInfo.getScript() == null) {
            return null;
        }

        final long startTime = System.nanoTime();

        try {
            final Map<String, Object> sourceAsMap = SourceLookup
                    .sourceAsMap(source);
            final int size = getInt(sourceAsMap.get("size"), 10);
            final int from = getInt(sourceAsMap.get("from"), 0);
            if (size < 0 || from < 0) {
                return null;
            }

            if (from >= scriptInfo.getReorderSize()) {
                return null;
            }

            // Fetch from the first hit, and enough hits to cover both the
            // reorder window and the requested page.
            int maxSize = scriptInfo.getReorderSize();
            if (from + size > scriptInfo.getReorderSize()) {
                maxSize = from + size;
            }
            sourceAsMap.put("size", maxSize);
            sourceAsMap.put("from", 0);

            if (logger.isDebugEnabled()) {
                logger.debug("Rewrite query: from:{}->{} size:{}->{}", from, 0,
                        size, maxSize);
            }

            final XContentBuilder builder = XContentFactory
                    .contentBuilder(Requests.CONTENT_TYPE);
            builder.map(sourceAsMap);
            request.source(builder.bytes());

            final ActionListener<SearchResponse> searchResponseListener = createSearchResponseListener(
                    request, listener, from, size, scriptInfo.getReorderSize(),
                    startTime, scriptInfo);
            return new ActionListener<SearchResponse>() {
                @Override
                public void onResponse(SearchResponse response) {
                    try {
                        searchResponseListener.onResponse(response);
                    } catch (RetrySearchException e) {
                        // Re-ranking asked for a retry: rewrite the source,
                        // restore the caller's paging and search again with
                        // re-ranking switched off.
                        Map<String, Object> newSourceAsMap = e.rewrite(sourceAsMap);
                        if (newSourceAsMap == null) {
                            throw new ElasticsearchException("Failed to rewrite source: " + sourceAsMap);
                        }
                        newSourceAsMap.put("size", size);
                        newSourceAsMap.put("from", from);
                        if (logger.isDebugEnabled()) {
                            logger.debug("Original Query: \n{}\nNew Query: \n{}", sourceAsMap, newSourceAsMap);
                        }
                        try {
                            final XContentBuilder builder = XContentFactory.contentBuilder(Requests.CONTENT_TYPE);
                            builder.map(newSourceAsMap);
                            request.source(builder.bytes());
                            for (String name : request.getHeaders()) {
                                if (name.startsWith("filter.codelibs.")) {
                                    request.putHeader(name, Boolean.FALSE);
                                }
                            }
                            request.putHeader(DYNARANK_RERANK_ENABLE, Boolean.FALSE);
                            client.search(request, listener);
                        } catch (IOException ioe) {
                            throw new ElasticsearchException("Failed to parse a new source.", ioe);
                        }
                    }
                }

                @Override
                public void onFailure(Throwable t) {
                    searchResponseListener.onFailure(t);
                }
            };
        } catch (final IOException e) {
            throw new ElasticsearchException("Failed to parse a source.", e);
        }
    }

    /**
     * Returns the cached script configuration for an index or alias name,
     * loading it from the cluster metadata on a cache miss.
     *
     * @param index index or alias name
     * @return the script info ({@link ScriptInfo#NO_SCRIPT_INFO} when no script
     *         is configured), or {@code null} if loading failed
     */
    public ScriptInfo getScriptInfo(final String index) {
        try {
            return scriptInfoCache.get(index, new Callable<ScriptInfo>() {
                @Override
                public ScriptInfo call() throws Exception {
                    return loadScriptInfo(index);
                }
            });
        } catch (final Exception e) {
            logger.warn("Failed to load ScriptInfo for {}.", e, index);
            return null;
        }
    }

    /**
     * Builds the script configuration for an index or alias name from the
     * current cluster state. When the name resolves to several indices, the
     * settings of the last one that defines a non-empty script are used.
     *
     * @param index index or alias name
     * @return the script info, or {@link ScriptInfo#NO_SCRIPT_INFO} when the
     *         name is unknown or no resolved index defines a script
     */
    private ScriptInfo loadScriptInfo(final String index) {
        final MetaData metaData = clusterService.state().getMetaData();
        final AliasOrIndex aliasOrIndex = metaData.getAliasAndIndexLookup()
                .get(index);
        if (aliasOrIndex == null) {
            return ScriptInfo.NO_SCRIPT_INFO;
        }

        Settings indexSettings = null;
        for (final IndexMetaData indexMD : aliasOrIndex.getIndices()) {
            final Settings scriptSettings = indexMD.getSettings();
            final String script = scriptSettings.get(INDEX_DYNARANK_SCRIPT);
            if (script != null && script.length() > 0) {
                indexSettings = scriptSettings;
            }
        }

        if (indexSettings == null) {
            return ScriptInfo.NO_SCRIPT_INFO;
        }

        return new ScriptInfo(indexSettings.get(INDEX_DYNARANK_SCRIPT),
                indexSettings.get(INDEX_DYNARANK_SCRIPT_LANG,
                        DEFAULT_SCRIPT_LANG),
                indexSettings.get(INDEX_DYNARANK_SCRIPT_TYPE,
                        DEFAULT_SCRIPT_TYPE),
                indexSettings.getByPrefix(INDEX_DYNARANK_SCRIPT_PARAMS),
                indexSettings.getAsInt(INDEX_DYNARANK_REORDER_SIZE,
                        defaultReorderSize));
    }

    /**
     * Creates the listener that re-orders the hits of a response and forwards
     * a rebuilt response to {@code listener}.
     *
     * <p>
     * Responses with no hits, or with fewer total hits than the numeric
     * {@code _minTotalHits} request header, are forwarded unchanged. Otherwise
     * the response is serialized and read back field by field so that the hits
     * can be replaced by the re-ordered ones.
     * </p>
     *
     * @param request the (already rewritten) search request
     * @param listener listener that receives the final response
     * @param from original {@code from} of the request
     * @param size original {@code size} of the request
     * @param reorderSize number of leading hits handed to the script
     * @param startTime {@link System#nanoTime()} taken when wrapping started
     * @param scriptInfo script configuration of the searched index
     * @return the re-ranking listener
     */
    private ActionListener<SearchResponse> createSearchResponseListener(
            final SearchRequest request,
            final ActionListener<SearchResponse> listener, final int from,
            final int size, final int reorderSize, final long startTime,
            final ScriptInfo scriptInfo) {
        return new ActionListener<SearchResponse>() {
            @Override
            public void onResponse(final SearchResponse response) {
                final long totalHits = response.getHits().getTotalHits();
                if (totalHits == 0) {
                    if (logger.isDebugEnabled()) {
                        logger.debug(
                                "totalHits is {}. No reranking results: {}",
                                totalHits, response);
                    }
                    listener.onResponse(response);
                    return;
                }

                final Object minTotalHits = request
                        .getHeader(DYNARANK_MIN_TOTAL_HITS);
                if (minTotalHits instanceof Number
                        && totalHits < ((Number) minTotalHits).longValue()) {
                    if (logger.isDebugEnabled()) {
                        logger.debug(
                                "totalHits is {} < {}. No reranking results: {}",
                                totalHits, minTotalHits, response);
                    }
                    listener.onResponse(response);
                    return;
                }

                if (logger.isDebugEnabled()) {
                    logger.debug("Reranking results: {}", response);
                }

                try {
                    // Round-trip the response through its wire format; the
                    // reads below must follow the order in which
                    // SearchResponse.writeTo emits its fields.
                    final BytesStreamOutput out = new BytesStreamOutput();
                    response.writeTo(out);

                    if (logger.isDebugEnabled()) {
                        logger.debug("Reading headers...");
                    }
                    final ChannelBufferStreamInput in = new ChannelBufferStreamInput(
                            out.bytes().toChannelBuffer());
                    Map<String, Object> headers = null;
                    if (in.readBoolean()) {
                        headers = in.readMap();
                    }
                    if (logger.isDebugEnabled()) {
                        logger.debug("Reading hits...");
                    }
                    final InternalSearchHits hits = readSearchHits(in);
                    final InternalSearchHits newHits = doReorder(hits, from,
                            size, reorderSize, scriptInfo);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Reading aggregations...");
                    }
                    InternalAggregations aggregations = null;
                    if (in.readBoolean()) {
                        aggregations = InternalAggregations
                                .readAggregations(in);
                    }
                    if (logger.isDebugEnabled()) {
                        logger.debug("Reading suggest...");
                    }
                    Suggest suggest = null;
                    if (in.readBoolean()) {
                        suggest = Suggest.readSuggest(Suggest.Fields.SUGGEST,
                                in);
                    }
                    final boolean timedOut = in.readBoolean();
                    Boolean terminatedEarly = in.readOptionalBoolean();
                    InternalProfileShardResults profileResults;

                    if (in.getVersion().onOrAfter(Version.V_2_2_0)
                            && in.readBoolean()) {
                        profileResults = new InternalProfileShardResults(in);
                    } else {
                        profileResults = null;
                    }

                    final InternalSearchResponse internalResponse = new InternalSearchResponse(
                            newHits, aggregations, suggest, profileResults, timedOut,
                            terminatedEarly);
                    final int totalShards = in.readVInt();
                    final int successfulShards = in.readVInt();
                    final int shardFailureSize = in.readVInt();
                    ShardSearchFailure[] shardFailures;
                    if (shardFailureSize == 0) {
                        shardFailures = ShardSearchFailure.EMPTY_ARRAY;
                    } else {
                        shardFailures = new ShardSearchFailure[shardFailureSize];
                        for (int i = 0; i < shardFailures.length; i++) {
                            shardFailures[i] = readShardSearchFailure(in);
                        }
                    }
                    final String scrollId = in.readOptionalString();
                    // Report the time since wrapping started, so the
                    // re-ranking overhead is included.
                    final long tookInMillis = (System.nanoTime() - startTime) / 1000000;

                    if (logger.isDebugEnabled()) {
                        logger.debug("Creating new SearchResponse...");
                    }
                    final SearchResponse newResponse = new SearchResponse(
                            internalResponse, scrollId, totalShards,
                            successfulShards, tookInMillis, shardFailures);
                    if (headers != null) {
                        for (final Map.Entry<String, Object> entry : headers
                                .entrySet()) {
                            newResponse.putHeader(entry.getKey(),
                                    entry.getValue());
                        }
                    }
                    listener.onResponse(newResponse);

                    if (logger.isDebugEnabled()) {
                        logger.debug("Rewriting overhead time: {} - {} = {}ms",
                                tookInMillis, response.getTookInMillis(),
                                tookInMillis - response.getTookInMillis());
                    }

                } catch (final RetrySearchException e) {
                    // Handled by the listener returned from wrapActionListener.
                    throw e;
                } catch (final Exception e) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Failed to parse a search response.", e);
                    }
                    throw new ElasticsearchException(
                            "Failed to parse a search response.", e);
                }
            }

            @Override
            public void onFailure(final Throwable e) {
                listener.onFailure(e);
            }
        };
    }

    /**
     * Runs the ranking script over the first {@code reorderSize} hits and
     * returns the page requested by the caller.
     *
     * @param hits hits fetched with the rewritten {@code from}/{@code size}
     * @param from original {@code from} of the request
     * @param size original {@code size} of the request
     * @param reorderSize number of leading hits handed to the script
     * @param scriptInfo script configuration of the searched index
     * @return new hits carrying the original total hit count and max score
     */
    private InternalSearchHits doReorder(final InternalSearchHits hits,
            final int from, final int size, final int reorderSize,
            final ScriptInfo scriptInfo) {
        final InternalSearchHit[] searchHits = hits.internalHits();
        InternalSearchHit[] newSearchHits;
        if (logger.isDebugEnabled()) {
            logger.debug("searchHits.length <= reorderSize: {}",
                    searchHits.length <= reorderSize);
        }
        if (searchHits.length <= reorderSize) {
            // Everything fetched is within the reorder window: re-order all
            // of it and cut out the requested page.
            final InternalSearchHit[] targets = onReorder(searchHits,
                    scriptInfo);
            if (from >= targets.length) {
                newSearchHits = new InternalSearchHit[0];
                if (logger.isDebugEnabled()) {
                    logger.debug("Invalid argument: " + from + " >= "
                            + targets.length);
                }
            } else {
                int end = from + size;
                if (end > targets.length) {
                    end = targets.length;
                }
                newSearchHits = Arrays.copyOfRange(targets, from, end);
            }
        } else {
            // More hits than the reorder window: re-order the window only and
            // append the remaining hits in their original order.
            InternalSearchHit[] targets = Arrays.copyOfRange(searchHits, 0,
                    reorderSize);
            targets = onReorder(targets, scriptInfo);
            final List<InternalSearchHit> list = new ArrayList<>(size);
            for (int i = from; i < targets.length; i++) {
                list.add(targets[i]);
            }
            for (int i = targets.length; i < searchHits.length; i++) {
                list.add(searchHits[i]);
            }
            newSearchHits = list.toArray(new InternalSearchHit[list.size()]);
        }
        return new InternalSearchHits(newSearchHits, hits.totalHits(),
                hits.maxScore());
    }

    /**
     * Compiles and runs the configured script. The hits are exposed to the
     * script as the {@code searchHits} variable, together with the script
     * parameters from the index settings.
     *
     * @param searchHits hits to re-order
     * @param scriptInfo script configuration of the searched index
     * @return the value returned by the script, cast to a hit array
     */
    private InternalSearchHit[] onReorder(final InternalSearchHit[] searchHits,
            final ScriptInfo scriptInfo) {
        final Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("searchHits", searchHits);
        vars.putAll(scriptInfo.getSettings());
        final CompiledScript compiledScript = scriptService.compile(
                new Script(scriptInfo.getScript(), scriptInfo.getScriptType(),
                        scriptInfo.getLang(), new HashMap<String, Object>()),
                ScriptContext.Standard.SEARCH, SearchContext.current(),
                Collections.<String, String> emptyMap());
        return (InternalSearchHit[]) scriptService.executable(compiledScript,
                vars).run();
    }

    /**
     * Converts a parsed source value to an {@code int}.
     *
     * @param value a {@link Number}, a numeric {@link String}, or anything else
     * @param defaultValue value returned for {@code null} and other types
     * @return the integer value
     * @throws NumberFormatException if {@code value} is a non-numeric string
     */
    private int getInt(final Object value, final int defaultValue) {
        if (value instanceof Number) {
            return ((Number) value).intValue();
        } else if (value instanceof String) {
            return Integer.parseInt(value.toString());
        }
        return defaultValue;
    }

    /**
     * Script configuration of one index: script, language, type, script
     * variables and reorder size.
     */
    public static class ScriptInfo {
        /** Placeholder for "no script configured"; all of its fields are unset. */
        protected static final ScriptInfo NO_SCRIPT_INFO = new ScriptInfo();

        private final String script;

        private final String lang;

        private final ScriptType scriptType;

        private final Map<String, Object> settings;

        private final int reorderSize;

        ScriptInfo() {
            this.script = null;
            this.lang = null;
            this.scriptType = null;
            this.settings = null;
            this.reorderSize = 0;
        }

        /**
         * @param script script source, id or file name
         * @param lang script language
         * @param scriptType "indexed" or "file" (case-insensitive); anything
         *        else means inline
         * @param settings script parameters; each becomes a script variable
         * @param reorderSize number of leading hits handed to the script
         */
        ScriptInfo(final String script, final String lang,
                final String scriptType, final Settings settings,
                final int reorderSize) {
            this.script = script;
            this.lang = lang;
            this.reorderSize = reorderSize;
            this.settings = new HashMap<>();
            for (final String name : settings.names()) {
                final String value = settings.get(name);
                if (value != null) {
                    this.settings.put(name, value);
                } else {
                    // No single value under this name: read it as an array.
                    this.settings.put(name, settings.getAsArray(name));
                }
            }
            if ("INDEXED".equalsIgnoreCase(scriptType)) {
                this.scriptType = ScriptType.INDEXED;
            } else if ("FILE".equalsIgnoreCase(scriptType)) {
                this.scriptType = ScriptType.FILE;
            } else {
                this.scriptType = ScriptType.INLINE;
            }
        }

        public String getScript() {
            return script;
        }

        public String getLang() {
            return lang;
        }

        public ScriptType getScriptType() {
            return scriptType;
        }

        public Map<String, Object> getSettings() {
            return settings;
        }

        public int getReorderSize() {
            return reorderSize;
        }

        @Override
        public String toString() {
            return "ScriptInfo [script=" + script + ", lang=" + lang
                    + ", scriptType=" + scriptType + ", settings=" + settings
                    + ", reorderSize=" + reorderSize + "]";
        }
    }

    /**
     * Periodic task that reloads every cached script info from the current
     * cluster state, evicts entries that no longer have a script, and then
     * reschedules itself until closed.
     */
    private class Reaper implements Runnable {
        private volatile boolean closed;

        void close() {
            closed = true;
        }

        @Override
        public void run() {
            if (closed) {
                return;
            }

            try {
                for (final Map.Entry<String, ScriptInfo> entry : scriptInfoCache
                        .asMap().entrySet()) {
                    final String index = entry.getKey();

                    // Resolve the key the same way getScriptInfo does, so
                    // that alias keys are refreshed rather than evicted.
                    final ScriptInfo scriptInfo = loadScriptInfo(index);
                    if (scriptInfo.getScript() == null) {
                        scriptInfoCache.invalidate(index);
                        if (logger.isDebugEnabled()) {
                            logger.debug("Invalidate cache for " + index);
                        }
                        continue;
                    }

                    if (logger.isDebugEnabled()) {
                        logger.debug("Reload cache for " + index + " => "
                                + scriptInfo);
                    }
                    scriptInfoCache.put(index, scriptInfo);
                }
            } catch (final Exception e) {
                logger.warn("Failed to update a cache for ScriptInfo.", e);
            } finally {
                // Do not reschedule once close() has been called.
                if (!closed) {
                    threadPool.schedule(cleanInterval,
                            ThreadPool.Names.GENERIC, this);
                }
            }

        }

    }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy