com.o19s.es.ltr.feature.store.PrecompiledExpressionFeature Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of elasticsearch-learning-to-rank Show documentation
Show all versions of elasticsearch-learning-to-rank Show documentation
Learning to Rank Query w/ RankLib Models
/*
* Copyright [2017] Wikimedia Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.o19s.es.ltr.feature.store;
import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.DerivedExpressionQuery;
import com.o19s.es.ltr.utils.Scripting;
import org.apache.lucene.expressions.Expression;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
/**
 * A {@link Feature} whose template is a Lucene {@link Expression}, compiled once
 * by {@link #compile(StoredFeature)} and reused for every query built from it.
 *
 * <p>Each variable of the expression must name either another feature of the
 * enclosing {@link FeatureSet} or one of this feature's query parameters
 * (see {@link #validate(FeatureSet)}).
 */
public class PrecompiledExpressionFeature implements Feature, Accountable {
    /** Template language identifier handled by this feature type. */
    public static final String TEMPLATE_LANGUAGE = "derived_expression";
    private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(PrecompiledExpressionFeature.class);

    private final String name;
    private final Expression expression;
    /** Set view of {@code expression.variables}, kept for fast membership tests. */
    private final Set<String> expressionVariables;
    /** Names of the query parameters declared by the stored feature. */
    private final Collection<String> queryParams;

    private PrecompiledExpressionFeature(String name, Expression expression, Collection<String> queryParams) {
        this.name = name;
        this.expression = expression;
        this.queryParams = queryParams;
        this.expressionVariables = new HashSet<>(Arrays.asList(this.expression.variables));
    }

    /**
     * Compiles the template of a stored feature into a precompiled expression feature.
     *
     * @param feature stored feature whose template language is {@link #TEMPLATE_LANGUAGE}
     * @return a feature carrying the compiled expression, the stored name and query params
     */
    public static PrecompiledExpressionFeature compile(StoredFeature feature) {
        assert TEMPLATE_LANGUAGE.equals(feature.templateLanguage());
        Expression expr = (Expression) Scripting.compile(feature.template());
        return new PrecompiledExpressionFeature(feature.name(), expr, feature.queryParams());
    }

    /**
     * Estimates memory usage: shallow instance size, the characters of the name
     * plus one array header, and twice the characters of the expression source
     * text plus one array header.
     */
    @Override
    public long ramBytesUsed() {
        // Widen to long before multiplying so the sum cannot overflow int.
        long nameBytes = (long) Character.BYTES * name.length() + NUM_BYTES_ARRAY_HEADER;
        long sourceBytes = (long) Character.BYTES * expression.sourceText.length() + NUM_BYTES_ARRAY_HEADER;
        // NOTE(review): source text is counted twice, as in the original estimate;
        // presumably to account for the compiled form - confirm.
        return BASE_RAM_USED + nameBytes + sourceBytes * 2;
    }

    @Override
    public String name() {
        return name;
    }

    /**
     * Builds the query evaluating this expression against the given feature set.
     *
     * @param context query context (not used by this implementation)
     * @param set     feature set passed on to the resulting query
     * @param params  query-time parameters; may be {@code null} only when this
     *                feature declares no query params
     * @return a {@link DerivedExpressionQuery} bound to the numeric param values
     * @throws IllegalArgumentException if a declared query param is missing, or a
     *                                  param used as an expression variable is not a number
     */
    @Override
    public Query doToQuery(LtrQueryContext context, FeatureSet set, Map<String, Object> params) {
        List<String> missingParams = queryParams.stream()
                .filter((x) -> params == null || !params.containsKey(x))
                .collect(Collectors.toList());
        if (!missingParams.isEmpty()) {
            String names = missingParams.stream().collect(Collectors.joining(","));
            throw new IllegalArgumentException("Missing required param(s): [" + names + "]");
        }

        Map<String, Double> queryParamValues = getQueryParamValues(params);
        return new DerivedExpressionQuery(set, expression, queryParamValues);
    }

    /**
     * If a param is an expression variable then we need to use it as {@link Expression#variables}
     * in lucene expression exposed via Bindings.
     *
     * @param params Map all query params across features.
     * @return Map of {@link Expression#variables} names to params values (values are doubles).
     * @throws IllegalArgumentException if the value of such a param is {@code null} or not a {@link Number}
     */
    private Map<String, Double> getQueryParamValues(Map<String, Object> params) {
        Map<String, Double> queryParamValues = new HashMap<>();
        for (String param : queryParams) {
            if (!expressionVariables.contains(param)) {
                // Param is declared but not referenced by the expression: nothing to bind.
                continue;
            }
            Object value = params.get(param);
            // instanceof also rejects null, which would otherwise surface as a bare NPE.
            if (!(value instanceof Number)) {
                throw new IllegalArgumentException("parameter: " + param + " expected to be of type Double");
            }
            queryParamValues.put(param, ((Number) value).doubleValue());
        }
        return queryParamValues;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        PrecompiledExpressionFeature that = (PrecompiledExpressionFeature) o;
        return Objects.equals(name, that.name)
                && Objects.equals(expression, that.expression)
                && Objects.equals(queryParams, that.queryParams)
                && Objects.equals(expressionVariables, that.expressionVariables);
    }

    @Override
    public int hashCode() {
        return Objects.hash(name, expression, queryParams, expressionVariables);
    }

    /**
     * Checks that every expression variable resolves to exactly one source.
     *
     * @param set feature set this feature belongs to
     * @throws IllegalArgumentException if a variable is neither a feature of {@code set}
     *                                  nor a query param, or if it is both
     */
    @Override
    public void validate(FeatureSet set) {
        for (String var : expression.variables) {
            boolean isFeature = set.hasFeature(var);
            boolean isQueryParam = queryParams.contains(var);
            if (!isFeature && !isQueryParam) {
                throw new IllegalArgumentException("Derived feature [" + this.name + "] refers " +
                        "to unknown feature or parameter: [" + var + "]");
            }
            if (isFeature && isQueryParam) {
                throw new IllegalArgumentException("Duplicate name " + var + " . " +
                        "Cannot be used as both feature name and query parameter name");
            }
        }
    }
}