com.yahoo.tensor.functions.Map Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
/**
* The map tensor function produces a tensor where the given function is applied on each cell value.
*
* @author bratseth
*/
public class Map extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final DoubleUnaryOperator mapper;
public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
this.mapper = mapper;
}
public static TensorType outputType(TensorType inputType) {
return TypeResolver.map(inputType);
}
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
@Override
public List> arguments() { return List.of(argument); }
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
return new Map<>(arguments.get(0), mapper);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Map<>(argument.toPrimitive(), mapper);
}
@Override
public TensorType type(TypeContext context) {
return outputType(argument.type(context));
}
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor input = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(outputType(input.type()));
for (Iterator i = input.cellIterator(); i.hasNext(); ) {
java.util.Map.Entry cell = i.next();
builder.cell(cell.getKey(), mapper.applyAsDouble(cell.getValue()));
}
return builder.build();
}
@Override
public String toString(ToStringContext context) {
return "map(" + argument.toString(context) + ", " + mapper + ")";
}
@Override
public int hashCode() { return Objects.hash("map", argument, mapper); }
}