All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.schema.expressiontransforms.NormalizerFunctionExpander Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;

import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.RankProfile.RankFeatureNormalizer;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.functions.Generate;

import java.io.StringReader;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.ArrayList;

/**
 * Recognizes pseudo-functions and creates global-phase normalizers.
 *
 * <p>The recognized pseudo-functions are {@value #NORMALIZE_LINEAR},
 * {@value #RECIPROCAL_RANK} and {@value #RECIPROCAL_RANK_FUSION}. A reference
 * to one of them is replaced by a reference to a normalizer which is
 * registered on the rank profile found in the transform context.
 *
 * @author arnej
 */
public class NormalizerFunctionExpander extends ExpressionTransformer<RankProfileTransformContext> {

    /** Name of the pseudo-function taking exactly one simple feature argument. */
    public static final String NORMALIZE_LINEAR = "normalize_linear";

    /** Name of the pseudo-function taking a simple feature and an optional constant k. */
    public static final String RECIPROCAL_RANK = "reciprocal_rank";

    /** Name of the pseudo-function which expands to a sum of {@value #RECIPROCAL_RANK} calls. */
    public static final String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion";

    /** The k used by {@value #RECIPROCAL_RANK} when no second argument is given. */
    private static final double DEFAULT_RECIPROCAL_RANK_K = 60.0;

    /**
     * Transforms the given node: a reference node is first checked against the
     * pseudo-function names, and the children of any (resulting) composite node
     * are then transformed as well.
     *
     * @param node the expression node to transform
     * @param context the context holding the rank profile that receives normalizers
     * @return the transformed node, or the given node if nothing applied
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode r) {
            node = transformReference(r, context);
        }
        if (node instanceof CompositeNode composite) {
            node = transformChildren(composite, context);
        }
        return node;
    }

    /**
     * Replaces a reference to one of the pseudo-functions. References which
     * specify an output, or whose name is a function declared in the rank
     * profile, are returned unchanged.
     */
    private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) {
        Reference ref = node.reference();
        String name = ref.name();
        if (ref.output() != null) {
            return node;
        }
        var f = context.rankProfile().getFunctions().get(name);
        if (f != null) {
            // never transform declared functions
            return node;
        }
        return switch (name) {
            // the expansion consists of reciprocal_rank references, so it must be transformed again
            case RECIPROCAL_RANK_FUSION -> transform(expandRRF(ref), context);
            case NORMALIZE_LINEAR -> transformNormLin(ref, context);
            case RECIPROCAL_RANK -> transformRRank(ref, context);
            default -> node;
        };
    }

    /**
     * Expands {@code reciprocal_rank_fusion(a, b, ...)} into
     * {@code reciprocal_rank(a) + reciprocal_rank(b) + ...}.
     *
     * @throws IllegalArgumentException if there are fewer than 2 arguments
     */
    private ExpressionNode expandRRF(Reference ref) {
        var args = ref.arguments();
        if (args.size() < 2) {
            throw new IllegalArgumentException("must have at least 2 arguments: " + ref);
        }
        List<ExpressionNode> children = new ArrayList<>();
        List<Operator> operators = new ArrayList<>();
        for (var arg : args.expressions()) {
            // one operator between each pair of children
            if (! children.isEmpty()) operators.add(Operator.plus);
            children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null));
        }
        // must be further transformed (see transformReference)
        return new OperationNode(children, operators);
    }

    /**
     * Replaces {@code normalize_linear(feature)} by a reference to a linear normalizer.
     *
     * @throws IllegalArgumentException if there is not exactly 1 argument,
     *         or the argument is not a reference node
     */
    private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) {
        var args = ref.arguments();
        if (args.size() != 1) {
            throw new IllegalArgumentException("must have exactly 1 argument: " + ref);
        }
        var input = args.expressions().get(0);
        if (input instanceof ReferenceNode inputRefNode) {
            var inputRef = inputRefNode.reference();
            return registerNormalizer(RankFeatureNormalizer.linear(ref, inputRef), context);
        } else {
            throw new IllegalArgumentException("the first argument must be a simple feature: " + ref + " => " + input.getClass());
        }
    }

    /**
     * Replaces {@code reciprocal_rank(feature)} or {@code reciprocal_rank(feature, k)}
     * by a reference to a reciprocal rank normalizer. When k is not given, 60.0 is used.
     *
     * @throws IllegalArgumentException if there are not 1 or 2 arguments, if k is
     *         not a constant, or if the first argument is not a reference node
     */
    private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) {
        var args = ref.arguments();
        if (args.size() < 1 || args.size() > 2) {
            throw new IllegalArgumentException("must have 1 or 2 arguments: " + ref);
        }
        double k = DEFAULT_RECIPROCAL_RANK_K;
        if (args.size() == 2) {
            var kArg = args.expressions().get(1);
            if (kArg instanceof ConstantNode kNode) {
                k = kNode.getValue().asDouble();
            } else {
                throw new IllegalArgumentException("the second argument (k) must be a constant in: " + ref);
            }
        }
        var input = args.expressions().get(0);
        if (input instanceof ReferenceNode inputRefNode) {
            var inputRef = inputRefNode.reference();
            return registerNormalizer(RankFeatureNormalizer.rrank(ref, inputRef, k), context);
        } else {
            throw new IllegalArgumentException("the first argument must be a simple feature: " + ref);
        }
    }

    /** Adds the normalizer to the rank profile of the context and returns a reference node naming it. */
    private static ExpressionNode registerNormalizer(RankFeatureNormalizer normalizer,
                                                     RankProfileTransformContext context) {
        context.rankProfile().addFeatureNormalizer(normalizer);
        return new ReferenceNode(Reference.fromIdentifier(normalizer.name()));
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy