All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.vespa.schemals.schemadocument.resolvers.RankExpressionSymbolResolver Maven / Gradle / Ivy

There is a newer version: 8.441.21
Show newest version
package ai.vespa.schemals.schemadocument.resolvers;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import org.eclipse.lsp4j.Diagnostic;

import ai.vespa.schemals.context.ParseContext;
import ai.vespa.schemals.index.Symbol;
import ai.vespa.schemals.index.Symbol.SymbolStatus;
import ai.vespa.schemals.index.Symbol.SymbolType;
import ai.vespa.schemals.parser.rankingexpression.ast.unaryFunctionName;
import ai.vespa.schemals.schemadocument.resolvers.RankExpression.BuiltInFunctions;
import ai.vespa.schemals.schemadocument.resolvers.RankExpression.GenericFunction;
import ai.vespa.schemals.tree.Node;
import ai.vespa.schemals.tree.SchemaNode;
import ai.vespa.schemals.tree.Node.LanguageType;
import ai.vespa.schemals.tree.rankingexpression.RankNode;
import ai.vespa.schemals.tree.rankingexpression.RankNode.RankNodeType;
import ai.vespa.schemals.tree.rankingexpression.RankNode.ReturnType;

/**
 * RankExpressionSymbolResolver goes through unresolved symbols in rank expression, to check if they are calling built in functions and tries
 * to figure out which symbolType the symbol referes to.
 */
public class RankExpressionSymbolResolver { 

    /**
     * Resolves rank expression references in the tree
     *
     * @param node        The schema node to resolve the rank expression references in.
     * @param context     The parse context.
     */
    public static List resolveRankExpressionReferences(SchemaNode node, ParseContext context) {
        List diagnostics = new ArrayList<>();

        if (node.getLanguageType() == LanguageType.RANK_EXPRESSION) {

            diagnostics.addAll(resolveRankExpression(node, context));

        } else {
            for (Node child : node) {
                diagnostics.addAll(resolveRankExpressionReferences(child.getSchemaNode(), context));
            }
        }

        return diagnostics;
    }

    public static List resolveRankExpression(SchemaNode schemaNode, ParseContext context) {
        List diagnostics = new ArrayList<>();

        List rankNodes = RankNode.createTree(schemaNode);

        for (RankNode node : rankNodes) {
            diagnostics.addAll(traverseRankExpressionTree(node, context));
        }

        return diagnostics;
    }

    private static List traverseRankExpressionTree(RankNode node, ParseContext context) {
        List diagnostics = new ArrayList<>();

        for (RankNode child : node) {
            diagnostics.addAll(traverseRankExpressionTree(child, context));
        }

        // All feature nodes has a symbol before the traversal
        if (node.hasSymbol()) {

            if (node.getSymbolStatus() == SymbolStatus.UNRESOLVED) {
                resolveReference(node, context, diagnostics);
            }

            if (node.getSymbolStatus() == SymbolStatus.UNRESOLVED) {
                findBuiltInTensorFunction(node);
            }

            if (node.getSymbolStatus() == SymbolStatus.UNRESOLVED) {
                findBuiltInFunction(node, context, diagnostics);
            }

        }

        return diagnostics;
    }

    public static final Set> builtInTokenizedFunctions = new HashSet<>() {{
        add(unaryFunctionName.class);
    }};

    private static void findBuiltInTensorFunction(RankNode node) {

        if (node.getType() == RankNodeType.BUILT_IN_FUNCTION) {
            Symbol symbol = node.getSymbol();
            symbol.setType(SymbolType.FUNCTION);
            symbol.setStatus(SymbolStatus.BUILTIN_REFERENCE);
        }
    }

    private static void removeSymbolFromIndex(ParseContext context, SchemaNode node) {
        // walk down first-child path and remove the first symbol found 
        // because some branches of the tree look like a long chain
        while (true) {
            if (node.hasSymbol()) {
                Symbol symbol = node.getSymbol();
                if (symbol.getStatus() == SymbolStatus.REFERENCE) {
                    context.schemaIndex().deleteSymbolReference(symbol);
                }
                node.removeSymbol();
                return;
            }
            if (node.size() > 0)node = node.get(0).getSchemaNode();
            else break;
        }
    }

    private static void findBuiltInFunction(RankNode node, ParseContext context, List diagnostics) {
        if (node.getType() != RankNodeType.FEATURE) {
            return;
        }

        String identifier = node.getSymbol().getShortIdentifier();

        GenericFunction functionHandler = BuiltInFunctions.rankExpressionBuiltInFunctions.get(identifier);
        if (functionHandler == null) return;
        
        node.setReturnType(ReturnType.DOUBLE);
        node.getSymbol().setType(SymbolType.FUNCTION);
        node.getSymbol().setStatus(SymbolStatus.BUILTIN_REFERENCE);

        Optional functionProperty = node.getProperty();
        if (functionProperty.isPresent()) {
            removeSymbolFromIndex(context, functionProperty.get());
        }

        diagnostics.addAll(functionHandler.handleArgumentList(context, node));
    }

    private static final List possibleTypes = new ArrayList<>() {{
        // add(SymbolType.PARAMETER); // This is a special case
        add(SymbolType.FUNCTION);
        add(SymbolType.RANK_CONSTANT);
        add(SymbolType.TENSOR_DIMENSION_MAPPED);
        add(SymbolType.TENSOR_DIMENSION_INDEXED);
    }};

    private static void resolveReference(RankNode referenceNode, ParseContext context, List diagnostics) {

        if (referenceNode.getSymbolType() != SymbolType.TYPE_UNKNOWN) {
            return;
        }

        if (referenceNode.getInsideLambdaFunction()) {
            resolveReferenceInsideLambda(referenceNode, context, diagnostics);
            return;
        }

        Symbol reference = referenceNode.getSymbol();

        Optional definition = Optional.empty();

        if (!referenceNode.getArgumentListExists()) {
            // This can be a parameter
            definition = context.schemaIndex().findSymbol(reference.getScope(), SymbolType.PARAMETER, reference.getShortIdentifier());
        }

        // If the symbol isn't a parameter, maybe it is a function
        if (definition.isEmpty()) {
            // NOTE: Seems like a name collision between a parameter and a constants leads to undefined behaviour
            for (SymbolType possibleType : possibleTypes) {
                definition = context.schemaIndex().findSymbol(reference.getScope(), possibleType, reference.getShortIdentifier());

                // If this is a ambiguous reference to a function or a constant, the app doesn't deploy. Therefore a break will not be a problem.
                // TODO: Implement error message for the error above.
                if (definition.isPresent()) {
                    break;
                }
            }
        }

        if (definition.isEmpty()) {
            return;
        }

        reference.setType(definition.get().getType());
        reference.setStatus(SymbolStatus.REFERENCE);
        context.schemaIndex().insertSymbolReference(definition.get(), reference);
    }

    private static void resolveReferenceInsideLambda(RankNode node, ParseContext context, List diagnostics) {
        
        Symbol symbol = node.getSymbol();

        List possibleDefinition = context.schemaIndex().findSymbolsInScope(symbol.getScope(), SymbolType.PARAMETER, symbol.getShortIdentifier());

        if (possibleDefinition.size() == 0) {
            // Symbol not found
            return;
        }

        if (possibleDefinition.size() > 1) {
            return;
        }

        symbol.setType(SymbolType.PARAMETER);
        symbol.setStatus(SymbolStatus.REFERENCE);
        context.schemaIndex().insertSymbolReference(possibleDefinition.get(0), symbol);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy