All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.tensor.functions.Join Maven / Gradle / Ivy

// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.TensorAddressAny;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

/**
 * The join tensor operation produces a tensor from the argument tensors containing the set of cells
 * given by the cross product of the cells of the given tensors, having as values the value produced by
 * applying the given combinator function on the values from the two source cells.
 *
 * @author bratseth
 */
public class Join extends PrimitiveTensorFunction {

    private final TensorFunction argumentA, argumentB;
    private final DoubleBinaryOperator combinator;

    public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(combinator, "The combinator function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.combinator = combinator;
    }

    /** Returns the type resulting from applying Join to the two given types */
    public static TensorType outputType(TensorType a, TensorType b) {
        try {
            return TypeResolver.join(a, b);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
        }
    }

    public DoubleBinaryOperator combinator() { return combinator; }

    @Override
    public List> arguments() { return List.of(argumentA, argumentB); }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() != 2)
            throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
        return new Join<>(arguments.get(0), arguments.get(1), combinator);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
    }

    @Override
    public String toString(ToStringContext context) {
        return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
    }

    @Override
    public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); }

    @Override
    public TensorType type(TypeContext context) {
        return outputType(argumentA.type(context), argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        TensorType joinedType = outputType(a.type(), b.type());
        return evaluate(a, b, joinedType, combinator);
    }

    static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        // Choose join algorithm
        if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
            return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator);
        else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
            return singleSpaceJoin(a, b, joinedType, combinator);
        else if (a.type().dimensions().containsAll(b.type().dimensions()))
            return subspaceJoin(b, a, joinedType, true, combinator);
        else if (b.type().dimensions().containsAll(a.type().dimensions()))
            return subspaceJoin(a, b, joinedType, false, combinator);
        else
            return generalJoin(a, b, joinedType, combinator);

    }

    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
        int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Iterator aIterator = a.valueIterator();
        Iterator bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
        for (int i = 0; i < joinedRank; i++)
            builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
        return builder.build();
    }

    /** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */
    private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator i = a.cellIterator(); i.hasNext(); ) {
            Map.Entry aCell = i.next();
            var key = aCell.getKey();
            Double bVal = b.getAsDouble(key);
            if (bVal != null) {
                builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
            }
        }
        return builder.build();
    }

    /** Join a tensor into a superspace */
    private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
            return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator);
        else
            return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
    }

    private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes
            return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();

        DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);

        IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);

        // Find dimensions which are only in the supertype
        Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
        superDimensionNames.removeAll(subspace.type().dimensionNames());

        for (Iterator i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
            IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
            joinSubspaces(subspace.valueIterator(), subspace.size(),
                    subspaceInSuper, subspaceInSuper.size(),
                    reversedArgumentOrder, builder, combinator);
        }

        return builder.build();
    }

    private static void joinSubspaces(Iterator subspace, long subspaceSize,
                                      Iterator superspace, long superspaceSize,
                                      boolean reversedArgumentOrder, IndexedTensor.Builder builder,
                                      DoubleBinaryOperator combinator) {
        int joinedLength = (int)Math.min(subspaceSize, superspaceSize);
        if (reversedArgumentOrder) {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next()));
            }
        } else {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue()));
            }
        }
    }

    private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); i++) {
            String dimensionName = joinedType.dimensions().get(i).name();
            Optional aIndex = a.type().indexOfDimension(dimensionName);
            Optional bIndex = b.type().indexOfDimension(dimensionName);
            if (aIndex.isPresent() && bIndex.isPresent())
                builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
            else if (aIndex.isPresent())
                builder.set(i, a.dimensionSizes().size(aIndex.get()));
            else if (bIndex.isPresent())
                builder.set(i, b.dimensionSizes().size(bIndex.get()));
        }
        return builder.build();
    }

    private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator i = superspace.cellIterator(); i.hasNext(); ) {
            Map.Entry supercell = i.next();
            TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
            Double subspaceValue = subspace.getAsDouble(subaddress);
            if (subspaceValue != null) {
                builder.cell(supercell.getKey(),
                        reversedArgumentOrder
                                ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
                                : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
            }
        }
        return builder.build();
    }

    /** Returns the indexes in the superspace type which should be retained to create the subspace type */
    private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
        int[] subspaceIndexes = new int[subtype.dimensions().size()];
        for (int i = 0; i < subtype.dimensions().size(); i++)
            subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
        return subspaceIndexes;
    }

    /** Slow join which works for any two tensors */
    private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        if (a instanceof IndexedTensor && b instanceof IndexedTensor)
            return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
        else
            return mappedHashJoin(a, b, joinedType, combinator);
    }

    private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        DimensionSizes joinedSize = joinedSize(joinedType, a, b);
        Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator);
        return builder.build();
    }

    private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
                               int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
                               DoubleBinaryOperator combinator) {
        Set sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
        int sharedDimensionSize = sharedDimensions.size();  // Expensive to compute size after intersection
        Set dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));

        DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
        DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);

        // for each combination of dimensions only in a
        for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
            IndexedTensor.SubspaceIterator aSubspace = ia.next();
            // for each combination of dimensions in a which is also in b
            while (aSubspace.hasNext()) {
                Tensor.Cell aCell = aSubspace.next();
                PartialAddress matchingBCells = sharedDimensionSize > 0
                        ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
                        : empty;
                // for each matching combination of dimensions ony in b
                for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
                    Tensor.Cell bCell = bSubspace.next();
                    TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                    double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                    builder.cell(joinedAddress, joinedValue);
                }
            }
        }
    }

    private static final PartialAddress empty = new PartialAddress.Builder(0).build();
    private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
                                                 Set retainDimensions, int sharedDimensionSize) {
        PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
        for (int i = 0; i < addressType.dimensions().size(); i++) {
            String dimension = addressType.dimensions().get(i).name();
            if (retainDimensions.contains(dimension))
                builder.add(dimension, address.numericLabel(i));
        }
        return builder.build();
    }

    /** Returns the sizes from the joined sizes which are present in the type argument */
    private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        int dimensionIndex = 0;
        for (int i = 0; i < joinedType.dimensions().size(); i++) {
            if (type.dimensionNames().contains(joinedType.dimensions().get(i).name()))
                builder.set(dimensionIndex++, joinedSizes.size(i));
        }
        return builder.build();
    }

    private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator aIterator = a.cellIterator(); aIterator.hasNext(); ) {
            Map.Entry aCell = aIterator.next();
            for (Iterator bIterator = b.cellIterator(); bIterator.hasNext(); ) {
                Map.Entry bCell = bIterator.next();
                TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
                                                              bCell.getKey(), bToIndexes, joinedType);
                if (combinedAddress == null) continue; // not combinable
                builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
            }
        }
        return builder.build();
    }

    private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        TensorType commonDimensionType = commonDimensions(a, b);
        if (commonDimensionType.dimensions().isEmpty()) {
            return mappedGeneralJoin(a, b, joinedType, combinator); // fallback
        }

        boolean swapTensors = a.size() > b.size();
        if (swapTensors) {
            Tensor temp = a;
            a = b;
            b = temp;
        }

        // Map dimension indexes to common and joined type
        int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type());
        int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type());
        int[] aIndexesInJoined = mapIndexes(a.type(), joinedType);
        int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);

        // Iterate once through the smaller tensor and construct a hash map for common dimensions
        Map> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
        for (Iterator cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
            Tensor.Cell aCell = cellIterator.next();
            TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
            aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
        }

        // Iterate once through the larger tensor and use the hash map to find joinable cells
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
            Tensor.Cell bCell = cellIterator.next();
            TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
            for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
                TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
                                                              bCell.getKey(), bIndexesInJoined, joinedType);
                if (combinedAddress == null) continue; // not combinable
                double combinedValue = swapTensors ?
                        combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) :
                        combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                builder.cell(combinedAddress, combinedValue);
            }
        }

        return builder.build();
    }

    /**
     * Returns an array having one entry in order for each dimension of fromType
     * containing the index at which toType contains the same dimension name.
     * That is, if the returned array contains n at index i then
     * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
     * If some dimension in fromType is not present in toType, the corresponding index will be -1
     */
    static int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < fromType.dimensions().size(); i++)
            toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name());
        return toIndexes;
    }

    private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                               TensorType joinedType) {
        long[] joinedLabels = new long[joinedType.dimensions().size()];
        Arrays.fill(joinedLabels, Tensor.invalidIndex);
        mapContent(a, joinedLabels, aToIndexes);
        boolean compatible = mapContent(b, joinedLabels, bToIndexes);
        if ( ! compatible) return null;
        return TensorAddressAny.ofUnsafe(joinedLabels);
    }

    /**
     * Maps the content in the given list to the given array, using the given index map.
     *
     * @return true if the mapping was successful, false if one of the destination positions was
     *         occupied by a different value
     */
    private static boolean mapContent(TensorAddress from, long[] to, int[] indexMap) {
        for (int i = 0, size = from.size(); i < size; i++) {
            int toIndex = indexMap[i];
            long label = from.numericLabel(i);
            if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != label)
                return false;
            to[toIndex] = label;
        }
        return true;
    }

    /** Returns common dimension of a and b as a new tensor type */
    private static TensorType commonDimensions(Tensor a, Tensor b) {
        TensorType aType = a.type();
        TensorType bType = b.type();
        TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
        for (int i = 0; i < aType.dimensions().size(); ++i) {
            TensorType.Dimension aDim = aType.dimensions().get(i);
            for (int j = 0; j < bType.dimensions().size(); ++j) {
                TensorType.Dimension bDim = bType.dimensions().get(j);
                if (aDim.equals(bDim)) {
                    typeBuilder.set(bDim);
                }
            }
        }
        return typeBuilder.build();
    }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy