com.yahoo.tensor.IndexedTensor Maven / Gradle / Ivy
Show all versions of vespajlib Show documentation
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
/**
* An indexed (dense) tensor.
*
* Some methods on indexed tensors make use of a standard value order: Cells are ordered by increasing
* index, where indexes of dimensions to the right are incremented before indexes of dimensions to the left,
* and dimensions are ordered alphabetically by name. Consequently, tensor value ordering is independent of the
* order in which dimensions are specified, and the values of the right-most dimension are adjacent.
*
* @author bratseth
*/
public abstract class IndexedTensor implements Tensor {
/** The sizes of the dimensions of this, in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;

/**
 * Creates the shared state of an indexed tensor. Cell storage is left to concrete subclasses,
 * which implement the abstract value accessors.
 */
IndexedTensor(TensorType type, DimensionSizes dimensionSizes) {
    this.dimensionSizes = dimensionSizes;
    this.type = type;
}
/**
 * Returns an iterator which visits every cell of this tensor, in the standard value order.
 */
@Override
public Iterator cellIterator() {
    return new CellIterator();
}
/**
 * Returns an iterator over all the cells in this tensor which match the given partial address.
 * Dimensions given a (non-negative) label by the partial address are held fixed at that label;
 * every other dimension is iterated over.
 *
 * @param partialAddress the labels of the dimensions to hold fixed
 * @param iterationSizes the dimension sizes handed to the returned iterator
 */
// TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently
public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) {
    int rank = type().dimensions().size();
    long[] startAddress = new long[rank];
    List<Integer> iterateDimensions = new ArrayList<>();
    for (int dimension = 0; dimension < rank; dimension++) {
        long fixedLabel = partialAddress.numericLabel(type.dimensions().get(dimension).name());
        if (fixedLabel < 0)
            iterateDimensions.add(dimension); // no label given: iterate over this dimension
        else
            startAddress[dimension] = fixedLabel; // label given: stay at it
    }
    return new SubspaceIterator(iterateDimensions, startAddress, iterationSizes);
}
/**
 * Returns an iterator which yields the cell values of this tensor, in the standard value order.
 */
@Override
public Iterator valueIterator() {
    return new ValueIterator();
}
/**
 * Returns a nested iterator: The outer iterator visits each unique combination of values of the given
 * (superspace) dimensions, and each inner iterator visits each unique combination of values of the
 * remaining dimensions, in the standard value order.
 *
 * @param dimensions the names of the dimensions of the superspace
 * @param sizes the size of each dimension of the space to return values for, one entry per dimension of this
 *              tensor (in order). Each size may be equal to or smaller than the corresponding size of this tensor
 */
public Iterator subspaceIterator(Set dimensions, DimensionSizes sizes) {
    return new SuperspaceIterator(dimensions, sizes);
}

/**
 * Returns a subspace iterator over the full extent of this tensor, i.e. using the dimension sizes of this.
 *
 * @param dimensions the names of the dimensions of the superspace
 */
public Iterator subspaceIterator(Set dimensions) {
    DimensionSizes fullSizes = dimensionSizes;
    return subspaceIterator(dimensions, fullSizes);
}
/**
 * Returns the value at the given indexes as a double
 *
 * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
 * @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given
 */
public double get(long ... indexes) {
    long valueIndex = toValueIndex(indexes, dimensionSizes);
    return get(valueIndex);
}

/** Returns the value at the standard value order index held by the given direct address, as a double. */
public double get(DirectIndexedAddress address) {
    long valueIndex = address.getDirectIndex();
    return get(valueIndex);
}

/** Returns a new direct address over the dimension sizes of this tensor. */
public DirectIndexedAddress directAddress() {
    return DirectIndexedAddress.of(dimensionSizes);
}
/**
 * Returns the value at the given indexes as a float
 *
 * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
 * @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given
 */
public float getFloat(long ... indexes) {
    // Keep the value index a long, as get(long...) does: narrowing it to int would silently wrap
    // for tensors with more than Integer.MAX_VALUE cells and read the wrong cell.
    return getFloat(toValueIndex(indexes, dimensionSizes));
}
/** Returns the value at this address, or 0.0 if there is no value at this address */
@Override
public double get(TensorAddress address) {
    try {
        long index = toValueIndex(address, dimensionSizes, type);
        // Check the bounds explicitly, as getAsDouble and has do: get(long) is abstract, so how it
        // reacts to an out-of-range index depends on the subclass and may not be an IllegalArgumentException.
        if (index < 0 || index >= size()) return 0.0;
        return get(index);
    }
    catch (IllegalArgumentException e) {
        return 0.0; // the address is not within the bounds of this
    }
}
/** Returns the value at the given address, or null if the address is not within the bounds of this. */
@Override
public Double getAsDouble(TensorAddress address) {
    try {
        long index = toValueIndex(address, dimensionSizes, type);
        if (index >= 0 && index < size())
            return get(index);
        return null;
    } catch (IllegalArgumentException e) {
        return null;
    }
}
/** Returns whether the given address denotes a cell within the bounds of this tensor. */
@Override
public boolean has(TensorAddress address) {
    try {
        long index = toValueIndex(address, dimensionSizes, type);
        return index >= 0 && index < size();
    } catch (IllegalArgumentException e) {
        return false;
    }
}
/**
 * Returns, as a double, the value stored at the given position in the standard value order.
 *
 * @param valueIndex the index of the value in the underlying data
 * @throws IllegalArgumentException if index is out of bounds
 */
public abstract double get(long valueIndex);

/**
 * Returns, as a float, the value stored at the given position in the standard value order.
 *
 * @param valueIndex the index of the value in the underlying data
 * @throws IllegalArgumentException if index is out of bounds
 */
public abstract float getFloat(long valueIndex);
/**
 * Converts one index per dimension into the corresponding index in the standard value order.
 * Arrays of length 0 and 1 take a fast path which performs no bounds checking.
 *
 * @param indexes the index into each dimension, in the order of the dimensions
 * @param sizes the sizes of the dimensions
 * @throws IllegalArgumentException if, for two or more indexes, any index is negative or not smaller
 *         than the size of its dimension
 */
static long toValueIndex(long[] indexes, DimensionSizes sizes) {
    if (indexes.length == 1) return indexes[0]; // for speed
    if (indexes.length == 0) return 0; // for speed
    long valueIndex = 0;
    for (int i = 0; i < indexes.length; i++) {
        // A negative index must be rejected too: it would otherwise be folded into the sum and
        // silently address some other cell (e.g. [1, -1] in a 2x3 tensor would map to index 2)
        if (indexes[i] < 0 || indexes[i] >= sizes.size(i))
            throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds");
        valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i];
    }
    return valueIndex;
}
/**
 * Converts the numeric labels of the given address into the corresponding index in the standard value order.
 *
 * @param address the address to convert, one label per dimension
 * @param sizes the sizes of the dimensions
 * @param type the type the address is resolved against, used in the error message
 * @throws IllegalArgumentException if any numeric label is negative or not smaller than the size of its dimension
 */
static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) {
    long valueIndex = 0;
    for (int i = 0, size = address.size(); i < size; i++) {
        long label = address.numericLabel(i);
        // A negative numeric label must be rejected too: otherwise it is folded into the sum and, with
        // several dimensions, yields an in-range index to a different cell, so has/get report a wrong cell.
        if (label < 0 || label >= sizes.size(i))
            throw new IllegalArgumentException(address + " is not within the bounds of " + type);
        valueIndex += sizes.productOfDimensionsAfter(i) * label;
    }
    return valueIndex;
}
/** Throws IllegalArgumentException unless the type of this is renamable to the given type. */
void throwOnIncompatibleType(TensorType type) {
    boolean compatible = this.type().isRenamableTo(type);
    if (compatible) return;
    throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
                                       ": Types are not compatible");
}
/** Returns the prescribed type of this tensor. */
@Override
public TensorType type() {
    return type;
}

/** Returns this tensor with the given type instead; the implementation is left to subclasses. */
@Override
public abstract IndexedTensor withType(TensorType type);

/** Returns the sizes of the dimensions of this, in the order of the dimensions of the type. */
public DimensionSizes dimensionSizes() {
    return dimensionSizes;
}
/** Returns a new array holding the size of each dimension of this, in the order of the dimensions of the type. */
public long[] shape() {
    int rank = dimensionSizes.dimensions();
    long[] shape = new long[rank];
    for (int d = 0; d < rank; d++)
        shape[d] = dimensionSizes.size(d);
    return shape;
}
/**
 * Returns all cells of this as an immutable map from address to value,
 * built by visiting the cells in the standard value order.
 */
@Override
public Map cells() {
    if (dimensionSizes.dimensions() == 0) // no dimensions: the single cell lives at the empty address
        return Map.of(TensorAddress.of(), get(0));

    long cellCount = size();
    Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, cellCount);
    ImmutableMap.Builder<TensorAddress, Double> cells = ImmutableMap.builder();
    for (long valueIndex = 0; valueIndex < cellCount; valueIndex++) {
        indexes.next();
        cells.put(indexes.toAddress(), get(valueIndex));
    }
    return cells.build();
}
/** Not supported: indexed tensors cannot have cells removed, so this always throws IllegalArgumentException. */
@Override
public Tensor remove(Set addresses) {
    String message = "Remove is not supported for indexed tensors";
    throw new IllegalArgumentException(message);
}
/** Returns the full string form of this, including the type and using short forms. */
@Override
public String toString() {
    boolean withType = true;
    boolean shortForms = true;
    return toString(withType, shortForms);
}

/** Returns the string form of this with no limit on the number of cells written. */
@Override
public String toString(boolean withType, boolean shortForms) {
    long unlimited = Long.MAX_VALUE;
    return toString(withType, shortForms, unlimited);
}

/**
 * Returns an abbreviated string form of this, writing at most max(2, 10 / (mapped dimensions + 1)) cells.
 */
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
    long mappedDimensions = type().dimensions().stream().filter(TensorType.Dimension::isMapped).count();
    long maxCells = Math.max(2, 10 / (mappedDimensions + 1));
    return toString(withType, shortForms, maxCells);
}
/**
 * Writes this as a string of at most maxCells cells. The nested-bracket short form is used only when short
 * forms are requested, the type has rank above zero and every dimension has a size; otherwise the
 * standard tensor string form is produced.
 */
private String toString(boolean withType, boolean shortForms, long maxCells) {
    boolean useStandardForm = ! shortForms
                              || type.rank() == 0
                              || type.dimensions().stream().anyMatch(d -> d.size().isEmpty());
    if (useStandardForm)
        return Tensor.toStandardString(this, withType, shortForms, maxCells);

    Indexes indexes = Indexes.of(dimensionSizes);
    StringBuilder out = new StringBuilder();
    if (withType)
        out.append(type).append(":");
    indexedBlockToString(this, indexes, maxCells, out);
    return out.toString();
}
/**
 * Appends the cells of the given tensor to b in nested-bracket form, in the standard value order.
 * At most maxCells cells are written; when the tensor has more cells than that, ", ...]" is appended.
 *
 * @param tensor the tensor whose cell values are written
 * @param indexes the index tracker advanced once per written cell, which decides where brackets open and close
 * @param maxCells the maximum number of cells to write
 * @param b the builder to append to
 */
static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, long maxCells, StringBuilder b) {
    int index = 0;
    while (index < tensor.size() && index < maxCells) {
        indexes.next();
        if (index > 0)
            b.append(", ");
        // opening brackets, as many as indexes.nextDimensionsAtStart() reports
        b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
        // the value, formatted according to the cell value type
        switch (tensor.type().valueType()) {
            case DOUBLE -> b.append(tensor.get(index));
            case FLOAT, BFLOAT16 -> b.append(tensor.getFloat(index));
            case INT8 -> b.append((byte)tensor.getFloat(index));
            default -> throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
        }
        // closing brackets, as many as indexes.nextDimensionsAtEnd() reports
        b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
        index++;
    }
    boolean truncated = index == maxCells && index < tensor.size();
    if (truncated)
        b.append(", ...]");
}
/** Returns whether the given object is a tensor which Tensor.equals considers equal to this. */
@Override
public boolean equals(Object other) {
    return other instanceof Tensor && Tensor.equals(this, (Tensor)other);
}
/**
 * A builder of indexed tensors. Instances are obtained through the static of methods,
 * which pick an implementation suited to the type and its cell value type.
 */
public abstract static class Builder implements Tensor.Builder {

    final TensorType type;

    private Builder(TensorType type) {
        this.type = type;
    }

    /**
     * Returns a builder for the given type: A bound builder when every dimension of the type is indexed
     * and bound, and an unbound builder otherwise.
     */
    public static Builder of(TensorType type) {
        if (type.hasOnlyIndexedBoundDimensions())
            return of(type, BoundBuilder.dimensionSizesOf(type));
        else
            return new UnboundBuilder(type);
    }

    /**
     * Creates a builder initialized with the given values
     *
     * @param type the type of the tensor to build. This must have only indexed bound dimensions,
     *             since the dimension sizes are needed to place the values
     * @param values the initial values of the tensor in the standard value order. This transfers ownership
     *               of the value array - it must not be further mutated by the caller
     * @throws IllegalArgumentException if the type has a dimension which is not indexed and bound,
     *                                  or the number of values does not match the size of the type
     */
    public static Builder of(TensorType type, float[] values) {
        return of(type, boundSizesOf(type), values);
    }

    /**
     * Creates a builder initialized with the given values
     *
     * @param type the type of the tensor to build. This must have only indexed bound dimensions,
     *             since the dimension sizes are needed to place the values
     * @param values the initial values of the tensor in the standard value order. This transfers ownership
     *               of the value array - it must not be further mutated by the caller
     * @throws IllegalArgumentException if the type has a dimension which is not indexed and bound,
     *                                  or the number of values does not match the size of the type
     */
    public static Builder of(TensorType type, double[] values) {
        return of(type, boundSizesOf(type), values);
    }

    /**
     * Create a builder with dimension size information for this instance. Must be one size entry per dimension,
     * and, agree with the type size information when specified in the type.
     * If sizes are completely specified in the type this size information is redundant.
     *
     * @throws IllegalArgumentException if the sizes do not agree with the type
     */
    public static Builder of(TensorType type, DimensionSizes sizes) {
        validate(type, sizes);
        return switch (type.valueType()) {
            case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
            case FLOAT, BFLOAT16, INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
            default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
        };
    }

    /**
     * Creates a builder initialized with the given values
     *
     * @param type the type of the tensor to build
     * @param values the initial values of the tensor in the standard value order.
     *               This transfers ownership of the value array - it
     *               must not be further mutated by the caller
     * @throws IllegalArgumentException if the sizes do not agree with the type or with the number of values
     */
    public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
        validate(type, sizes);
        validateSizes(sizes, values.length);
        return switch (type.valueType()) {
            case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
            case FLOAT, BFLOAT16, INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
            default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
        };
    }

    /**
     * Creates a builder initialized with the given values
     *
     * @param type the type of the tensor to build
     * @param values the initial values of the tensor in the standard value order.
     *               This transfers ownership of the value array - it
     *               must not be further mutated by the caller
     * @throws IllegalArgumentException if the sizes do not agree with the type or with the number of values
     */
    public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
        validate(type, sizes);
        validateSizes(sizes, values.length);
        return switch (type.valueType()) {
            case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
            case FLOAT, BFLOAT16, INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
            default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
        };
    }

    /**
     * Returns the dimension sizes declared by the given type. Fails when the type has a dimension which is
     * not indexed and bound, rather than silently discarding caller-supplied values in an unbound builder.
     */
    private static DimensionSizes boundSizesOf(TensorType type) {
        if ( ! type.hasOnlyIndexedBoundDimensions())
            throw new IllegalArgumentException("Cannot create a builder initialized with values for " + type +
                                               ": All dimensions must be indexed and bound");
        return BoundBuilder.dimensionSizesOf(type);
    }

    /** Throws IllegalArgumentException unless the total size given by sizes equals the given value count. */
    private static void validateSizes(DimensionSizes sizes, int length) {
        if (sizes.totalSize() != length) {
            throw new IllegalArgumentException("Invalid size(" + length + ") of supplied value vector." +
                                               " Type specifies that size should be " + sizes.totalSize());
        }
    }

    /**
     * Throws IllegalArgumentException unless sizes has one entry per dimension of the type and no entry
     * exceeds the size declared for that dimension in the type (if any).
     */
    private static void validate(TensorType type, DimensionSizes sizes) {
        if (sizes.dimensions() != type.dimensions().size())
            throw new IllegalArgumentException(sizes.dimensions() +
                                               " is the wrong number of dimensions for " + type);
        for (int i = 0; i < sizes.dimensions(); i++ ) {
            Optional<Long> size = type.dimensions().get(i).size();
            if (size.isPresent() && size.get() < sizes.size(i))
                throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
                                                   sizes.size(i) +
                                                   " but cannot be larger than " + size.get() + " in " + type);
        }
    }

    /** Sets a cell value at the given indexes (presumably one per dimension, as in IndexedTensor.get(long...)). */
    public abstract Builder cell(double value, long ... indexes);

    /** Sets a cell value at the given indexes (presumably one per dimension, as in IndexedTensor.get(long...)). */
    public abstract Builder cell(float value, long ... indexes);

    @Override
    public TensorType type() { return type; }

    @Override
    public abstract IndexedTensor build();

}
/** A builder whose cells can be set directly by their index in the standard value order. */
public interface DirectIndexBuilder {

    /** Returns the type of the tensor being built. */
    TensorType type();

    /** Sets the double value at the given standard value order index. */
    void cellByDirectIndex(long index, double value);

    /** Sets the float value at the given standard value order index. */
    void cellByDirectIndex(long index, float value);

}
/**
 * A builder for tensors whose dimension sizes are all known up front,
 * so that cells can be set directly by their standard value order index.
 */
public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder {

    private final DimensionSizes sizes;

    /** Returns the sizes declared in the given type, which must declare a size for every dimension. */
    private static DimensionSizes dimensionSizesOf(TensorType type) {
        int rank = type.dimensions().size();
        DimensionSizes.Builder builder = new DimensionSizes.Builder(rank);
        for (int d = 0; d < rank; d++)
            builder.set(d, type.dimensions().get(d).size().get());
        return builder.build();
    }

    BoundBuilder(TensorType type, DimensionSizes sizes) {
        super(type);
        boolean oneSizePerDimension = sizes.dimensions() == type.dimensions().size();
        if ( ! oneSizePerDimension)
            throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
        this.sizes = sizes;
    }

    /** Sets the cell at standard value order index i to values[i], for every i, and returns this. */
    public BoundBuilder fill(float[] values) {
        for (int i = 0; i < values.length; i++)
            cellByDirectIndex(i, values[i]);
        return this;
    }

    /** Sets the cell at standard value order index i to values[i], for every i, and returns this. */
    public BoundBuilder fill(double[] values) {
        for (int i = 0; i < values.length; i++)
            cellByDirectIndex(i, values[i]);
        return this;
    }

    /** Returns the dimension sizes of the tensor being built. */
    DimensionSizes sizes() { return sizes; }

}
/**
* A builder used when we don't know the size of the dimensions up front.
* All values in all dimensions must be specified.
*/
private static class UnboundBuilder extends Builder {
/** List of List or Double */
private List |