Downloading a project requires considerable resources. Please understand that we need to cover our server costs. Thank you in advance. The project price is only $1.
You can buy this project and download or modify it as often as you want.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.api.annotations.Beta;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
* Returns a subspace of a tensor
*
* @author bratseth
*/
@Beta
public class Slice extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List> subspaceAddress;
/**
* Creates a value function
*
* @param argument the tensor to return a cell value from
* @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress
* because those require a type, but a type is not resolved until this is evaluated
*/
public Slice(TensorFunction argument, List> subspaceAddress) {
this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
"Specify dimension names explicitly instead");
this.subspaceAddress = subspaceAddress;
}
@Override
public List> arguments() { return List.of(argument); }
public List> selectorFunctions() {
var result = new ArrayList>();
for (var dimVal : subspaceAddress) {
dimVal.index().flatMap(ScalarFunction::asTensorFunction).ifPresent(result::add);
}
return result;
}
public TensorFunction withTransformedFunctions(
Function, ScalarFunction> transformer)
{
List> transformedAddress = new ArrayList<>();
for (var orig : subspaceAddress) {
var idxFun = orig.index();
if (idxFun.isPresent()) {
var transformed = transformer.apply(idxFun.get());
transformedAddress.add(new DimensionValue(orig.dimension(), transformed));
} else {
transformedAddress.add(orig);
}
}
return new Slice<>(argument, transformedAddress);
}
@Override
public Slice withArguments(List> arguments) {
if (arguments.size() != 1)
throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
return new Slice<>(arguments.get(0), subspaceAddress);
}
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor tensor = argument.evaluate(context);
TensorType resultType = resultType(tensor.type());
PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context);
if (resultType.rank() == 0) { // shortcut common case
return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type())));
}
Tensor.Builder b = Tensor.Builder.of(resultType);
for (Iterator i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
if (matches(subspaceAddress, cell.getKey(), tensor.type()))
b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue());
}
return b.build();
}
private PartialAddress subspaceToAddress(TensorType type, EvaluationContext context) {
PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size());
for (int i = 0; i < subspaceAddress.size(); i++) {
if (subspaceAddress.get(i).label().isPresent())
b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
subspaceAddress.get(i).label().get());
else
b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
subspaceAddress.get(i).index().get().apply(context).intValue());
}
return b.build();
}
private boolean matches(PartialAddress subspaceAddress,
TensorAddress address, TensorType type) {
for (int i = 0; i < subspaceAddress.size(); i++) {
String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get());
if ( ! label.equals(subspaceAddress.label(i)))
return false;
}
return true;
}
/** Returns the subset of the given address which is present in the subspace type */
private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
for (int i = 0; i < address.size(); i++) {
String dimension = type.dimensions().get(i).name();
if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
b.add(dimension, address.numericLabel(i));
}
return b.build();
}
@Override
public TensorType type(TypeContext context) {
return resultType(argument.type(context));
}
private List findDimensions(List dims, Predicate pred) {
return dims.stream().filter(pred).map(TensorType.Dimension::name).toList();
}
private TensorType resultType(TensorType argumentType) {
List peekDimensions;
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
// Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.get(0).index().isPresent()) {
peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
if (peekDimensions.size() > 1) {
throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
"to " + argumentType + ", which has multiple");
}
}
else {
peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped);
if (peekDimensions.size() > 1)
throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
"to " + argumentType + ", which has multiple");
}
}
else { // general slicing
peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).toList();
}
try {
return TypeResolver.peek(argumentType, peekDimensions);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e);
}
}
@Override
public String toString(ToStringContext context) {
StringBuilder b = new StringBuilder(argument.toString(context));
if (context.typeContext().isEmpty()
&& subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms
if (subspaceAddress.get(0).index().isPresent())
b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
else
b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
}
else { // general form
b.append("{").append(subspaceAddress.stream()
.map(i -> i.toString(context, this))
.collect(Collectors.joining(", "))).append("}");
}
return b.toString();
}
@Override
public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); }
public static class DimensionValue {
private final Optional dimension;
/** The label of this, or null if index is set */
private final String label;
/** The function returning the index of this, or null if label is set */
private final ScalarFunction index;
public DimensionValue(String dimension, String label) {
this(Optional.of(dimension), label, null);
}
public DimensionValue(String dimension, int index) {
this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
}
public DimensionValue(int index) {
this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
}
public DimensionValue(String label) {
this(Optional.empty(), label, null);
}
public DimensionValue(ScalarFunction index) {
this(Optional.empty(), null, index);
}
public DimensionValue(Optional dimension, String label) {
this(dimension, label, null);
}
public DimensionValue(Optional dimension, ScalarFunction index) {
this(dimension, null, index);
}
public DimensionValue(String dimension, ScalarFunction index) {
this(Optional.of(dimension), null, index);
}
private DimensionValue(Optional dimension, String label, ScalarFunction index) {
this.dimension = dimension;
this.label = label;
this.index = index;
}
/**
* Returns the given name of the dimension, or null if dense form is used, such that name
* must be inferred from order
*/
public Optional dimension() { return dimension; }
/** Returns the label for this dimension or empty if it is provided by an index function */
public Optional label() { return Optional.ofNullable(label); }
/** Returns the index expression for this dimension, or empty if it is not a number */
public Optional> index() { return Optional.ofNullable(index); }
@Override
public String toString() {
return toString(null, null);
}
String toString(ToStringContext context, Slice owner) {
StringBuilder b = new StringBuilder();
Optional dimensionName = dimension;
if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail
TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null;
if (type == null || type.dimensions().size() != 1)
throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner +
" cannot be uniquely resolved. Use the full form: " +
"'slice{myDimensionName:" + valueToString(context) + "}'");
else
dimensionName = Optional.of(type.dimensions().get(0).name());
}
dimensionName.ifPresent(d -> b.append(d).append(":"));
b.append(valueToString(context));
return b.toString();
}
private String valueToString(ToStringContext context) {
if (label != null) {
return TensorAddress.labelToString(label);
} else {
return index.toString(context);
}
}
@Override
public int hashCode() { return Objects.hash(dimension, label, index); }
}
private static class ConstantIntegerFunction implements ScalarFunction {
private final int value;
public ConstantIntegerFunction(int value) {
this.value = value;
}
@Override
public Double apply(EvaluationContext context) {
return (double)value;
}
@Override
public String toString() { return String.valueOf(value); }
@Override
public int hashCode() { return Objects.hash("constantIntegerFunction", value); }
}
}