com.yahoo.tensor.functions.DynamicTensor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.List;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
*
* @author bratseth
*/
public abstract class DynamicTensor extends PrimitiveTensorFunction {
private final TensorType type;
DynamicTensor(TensorType type) {
this.type = type;
}
@Override
public TensorType type(TypeContext context) { return type; }
@Override
public List> arguments() { return List.of(); }
public abstract List> cellGeneratorFunctions();
@Override
public TensorFunction withArguments(List> arguments) {
if (!arguments.isEmpty())
throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
return this;
}
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
TensorType type() { return type; }
abstract String contentToString(ToStringContext context);
@Override
public String toString(ToStringContext context) {
return type().toString() + ":" + contentToString(context);
}
/** Creates a dynamic tensor function. The cell addresses must match the type. */
public static DynamicTensor from(TensorType type, Map> cells) {
return new MappedDynamicTensor<>(type, cells);
}
/** Creates a dynamic tensor function for a bound, indexed tensor */
public static DynamicTensor from(TensorType type, List> cells) {
return new IndexedDynamicTensor<>(type, cells);
}
private static class MappedDynamicTensor extends DynamicTensor {
private final ImmutableMap> cells;
MappedDynamicTensor(TensorType type, Map> cells) {
super(type);
this.cells = ImmutableMap.copyOf(cells);
}
public List> cellGeneratorFunctions() {
var result = new ArrayList>();
for (var fun : cells.values()) {
fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
public TensorFunction withTransformedFunctions(
Function, ScalarFunction> transformer)
{
Map> transformedCells = new LinkedHashMap<>();
for (var orig : cells.entrySet()) {
var transformed = transformer.apply(orig.getValue());
transformedCells.put(orig.getKey(), transformed);
}
return new MappedDynamicTensor<>(type(), transformedCells);
}
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
builder.cell(cell.getKey(), cell.getValue().apply(context));
return builder.build();
}
@Override
String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{{}:" + cells.values().iterator().next().toString(context) + "}";
}
StringBuilder b = new StringBuilder("{");
for (var cell : cells.entrySet()) {
b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context));
b.append(",");
}
if (b.length() > 1)
b.setLength(b.length() - 1);
b.append("}");
return b.toString();
}
@Override
public int hashCode() { return Objects.hash("mappedDynamicTensor", type(), cells); }
}
private static class IndexedDynamicTensor extends DynamicTensor {
private final List> cells;
IndexedDynamicTensor(TensorType type, List> cells) {
super(type);
if ( ! type.hasOnlyIndexedBoundDimensions())
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
"only indexed, bound dimensions, but this has " + type);
this.cells = List.copyOf(cells);
}
public List> cellGeneratorFunctions() {
var result = new ArrayList>();
for (var fun : cells) {
fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
public TensorFunction withTransformedFunctions(
Function, ScalarFunction> transformer)
{
List> transformedCells = new ArrayList<>();
for (var orig : cells) {
var transformed = transformer.apply(orig);
transformedCells.add(transformed);
}
return new IndexedDynamicTensor<>(type(), transformedCells);
}
@Override
public Tensor evaluate(EvaluationContext context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
builder.cellByDirectIndex(i, cells.get(i).apply(context));
return builder.build();
}
@Override
String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.get(0).toString(context) + "}";
}
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type());
StringBuilder b = new StringBuilder("{");
for (var cell : cells) {
indexes.next();
b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context));
b.append(",");
}
if (b.length() > 1)
b.setLength(b.length() - 1);
b.append("}");
return b.toString();
}
@Override
public int hashCode() { return Objects.hash("indexedDynamicTensor", type(), cells); }
}
}