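// NOTE: the lines below appear to be page text from the repository browser this listing was
// copied from, not part of the Java source. They are kept as comments so the file stays valid Java.
//
// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.o19s.es.ltr.LtrQueryParserPlugin (Maven / Gradle / Ivy)
//
// There is a newer version: 6.8.0 (show newest version)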
/*
 * Copyright [2016] Doug Turnbull
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
package com.o19s.es.ltr;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.Collections.unmodifiableList;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.miscellaneous.LengthFilter;
import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.analysis.AbstractTokenFilterFactory;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.indices.analysis.AnalysisModule;
import org.elasticsearch.index.Index;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.AnalysisPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.ResourceWatcherService;

import com.o19s.es.explore.ExplorerQueryBuilder;
import com.o19s.es.ltr.action.AddFeaturesToSetAction;
import com.o19s.es.ltr.action.CachesStatsAction;
import com.o19s.es.ltr.action.ClearCachesAction;
import com.o19s.es.ltr.action.CreateModelFromSetAction;
import com.o19s.es.ltr.action.FeatureStoreAction;
import com.o19s.es.ltr.action.ListStoresAction;
import com.o19s.es.ltr.action.TransportAddFeatureToSetAction;
import com.o19s.es.ltr.action.TransportCacheStatsAction;
import com.o19s.es.ltr.action.TransportClearCachesAction;
import com.o19s.es.ltr.action.TransportCreateModelFromSetAction;
import com.o19s.es.ltr.action.TransportFeatureStoreAction;
import com.o19s.es.ltr.action.TransportListStoresAction;
import com.o19s.es.ltr.feature.store.StorableElement;
import com.o19s.es.ltr.feature.store.StoredFeature;
import com.o19s.es.ltr.feature.store.StoredFeatureSet;
import com.o19s.es.ltr.feature.store.StoredLtrModel;
import com.o19s.es.ltr.feature.store.index.CachedFeatureStore;
import com.o19s.es.ltr.feature.store.index.Caches;
import com.o19s.es.ltr.feature.store.index.IndexFeatureStore;
import com.o19s.es.ltr.logging.LoggingFetchSubPhase;
import com.o19s.es.ltr.logging.LoggingSearchExtBuilder;
import com.o19s.es.ltr.query.LtrQueryBuilder;
import com.o19s.es.ltr.query.StoredLtrQueryBuilder;
import com.o19s.es.ltr.query.ValidatingLtrQueryBuilder;
import com.o19s.es.ltr.ranker.parser.LinearRankerParser;
import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory;
import com.o19s.es.ltr.ranker.parser.XGBoostJsonParser;
import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine;
import com.o19s.es.ltr.ranker.ranklib.RanklibModelParser;
import com.o19s.es.ltr.rest.RestAddFeatureToSet;
import com.o19s.es.ltr.rest.RestCreateModelFromSet;
import com.o19s.es.ltr.rest.RestFeatureStoreCaches;
import com.o19s.es.ltr.rest.RestSimpleFeatureStore;
import com.o19s.es.ltr.utils.FeatureStoreLoader;
import com.o19s.es.ltr.utils.Suppliers;

import ciir.umass.edu.learning.RankerFactory;

public class LtrQueryParserPlugin extends Plugin implements SearchPlugin, ScriptPlugin, ActionPlugin, AnalysisPlugin {
    private final LtrRankerParserFactory parserFactory;
    private final Caches caches;

    public LtrQueryParserPlugin(Settings settings) {
        caches = new Caches(settings);
        // Use memoize to Lazy load the RankerFactory as it's a heavy object to construct
        Supplier ranklib = Suppliers.memoize(RankerFactory::new);
        parserFactory = new LtrRankerParserFactory.Builder()
                .register(RanklibModelParser.TYPE, () -> new RanklibModelParser(ranklib.get()))
                .register(LinearRankerParser.TYPE, LinearRankerParser::new)
                .register(XGBoostJsonParser.TYPE, XGBoostJsonParser::new)
                .build();
    }

    @Override
    public List> getQueries() {

        return asList(
                new QuerySpec<>(ExplorerQueryBuilder.NAME, ExplorerQueryBuilder::new, ExplorerQueryBuilder::fromXContent),
                new QuerySpec<>(LtrQueryBuilder.NAME, LtrQueryBuilder::new, LtrQueryBuilder::fromXContent),
                new QuerySpec<>(StoredLtrQueryBuilder.NAME,
                        (input) -> new StoredLtrQueryBuilder(getFeatureStoreLoader(), input),
                        (ctx) -> StoredLtrQueryBuilder.fromXContent(getFeatureStoreLoader(), ctx)),
                new QuerySpec<>(ValidatingLtrQueryBuilder.NAME,
                        (input) -> new ValidatingLtrQueryBuilder(input, parserFactory),
                        (ctx) -> ValidatingLtrQueryBuilder.fromXContent(ctx, parserFactory)));
    }

    @Override
    public List getFetchSubPhases(FetchPhaseConstructionContext context) {
        return singletonList(new LoggingFetchSubPhase());
    }

    @Override
    public List> getSearchExts() {
        return singletonList(
                new SearchExtSpec<>(LoggingSearchExtBuilder.NAME, LoggingSearchExtBuilder::new, LoggingSearchExtBuilder::parse));
    }

    @Override
    public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) {
        return new RankLibScriptEngine(settings, parserFactory);
    }

    @Override
    public List getRestHandlers(Settings settings, RestController restController,
                                             ClusterSettings clusterSettings, IndexScopedSettings indexScopedSettings,
                                             SettingsFilter settingsFilter, IndexNameExpressionResolver indexNameExpressionResolver,
                                             Supplier nodesInCluster) {
        List list = new ArrayList<>();
        RestSimpleFeatureStore.register(list, settings, restController);
        list.add(new RestFeatureStoreCaches(settings, restController));
        list.add(new RestCreateModelFromSet(settings, restController));
        list.add(new RestAddFeatureToSet(settings, restController));
        return unmodifiableList(list);
    }

    @Override
    public List> getActions() {
        return unmodifiableList(asList(
                new ActionHandler<>(FeatureStoreAction.INSTANCE, TransportFeatureStoreAction.class),
                new ActionHandler<>(CachesStatsAction.INSTANCE, TransportCacheStatsAction.class),
                new ActionHandler<>(ClearCachesAction.INSTANCE, TransportClearCachesAction.class),
                new ActionHandler<>(AddFeaturesToSetAction.INSTANCE, TransportAddFeatureToSetAction.class),
                new ActionHandler<>(CreateModelFromSetAction.INSTANCE, TransportCreateModelFromSetAction.class),
                new ActionHandler<>(ListStoresAction.INSTANCE, TransportListStoresAction.class)));
    }

    @Override
    public List getNamedWriteables() {
        return unmodifiableList(asList(
                new Entry(StorableElement.class, StoredFeature.TYPE, StoredFeature::new),
                new Entry(StorableElement.class, StoredFeatureSet.TYPE, StoredFeatureSet::new),
                new Entry(StorableElement.class, StoredLtrModel.TYPE, StoredLtrModel::new)
        ));
    }

    @Override
    public List getNamedXContent() {
        return unmodifiableList(asList(
                new NamedXContentRegistry.Entry(StorableElement.class,
                        new ParseField(StoredFeature.TYPE),
                        (CheckedFunction) StoredFeature::parse),
                new NamedXContentRegistry.Entry(StorableElement.class,
                        new ParseField(StoredFeatureSet.TYPE),
                        (CheckedFunction) StoredFeatureSet::parse),
                new NamedXContentRegistry.Entry(StorableElement.class,
                        new ParseField(StoredLtrModel.TYPE),
                        (CheckedFunction) StoredLtrModel::parse)
        ));
    }

    @Override
    public List> getSettings() {
        return unmodifiableList(asList(
                IndexFeatureStore.STORE_VERSION_PROP,
                Caches.LTR_CACHE_MEM_SETTING,
                Caches.LTR_CACHE_EXPIRE_AFTER_READ,
                Caches.LTR_CACHE_EXPIRE_AFTER_WRITE));
    }

    @Override
    public Collection createComponents(Client client,
                                               ClusterService clusterService,
                                               ThreadPool threadPool,
                                               ResourceWatcherService resourceWatcherService,
                                               ScriptService scriptService,
                                               NamedXContentRegistry xContentRegistry,
                                               Environment environment,
                                               NodeEnvironment nodeEnvironment,
                                               NamedWriteableRegistry namedWriteableRegistry) {
        clusterService.addListener(event -> {
            for (Index i : event.indicesDeleted()) {
                if (IndexFeatureStore.isIndexStore(i.getName())) {
                    caches.evict(i.getName());
                }
            }
        });
        return asList(caches, parserFactory);
    }

    protected FeatureStoreLoader getFeatureStoreLoader() {
        return (storeName, client) -> new CachedFeatureStore(new IndexFeatureStore(storeName, client, parserFactory), caches);
    }


    @Override
    public Map> getTokenFilters() {
        Map> filters = new HashMap<>();
        filters.put("ltr_edge_ngram", LtrEdgeNGramFilterFactory::new);
        filters.put("ltr_length", LtrLengthTokenFilterFactory::new);
        return Collections.unmodifiableMap(filters);
    }

    // A simplified version of some token filters needed by the feature stores.
    // This is because some common filter have been moved to analysis-common module
    // which is not included in the integration test cluster.
    // Add a simple version of these token filter to make the plugin self contained.
    private static final int STORABLE_ELEMENT_MAX_NAME_SIZE = 512;
    private static class LtrEdgeNGramFilterFactory extends AbstractTokenFilterFactory {
        LtrEdgeNGramFilterFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
            super(indexSettings, name, settings);
        }

        @Override
        public TokenStream create(TokenStream tokenStream) {
            return new EdgeNGramTokenFilter(tokenStream, 1, STORABLE_ELEMENT_MAX_NAME_SIZE);
        }
    }

    private class LtrLengthTokenFilterFactory extends AbstractTokenFilterFactory {
        LtrLengthTokenFilterFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
            super(indexSettings, name, settings);
        }

        @Override
        public TokenStream create(TokenStream tokenStream) {
            return new LengthFilter(tokenStream, 0, STORABLE_ELEMENT_MAX_NAME_SIZE);
        }
    }
}