All downloads are free. Search and download functionality uses the official Maven repository.

com.yahoo.tensor.functions.Merge Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa: shared code that does not fit anywhere else.

There is a newer version: 8.409.18
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;

/**
 * The merge tensor operation produces from two argument tensors having equal types
 * a tensor having the same type where the values are the union of the values of both tensors. In the cases where both
 * tensors contain a value for a given cell, and only then, the lambda scalar expression is evaluated to produce
 * the resulting cell value.
 *
 * @author bratseth
 */
public class Merge extends PrimitiveTensorFunction {

    private final TensorFunction argumentA, argumentB;
    private final DoubleBinaryOperator merger;

    public Merge(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator merger) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(merger, "The merger function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.merger = merger;
    }

    /** Returns the type resulting from applying Merge to the two given types */
    public static TensorType outputType(TensorType a, TensorType b) {
        return TypeResolver.merge(a, b);
    }

    public DoubleBinaryOperator merger() { return merger; }

    @Override
    public List> arguments() { return List.of(argumentA, argumentB); }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() != 2)
            throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size());
        return new Merge<>(arguments.get(0), arguments.get(1), merger);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger);
    }

    @Override
    public TensorType type(TypeContext context) {
        return outputType(argumentA.type(context), argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        TensorType mergedType = outputType(a.type(), b.type());
        return evaluate(a, b, mergedType, merger);
    }


    @Override
    public String toString(ToStringContext context) {
        return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
    }

    @Override
    public int hashCode() { return Objects.hash("merge", argumentA, argumentB, merger); }

    static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        // Choose merge algorithm
        if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
            return indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator);
        else
            return generalMerge(a, b, mergedType, combinator);
    }

    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
        long aSize = a.dimensionSizes().size(0);
        long bSize = b.dimensionSizes().size(0);
        long mergedSize = Math.max(aSize, bSize);
        long sharedSize = Math.min(aSize, bSize);
        Iterator aIterator = a.valueIterator();
        Iterator bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (long i = 0; i < sharedSize; i++)
            builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
        Iterator largestIterator = aSize > bSize ? aIterator : bIterator;
        for (long i = sharedSize; i < mergedSize; i++)
            builder.cell(largestIterator.next(), i);
        return builder.build();
    }

    private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(mergedType);
        addCellsOf(a, b, builder, combinator);
        addCellsOf(b, a, builder, null);
        return builder.build();
    }

    private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) {
        for (Iterator i = a.cellIterator(); i.hasNext(); ) {
            Map.Entry aCell = i.next();
            var key = aCell.getKey();
            Double bVal = b.getAsDouble(key);
            if (bVal == null) {
                builder.cell(key, aCell.getValue());
            } else if (combinator != null) {
                builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
            }
        }
    }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy