com.yahoo.tensor.functions.Join Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
 * The join tensor operation produces a tensor from the argument tensors containing the set of cells
 * given by the cross product of the cells of the given tensors, having as values the value produced by
 * applying the given combinator function on the values from the two source cells.
 *
 * @param <NAMETYPE> the name type shared by this function and its two argument functions
 * @author bratseth
 */
public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argumentA, argumentB;
    private final DoubleBinaryOperator combinator;

    /**
     * Creates a join of the tensors produced by two functions.
     *
     * @param argumentA the function producing the first tensor
     * @param argumentB the function producing the second tensor
     * @param combinator the function applied to each pair of joined cell values,
     *                   receiving the value from the first tensor as its first argument
     * @throws NullPointerException if any argument is null
     */
    public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.combinator = Objects.requireNonNull(combinator, "The combinator function cannot be null");
    }
/**
 * Resolves the tensor type produced by joining tensors of the two given types.
 *
 * @throws IllegalArgumentException if the type resolver rejects the pair of types;
 *                                  the resolver's exception is kept as the cause
 */
public static TensorType outputType(TensorType a, TensorType b) {
    final TensorType joined;
    try {
        joined = TypeResolver.join(a, b);
    } catch (IllegalArgumentException cause) {
        throw new IllegalArgumentException("Can not join " + a + " and " + b, cause);
    }
    return joined;
}
/** Returns the function this join applies to the values of each pair of joined cells */
public DoubleBinaryOperator combinator() {
    return combinator;
}
@Override
public List> arguments() { return List.of(argumentA, argumentB); }
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
return new Join<>(arguments.get(0), arguments.get(1), combinator);
}
/** Returns a join, with the same combinator, of the primitive forms of both arguments */
@Override
public PrimitiveTensorFunction toPrimitive() {
    // NOTE(review): the return type is a raw type here - presumably it should carry
    // the name type of the enclosing class; confirm before adding the type argument.
    return new Join<>(
            argumentA.toPrimitive(),
            argumentB.toPrimitive(),
            combinator);
}
/** Renders this as "join(a, b, combinator)", with each argument rendered using the given context */
@Override
public String toString(ToStringContext context) {
    // NOTE(review): ToStringContext is used as a raw type here - presumably it should carry
    // the name type of the enclosing class; confirm before adding the type argument.
    String first = argumentA.toString(context);
    String second = argumentB.toString(context);
    return "join(" + first + ", " + second + ", " + combinator + ")";
}
/** Returns a hash of the operation name "join", both arguments and the combinator */
@Override
public int hashCode() {
    Object[] identity = { "join", argumentA, argumentB, combinator };
    return Objects.hash(identity);
}
/** Returns the type of the join of the types the two arguments have in the given context */
@Override
public TensorType type(TypeContext context) {
    // NOTE(review): TypeContext is used as a raw type here - presumably it should carry
    // the name type of the enclosing class; confirm before adding the type argument.
    TensorType typeOfA = argumentA.type(context);
    TensorType typeOfB = argumentB.type(context);
    return outputType(typeOfA, typeOfB);
}
/**
 * Evaluates both arguments in the given context, first argument first, and joins the resulting tensors.
 * The result type is resolved from the types of the evaluated tensors.
 */
@Override
public Tensor evaluate(EvaluationContext context) {
    // NOTE(review): EvaluationContext is used as a raw type here - presumably it should carry
    // the name type of the enclosing class; confirm before adding the type argument.
    Tensor first = argumentA.evaluate(context);
    Tensor second = argumentB.evaluate(context);
    return evaluate(first, second, outputType(first.type(), second.type()), combinator);
}
/**
 * Joins two tensors into a tensor of the given joined type, dispatching to the most
 * specialized join algorithm the argument types allow. The checks are made in this order:
 * same single indexed dimension, same number of dimensions as the joined type,
 * one tensor's dimensions containing the other's, and finally the general join.
 */
static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    List<TensorType.Dimension> aDimensions = a.type().dimensions();
    List<TensorType.Dimension> bDimensions = b.type().dimensions();

    boolean sameNamedIndexedVectors = hasSingleIndexedDimension(a)
                                      && hasSingleIndexedDimension(b)
                                      && aDimensions.get(0).name().equals(bDimensions.get(0).name());
    if (sameNamedIndexedVectors)
        return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);

    int joinedDimensionCount = joinedType.dimensions().size();
    if (joinedDimensionCount == aDimensions.size() && joinedDimensionCount == bDimensions.size())
        return singleSpaceJoin(a, b, joinedType, combinator);

    // When joining into a superspace the subspace is passed first, so the flag records
    // whether that swapped the order of a and b
    if (aDimensions.containsAll(bDimensions))
        return subspaceJoin(b, a, joinedType, true, combinator);
    if (bDimensions.containsAll(aDimensions))
        return subspaceJoin(a, b, joinedType, false, combinator);

    return generalJoin(a, b, joinedType, combinator);
}
/** Returns whether the given tensor has exactly one dimension, and that dimension is indexed */
private static boolean hasSingleIndexedDimension(Tensor tensor) {
    if (tensor.type().dimensions().size() != 1)
        return false;
    return tensor.type().dimensions().get(0).isIndexed();
}
/**
 * Joins two indexed tensors having one dimension each by combining their values position by position.
 * The result has the size of the smaller of the two arguments.
 */
private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
    int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
    Iterator<Double> aIterator = a.valueIterator();
    Iterator<Double> bIterator = b.valueIterator();
    IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
    for (int i = 0; i < joinedRank; i++)
        builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
    return builder.build();
}
/**
 * Joins two tensors having the same dimensions: each cell of a is looked up by its own address in b,
 * so at most one cell matches a cell in the other tensor. Cells of a with no value in b are skipped.
 */
private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    Tensor.Builder builder = Tensor.Builder.of(joinedType);
    for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
        Map.Entry<TensorAddress, Double> aCell = i.next();
        TensorAddress key = aCell.getKey();
        Double bVal = b.getAsDouble(key);
        if (bVal != null) {
            builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
        }
    }
    return builder.build();
}
/**
 * Joins a tensor into a superspace, using the indexed algorithm when both tensors
 * are indexed tensors and the general one otherwise.
 */
private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
    boolean bothIndexed = subspace instanceof IndexedTensor && superspace instanceof IndexedTensor;
    if ( ! bothIndexed)
        return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
    return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace,
                               joinedType, reversedArgumentOrder, combinator);
}
/**
 * Joins an indexed tensor into an indexed superspace: the values of the subspace tensor are joined
 * with each subspace of the superspace spanned by the dimensions only the superspace has.
 * Returns an empty tensor of the joined type if either argument is empty.
 *
 * @param reversedArgumentOrder whether the superspace value should be the first argument to the combinator
 */
private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
    if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes
        return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();

    DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
    IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);

    // Find dimensions which are only in the supertype
    Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
    superDimensionNames.removeAll(subspace.type().dimensionNames());

    for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
        IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
        joinSubspaces(subspace.valueIterator(), subspace.size(),
                      subspaceInSuper, subspaceInSuper.size(),
                      reversedArgumentOrder, builder, combinator);
    }
    return builder.build();
}
/**
 * Combines the values of a subspace tensor with the cells of one subspace of the superspace,
 * pairing them in iteration order, and adds each result to the builder at the superspace cell.
 * Only as many pairs as the smaller of the two sizes are joined.
 *
 * @param reversedArgumentOrder if true the superspace value is the first argument to the combinator,
 *                              otherwise the subspace value is
 */
private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
                                  Iterator<Tensor.Cell> superspace, long superspaceSize,
                                  boolean reversedArgumentOrder, IndexedTensor.Builder builder,
                                  DoubleBinaryOperator combinator) {
    int joinedLength = (int)Math.min(subspaceSize, superspaceSize);
    // The argument order is decided once, outside the loops
    if (reversedArgumentOrder) {
        for (int i = 0; i < joinedLength; i++) {
            Tensor.Cell supercell = superspace.next();
            builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next()));
        }
    } else {
        for (int i = 0; i < joinedLength; i++) {
            Tensor.Cell supercell = superspace.next();
            builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue()));
        }
    }
}
/**
 * Returns the size of each dimension of the joined type: the smaller of the two sizes for a dimension
 * present in both tensors, and the size in the tensor having it for a dimension present in only one.
 * No size is set for a dimension present in neither tensor.
 */
private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
    DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
    for (int i = 0; i < builder.dimensions(); i++) {
        String dimensionName = joinedType.dimensions().get(i).name();
        Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName);
        Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName);
        if (aIndex.isPresent() && bIndex.isPresent())
            builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
        else if (aIndex.isPresent())
            builder.set(i, a.dimensionSizes().size(aIndex.get()));
        else if (bIndex.isPresent())
            builder.set(i, b.dimensionSizes().size(bIndex.get()));
    }
    return builder.build();
}
/**
 * Joins a tensor into a superspace for any kind of tensors: the address of each superspace cell is
 * reduced to the dimensions of the subspace and looked up there. Superspace cells with no value
 * in the subspace are skipped.
 *
 * @param reversedArgumentOrder if true the superspace value is the first argument to the combinator,
 *                              otherwise the subspace value is
 */
private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
    int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
    Tensor.Builder builder = Tensor.Builder.of(joinedType);
    for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
        Map.Entry<TensorAddress, Double> supercell = i.next();
        TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
        Double subspaceValue = subspace.getAsDouble(subaddress);
        if (subspaceValue != null) {
            builder.cell(supercell.getKey(),
                         reversedArgumentOrder
                         ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
                         : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
        }
    }
    return builder.build();
}
/**
 * Returns, for each dimension of the subtype in order, the index of the dimension with the same name
 * in the supertype. These are the indexes to retain to turn a supertype address into a subtype address.
 */
private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
    List<TensorType.Dimension> subDimensions = subtype.dimensions();
    int[] retainedIndexes = new int[subDimensions.size()];
    int position = 0;
    for (TensorType.Dimension subDimension : subDimensions)
        retainedIndexes[position++] = supertype.indexOfDimension(subDimension.name()).get();
    return retainedIndexes;
}
/**
 * Slow join which works for any two tensors. Two indexed tensors are joined by the
 * indexed algorithm, any other combination by the hash join.
 */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    boolean bothIndexed = a instanceof IndexedTensor && b instanceof IndexedTensor;
    if ( ! bothIndexed)
        return mappedHashJoin(a, b, joinedType, combinator);
    return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
}
/** Joins two indexed tensors of any dimensions into a tensor sized by the joined sizes of the two */
private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    DimensionSizes sizes = joinedSize(joinedType, a, b);
    Tensor.Builder result = Tensor.Builder.of(joinedType, sizes);
    joinTo(a, b, joinedType, sizes,
           mapIndexes(a.type(), joinedType),
           mapIndexes(b.type(), joinedType),
           result, combinator);
    return result.build();
}
/**
 * Joins two indexed tensors into the given builder: each cell of a is combined with every cell of b
 * which has the same labels in the dimensions the two tensors share.
 *
 * @param joinedSize the size of each dimension of the joined type
 * @param aToIndexes the index in the joined type of each dimension of a
 * @param bToIndexes the index in the joined type of each dimension of b
 */
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
                           int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
                           DoubleBinaryOperator combinator) {
    Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
    int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
    Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));

    DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
    DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);

    // for each combination of dimensions only in a
    for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
        IndexedTensor.SubspaceIterator aSubspace = ia.next();
        // for each combination of dimensions in a which is also in b
        while (aSubspace.hasNext()) {
            Tensor.Cell aCell = aSubspace.next();
            PartialAddress matchingBCells = sharedDimensionSize > 0
                                            ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
                                            : empty;
            // for each matching combination of dimensions only in b
            for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
                Tensor.Cell bCell = bSubspace.next();
                TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                builder.cell(joinedAddress, joinedValue);
            }
        }
    }
}
/** A partial address built with no labels, which joinTo passes to b.cellIterator when the tensors share no dimensions */
private static final PartialAddress empty = new PartialAddress.Builder(0).build();
/**
 * Returns a partial address containing the labels of the given address in the dimensions to retain.
 *
 * @param addressType the type of the tensor the address belongs to
 * @param address the full address to take labels from
 * @param retainDimensions the names of the dimensions whose labels should be included
 * @param sharedDimensionSize the number of dimensions to retain, passed to the partial address builder
 */
private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
                                             Set<String> retainDimensions, int sharedDimensionSize) {
    PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
    for (int i = 0; i < addressType.dimensions().size(); i++) {
        String dimension = addressType.dimensions().get(i).name();
        if (retainDimensions.contains(dimension))
            builder.add(dimension, address.numericLabel(i));
    }
    return builder.build();
}
/**
 * Returns the sizes from the joined sizes which belong to dimensions present in the type argument,
 * in the order the dimensions have in the joined type.
 */
private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
    DimensionSizes.Builder sizes = new DimensionSizes.Builder(type.dimensions().size());
    int nextSizeIndex = 0;
    int joinedIndex = 0;
    for (TensorType.Dimension joinedDimension : joinedType.dimensions()) {
        if (type.dimensionNames().contains(joinedDimension.name())) {
            sizes.set(nextSizeIndex, joinedSizes.size(joinedIndex));
            nextSizeIndex++;
        }
        joinedIndex++;
    }
    return sizes.build();
}
/**
 * Joins any two tensors by trying every pair of one cell from a and one cell from b.
 * Pairs whose addresses cannot be combined are skipped.
 */
private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    int[] aToIndexes = mapIndexes(a.type(), joinedType);
    int[] bToIndexes = mapIndexes(b.type(), joinedType);
    Tensor.Builder builder = Tensor.Builder.of(joinedType);
    for (Iterator<Tensor.Cell> aIterator = a.cellIterator(); aIterator.hasNext(); ) {
        Map.Entry<TensorAddress, Double> aCell = aIterator.next();
        for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) {
            Map.Entry<TensorAddress, Double> bCell = bIterator.next();
            TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
                                                          bCell.getKey(), bToIndexes, joinedType);
            if (combinedAddress == null) continue; // not combinable
            builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
        }
    }
    return builder.build();
}
/**
 * Joins any two tensors using a hash map keyed on the labels of their common dimensions:
 * the cells of the smaller tensor are grouped by those labels, and each cell of the larger tensor
 * is then only combined with the cells in its group. Falls back to the pairwise join when
 * the tensors have no common dimensions.
 */
private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
    TensorType commonDimensionType = commonDimensions(a, b);
    if (commonDimensionType.dimensions().isEmpty()) {
        return mappedGeneralJoin(a, b, joinedType, combinator); // fallback
    }

    // Build the hash map from the smaller tensor. The swap is remembered so the combinator
    // still receives the value from a first and the value from b second.
    boolean swapTensors = a.size() > b.size();
    Tensor smaller = swapTensors ? b : a;
    Tensor larger = swapTensors ? a : b;

    // Map dimension indexes to common and joined type
    int[] smallerIndexesInCommon = mapIndexes(commonDimensionType, smaller.type());
    int[] largerIndexesInCommon = mapIndexes(commonDimensionType, larger.type());
    int[] smallerIndexesInJoined = mapIndexes(smaller.type(), joinedType);
    int[] largerIndexesInJoined = mapIndexes(larger.type(), joinedType);

    // Iterate once through the smaller tensor and construct a hash map for common dimensions
    Map<TensorAddress, List<Tensor.Cell>> smallerCellsByCommonAddress = new HashMap<>(smaller.sizeAsInt());
    for (Iterator<Tensor.Cell> cellIterator = smaller.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell smallerCell = cellIterator.next();
        TensorAddress partialCommonAddress = smallerCell.getKey().partialCopy(smallerIndexesInCommon);
        smallerCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(smallerCell);
    }

    // Iterate once through the larger tensor and use the hash map to find joinable cells
    Tensor.Builder builder = Tensor.Builder.of(joinedType);
    for (Iterator<Tensor.Cell> cellIterator = larger.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell largerCell = cellIterator.next();
        TensorAddress partialCommonAddress = largerCell.getKey().partialCopy(largerIndexesInCommon);
        for (Tensor.Cell smallerCell : smallerCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
            TensorAddress combinedAddress = joinAddresses(smallerCell.getKey(), smallerIndexesInJoined,
                                                          largerCell.getKey(), largerIndexesInJoined, joinedType);
            if (combinedAddress == null) continue; // not combinable
            double combinedValue = swapTensors ?
                                   combinator.applyAsDouble(largerCell.getValue(), smallerCell.getValue()) :
                                   combinator.applyAsDouble(smallerCell.getValue(), largerCell.getValue());
            builder.cell(combinedAddress, combinedValue);
        }
    }
    return builder.build();
}
/**
 * Maps each dimension of fromType to the position of the same-named dimension in toType.
 * The returned array has one entry per dimension of fromType, in order: if it contains n at index i then
 * fromType.dimensions().get(i).name().equals(toType.dimensions().get(n).name()).
 * If some dimension in fromType is not present in toType, the corresponding index will be -1.
 */
static int[] mapIndexes(TensorType fromType, TensorType toType) {
    List<TensorType.Dimension> fromDimensions = fromType.dimensions();
    int[] toIndexes = new int[fromDimensions.size()];
    Arrays.setAll(toIndexes, i -> toType.indexOfDimensionAsInt(fromDimensions.get(i).name()));
    return toIndexes;
}
/**
 * Combines two addresses into one address of the joined type, placing each label at the
 * position given by the index map of its address.
 *
 * @return the joined address, or null if a label of b maps to a position already holding a different label
 */
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                           TensorType joinedType) {
    int joinedDimensionCount = joinedType.dimensions().size();
    long[] labels = new long[joinedDimensionCount];
    Arrays.fill(labels, Tensor.invalidIndex); // marks every position as not yet assigned
    mapContent(a, labels, aToIndexes); // the outcome of mapping a is not checked
    if (mapContent(b, labels, bToIndexes))
        return TensorAddressAny.ofUnsafe(labels);
    return null;
}
/**
 * Copies the labels of the given address into the given array, placing label i at position indexMap[i].
 * Copying stops at the first conflict, leaving the labels copied before it in place.
 *
 * @return true if all labels were copied, false if one of the destination positions was
 *         already occupied by a different label
 */
private static boolean mapContent(TensorAddress from, long[] to, int[] indexMap) {
    int labelCount = from.size();
    for (int source = 0; source < labelCount; source++) {
        int target = indexMap[source];
        long label = from.numericLabel(source);
        long existing = to[target];
        boolean occupiedByOtherLabel = existing != Tensor.invalidIndex && existing != label;
        if (occupiedByOtherLabel)
            return false;
        to[target] = label;
    }
    return true;
}
/**
 * Returns a new tensor type, having the combined value type of a and b, containing
 * each dimension of a which equals some dimension of b.
 */
private static TensorType commonDimensions(Tensor a, Tensor b) {
    TensorType aType = a.type();
    TensorType bType = b.type();
    TensorType.Builder common = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
    for (TensorType.Dimension aDimension : aType.dimensions()) {
        for (TensorType.Dimension bDimension : bType.dimensions()) {
            if ( ! aDimension.equals(bDimension)) continue;
            common.set(bDimension);
        }
    }
    return common.build();
}
}