com.yahoo.tensor.functions.Concat Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
*
* @author bratseth
*/
public class Concat extends PrimitiveTensorFunction {
enum DimType { common, separate, concat }
private final TensorFunction argumentA, argumentB;
private final String dimension;
public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(dimension, "The dimension cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.dimension = dimension;
}
@Override
public List> arguments() { return List.of(argumentA, argumentB); }
@Override
public TensorFunction withArguments(List> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
return new Concat<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
}
@Override
public String toString(ToStringContext context) {
return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) +
", " + context.resolveBinding(dimension) + ")";
}
@Override
public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); }
@Override
public TensorType type(TypeContext context) {
return TypeResolver.concat(argumentA.type(context), argumentB.type(context), context.resolveBinding(dimension));
}
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
return oldEvaluate(a, b);
}
var helper = new Helper(a, b, dimension);
return helper.result;
}
private Tensor oldEvaluate(Tensor a, Tensor b) {
TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
a = ensureIndexedDimension(dimension, a, concatType.valueType());
b = ensureIndexedDimension(dimension, b, concatType.valueType());
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
for (Iterator ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
IndexedTensor.SubspaceIterator iaSubspace = ia.next();
TensorAddress aAddress = iaSubspace.address();
for (Iterator ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
IndexedTensor.SubspaceIterator ibSubspace = ib.next();
while (ibSubspace.hasNext()) {
Tensor.Cell bCell = ibSubspace.next();
TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
concatType, offset, dimension);
if (combinedAddress == null) continue; // incompatible
builder.cell(combinedAddress, bCell.getValue());
}
iaSubspace.reset();
}
}
}
private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
Optional dimension = tensor.type().dimension(dimensionName);
if ( dimension.isPresent() ) {
if ( ! dimension.get().isIndexed())
throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
"' requires that dimension to be indexed or absent, " +
"but got a tensor with type " + tensor.type());
return tensor;
}
else { // extend tensor with this dimension
if (tensor.type().hasMappedDimensions())
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
.indexed(dimensionName, 1)
.build())
.cell(1,0)
.build();
return tensor.multiply(unitTensor);
}
}
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
concatSizes.set(i, Math.min(aSize, bSize));
else
concatSizes.set(i, Math.max(aSize, bSize));
}
return concatSizes.build();
}
/**
* Combine two addresses, adding the offset to the concat dimension
*
* @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType concatType, long concatOffset, String concatDimension) {
long[] combinedLabels = new long[concatType.dimensions().size()];
Arrays.fill(combinedLabels, Tensor.invalidIndex);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
if ( ! compatible) return null;
return TensorAddress.of(combinedLabels);
}
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
* That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
// TODO: Stolen from join
private int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.invalidIndex);
return toIndexes;
}
/**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (concatDimension == toIndex) {
to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != from.numericLabel(i)) return false;
to[toIndex] = from.numericLabel(i);
}
}
return true;
}
static class CellVector {
ArrayList values = new ArrayList<>();
void setValue(int ccDimIndex, double value) {
while (values.size() <= ccDimIndex) {
values.add(0.0);
}
values.set(ccDimIndex, value);
}
}
static class CellVectorMap {
Map map = new HashMap<>();
CellVector lookupCreate(TensorAddress addr) {
return map.computeIfAbsent(addr, k -> new CellVector());
}
}
static class CellVectorMapMap {
Map map = new HashMap<>();
CellVectorMap lookupCreate(TensorAddress addr) {
return map.computeIfAbsent(addr, k -> new CellVectorMap());
}
}
static class SplitHow {
List handleDims = new ArrayList<>();
long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); }
long numSeparate() { return handleDims.stream().filter(t -> (t == DimType.separate)).count(); }
}
static class ConcatPlan {
final TensorType resultType;
final String concatDimension;
SplitHow splitInfoA = new SplitHow();
SplitHow splitInfoB = new SplitHow();
enum CombineHow { left, right, both, concat }
List combineHow = new ArrayList<>();
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoA.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
splitInfoA.handleDims.add(DimType.separate);
combineHow.add(CombineHow.left);
}
}
void bOnly(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoB.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
splitInfoB.handleDims.add(DimType.separate);
combineHow.add(CombineHow.right);
}
}
void bothAandB(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoA.handleDims.add(DimType.concat);
splitInfoB.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
splitInfoA.handleDims.add(DimType.common);
splitInfoB.handleDims.add(DimType.common);
combineHow.add(CombineHow.both);
}
}
ConcatPlan(TensorType aType, TensorType bType, String concatDimension) {
this.resultType = TypeResolver.concat(aType, bType, concatDimension);
this.concatDimension = concatDimension;
var aDims = aType.dimensions();
var bDims = bType.dimensions();
int i = 0;
int j = 0;
while (i < aDims.size() && j < bDims.size()) {
String aName = aDims.get(i).name();
String bName = bDims.get(j).name();
int cmp = aName.compareTo(bName);
if (cmp == 0) {
bothAandB(aName);
++i;
++j;
} else if (cmp < 0) {
aOnly(aName);
++i;
} else {
bOnly(bName);
++j;
}
}
while (i < aDims.size()) {
aOnly(aDims.get(i++).name());
}
while (j < bDims.size()) {
bOnly(bDims.get(j++).name());
}
if (combineHow.size() < resultType.rank()) {
var idx = resultType.indexOfDimension(concatDimension);
combineHow.add(idx.get(), CombineHow.concat);
}
}
}
static class Helper {
ConcatPlan plan;
Tensor result;
Helper(Tensor a, Tensor b, String dimension) {
this.plan = new ConcatPlan(a.type(), b.type(), dimension);
CellVectorMapMap aData = decompose(a, plan.splitInfoA);
CellVectorMapMap bData = decompose(b, plan.splitInfoB);
this.result = merge(aData, bData);
}
static int concatDimensionSize(CellVectorMapMap data) {
Set sizes = new HashSet<>();
data.map.forEach((m, cvmap) ->
cvmap.map.forEach((e, vector) ->
sizes.add(vector.values.size())));
if (sizes.isEmpty()) {
return 1;
}
if (sizes.size() == 1) {
return sizes.iterator().next();
}
throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values");
}
TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
long[] labels = new long[plan.resultType.rank()];
int out = 0;
int m = 0;
int a = 0;
int b = 0;
for (var how : plan.combineHow) {
switch (how) {
case left -> labels[out++] = leftOnly.numericLabel(a++);
case right -> labels[out++] = rightOnly.numericLabel(b++);
case both -> labels[out++] = match.numericLabel(m++);
case concat -> labels[out++] = concatDimIdx;
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
return TensorAddressAny.ofUnsafe(labels);
}
Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
var builder = Tensor.Builder.of(plan.resultType);
int aConcatSize = concatDimensionSize(a);
for (var entry : a.map.entrySet()) {
TensorAddress common = entry.getKey();
if (b.map.containsKey(common)) {
var lhs = entry.getValue();
var rhs = b.map.get(common);
lhs.map.forEach((leftOnly, leftCells) -> {
rhs.map.forEach((rightOnly, rightCells) -> {
for (int i = 0; i < leftCells.values.size(); i++) {
TensorAddress addr = combine(common, leftOnly, rightOnly, i);
builder.cell(addr, leftCells.values.get(i));
}
for (int i = 0; i < rightCells.values.size(); i++) {
TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
builder.cell(addr, rightCells.values.get(i));
}
});
});
}
}
return builder.build();
}
CellVectorMapMap decompose(Tensor input, SplitHow how) {
var iter = input.cellIterator();
long[] commonLabels = new long[(int)how.numCommon()];
long[] separateLabels = new long[(int)how.numSeparate()];
CellVectorMapMap result = new CellVectorMapMap();
while (iter.hasNext()) {
var cell = iter.next();
var addr = cell.getKey();
long ccDimIndex = 0;
int commonIdx = 0;
int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
case common -> commonLabels[commonIdx++] = addr.numericLabel(i);
case separate -> separateLabels[separateIdx++] = addr.numericLabel(i);
case concat -> ccDimIndex = addr.numericLabel(i);
default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i));
}
}
TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels);
TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels);
result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
}
return result;
}
}
}