com.yahoo.tensor.evaluation.VariableTensor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa: shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.evaluation;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
*
* @author bratseth
*/
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
private final Optional requiredType;
public VariableTensor(String name) {
this.name = name;
this.requiredType = Optional.empty();
}
/** A variable tensor which must be compatible with the given type */
public VariableTensor(String name, TensorType requiredType) {
this.name = name;
this.requiredType = Optional.of(requiredType);
}
@Override
public List> arguments() { return List.of(); }
@Override
public TensorFunction withArguments(List> arguments) { return this; }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
public TensorType type(TypeContext context) {
TensorType givenType = context.getType(name);
if (givenType == null) return null;
verifyType(givenType);
return givenType;
}
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor tensor = context.getTensor(name);
if (tensor == null) return null;
verifyType(tensor.type());
return tensor;
}
@Override
public String toString(ToStringContext context) {
return name;
}
@Override
public int hashCode() { return Objects.hash("variableTensor", name, requiredType); }
private void verifyType(TensorType givenType) {
if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
requiredType.get() + " but was " + givenType);
}
}