com.yahoo.tensor.TensorParser Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code that does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
/**
* @author bratseth
*/
class TensorParser {
/**
 * Parses a tensor from its string form, wrapping any parse failure in an exception which
 * names the offending string and, when one was given, the explicitly requested type.
 *
 * @param tensorString the string to parse
 * @param explicitType the type the caller requires the tensor to have, or empty if none
 * @return the tensor produced by {@code tensorFromBody}
 * @throws IllegalArgumentException if parsing fails; the original failure is kept as the cause
 */
static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
    try {
        return tensorFromBody(tensorString, explicitType);
    }
    catch (IllegalArgumentException e) {
        String typeSuffix = explicitType.map(type -> " of type " + type).orElse("");
        throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" + typeSuffix, e);
    }
}
/**
 * Parses a tensor string, dispatching on the shape of its value part:
 * <ul>
 *   <li>a value starting with '{' is parsed on the mapped form when there is no type, the type has rank 0,
 *       the value starts with a nested '{', or the value is an empty "{}"; otherwise on the mixed form</li>
 *   <li>a value starting with '[' is parsed on the dense form</li>
 *   <li>anything else is first tried as a hex-encoded binary value, then as a single number</li>
 * </ul>
 * A leading "tensor...:" prefix, if present, supplies the type and the dimension order.
 *
 * @param tensorString the string to parse; leading and trailing whitespace is ignored
 * @param explicitType the type the caller requires, or empty if none
 * @return the parsed tensor
 * @throws IllegalArgumentException if the string is malformed or its type prefix conflicts with explicitType
 */
static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) {
    Optional<TensorType> type;
    String valueString;

    // The order in which dimensions are written in the type string.
    // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than
    // the natural order of the tensor.
    List<String> dimensionOrder;

    tensorString = tensorString.trim();
    if (tensorString.startsWith("tensor")) {
        int colonIndex = tensorString.indexOf(':');
        // Without this guard substring(0, -1) would throw StringIndexOutOfBoundsException,
        // which is not an IllegalArgumentException and so would escape tensorFrom's error wrapping
        if (colonIndex < 0)
            throw new IllegalArgumentException("Expected a ':' separating the tensor type from the tensor value");
        String typeString = tensorString.substring(0, colonIndex);
        dimensionOrder = new ArrayList<>();
        TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder);
        if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString))
            throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
                                               "passed type " + explicitType.get());
        type = Optional.of(typeFromString);
        valueString = tensorString.substring(colonIndex + 1);
    }
    else {
        type = explicitType;
        valueString = tensorString;
        dimensionOrder = null;
    }

    valueString = valueString.trim();
    if (valueString.startsWith("{")) {
        String afterBrace = valueString.substring(1).trim();
        boolean mappedForm = type.isEmpty()
                             || type.get().rank() == 0
                             || afterBrace.startsWith("{")
                             || afterBrace.equals("}");
        if (mappedForm)
            return tensorFromMappedValueString(valueString, type);
        return tensorFromMixedValueString(valueString, type, dimensionOrder);
    }
    if (valueString.startsWith("[")) {
        return tensorFromDenseValueString(valueString, type, dimensionOrder);
    }

    Optional<Tensor> binary = maybeFromBinaryValueString(valueString, type, dimensionOrder);
    if (binary.isPresent()) return binary.get();

    if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
        throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString +
                                           "') where type " + explicitType.get() + " is required");
    try {
        return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
    }
    catch (NumberFormatException e) {
        throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
                                           tensorString + "'");
    }
}
/** Builds a tensor type from the dimensions named in the first address of the given mapped value string */
private static TensorType typeFromMappedValueString(String valueString) {
    TensorType.Builder typeBuilder = new TensorType.Builder();
    // The parser registers each dimension it reads on typeBuilder as a side effect
    new MappedValueTypeParser(valueString, typeBuilder).parse();
    return typeBuilder.build();
}
/**
 * Parses a tensor on the mapped form, {@code {{dimension:label,...}:value, ...}}.
 *
 * @param valueString the value part of the tensor string; surrounding whitespace is ignored
 * @param type the tensor type to build; if empty, the type is derived from the first address in valueString
 * @return the parsed tensor
 * @throws IllegalArgumentException if the string is malformed
 */
private static Tensor tensorFromMappedValueString(String valueString, Optional<TensorType> type) {
    try {
        String trimmed = valueString.trim();
        // Derive the type lazily: when a type is given, the value string should not be parsed for one
        TensorType tensorType = type.orElseGet(() -> typeFromMappedValueString(trimmed));
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        new MappedValueParser(trimmed, builder).parse();
        return builder.build();
    }
    catch (NumberFormatException e) {
        throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
    }
}
/**
 * Parses a tensor on the mixed form, {@code {label:[dense values], ...}}, or the
 * single-dimension mapped short form, {@code {label:value, ...}}.
 *
 * @param valueString the value part of the tensor string; surrounding whitespace is ignored
 * @param type the tensor type, which must be present and have at most one mapped dimension
 * @param dimensionOrder the order dimensions were written in the type string, or null if there was none
 * @return the parsed tensor
 * @throws IllegalArgumentException if the type is missing or unsuitable, or the string is malformed
 */
private static Tensor tensorFromMixedValueString(String valueString,
                                                 Optional<TensorType> type,
                                                 List<String> dimensionOrder) {
    if (type.isEmpty())
        throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
                                           "on the form 'tensor(dimensions):...");
    TensorType tensorType = type.get();
    long mappedDimensions = tensorType.dimensions().stream().filter(TensorType.Dimension::isMapped).count();
    if (mappedDimensions > 1)
        throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
                                           "but got " + tensorType);
    if (MixedValueParser.findMappedDimension(tensorType).isEmpty())
        throw new IllegalArgumentException("No suitable dimension in " + tensorType + " for parsing a tensor on " +
                                           "the mixed form: Should have one mapped dimension");

    try {
        valueString = valueString.trim();
        // Both braces are required. The negation must cover the whole conjunction, otherwise the
        // check only fires for strings which end with '}' but do not start with '{'
        if ( ! (valueString.startsWith("{") && valueString.endsWith("}")))
            throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
        Tensor.Builder builder = Tensor.Builder.of(tensorType);
        MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder);
        parser.parse();
        return builder.build();
    }
    catch (NumberFormatException e) {
        throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
    }
}
/**
 * Attempts to read the value string as a hex string holding all cells of a tensor whose
 * dimensions all have a known size.
 *
 * This is only attempted when a type is given and no dimension order was collected from a type prefix.
 * The string must consist solely of hex digits and have at least two characters per cell,
 * and it must decode to exactly as many values as the type has cells.
 *
 * @param valueString the candidate hex string
 * @param optType the tensor type, if known
 * @param dimensionOrder the dimension order from a type prefix, or null if there was no prefix
 * @return the decoded tensor, or empty if the string does not qualify as a binary value for this type
 */
private static Optional<Tensor> maybeFromBinaryValueString(String valueString,
                                                           Optional<TensorType> optType,
                                                           List<String> dimensionOrder) {
    if (optType.isEmpty() || dimensionOrder != null) {
        return Optional.empty();
    }
    TensorType type = optType.get();

    // Total number of cells; a dimension without a size contributes 0 and so disqualifies the type
    long sz = 1;
    for (var d : type.dimensions()) {
        sz *= d.size().orElse(0L);
    }
    if (sz == 0
        || type.dimensions().isEmpty()
        || valueString.length() < sz * 2
        || valueString.chars().anyMatch(ch -> (Character.digit(ch, 16) == -1)))
    {
        return Optional.empty();
    }

    try {
        double[] values = decodeHexString(valueString, type.valueType());
        if (values.length != sz) {
            return Optional.empty();
        }
        var builder = IndexedTensor.Builder.of(type);
        var dib = (IndexedTensor.DirectIndexBuilder) builder;
        for (int i = 0; i < values.length; i++) {
            dib.cellByDirectIndex(i, values[i]);
        }
        return Optional.of(builder.build());
    }
    catch (IllegalArgumentException e) {
        // Not decodable as a binary value: let the caller try other forms
        return Optional.empty();
    }
}
/**
 * Parses a tensor on the dense form, {@code [[1,2],[3,4]]} or a flat {@code [1,2,3,4]}.
 *
 * If any dimension of the type lacks a size, the string is parsed by UnboundDenseValueParser
 * and the sizes of any sized dimensions are verified afterwards; otherwise DenseValueParser is used.
 *
 * @param valueString the value part of the tensor string
 * @param type the tensor type, which must be present
 * @param dimensionOrder the order dimensions were written in the type string, or null if there was none
 * @return the parsed tensor
 * @throws IllegalArgumentException if the type is missing or the string does not match the type
 */
private static Tensor tensorFromDenseValueString(String valueString,
                                                 Optional<TensorType> type,
                                                 List<String> dimensionOrder) {
    if (type.isEmpty())
        throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
                                           "on the form 'tensor(dimensions):...");

    IndexedTensor.Builder builder = IndexedTensor.Builder.of(type.get());

    boolean hasUnsizedDimension = type.get().dimensions().stream().anyMatch(d -> d.size().isEmpty());
    if (hasUnsizedDimension) {
        new UnboundDenseValueParser(valueString, builder).parse();
        return checkBoundDimensionSizes(builder.build());
    }

    new DenseValueParser(valueString, dimensionOrder, (IndexedTensor.BoundBuilder) builder).parse();
    return builder.build();
}
/**
 * Verifies that every dimension which declares a size in the tensor's type has exactly that size in the tensor.
 *
 * @param tensor the tensor to check
 * @return the given tensor, unchanged
 * @throws IllegalArgumentException if a sized dimension has a different actual size
 */
private static Tensor checkBoundDimensionSizes(IndexedTensor tensor) {
    TensorType type = tensor.type();
    int i = 0;
    for (TensorType.Dimension dimension : type.dimensions()) {
        boolean sizeMismatch = dimension.size().isPresent()
                               && dimension.size().get() != tensor.dimensionSizes().size(i);
        if (sizeMismatch) {
            throw new IllegalArgumentException("Unexpected size " + tensor.dimensionSizes().size(i) +
                                               " for dimension " + dimension.name() + " for type " + type);
        }
        i++;
    }
    return tensor;
}
/**
 * Base class of the value string parsers: holds the string being parsed and the current
 * read position, and provides the token-level operations the concrete parsers share.
 */
private static abstract class ValueParser {

    /** The full string being parsed */
    protected final String string;

    /** Index of the next character to read */
    protected int position = 0;

    protected ValueParser(String string) {
        this.string = string;
    }

    /** Advances the position past any whitespace */
    protected void skipSpace() {
        for ( ; position < string.length(); position++) {
            if ( ! Character.isWhitespace(string.charAt(position))) return;
        }
    }

    /**
     * Skips whitespace, then requires the next character to be the given one and moves past it.
     *
     * @throws IllegalArgumentException if the string ends or another character is found
     */
    protected void consume(char character) {
        skipSpace();
        if (position >= string.length())
            throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
                                               "' but got the end of the string");
        char actual = string.charAt(position);
        if (actual != character)
            throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
                                               "' but got '" + actual + "'");
        position++;
    }

    /** Reads from the current position (whitespace is not skipped first) up to the next stop character */
    protected String consumeIdentifier() {
        int start = position;
        position = nextStopCharIndex(start, string);
        return string.substring(start, position);
    }

    /**
     * Reads a label which is either enclosed in single quotes, enclosed in double quotes,
     * or a plain identifier. Quoted labels are returned without their quotes.
     *
     * @throws IllegalArgumentException if an opening quote has no matching closing quote
     */
    protected String consumeLabel() {
        if (consumeOptional('\''))
            return consumeQuoted('\'', ": A label quoted by a tick (') must end by another tick");
        if (consumeOptional('"'))
            return consumeQuoted('"', ": A label quoted by a double quote (\") must end by another double quote");
        return consumeIdentifier();
    }

    /** Reads up to the next occurrence of the quote character and moves past it; the opening quote is already consumed */
    private String consumeQuoted(char quote, String unterminatedMessage) {
        int closingQuote = string.indexOf(quote, position);
        if (closingQuote < 0)
            throw new IllegalArgumentException("At value position " + position + unterminatedMessage);
        String label = string.substring(position, closingQuote);
        position = closingQuote + 1;
        return label;
    }

    /**
     * Skips whitespace and reads the text up to the next stop character as a number.
     * DOUBLE cells are parsed as double; FLOAT, BFLOAT16 and INT8 cells are parsed as float.
     * The position is moved past the number text whether or not it parses.
     *
     * @return a Double for DOUBLE cells and a Float for the other supported cell types
     * @throws IllegalArgumentException if the text is not a valid number or the cell type is not supported
     */
    protected Number consumeNumber(TensorType.Value cellValueType) {
        skipSpace();
        int start = position;
        int end = nextStopCharIndex(start, string);
        position = end;
        String cellValueString = string.substring(start, end);
        try {
            return switch (cellValueType) {
                case DOUBLE -> Double.parseDouble(cellValueString);
                case FLOAT, BFLOAT16, INT8 -> Float.parseFloat(cellValueString);
                default -> throw new IllegalArgumentException(cellValueType + " is not supported");
            };
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("At value position " + start + ": '" +
                                               cellValueString + "' is not a valid " + cellValueType);
        }
    }

    /** Skips whitespace, then moves past the given character if it is next; returns whether it was */
    protected boolean consumeOptional(char character) {
        skipSpace();
        boolean found = position < string.length() && string.charAt(position) == character;
        if (found)
            position++;
        return found;
    }

    /**
     * Returns the index of the first whitespace, ',', ']', '}' or ':' at or after the given position.
     *
     * @throws IllegalArgumentException if there is no such character
     */
    protected int nextStopCharIndex(int position, String valueString) {
        int index = position;
        while (index < valueString.length()) {
            char c = valueString.charAt(index);
            boolean stop = Character.isWhitespace(c) || c == ',' || c == ']' || c == '}' || c == ':';
            if (stop) return index;
            index++;
        }
        throw new IllegalArgumentException("Malformed tensor string '" + valueString +
                                           "': Expected a ',', ']' or '}', ':' after position " + index);
    }

}
/**
 * A single-use dense tensor string parser.
 *
 * Accepts either a flat list, {@code [1,2,3,4]}, or nested lists whose brackets follow the
 * dimensions of the type, {@code [[1,2],[3,4]]}. Cells are written to the builder by direct index.
 */
private static class DenseValueParser extends ValueParser {

    private final IndexedTensor.DirectIndexBuilder builder;
    private final IndexedTensor.Indexes indexes;
    private final boolean hasInnerStructure;

    /**
     * @param string the string to parse, starting at the opening '['; it may continue past the closing ']'
     * @param dimensionOrder the order dimensions were written in the type string, or null if there was none
     * @param builder the builder receiving the parsed cells
     */
    public DenseValueParser(String string,
                            List<String> dimensionOrder,
                            IndexedTensor.DirectIndexBuilder builder) {
        super(string);
        this.builder = builder;
        indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder);
        hasInnerStructure = hasInnerStructure(string);
    }

    /**
     * Reads one number per cell index. With inner structure, brackets are consumed as reported
     * by the index iterator around each cell; without it, a single enclosing bracket pair is consumed.
     *
     * @throws IllegalArgumentException if the string does not match the expected structure
     */
    public void parse() {
        if (!hasInnerStructure)
            consume('[');

        while (indexes.hasNext()) {
            indexes.next();
            for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++)
                consume('[');
            consumeNumber();
            for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++)
                consume(']');
            if (indexes.hasNext())
                consume(',');
        }

        if (!hasInnerStructure)
            consume(']');
    }

    /** Returns the current read position, so that an enclosing parser can continue where this one stopped */
    public int position() { return position; }

    /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */
    private static boolean hasInnerStructure(String valueString) {
        valueString = valueString.trim();
        // Nothing to inspect. Returning false makes parse() report the missing '[' as an
        // IllegalArgumentException, where substring(1) would throw StringIndexOutOfBoundsException
        if (valueString.isEmpty()) return false;
        valueString = valueString.substring(1);
        int firstLeftBracket = valueString.indexOf('[');
        // A '[' found only after the first ']' belongs to something following this list
        return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']');
    }

    /** Reads the next number and stores it at the source value index of the current cell */
    protected void consumeNumber() {
        Number number = consumeNumber(builder.type().valueType());
        switch (builder.type().valueType()) {
            case DOUBLE -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double) number);
            case FLOAT, BFLOAT16, INT8 -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number);
        }
    }

}
/**
 * Parses unbound tensor short forms - e.g. tensor(x[],y[]):[[1,2,3],[4,5,6]]
 *
 * One bracket level is read per dimension of the type, and the cell address is taken from
 * how far into each list the parser has come.
 */
private static class UnboundDenseValueParser extends ValueParser {

    private final IndexedTensor.Builder builder;

    /** The address of the cell being read: one index per dimension of the type */
    private final long[] indexes;

    public UnboundDenseValueParser(String string, IndexedTensor.Builder builder) {
        super(string);
        this.builder = builder;
        this.indexes = new long[builder.type().dimensions().size()];
    }

    /** Parses the whole string, starting at the outermost list */
    public void parse() {
        consumeList(0);
    }

    /** Reads one bracketed list at the given dimension: numbers at the innermost dimension, nested lists otherwise */
    private void consumeList(int dimension) {
        consume('[');
        for (indexes[dimension] = 0; ! atListEnd(); indexes[dimension]++) {
            if (isInnerMostDimension(dimension))
                consumeNumber();
            else
                consumeList(dimension + 1);
            consumeOptional(','); // separators are not required
        }
        consume(']');
    }

    /** Reads the next number and stores it at the current address */
    private void consumeNumber() {
        Number number = consumeNumber(builder.type().valueType());
        switch (builder.type().valueType()) {
            case DOUBLE -> builder.cell((Double) number, indexes);
            case FLOAT, BFLOAT16, INT8 -> builder.cell((Float) number, indexes);
        }
    }

    private boolean isInnerMostDimension(int dimension) {
        return indexes.length - 1 == dimension;
    }

    /**
     * Skips whitespace and returns whether the next character closes the current list.
     *
     * @throws IllegalArgumentException if the string ends before the list is closed
     */
    protected boolean atListEnd() {
        skipSpace();
        if (position < string.length())
            return string.charAt(position) == ']';
        throw new IllegalArgumentException("At value position " + position + ": Expected a ']'" +
                                           " but got the end of the string");
    }

}
/**
 * Parses mixed tensor short forms {a:[1,2], ...} AND 1d mapped tensor short form {a:b, ...}.
 */
private static class MixedValueParser extends ValueParser {

    private final Tensor.Builder builder;

    /** The order dimensions were written in the type string, or null; parse() removes the mapped dimension from it */
    private final List<String> dimensionOrder;

    public MixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) {
        super(string);
        this.dimensionOrder = dimensionOrder;
        this.builder = builder;
    }

    /**
     * Reads "label:value" entries until the end of the string. The value is a dense subspace
     * when the type has more than one dimension, and a single number otherwise.
     *
     * @throws java.util.NoSuchElementException if the builder's type has no dimension usable as the mapped one
     * @throws IllegalArgumentException if the string is malformed
     */
    private void parse() {
        TensorType.Dimension mappedDimension = findMappedDimension().orElseThrow();
        TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
        // What remains in the order afterwards is the order of the dense dimensions only
        if (dimensionOrder != null)
            dimensionOrder.remove(mappedDimension.name());

        skipSpace();
        consume('{');
        skipSpace();
        while (position + 1 < string.length()) {
            String label = consumeLabel();
            consume(':');
            TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build();
            if (builder.type().rank() > 1)
                parseDenseSubspace(mappedAddress, dimensionOrder);
            else
                consumeNumber(mappedAddress);
            if ( ! consumeOptional(','))
                consume('}');
            skipSpace();
        }
    }

    private Optional<TensorType.Dimension> findMappedDimension() {
        return findMappedDimension(builder.type());
    }

    /**
     * Returns the dimension to treat as the mapped one in the mixed form: the first mapped
     * dimension of the type if it has any, otherwise the only dimension of a rank 1 type
     * when that dimension has no size, otherwise empty.
     */
    static Optional<TensorType.Dimension> findMappedDimension(TensorType type) {
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (dimension.isMapped()) return Optional.of(dimension);
        }
        if (type.rank() == 1 && type.dimensions().get(0).size().isEmpty())
            return Optional.of(type.dimensions().get(0));
        return Optional.empty();
    }

    /** Parses the dense values following the current position into the subspace at the given mapped address */
    private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
        DenseValueParser denseParser = new DenseValueParser(string.substring(position),
                                                            denseDimensionOrder,
                                                            ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(mappedAddress));
        denseParser.parse();
        // The dense parser worked on a substring, so its position is relative to where it started
        position += denseParser.position();
    }

    /** Reads the next number and stores it at the given address */
    private void consumeNumber(TensorAddress address) {
        Number number = consumeNumber(builder.type().valueType());
        switch (builder.type().valueType()) {
            case DOUBLE -> builder.cell(address, (Double) number);
            case FLOAT, BFLOAT16, INT8 -> builder.cell(address, (Float) number);
        }
    }

}
/**
 * Parses tensors on the mapped form, {@code {{dimension:label,...}:value, ...}},
 * adding one cell to the builder per address-value pair.
 */
private static class MappedValueParser extends ValueParser {

    private final Tensor.Builder builder;

    public MappedValueParser(String string, Tensor.Builder builder) {
        super(string);
        this.builder = builder;
    }

    /**
     * Reads address-value pairs until the end of the string.
     * The ':' between address and value is required unless the address is empty.
     *
     * @throws IllegalArgumentException if the string is malformed or a value is not a valid number
     */
    private void parse() {
        consume('{');
        skipSpace();
        while (position + 1 < string.length()) {
            TensorAddress address = consumeLabels();
            if (address.isEmpty())
                consumeOptional(':');
            else
                consume(':');

            int valueEnd = findValueEnd();
            addCell(address, string.substring(position, valueEnd).trim());
            position = valueEnd + 1;
            skipSpace();
        }
    }

    /** Returns the index of the next ',' after the current position, or of the next '}' if there is no ',' */
    private int findValueEnd() {
        int nextComma = string.indexOf(',', position);
        if (nextComma >= 0) return nextComma;

        int closingBrace = string.indexOf('}', position); // last value
        if (closingBrace < 0)
            throw new IllegalArgumentException("A mapped tensor string must end by '}'");
        return closingBrace;
    }

    /** Parses the value text according to the cell type of the builder and adds it at the given address */
    private void addCell(TensorAddress address, String cellValueString) {
        TensorType.Value cellValueType = builder.type().valueType();
        try {
            switch (cellValueType) {
                case DOUBLE -> builder.cell(address, Double.parseDouble(cellValueString));
                case FLOAT, BFLOAT16, INT8 -> builder.cell(address, Float.parseFloat(cellValueString));
                default -> throw new IllegalArgumentException(cellValueType + " is not supported");
            }
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" +
                                               cellValueString + "' is not a valid " + cellValueType);
        }
    }

    /** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */
    private TensorAddress consumeLabels() {
        TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type());
        if (consumeOptional('{')) {
            while ( ! consumeOptional('}')) {
                String dimensionName = consumeIdentifier();
                consume(':');
                addressBuilder.add(dimensionName, consumeLabel());
                consumeOptional(',');
            }
        }
        return addressBuilder.build();
    }

}
/** Parses a tensor *value* into a type */
private static class MappedValueTypeParser extends ValueParser {

    private final TensorType.Builder builder;

    public MappedValueTypeParser(String string, TensorType.Builder builder) {
        super(string);
        this.builder = builder;
    }

    /**
     * Derives the tensor type from the first address string in the given tensor string:
     * each dimension named in that address is added to the type builder as a mapped dimension.
     * If the value has no address, nothing is added.
     */
    public void parse() {
        consume('{');
        if ( ! consumeOptional('{')) return;

        while ( ! consumeOptional('}')) {
            String dimensionName = consumeIdentifier();
            consume(':');
            consumeLabel(); // the label must be read past, but only the dimension name is used
            builder.mapped(dimensionName);
            consumeOptional(',');
        }
    }

}
}