com.yahoo.tensor.functions.Expand Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code that does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
import java.util.Objects;
/**
* The expand tensor function returns a tensor with a new dimension of
* size 1 is added, equivalent to "tensor * tensor(dim_name[1])(1)".
*
* @author lesters
*/
public class Expand extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
public Expand(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
public List> arguments() { return List.of(argument); }
@Override
public TensorFunction withArguments(List> arguments) {
if (arguments.size() != 1)
throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size());
return new Expand<>(arguments.get(0), dimension);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
return toPrimitive(dimension);
}
@Override
public final TensorType type(TypeContext context) {
return toPrimitive(context.resolveBinding(dimension)).type(context);
}
private PrimitiveTensorFunction toPrimitive(String dimension) {
TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimension, 1).build();
Generate expansion = new Generate<>(type, ScalarFunctions.constant(1.0));
return new Join<>(expansion, argument, ScalarFunctions.multiply());
}
@Override
public String toString(ToStringContext context) {
return "expand(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")";
}
@Override
public int hashCode() { return Objects.hash("expand", argument, dimension); }
}