com.yahoo.tensor.functions.ReduceJoin Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vespajlib Show documentation
Show all versions of vespajlib Show documentation
Library for use in Java components of Vespa. Shared code which does
not fit anywhere else.
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
/**
* An optimization for tensor expressions where a join immediately follows a
* reduce. Evaluating this as one operation is significantly more efficient
* than evaluating each separately.
*
* This implementation optimizes the case where the reduce is done on the same
* dimensions as the join. A particularly efficient evaluation is done if there
* is one common dimension that is joined and reduced on, which is a common
* case as it covers vector and matrix like multiplications.
*
* @author lesters
*/
public class ReduceJoin extends CompositeTensorFunction {
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
private final Reduce.Aggregator aggregator;
private final List dimensions;
public ReduceJoin(Reduce reduce, Join join) {
this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
}
public ReduceJoin(TensorFunction argumentA,
TensorFunction argumentB,
DoubleBinaryOperator combinator,
Reduce.Aggregator aggregator,
List dimensions) {
this.argumentA = argumentA;
this.argumentB = argumentB;
this.combinator = combinator;
this.aggregator = aggregator;
this.dimensions = List.copyOf(dimensions);
}
@Override
public List> arguments() {
return List.of(argumentA, argumentB);
}
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
return new ReduceJoin<>(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
Join join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
return new Reduce<>(join, aggregator, dimensions);
}
@Override
public final Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
if (canOptimize(a, b)) {
return evaluate((IndexedTensor)a, (IndexedTensor)b, joinedType);
}
return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator);
}
/**
* Tests whether or not the reduce is over the join dimensions. The
* remaining logic in this class assumes this to be true.
*
* If no dimensions are given, the join must be on all tensor dimensions.
*
* @return {@code true} if the implementation can optimize evaluation
* given the two tensors.
*/
public boolean canOptimize(Tensor a, Tensor b) {
if (a.type().dimensions().isEmpty() || b.type().dimensions().isEmpty()) // TODO: support scalars
return false;
if ( ! (a instanceof IndexedTensor))
return false;
if ( ! (a.type().hasOnlyIndexedBoundDimensions()))
return false;
if ( ! (b instanceof IndexedTensor))
return false;
if ( ! (b.type().hasOnlyIndexedBoundDimensions()))
return false;
TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
if (dimensions.isEmpty()) {
if (a.type().dimensions().size() != commonDimensions.dimensions().size())
return false;
if (b.type().dimensions().size() != commonDimensions.dimensions().size())
return false;
} else if (dimensions.size() != commonDimensions.dimensions().size()) {
return false;
} else {
for (TensorType.Dimension dimension : commonDimensions.dimensions()) {
if (!dimensions.contains(dimension.name()))
return false;
}
}
return true;
}
/**
* Evaluates the reduce-join. Special handling for common cases where the
* reduce dimension is the innermost dimension in both tensors.
*/
private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
TensorType reducedType = Reduce.outputType(joinedType, dimensions);
if (reduceDimensionIsInnermost(a, b)) {
if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) {
return vectorVectorProduct(a, b, reducedType);
}
if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 2) {
return vectorMatrixProduct(a, b, reducedType, false);
}
if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 1) {
return vectorMatrixProduct(b, a, reducedType, true);
}
if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 2) {
return matrixMatrixProduct(a, b, reducedType);
}
}
return evaluateGeneral(a, b, reducedType);
}
private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) {
throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product");
}
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
for (int ic = 0; ic < commonSize; ++ic) {
double va = a.get(ic);
double vb = b.get(ic);
agg.aggregate(combinator.applyAsDouble(va, vb));
}
builder.cellByDirectIndex(0, agg.aggregatedValue());
return builder.build();
}
private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) {
if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) {
throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product");
}
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
DimensionSizes sizesA = a.dimensionSizes();
DimensionSizes sizesB = b.dimensionSizes();
Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
for (int ib = 0; ib < sizesB.size(0); ++ib) {
agg.reset();
for (int ic = 0; ic < Math.min(sizesA.size(0), sizesB.size(1)); ++ic) {
double va = a.get(ic);
double vb = b.get(ib * sizesB.size(1) + ic);
double result = swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb);
agg.aggregate(result);
}
builder.cellByDirectIndex(ib, agg.aggregatedValue());
}
return builder.build();
}
private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
}
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
DimensionSizes sizesA = a.dimensionSizes();
DimensionSizes sizesB = b.dimensionSizes();
int iaToReduced = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get();
int ibToReduced = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get();
long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1;
long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1;
Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
for (int ia = 0; ia < sizesA.size(0); ++ia) {
for (int ib = 0; ib < sizesB.size(0); ++ib) {
agg.reset();
for (int ic = 0; ic < Math.min(sizesA.size(1), sizesB.size(1)); ++ic) {
double va = a.get(ia * sizesA.size(1) + ic);
double vb = b.get(ib * sizesB.size(1) + ic);
agg.aggregate(combinator.applyAsDouble(va, vb));
}
builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue());
}
}
return builder.build();
}
private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
TensorType common = dimensionsInCommon(a, b);
// pre-calculate strides for each index position
long[] stridesA = strides(a.type());
long[] stridesB = strides(b.type());
long[] stridesResult = strides(reducedType);
// mapping of dimension indexes
int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type());
int[] mapCommonToA = Join.mapIndexes(common, a.type());
int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type());
int[] mapCommonToB = Join.mapIndexes(common, b.type());
int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reducedType);
int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reducedType);
// TODO: refactor with code in IndexedTensor and Join
MultiDimensionIterator ic = new MultiDimensionIterator(common);
Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) {
for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) {
agg.reset();
for (ic.reset(); ic.hasNext(); ic.next()) {
double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA));
double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB));
agg.aggregate(combinator.applyAsDouble(va, vb));
}
builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult),
agg.aggregatedValue());
}
}
return builder.build();
}
private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) {
long directIndex = 0;
for (int i = 0; i < iter.length(); ++i) {
directIndex += strides[map[i]] * iter.iterator[i];
}
for (int i = 0; i < common.length(); ++i) {
directIndex += strides[commonmap[i]] * common.iterator[i];
}
return directIndex;
}
private long[] strides(TensorType type) {
long[] strides = new long[type.dimensions().size()];
if (strides.length > 0) {
long previous = 1;
strides[strides.length - 1] = previous;
for (int i = strides.length - 2; i >= 0; --i) {
strides[i] = previous * type.dimensions().get(i + 1).size().get();
previous = strides[i];
}
}
return strides;
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a.type(), b.type()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {
if ( ! aDim.size().isPresent()) {
builder.set(aDim);
} else if ( ! bDim.size().isPresent()) {
builder.set(bDim);
} else {
builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim); // minimum size of dimension
}
}
}
}
return builder.build();
}
/**
* Tests if there is exactly one reduce dimension and it is the innermost
* dimension in both tensors.
*/
private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
List reducingDimensions = dimensions;
if (reducingDimensions.isEmpty()) {
reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream()
.map(TensorType.Dimension::name)
.toList();
}
if (reducingDimensions.size() != 1) {
return false;
}
String dimension = reducingDimensions.get(0);
int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() ->
new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
if (indexInA != (a.type().dimensions().size() - 1)) {
return false;
}
int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() ->
new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
if (indexInB < (b.type().dimensions().size() - 1)) {
return false;
}
return true;
}
@Override
public String toString(ToStringContext context) {
return "reduce_join(" + argumentA.toString(context) + ", " +
argumentB.toString(context) + ", " +
combinator + ", " +
aggregator +
Reduce.commaSeparatedNames(dimensions, context) + ")";
}
@Override
public int hashCode() {
return Objects.hash("reduce_join", argumentA, argumentB, combinator, aggregator, dimensions);
}
private static class MultiDimensionIterator {
private final long[] bounds;
private final long[] iterator;
private long remaining;
MultiDimensionIterator(TensorType type) {
bounds = new long[type.dimensions().size()];
iterator = new long[type.dimensions().size()];
for (int i = 0; i < bounds.length; ++i) {
bounds[i] = type.dimensions().get(i).size().get();
}
reset();
}
public int length() {
return iterator.length;
}
public boolean hasNext() {
return remaining > 0;
}
public void reset() {
remaining = 1;
for (int i = iterator.length - 1; i >= 0; --i) {
iterator[i] = 0;
remaining *= bounds[i];
}
}
public void next() {
for (int i = iterator.length - 1; i >= 0; --i) {
iterator[i] += 1;
if (iterator[i] < bounds[i]) {
break;
}
iterator[i] = 0;
}
remaining -= 1;
}
@Override
public String toString() {
return Arrays.toString(iterator);
}
}
}