com.yahoo.tensor.DimensionSizes Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
import java.util.Arrays;
/**
 * The sizes of a set of dimensions.
 * Instances are created by {@link #of(TensorType)} or by a {@link Builder}. The builder hands its size
 * array over to the built instance (no copy is made) and becomes unusable afterwards.
 *
 * @author bratseth
 */
public final class DimensionSizes {

    /** The size of each dimension, by dimension index. */
    private final long[] sizes;

    /** Element i is the product of the sizes of all dimensions with an index greater than i (1 for the last). */
    private final long[] productOfSizesAfter;

    /** The product of all sizes; 1 when there are no dimensions. */
    private final long totalSize;

    private DimensionSizes(Builder builder) {
        // Take over the builder's array (invalidating the builder) rather than copying it
        this.sizes = builder.takeSizes();
        this.productOfSizesAfter = new long[sizes.length];
        long product = 1;
        for (int i = sizes.length; i-- > 0; ) {
            productOfSizesAfter[i] = product; // here product == sizes[i+1] * ... * sizes[n-1]
            product *= sizes[i]; // NOTE: plain long multiplication, not checked for overflow
        }
        this.totalSize = product;
    }

    /**
     * Creates sizes from a type containing bound indexed dimensions only.
     *
     * @param type the tensor type to take dimension sizes from, in the type's dimension order
     * @return the sizes of the dimensions of the given type
     * @throws IllegalArgumentException if the type contains dimensions which are not bound and indexed
     */
    public static DimensionSizes of(TensorType type) {
        Builder b = new Builder(type.rank());
        for (int i = 0; i < type.rank(); i++) {
            if ( type.dimensions().get(i).type() != TensorType.Dimension.Type.indexedBound)
                throw new IllegalArgumentException(type + " contains dimensions without a size");
            // assumes an indexedBound dimension always has a size present (implied by the check above)
            b.set(i, type.dimensions().get(i).size().get());
        }
        return b.build();
    }

    /**
     * Returns the length of this in the nth dimension.
     *
     * @param dimensionIndex the 0-based index of the dimension
     * @throws IllegalArgumentException if the index is negative or not less than the number of dimensions
     */
    public long size(int dimensionIndex) {
        if (dimensionIndex < 0 || dimensionIndex >= sizes.length)
            throw new IllegalArgumentException("Illegal dimension index " + dimensionIndex +
                                               ": This has " + sizes.length + " dimensions");
        return sizes[dimensionIndex];
    }

    /** Returns the number of dimensions this provides the size of */
    public int dimensions() { return sizes.length; }

    /** Returns the product of the sizes of this (1 if there are no dimensions). The product is not overflow checked. */
    public long totalSize() {
        return totalSize;
    }

    /**
     * Returns the product of the sizes of all dimensions with an index greater than the given one,
     * which is 1 for the last dimension.
     * An index outside [0, dimensions()) causes an ArrayIndexOutOfBoundsException.
     */
    long productOfDimensionsAfter(int afterIndex) {
        return productOfSizesAfter[afterIndex];
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if (!(o instanceof DimensionSizes)) return false;
        return Arrays.equals(((DimensionSizes) o).sizes, this.sizes);
    }

    @Override
    public int hashCode() { return Arrays.hashCode(sizes); }

    /**
     * Builder of a set of dimension sizes.
     * Dimensions whose size is not set before building will get size 0.
     * A builder can only be built once: any use of it after {@link #build()} throws IllegalStateException.
     */
    public static final class Builder {

        /** The dimension index the next call to {@link #add(long)} will set */
        private int nextDimensionIndex = 0;

        /** The sizes set so far; null once this has been built */
        private long[] sizes;

        /** Creates a builder for the given number of dimensions, all initially of size 0. */
        public Builder(int dimensions) {
            this.sizes = new long[dimensions];
        }

        /**
         * Sets the size of the dimension at the given index.
         *
         * @throws ArrayIndexOutOfBoundsException if the index is outside [0, dimensions())
         * @throws IllegalStateException if this has already been built
         */
        public Builder set(int dimensionIndex, long size) {
            validSizes()[dimensionIndex] = size;
            return this;
        }

        /**
         * Sets the size of the next dimension. The first call sets dimension 0, the next dimension 1, and so on.
         * This sequence is independent of any calls to {@link #set(int, long)}.
         *
         * @throws ArrayIndexOutOfBoundsException if called more times than there are dimensions
         * @throws IllegalStateException if this has already been built
         */
        public Builder add(long size) {
            validSizes()[nextDimensionIndex++] = size;
            return this;
        }

        /**
         * Returns the length of this in the nth dimension.
         *
         * @param dimensionIndex the 0-based index of the dimension
         * @throws IllegalArgumentException if the index is negative or not less than the number of dimensions
         * @throws IllegalStateException if this has already been built
         */
        public long size(int dimensionIndex) {
            long[] current = validSizes();
            if (dimensionIndex < 0 || dimensionIndex >= current.length)
                throw new IllegalArgumentException("Illegal dimension index " + dimensionIndex +
                                                   ": This has " + current.length + " dimensions");
            return current[dimensionIndex];
        }

        /**
         * Returns the number of dimensions this provides the size of.
         *
         * @throws IllegalStateException if this has already been built
         */
        public int dimensions() { return validSizes().length; }

        /**
         * Builds this. This builder becomes invalid after calling this.
         *
         * @throws IllegalStateException if this has already been built
         */
        public DimensionSizes build() { return new DimensionSizes(this); }

        /** Returns the sizes array and invalidates this builder, so that the built instance alone holds the array */
        private long[] takeSizes() {
            long[] taken = validSizes();
            sizes = null;
            return taken;
        }

        /** Returns the sizes array, failing clearly (instead of with a NullPointerException) if this was already built */
        private long[] validSizes() {
            if (sizes == null)
                throw new IllegalStateException("This builder has already been built and cannot be used again");
            return sizes;
        }

    }

}