All downloads are free. Search and download functionality uses the official Maven repository.

com.o19s.es.ltr.query.StoredLtrQueryBuilder Maven / Gradle / Ivy

There is a newer version: 6.8.0
Show newest version
/*
 * Copyright [2017] Wikimedia Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.o19s.es.ltr.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;

import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.feature.store.CompiledLtrModel;
import com.o19s.es.ltr.feature.store.FeatureStore;
import com.o19s.es.ltr.feature.store.index.IndexFeatureStore;
import com.o19s.es.ltr.ranker.linear.LinearRanker;
import com.o19s.es.ltr.utils.AbstractQueryBuilderUtils;
import com.o19s.es.ltr.utils.FeatureStoreLoader;

/**
 * sltr query: builds a learning-to-rank query from a model or a feature set held in a feature store.
 *
 * <p>At least one of {@code model} or {@code featureset} must be provided. When both are set, the model
 * takes precedence. When only a feature set is given, every feature in the set is combined by a
 * {@link LinearRanker} with all weights equal to 1.
 */
public class StoredLtrQueryBuilder extends AbstractQueryBuilder<StoredLtrQueryBuilder> implements NamedWriteable {
    public static final String NAME = "sltr";
    public static final ParseField MODEL_NAME = new ParseField("model");
    public static final ParseField FEATURESET_NAME = new ParseField("featureset");
    public static final ParseField STORE_NAME = new ParseField("store");
    public static final ParseField PARAMS = new ParseField("params");
    public static final ParseField ACTIVE_FEATURES = new ParseField("active_features");
    private static final ObjectParser<StoredLtrQueryBuilder, Void> PARSER;

    public static final Logger LOGGER = ESLoggerFactory.getLogger(StoredLtrQueryBuilder.class);

    static {
        PARSER = new ObjectParser<>(NAME);
        PARSER.declareString(StoredLtrQueryBuilder::modelName, MODEL_NAME);
        PARSER.declareString(StoredLtrQueryBuilder::featureSetName, FEATURESET_NAME);
        PARSER.declareString(StoredLtrQueryBuilder::storeName, STORE_NAME);
        PARSER.declareField(StoredLtrQueryBuilder::params, XContentParser::map, PARAMS, ObjectParser.ValueType.OBJECT);
        PARSER.declareStringArray(StoredLtrQueryBuilder::activeFeatures, ACTIVE_FEATURES);
        AbstractQueryBuilderUtils.declareStandardFields(PARSER);
    }

    /**
     * Injected context used to load a {@link FeatureStore} when running {@link #doToQuery(QueryShardContext)}
     */
    private final transient FeatureStoreLoader storeLoader;
    private String modelName;
    private String featureSetName;
    private String storeName;
    private Map<String, Object> params;
    private List<String> activeFeatures;

    /**
     * Creates an empty builder; fields are filled in through the fluent setters or the parser.
     *
     * @param storeLoader loader used at query-build time to obtain the feature store
     */
    public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader) {
        this.storeLoader = storeLoader;
    }


    /**
     * Deserializes a builder from the transport stream. The field order must mirror {@link #doWriteTo(StreamOutput)}.
     *
     * @param storeLoader loader used at query-build time to obtain the feature store (must not be null)
     * @param input       stream to read from
     * @throws IOException if the stream cannot be read
     */
    public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader, StreamInput input) throws IOException {
        super(input);
        this.storeLoader = Objects.requireNonNull(storeLoader);
        modelName = input.readOptionalString();
        featureSetName = input.readOptionalString();
        params = input.readMap();
        // active_features only exists on the wire from 6.2.4 onwards; older peers leave it null.
        if (input.getVersion().onOrAfter(Version.V_6_2_4)) {
            String[] activeFeat = input.readOptionalStringArray();
            activeFeatures = activeFeat == null ? null : Arrays.asList(activeFeat);
        }
        storeName = input.readOptionalString();
    }

    /**
     * Parses an sltr query body.
     *
     * @param storeLoader loader used at query-build time to obtain the feature store (must not be null)
     * @param parser      parser positioned on the sltr object
     * @return the parsed builder
     * @throws ParsingException if the body is malformed, if neither {@code model} nor {@code featureset}
     *                          is set, or if {@code params} is missing
     * @throws IOException      if the underlying parser fails
     */
    public static StoredLtrQueryBuilder fromXContent(FeatureStoreLoader storeLoader,
                                                     XContentParser parser) throws IOException {
        final StoredLtrQueryBuilder builder = new StoredLtrQueryBuilder(Objects.requireNonNull(storeLoader));
        try {
            PARSER.parse(parser, builder, null);
        } catch (IllegalArgumentException iae) {
            throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae);
        }
        if (builder.modelName() == null && builder.featureSetName() == null) {
            throw new ParsingException(parser.getTokenLocation(), "Either [" + MODEL_NAME + "] or [" + FEATURESET_NAME + "] must be set.");
        }
        if (builder.params() == null) {
            throw new ParsingException(parser.getTokenLocation(), "Field [" + PARAMS + "] is mandatory.");
        }
        return builder;
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeOptionalString(modelName);
        out.writeOptionalString(featureSetName);
        out.writeMap(params);
        // Keep in sync with the version gate in the StreamInput constructor.
        if (out.getVersion().onOrAfter(Version.V_6_2_4)) {
            out.writeOptionalStringArray(activeFeatures != null ? activeFeatures.toArray(new String[0]) : null);
        }
        out.writeOptionalString(storeName);
    }

    /**
     * Renders the query as XContent.
     *
     * <p>{@code params} and {@code active_features} are emitted whenever they are set, even when empty.
     * {@link #fromXContent} rejects a body without {@code params}, so an empty map must survive the round
     * trip. Likewise, an empty {@code active_features} list must not come back as {@code null}, or
     * {@link #doEquals} would no longer match.
     */
    @Override
    protected void doXContent(XContentBuilder builder, Params p) throws IOException {
        builder.startObject(NAME);
        if (modelName != null) {
            builder.field(MODEL_NAME.getPreferredName(), modelName);
        }
        if (featureSetName != null) {
            builder.field(FEATURESET_NAME.getPreferredName(), featureSetName);
        }
        if (storeName != null) {
            builder.field(STORE_NAME.getPreferredName(), storeName);
        }
        if (this.params != null) {
            builder.field(PARAMS.getPreferredName(), this.params);
        }
        if (this.activeFeatures != null) {
            builder.field(ACTIVE_FEATURES.getPreferredName(), this.activeFeatures);
        }
        printBoostAndQueryName(builder);
        builder.endObject();
    }

    /**
     * Checks that every requested active feature exists in the feature set.
     *
     * @throws IllegalArgumentException naming the first unknown feature
     */
    private static void validateActiveFeatures(FeatureSet features, LtrQueryContext context) {
        for (String featureName : context.getActiveFeatures()) {
            if (!features.hasFeature(featureName)) {
                throw new IllegalArgumentException("Feature: [" + featureName + "] " + "provided in active_features does not exist");
            }
        }
    }

    /**
     * Builds the Lucene ranking query.
     *
     * <p>Any failure (store loading, model lookup, active-feature validation, or query construction) is
     * logged and replaced by a match-all query instead of being propagated. This keeps the existing
     * best-effort behaviour of this builder.
     */
    @Override
    protected Query doToQuery(QueryShardContext context) throws IOException {
        try {
            String indexName = storeName != null ? IndexFeatureStore.indexName(storeName) : IndexFeatureStore.DEFAULT_STORE;
            FeatureStore store = storeLoader.load(indexName, context.getClient());
            LtrQueryContext ltrQueryContext = new LtrQueryContext(context,
                    activeFeatures == null ? Collections.emptySet() : new HashSet<>(activeFeatures));
            CompiledLtrModel model = resolveModel(store);
            validateActiveFeatures(model.featureSet(), ltrQueryContext);
            return RankerQuery.build(model, ltrQueryContext, params);
        } catch (Exception e) {
            // Deliberate fallback: log enough context to identify the failing query, then degrade to match_all.
            LOGGER.warn("Failed to build sltr query (model [{}], featureset [{}], store [{}]); falling back to match_all.",
                    modelName, featureSetName, storeName, e);
        }
        return Queries.newMatchAllQuery();
    }

    /**
     * Loads the stored model, or builds a unit-weight linear model over the named feature set when no
     * model name is configured.
     */
    private CompiledLtrModel resolveModel(FeatureStore store) throws IOException {
        if (modelName != null) {
            return store.loadModel(modelName);
        }
        assert featureSetName != null;
        FeatureSet set = store.loadSet(featureSetName);
        float[] weights = new float[set.size()];
        Arrays.fill(weights, 1F);
        return new CompiledLtrModel("linear", set, new LinearRanker(weights));
    }

    @Override
    protected boolean doEquals(StoredLtrQueryBuilder other) {
        return Objects.equals(modelName, other.modelName) &&
                Objects.equals(featureSetName, other.featureSetName) &&
                Objects.equals(storeName, other.storeName) &&
                Objects.equals(params, other.params) &&
                Objects.equals(activeFeatures, other.activeFeatures);
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(modelName, featureSetName, storeName, params, activeFeatures);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** @return the stored model name, or null if unset */
    public String modelName() {
        return modelName;
    }

    /** Sets the stored model name (must not be null). */
    public StoredLtrQueryBuilder modelName(String modelName) {
        this.modelName = Objects.requireNonNull(modelName);
        return this;
    }

    /** @return the feature set name, or null if unset */
    public String featureSetName() {
        return featureSetName;
    }

    /** Sets the feature set name; null is accepted. */
    public StoredLtrQueryBuilder featureSetName(String featureSetName) {
        this.featureSetName = featureSetName;
        return this;
    }

    /** @return the feature store name, or null to use the default store */
    public String storeName() {
        return storeName;
    }

    /** Sets the feature store name; null selects the default store. */
    public StoredLtrQueryBuilder storeName(String storeName) {
        this.storeName = storeName;
        return this;
    }

    /** @return the query-time template parameters, or null if unset */
    public Map<String, Object> params() {
        return params;
    }

    /** Sets the query-time template parameters (must not be null). The map is stored as given, not copied. */
    public StoredLtrQueryBuilder params(Map<String, Object> params) {
        this.params = Objects.requireNonNull(params);
        return this;
    }

    /** @return the requested active feature names, or null if unset */
    public List<String> activeFeatures() {
        return activeFeatures;
    }

    /** Sets the active feature names (must not be null). The list is stored as given, not copied. */
    public StoredLtrQueryBuilder activeFeatures(List<String> activeFeatures) {
        this.activeFeatures = Objects.requireNonNull(activeFeatures);
        return this;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy