com.yahoo.tensor.functions.Generate Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
*
* @author bratseth
*/
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
// One of these are null
private final Function, Double> freeGenerator;
private final ScalarFunction boundGenerator;
/** The same as Generate.free */
public Generate(TensorType type, Function, Double> generator) {
this(type, Objects.requireNonNull(generator), null);
}
/**
* Creates a generated tensor from a free function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
public static Generate free(TensorType type, Function, Double> generator) {
return new Generate<>(type, Objects.requireNonNull(generator), null);
}
/**
* Creates a generated tensor from a bound function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
public static Generate bound(TensorType type, ScalarFunction generator) {
return new Generate<>(type, null, Objects.requireNonNull(generator));
}
private Generate(TensorType type, Function, Double> freeGenerator, ScalarFunction boundGenerator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
validateType(type);
this.type = type;
this.freeGenerator = freeGenerator;
this.boundGenerator = boundGenerator;
}
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
}
@Override
public List> arguments() {
return boundGenerator != null && boundGenerator.asTensorFunction().isPresent()
? List.of(boundGenerator.asTensorFunction().get())
: List.of();
}
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() > 1)
throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size());
if (arguments.isEmpty()) return this;
if (arguments.get(0).asScalarFunction().isEmpty())
throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " +
"but got " + arguments.get(0));
return new Generate<>(type, null, arguments.get(0).asScalarFunction().get());
}
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
public TensorType type(TypeContext context) { return type; }
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
GenerateEvaluationContext generateContext = new GenerateEvaluationContext(type, context);
for (int i = 0; i < indexes.size(); i++) {
indexes.next();
builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
}
return builder.build();
}
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
@Override
public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }
private String generatorToString(ToStringContext context) {
if (freeGenerator != null)
return freeGenerator.toString();
else
return boundGenerator.toString(new GenerateToStringContext(context));
}
@Override
public int hashCode() { return Objects.hash("generate", type, freeGenerator, boundGenerator); }
/**
* A context for generating all the values of a tensor produced by evaluating Generate.
* This returns all the current index values as variables and falls back to delivering from the given
* evaluation context.
*/
private class GenerateEvaluationContext implements EvaluationContext {
private final TensorType type;
private final EvaluationContext context;
private IndexedTensor.Indexes indexes;
GenerateEvaluationContext(TensorType type, EvaluationContext context) {
this.type = type;
this.context = context;
}
double apply(IndexedTensor.Indexes indexes) {
if (freeGenerator != null) {
return freeGenerator.apply(indexes.toList());
}
else {
this.indexes = indexes;
return boundGenerator.apply(this);
}
}
@Override
public Tensor getTensor(String name) {
Optional index = type.indexOfDimension(name);
if (index.isPresent()) // this is the name of a dimension
return Tensor.from(indexes.indexesForReading()[index.get()]);
else
return context.getTensor(name);
}
@Override
public TensorType getType(NAMETYPE name) {
Optional index = type.indexOfDimension(name.name());
if (index.isPresent()) // this is the name of a dimension
return TensorType.empty;
else
return context.getType(name);
}
@Override
public TensorType getType(String name) {
Optional index = type.indexOfDimension(name);
if (index.isPresent()) // this is the name of a dimension
return TensorType.empty;
else
return context.getType(name);
}
@Override
public String resolveBinding(String name) {
return context.resolveBinding(name);
}
}
/** A context which adds the bindings of the generate dimension names to the given context. */
private class GenerateToStringContext implements ToStringContext {
private final ToStringContext context;
public GenerateToStringContext(ToStringContext context) {
this.context = context;
}
@Override
public String getBinding(String identifier) {
if (type.dimension(identifier).isPresent())
return identifier; // dimension names are bound but not substituted in the generate context
else
return context.getBinding(identifier);
}
@Override
public ToStringContext parent() { return context; }
}
}