Please wait. This can take a few minutes...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
com.yahoo.tensor.functions.Join Maven / Gradle / Ivy
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
* The join tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
*
* @author bratseth
*/
public class Join extends PrimitiveTensorFunction {
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(combinator, "The combinator function cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.combinator = combinator;
}
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
try {
return TypeResolver.join(a, b);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
}
}
public DoubleBinaryOperator combinator() { return combinator; }
@Override
public List> arguments() { return List.of(argumentA, argumentB); }
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
return new Join<>(arguments.get(0), arguments.get(1), combinator);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
@Override
public String toString(ToStringContext context) {
return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
}
@Override
public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); }
@Override
public TensorType type(TypeContext context) {
return outputType(argumentA.type(context), argumentB.type(context));
}
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = outputType(a.type(), b.type());
return evaluate(a, b, joinedType, combinator);
}
static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
// Choose join algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType, combinator);
else if (a.type().dimensions().containsAll(b.type().dimensions()))
return subspaceJoin(b, a, joinedType, true, combinator);
else if (b.type().dimensions().containsAll(a.type().dimensions()))
return subspaceJoin(a, b, joinedType, false, combinator);
else
return generalJoin(a, b, joinedType, combinator);
}
private static boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator aIterator = a.valueIterator();
Iterator bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
/** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */
private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator i = a.cellIterator(); i.hasNext(); ) {
Map.Entry aCell = i.next();
var key = aCell.getKey();
Double bVal = b.getAsDouble(key);
if (bVal != null) {
builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
}
}
return builder.build();
}
/** Join a tensor into a superspace */
private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator);
else
return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
}
private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
// Find dimensions which are only in the supertype
Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
for (Iterator i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder, combinator);
}
return builder.build();
}
private static void joinSubspaces(Iterator subspace, long subspaceSize,
Iterator superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder,
DoubleBinaryOperator combinator) {
int joinedLength = (int)Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next()));
}
} else {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue()));
}
}
}
private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
for (int i = 0; i < builder.dimensions(); i++) {
String dimensionName = joinedType.dimensions().get(i).name();
Optional aIndex = a.type().indexOfDimension(dimensionName);
Optional bIndex = b.type().indexOfDimension(dimensionName);
if (aIndex.isPresent() && bIndex.isPresent())
builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
else if (aIndex.isPresent())
builder.set(i, a.dimensionSizes().size(aIndex.get()));
else if (bIndex.isPresent())
builder.set(i, b.dimensionSizes().size(bIndex.get()));
}
return builder.build();
}
private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry supercell = i.next();
TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
Double subspaceValue = subspace.getAsDouble(subaddress);
if (subspaceValue != null) {
builder.cell(supercell.getKey(),
reversedArgumentOrder
? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
: combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
}
return builder.build();
}
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];
for (int i = 0; i < subtype.dimensions().size(); i++)
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
/** Slow join which works for any two tensors */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
else
return mappedHashJoin(a, b, joinedType, combinator);
}
private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
DimensionSizes joinedSize = joinedSize(joinedType, a, b);
Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator);
return builder.build();
}
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
Set sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
Set dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
PartialAddress matchingBCells = sharedDimensionSize > 0
? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
: empty;
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
builder.cell(joinedAddress, joinedValue);
}
}
}
}
private static final PartialAddress empty = new PartialAddress.Builder(0).build();
private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
Set retainDimensions, int sharedDimensionSize) {
PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
for (int i = 0; i < addressType.dimensions().size(); i++) {
String dimension = addressType.dimensions().get(i).name();
if (retainDimensions.contains(dimension))
builder.add(dimension, address.numericLabel(i));
}
return builder.build();
}
/** Returns the sizes from the joined sizes which are present in the type argument */
private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
int dimensionIndex = 0;
for (int i = 0; i < joinedType.dimensions().size(); i++) {
if (type.dimensionNames().contains(joinedType.dimensions().get(i).name()))
builder.set(dimensionIndex++, joinedSizes.size(i));
}
return builder.build();
}
private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator aIterator = a.cellIterator(); aIterator.hasNext(); ) {
Map.Entry aCell = aIterator.next();
for (Iterator bIterator = b.cellIterator(); bIterator.hasNext(); ) {
Map.Entry bCell = bIterator.next();
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
bCell.getKey(), bToIndexes, joinedType);
if (combinedAddress == null) continue; // not combinable
builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
}
}
return builder.build();
}
private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
TensorType commonDimensionType = commonDimensions(a, b);
if (commonDimensionType.dimensions().isEmpty()) {
return mappedGeneralJoin(a, b, joinedType, combinator); // fallback
}
boolean swapTensors = a.size() > b.size();
if (swapTensors) {
Tensor temp = a;
a = b;
b = temp;
}
// Map dimension indexes to common and joined type
int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type());
int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type());
int[] aIndexesInJoined = mapIndexes(a.type(), joinedType);
int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);
// Iterate once through the smaller tensor and construct a hash map for common dimensions
Map> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
// Iterate once through the larger tensor and use the hash map to find joinable cells
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell bCell = cellIterator.next();
TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
bCell.getKey(), bIndexesInJoined, joinedType);
if (combinedAddress == null) continue; // not combinable
double combinedValue = swapTensors ?
combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) :
combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
builder.cell(combinedAddress, combinedValue);
}
}
return builder.build();
}
/**
* Returns an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
* That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
static int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name());
return toIndexes;
}
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
long[] joinedLabels = new long[joinedType.dimensions().size()];
Arrays.fill(joinedLabels, Tensor.invalidIndex);
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
return TensorAddressAny.ofUnsafe(joinedLabels);
}
/**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
private static boolean mapContent(TensorAddress from, long[] to, int[] indexMap) {
for (int i = 0, size = from.size(); i < size; i++) {
int toIndex = indexMap[i];
long label = from.numericLabel(i);
if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != label)
return false;
to[toIndex] = label;
}
return true;
}
/** Returns common dimension of a and b as a new tensor type */
private static TensorType commonDimensions(Tensor a, Tensor b) {
TensorType aType = a.type();
TensorType bType = b.type();
TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
TensorType.Dimension bDim = bType.dimensions().get(j);
if (aDim.equals(bDim)) {
typeBuilder.set(bDim);
}
}
}
return typeBuilder.build();
}
}