All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.schema.MapEvaluationTypeContext Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;

import com.google.common.collect.ImmutableMap;
import com.yahoo.schema.expressiontransforms.OnnxModelTransformer;
import com.yahoo.schema.expressiontransforms.TokenTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * A context which only contains type information.
 * This returns empty tensor types (double) for unknown features which are not
 * query, attribute or constant features, as we do not have information about which such
 * features exist (but we know those that exist are doubles).
 *
 * This is not multithread safe.
 *
 * @author bratseth
 */
public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext {

    private final Optional parent;

    private final Map featureTypes = new HashMap<>();

    private final Map resolvedTypes = new HashMap<>();

    /** To avoid re-resolving diamond-shaped dependencies */
    private final Map globallyResolvedTypes;

    /** For invocation loop detection */
    private final Deque currentResolutionCallStack;

    private final SortedSet queryFeaturesNotDeclared;
    private boolean tensorsAreUsed;

    MapEvaluationTypeContext(ImmutableMap functions, Map featureTypes) {
        super(functions);
        this.parent = Optional.empty();
        this.featureTypes.putAll(featureTypes);
        this.currentResolutionCallStack =  new ArrayDeque<>();
        this.queryFeaturesNotDeclared = new TreeSet<>();
        tensorsAreUsed = false;
        globallyResolvedTypes = new HashMap<>();
    }

    private MapEvaluationTypeContext(Map functions,
                                     Map bindings,
                                     Optional parent,
                                     Map featureTypes,
                                     Deque currentResolutionCallStack,
                                     SortedSet queryFeaturesNotDeclared,
                                     boolean tensorsAreUsed,
                                     Map globallyResolvedTypes) {
        super(functions, bindings);
        this.parent = parent;
        this.featureTypes.putAll(featureTypes);
        this.currentResolutionCallStack = currentResolutionCallStack;
        this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
        this.tensorsAreUsed = tensorsAreUsed;
        this.globallyResolvedTypes = globallyResolvedTypes;
    }

    public void setType(Reference reference, TensorType type) {
        featureTypes.put(reference, type);
        queryFeaturesNotDeclared.remove(reference);
    }

    public Map featureTypes() { return Collections.unmodifiableMap(featureTypes); }

    @Override
    public TensorType getType(String reference) {
        throw new UnsupportedOperationException("Not able to parse general references from string form");
    }

    public void forgetResolvedTypes() {
        resolvedTypes.clear();
    }

    private boolean referenceCanBeResolvedGlobally(Reference reference) {
        Optional function = functionInvocation(reference);
        return function.isPresent() && function.get().arguments().size() == 0;
        // are there other cases we would like to resolve globally?
    }

    @Override
    public TensorType getType(Reference reference) {
        // computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
        boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);

        TensorType resolvedType = resolvedTypes.get(reference);
        if (resolvedType == null && canBeResolvedGlobally) {
            resolvedType = globallyResolvedTypes.get(reference);
        }
        if (resolvedType != null) {
            return resolvedType;
        }

        resolvedType = resolveType(reference);
        if (resolvedType == null)
            return defaultTypeOf(reference); // Don't store fallback to default as we may know more later
        resolvedTypes.put(reference, resolvedType);
        if (resolvedType.rank() > 0)
            tensorsAreUsed = true;

        if (canBeResolvedGlobally) {
            globallyResolvedTypes.put(reference, resolvedType);
        }

        return resolvedType;
    }

    MapEvaluationTypeContext getParent(String forArgument, String boundTo) {
        return parent.orElseThrow(
            () -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context"));
    }

    @Override
    public String resolveBinding(String name) {
        String bound = getBinding(name);
        if (bound == null) {
            return name;
        }
        return getParent(name, bound).resolveBinding(bound);
    }

    private TensorType resolveType(Reference reference) {
        if (currentResolutionCallStack.contains(reference))
            throw new IllegalArgumentException("Invocation loop: " +
                                               currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
                                               " -> " + reference);
        // Bound to a function argument?
        Optional binding = boundIdentifier(reference);
        if (binding.isPresent()) {
            try {
                // This is not pretty, but changing to bind expressions rather
                // than their string values requires deeper changes
                var expr = new RankingExpression(binding.get());
                return expr.type(getParent(reference.name(), binding.get()));
            } catch (ParseException e) {
                throw new IllegalArgumentException(e);
            }
        }

        try {
            currentResolutionCallStack.addLast(reference);

            // A reference to an attribute, query or constant feature?
            if (FeatureNames.isSimpleFeature(reference)) {
                // The argument may be a local identifier bound to the actual value
                String argument = reference.simpleArgument().get();
                String argumentBinding = resolveBinding(argument);
                reference = Reference.simple(reference.name(), argumentBinding);
                return featureTypes.get(reference);
            }

            // A reference to a function?
            Optional function = functionInvocation(reference);
            if (function.isPresent()) {
                var body = function.get().getBody();
                var child = this.withBindings(bind(function.get().arguments(), reference.arguments()));
                return body.type(child);
            }

            // A reference to an ONNX model?
            Optional onnxFeatureType = onnxFeatureType(reference);
            if (onnxFeatureType.isPresent()) {
                return onnxFeatureType.get();
            }

            // A reference to a feature for transformer token input?
            Optional transformerTokensFeatureType = transformerTokensFeatureType(reference);
            if (transformerTokensFeatureType.isPresent()) {
                return transformerTokensFeatureType.get();
            }

            // A reference to a feature which returns a tensor?
            Optional featureTensorType = tensorFeatureType(reference);
            if (featureTensorType.isPresent()) {
                return featureTensorType.get();
            }

            // A directly injected identifier? (Useful for stateless model evaluation)
            if (reference.isIdentifier() && featureTypes.containsKey(reference)) {
                return featureTypes.get(reference);
            }

            // the name of a constant feature?
            if (reference.isIdentifier()) {
                Reference asConst = FeatureNames.asConstantFeature(reference.name());
                if (featureTypes.containsKey(asConst)) {
                    return featureTypes.get(asConst);
                }
            }

            // We do not know what this is - since we do not have complete knowledge about the match features
            // in Java we must assume this is a match feature and return the double type - which is the type of
            // all match features
            return TensorType.empty;
        }
        finally {
            currentResolutionCallStack.removeLast();
        }
    }

    /**
     * Returns the default type for this simple feature, or null if it does not have a default
     */
    public TensorType defaultTypeOf(Reference reference) {
        if ( ! FeatureNames.isSimpleFeature(reference))
            throw new IllegalArgumentException("This can only be called for simple references, not " + reference);
        if (reference.name().equals("query")) { // we do not require all query features to be declared, only non-doubles
            queryFeaturesNotDeclared.add(reference);
            return TensorType.empty;
        }
        return null;
    }

    /**
     * Returns the binding if this reference is a simple identifier which is bound in this context.
     * Returns empty otherwise.
     */
    private Optional boundIdentifier(Reference reference) {
        if ( ! reference.arguments().isEmpty()) return Optional.empty();
        if ( reference.output() != null) return Optional.empty();
        return Optional.ofNullable(getBinding(reference.name()));
    }

    private Optional functionInvocation(Reference reference) {
        if (reference.output() != null) return Optional.empty();
        ExpressionFunction function = getFunctions().get(reference.name());
        if (function == null) return Optional.empty();
        if (function.arguments().size() != reference.arguments().size()) return Optional.empty();
        return Optional.of(function);
    }

    private Optional onnxFeatureType(Reference reference) {
        if ( ! reference.name().equals("onnxModel") && ! reference.name().equals("onnx"))
            return Optional.empty();

        if ( ! featureTypes.containsKey(reference)) {
            String configOrFileName = reference.arguments().expressions().get(0).toString();

            // Look up standardized format as added in RankProfile
            String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
            String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);

            reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
            if ( ! featureTypes.containsKey(reference)) {
                throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
            }
        }

        return Optional.of(featureTypes.get(reference));
    }

    private Optional transformerTokensFeatureType(Reference reference) {
        if ( ! reference.name().equals("tokenTypeIds") &&
                ! reference.name().equals("tokenInputIds") &&
                ! reference.name().equals("tokenAttentionMask"))
            return Optional.empty();

        if ( ! (reference.arguments().size() > 1))
            throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments");

        ExpressionNode size = reference.arguments().expressions().get(0);
        return Optional.of(TokenTransformer.createTensorType(reference.name(), size));
    }

    /**
     * There are three features which may return some (non-empty) tensor type:
     * - tensorFromLabels
     * - tensorFromWeightedSet
     * - closest
     * This returns the type of those features if this is a reference to either of them, or empty otherwise.
     */
    private Optional tensorFeatureType(Reference reference) {
        if ( ! reference.name().equals("tensorFromLabels") &&
             ! reference.name().equals("tensorFromWeightedSet") &&
             ! reference.name().equals("closest"))
        {
            return Optional.empty();
        }

        if (reference.arguments().size() != 1 && reference.arguments().size() != 2)
            throw new IllegalArgumentException(reference.name() + " must have one or two arguments");

        ExpressionNode arg0 = reference.arguments().expressions().get(0);
        if (reference.name().equals("closest")) {
            if (arg0 instanceof ReferenceNode argRefNode) {
                var argRef = argRefNode.reference();
                if (argRef.isIdentifier()) {
                    var attrFeature = FeatureNames.asAttributeFeature(argRef.name());
                    TensorType attrTT = featureTypes.get(attrFeature);
                    if (attrTT != null && attrTT.rank() > 0) {
                        TensorType mapped = attrTT.mappedSubtype();
                        if (mapped.rank() > 0) {
                            return Optional.of(mapped);
                        } else {
                            throw new IllegalArgumentException("Unexpected tensor type " + attrTT + " for " + attrFeature + " used by " + reference);
                        }
                    }
                }
            }
            throw new IllegalArgumentException("The first argument of " + reference.name() +
                                               " must be the name of a tensor attribute, not " + arg0);
        }
        if ( ! ( arg0 instanceof ReferenceNode) || ! FeatureNames.isSimpleFeature(((ReferenceNode)arg0).reference()))
            throw new IllegalArgumentException("The first argument of " + reference.name() +
                                               " must be a simple feature, not " + arg0);

        String dimension;
        if (reference.arguments().size() > 1) {
            ExpressionNode arg1 = reference.arguments().expressions().get(1);
            if ( ( ! (arg1 instanceof ReferenceNode) || ! (((ReferenceNode)arg1).reference().isIdentifier()))
                 &&
                 ( ! (arg1 instanceof NameNode)))
                throw new IllegalArgumentException("The second argument of " + reference.name() +
                                                   " must be a dimension name, not " + arg1);
            dimension = reference.arguments().expressions().get(1).toString();
        }
        else { // default
            dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString();
        }

        // TODO: Determine the type of the weighted set/vector and use that as value type
        return Optional.of(new TensorType.Builder().mapped(dimension).build());
    }

    /** Binds the given list of formal arguments to their actual values */
    private Map bind(List formalArguments,
                                     Arguments invocationArguments) {
        Map bindings = new HashMap<>(formalArguments.size());
        for (int i = 0; i < formalArguments.size(); i++) {
            String identifier = invocationArguments.expressions().get(i).toString();
            bindings.put(formalArguments.get(i), identifier);
        }
        return bindings;
    }

    /**
     * Returns an unmodifiable view of the query features which was requested but for which we have no type info
     * (such that they default to TensorType.empty), shared between all instances of this
     * involved in resolving a particular rank profile.
     */
    public SortedSet queryFeaturesNotDeclared() {
        return Collections.unmodifiableSortedSet(queryFeaturesNotDeclared);
    }

    /** Returns true if any feature across all instances involved in resolving this rank profile resolves to a tensor */
    public boolean tensorsAreUsed() { return tensorsAreUsed; }

    @Override
    public MapEvaluationTypeContext withBindings(Map bindings) {
        return new MapEvaluationTypeContext(getFunctions(),
                                            bindings,
                                            Optional.of(this),
                                            featureTypes,
                                            currentResolutionCallStack,
                                            queryFeaturesNotDeclared,
                                            tensorsAreUsed,
                                            globallyResolvedTypes);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy