// Many resources are needed to download a project. Please understand that we have to cover our server costs. Thank you in advance. Project price: only $1.
// You can buy this project and download or modify it as often as you want.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
import com.google.common.collect.ImmutableMap;
import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
import com.yahoo.schema.expressiontransforms.TokenTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
/**
* A context which only contains type information.
* This returns empty tensor types (double) for unknown features which are not
* query, attribute or constant features, as we do not have information about which such
* features exist (but we know those that exist are doubles).
*
* This is not multithread safe.
*
* @author bratseth
*/
public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext {
private final Optional parent;
private final Map featureTypes = new HashMap<>();
private final Map resolvedTypes = new HashMap<>();
/** To avoid re-resolving diamond-shaped dependencies */
private final Map globallyResolvedTypes;
/** For invocation loop detection */
private final Deque currentResolutionCallStack;
private final SortedSet queryFeaturesNotDeclared;
private boolean tensorsAreUsed;
MapEvaluationTypeContext(ImmutableMap functions, Map featureTypes) {
super(functions);
this.parent = Optional.empty();
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
tensorsAreUsed = false;
globallyResolvedTypes = new HashMap<>();
}
private MapEvaluationTypeContext(Map functions,
Map bindings,
Optional parent,
Map featureTypes,
Deque currentResolutionCallStack,
SortedSet queryFeaturesNotDeclared,
boolean tensorsAreUsed,
Map globallyResolvedTypes) {
super(functions, bindings);
this.parent = parent;
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
this.tensorsAreUsed = tensorsAreUsed;
this.globallyResolvedTypes = globallyResolvedTypes;
}
public void setType(Reference reference, TensorType type) {
featureTypes.put(reference, type);
queryFeaturesNotDeclared.remove(reference);
}
public Map featureTypes() { return Collections.unmodifiableMap(featureTypes); }
@Override
public TensorType getType(String reference) {
throw new UnsupportedOperationException("Not able to parse general references from string form");
}
public void forgetResolvedTypes() {
resolvedTypes.clear();
}
private boolean referenceCanBeResolvedGlobally(Reference reference) {
Optional function = functionInvocation(reference);
return function.isPresent() && function.get().arguments().size() == 0;
// are there other cases we would like to resolve globally?
}
@Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);
TensorType resolvedType = resolvedTypes.get(reference);
if (resolvedType == null && canBeResolvedGlobally) {
resolvedType = globallyResolvedTypes.get(reference);
}
if (resolvedType != null) {
return resolvedType;
}
resolvedType = resolveType(reference);
if (resolvedType == null)
return defaultTypeOf(reference); // Don't store fallback to default as we may know more later
resolvedTypes.put(reference, resolvedType);
if (resolvedType.rank() > 0)
tensorsAreUsed = true;
if (canBeResolvedGlobally) {
globallyResolvedTypes.put(reference, resolvedType);
}
return resolvedType;
}
MapEvaluationTypeContext getParent(String forArgument, String boundTo) {
return parent.orElseThrow(
() -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context"));
}
@Override
public String resolveBinding(String name) {
String bound = getBinding(name);
if (bound == null) {
return name;
}
return getParent(name, bound).resolveBinding(bound);
}
private TensorType resolveType(Reference reference) {
if (currentResolutionCallStack.contains(reference))
throw new IllegalArgumentException("Invocation loop: " +
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);
// Bound to a function argument?
Optional binding = boundIdentifier(reference);
if (binding.isPresent()) {
try {
// This is not pretty, but changing to bind expressions rather
// than their string values requires deeper changes
var expr = new RankingExpression(binding.get());
return expr.type(getParent(reference.name(), binding.get()));
} catch (ParseException e) {
throw new IllegalArgumentException(e);
}
}
try {
currentResolutionCallStack.addLast(reference);
// A reference to an attribute, query or constant feature?
if (FeatureNames.isSimpleFeature(reference)) {
// The argument may be a local identifier bound to the actual value
String argument = reference.simpleArgument().get();
String argumentBinding = resolveBinding(argument);
reference = Reference.simple(reference.name(), argumentBinding);
return featureTypes.get(reference);
}
// A reference to a function?
Optional function = functionInvocation(reference);
if (function.isPresent()) {
var body = function.get().getBody();
var child = this.withBindings(bind(function.get().arguments(), reference.arguments()));
return body.type(child);
}
// A reference to an ONNX model?
Optional onnxFeatureType = onnxFeatureType(reference);
if (onnxFeatureType.isPresent()) {
return onnxFeatureType.get();
}
// A reference to a feature for transformer token input?
Optional transformerTokensFeatureType = transformerTokensFeatureType(reference);
if (transformerTokensFeatureType.isPresent()) {
return transformerTokensFeatureType.get();
}
// A reference to a feature which returns a tensor?
Optional featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
return featureTensorType.get();
}
// A directly injected identifier? (Useful for stateless model evaluation)
if (reference.isIdentifier() && featureTypes.containsKey(reference)) {
return featureTypes.get(reference);
}
// the name of a constant feature?
if (reference.isIdentifier()) {
Reference asConst = FeatureNames.asConstantFeature(reference.name());
if (featureTypes.containsKey(asConst)) {
return featureTypes.get(asConst);
}
}
// We do not know what this is - since we do not have complete knowledge about the match features
// in Java we must assume this is a match feature and return the double type - which is the type of
// all match features
return TensorType.empty;
}
finally {
currentResolutionCallStack.removeLast();
}
}
/**
* Returns the default type for this simple feature, or null if it does not have a default
*/
public TensorType defaultTypeOf(Reference reference) {
if ( ! FeatureNames.isSimpleFeature(reference))
throw new IllegalArgumentException("This can only be called for simple references, not " + reference);
if (reference.name().equals("query")) { // we do not require all query features to be declared, only non-doubles
queryFeaturesNotDeclared.add(reference);
return TensorType.empty;
}
return null;
}
/**
* Returns the binding if this reference is a simple identifier which is bound in this context.
* Returns empty otherwise.
*/
private Optional boundIdentifier(Reference reference) {
if ( ! reference.arguments().isEmpty()) return Optional.empty();
if ( reference.output() != null) return Optional.empty();
return Optional.ofNullable(getBinding(reference.name()));
}
private Optional functionInvocation(Reference reference) {
if (reference.output() != null) return Optional.empty();
ExpressionFunction function = getFunctions().get(reference.name());
if (function == null) return Optional.empty();
if (function.arguments().size() != reference.arguments().size()) return Optional.empty();
return Optional.of(function);
}
private Optional onnxFeatureType(Reference reference) {
if ( ! reference.name().equals("onnxModel") && ! reference.name().equals("onnx"))
return Optional.empty();
if ( ! featureTypes.containsKey(reference)) {
String configOrFileName = reference.arguments().expressions().get(0).toString();
// Look up standardized format as added in RankProfile
String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
if ( ! featureTypes.containsKey(reference)) {
throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
}
}
return Optional.of(featureTypes.get(reference));
}
private Optional transformerTokensFeatureType(Reference reference) {
if ( ! reference.name().equals("tokenTypeIds") &&
! reference.name().equals("tokenInputIds") &&
! reference.name().equals("tokenAttentionMask"))
return Optional.empty();
if ( ! (reference.arguments().size() > 1))
throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments");
ExpressionNode size = reference.arguments().expressions().get(0);
return Optional.of(TokenTransformer.createTensorType(reference.name(), size));
}
/**
* There are three features which may return some (non-empty) tensor type:
* - tensorFromLabels
* - tensorFromWeightedSet
* - closest
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
*/
private Optional tensorFeatureType(Reference reference) {
if ( ! reference.name().equals("tensorFromLabels") &&
! reference.name().equals("tensorFromWeightedSet") &&
! reference.name().equals("closest"))
{
return Optional.empty();
}
if (reference.arguments().size() != 1 && reference.arguments().size() != 2)
throw new IllegalArgumentException(reference.name() + " must have one or two arguments");
ExpressionNode arg0 = reference.arguments().expressions().get(0);
if (reference.name().equals("closest")) {
if (arg0 instanceof ReferenceNode argRefNode) {
var argRef = argRefNode.reference();
if (argRef.isIdentifier()) {
var attrFeature = FeatureNames.asAttributeFeature(argRef.name());
TensorType attrTT = featureTypes.get(attrFeature);
if (attrTT != null && attrTT.rank() > 0) {
TensorType mapped = attrTT.mappedSubtype();
if (mapped.rank() > 0) {
return Optional.of(mapped);
} else {
throw new IllegalArgumentException("Unexpected tensor type " + attrTT + " for " + attrFeature + " used by " + reference);
}
}
}
}
throw new IllegalArgumentException("The first argument of " + reference.name() +
" must be the name of a tensor attribute, not " + arg0);
}
if ( ! ( arg0 instanceof ReferenceNode) || ! FeatureNames.isSimpleFeature(((ReferenceNode)arg0).reference()))
throw new IllegalArgumentException("The first argument of " + reference.name() +
" must be a simple feature, not " + arg0);
String dimension;
if (reference.arguments().size() > 1) {
ExpressionNode arg1 = reference.arguments().expressions().get(1);
if ( ( ! (arg1 instanceof ReferenceNode) || ! (((ReferenceNode)arg1).reference().isIdentifier()))
&&
( ! (arg1 instanceof NameNode)))
throw new IllegalArgumentException("The second argument of " + reference.name() +
" must be a dimension name, not " + arg1);
dimension = reference.arguments().expressions().get(1).toString();
}
else { // default
dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString();
}
// TODO: Determine the type of the weighted set/vector and use that as value type
return Optional.of(new TensorType.Builder().mapped(dimension).build());
}
/** Binds the given list of formal arguments to their actual values */
private Map bind(List formalArguments,
Arguments invocationArguments) {
Map bindings = new HashMap<>(formalArguments.size());
for (int i = 0; i < formalArguments.size(); i++) {
String identifier = invocationArguments.expressions().get(i).toString();
bindings.put(formalArguments.get(i), identifier);
}
return bindings;
}
/**
* Returns an unmodifiable view of the query features which was requested but for which we have no type info
* (such that they default to TensorType.empty), shared between all instances of this
* involved in resolving a particular rank profile.
*/
public SortedSet queryFeaturesNotDeclared() {
return Collections.unmodifiableSortedSet(queryFeaturesNotDeclared);
}
/** Returns true if any feature across all instances involved in resolving this rank profile resolves to a tensor */
public boolean tensorsAreUsed() { return tensorsAreUsed; }
@Override
public MapEvaluationTypeContext withBindings(Map bindings) {
return new MapEvaluationTypeContext(getFunctions(),
bindings,
Optional.of(this),
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
tensorsAreUsed,
globallyResolvedTypes);
}
}