All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.tensor.evaluation.VariableTensor Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa. Shared code which does not fit anywhere else.

There is a newer version: 8.409.18
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.evaluation;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * A tensor variable name which resolves to a tensor in the context at evaluation time
 *
 * @author bratseth
 */
public class VariableTensor extends PrimitiveTensorFunction {

    private final String name;
    private final Optional requiredType;

    public VariableTensor(String name) {
        this.name = name;
        this.requiredType = Optional.empty();
    }

    /** A variable tensor which must be compatible with the given type */
    public VariableTensor(String name, TensorType requiredType) {
        this.name = name;
        this.requiredType = Optional.of(requiredType);
    }

    @Override
    public List> arguments() { return List.of(); }

    @Override
    public TensorFunction withArguments(List> arguments) { return this; }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    @Override
    public TensorType type(TypeContext context) {
        TensorType givenType = context.getType(name);
        if (givenType == null) return null;
        verifyType(givenType);
        return givenType;
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor tensor = context.getTensor(name);
        if (tensor == null) return null;
        verifyType(tensor.type());
        return tensor;
    }

    @Override
    public String toString(ToStringContext context) {
        return name;
    }

    @Override
    public int hashCode() { return Objects.hash("variableTensor", name, requiredType); }

    private void verifyType(TensorType givenType) {
        if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
            throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
                                               requiredType.get() + " but was " + givenType);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy