All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.yahoo.tensor.functions.Reduce Maven / Gradle / Ivy

Go to download

Library for use in Java components of Vespa. Shared code that does not fit anywhere else.

There is a newer version: 8.409.18
Show newest version
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions
 * are collapsed to a single value using an aggregator function.
 *
 * @author bratseth
 */
public class Reduce extends PrimitiveTensorFunction {

    public enum Aggregator { avg, count, max, median, min, prod, sum ; }

    private final TensorFunction argument;
    private final List dimensions;
    private final Aggregator aggregator;

    /** Creates a reduce function reducing all dimensions */
    public Reduce(TensorFunction argument, Aggregator aggregator) {
        this(argument, aggregator, List.of());
    }

    /** Creates a reduce function reducing a single dimension */
    public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, List.of(dimension));
    }

    /**
     * Creates a reduce function.
     *
     * @param argument the tensor to reduce
     * @param aggregator the aggregator function to use
     * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
     *                   producing a dimensionless tensor (a scalar).
     * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
     */
    public Reduce(TensorFunction argument, Aggregator aggregator, List dimensions) {
        this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
        this.aggregator  = Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        this.dimensions = List.copyOf(dimensions);
    }

    public static TensorType outputType(TensorType inputType, List reduceDimensions) {
        return TypeResolver.reduce(inputType, reduceDimensions);
    }

    public TensorFunction argument() { return argument; }

    Aggregator aggregator() { return aggregator; }

    List dimensions() { return dimensions; }

    @Override
    public List> arguments() { return List.of(argument); }

    @Override
    public TensorFunction withArguments(List> arguments) {
        if ( arguments.size() != 1)
            throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
        return new Reduce<>(arguments.get(0), aggregator, dimensions);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Reduce<>(argument.toPrimitive(), aggregator, dimensions);
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparatedNames(dimensions, context) + ")";
    }

    static  String commaSeparatedNames(List list, ToStringContext context) {
        StringBuilder b = new StringBuilder();
        for (String element  : list)
            b.append(", ").append(context.resolveBinding(element));
        return b.toString();
    }

    @Override
    public TensorType type(TypeContext context) {
        List resolvedDimensions = dimensions.stream().map(d -> context.resolveBinding(d)).toList();
        return outputType(argument.type(context), resolvedDimensions);
    }

    @Override
    public Tensor evaluate(EvaluationContext context) {
        return evaluate(this.argument.evaluate(context), dimensions, aggregator);
    }

    @Override
    public int hashCode() {
        return Objects.hash("reduce", argument, dimensions, aggregator);
    }

    static Tensor evaluate(Tensor argument, List dimensions, Aggregator aggregator) {
        if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions))
            throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
                    dimensions + ": Not all those dimensions are present in this tensor");

        // Special case: Reduce all
        if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) {
            if (argument.isEmpty())
                return Tensor.from(0.0);
            else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
                return reduceIndexedVector((IndexedTensor) argument, aggregator);
            else
                return reduceAllGeneral(argument, aggregator);
        }

        TensorType reducedType = outputType(argument.type(), dimensions);
        int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions);
        int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce);
        if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) {
            return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator);
        } else {
            return reduceGeneral(argument, reducedType, indexesToKeep, aggregator);
        }
    }

    private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) {
        int currentIndex = reduce[reduceIndex];
        int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
        if (reduceIndex + 1  < reduce.length) {
            int nextDimension = reduceIndex + 1;
            for (int i = 0; i < dimSize; i++) {
                address.setIndex(currentIndex, i);
                reduce(argument, aggregator, address, reduce, nextDimension);
            }
        } else {
            address.setIndex(currentIndex, 0);
            long increment = address.getStride(currentIndex);
            long directIndex = address.getDirectIndex();
            for (int i = 0; i < dimSize; i++) {
                aggregator.aggregate(argument.get(directIndex + i * increment));
            }
        }
    }

    private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) {
        if (keepIndex < toKeep.length) {
            int currentIndex = toKeep[keepIndex];
            int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));

            int nextKeep = keepIndex + 1;
            for (int i = 0; i < dimSize; i++) {
                address.setIndex(currentIndex, i);
                destAddress.setIndex(keepIndex, i);
                reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce);
            }
        } else {
            ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
            reduce(argument, valueAggregator, address, toReduce, 0);
            builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes());
        }

    }

    private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) {

        var reducedBuilder = IndexedTensor.Builder.of(reducedType);
        DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType));
        reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce);
        return reducedBuilder.build();
    }

    private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) {
        // TODO cells.size() is most likely an overestimate, and might need a better heuristic
        // But the upside is larger than the downside.
        Map aggregatingCells = new HashMap<>(argument.sizeAsInt());
        for (Iterator i = argument.cellIterator(); i.hasNext(); ) {
            Map.Entry cell = i.next();
            TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep);
            ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator));
            aggr.aggregate(cell.getValue());
        }
        Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
        for (Map.Entry aggregatingCell : aggregatingCells.entrySet())
            reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());

        return reducedBuilder.build();
    }

    private static int[] createIndexesToReduce(TensorType tensorType, List dimensions) {
        int[] indexesToReduce = new int[dimensions.size()];
        for (int i = 0; i < dimensions.size(); i++) {
            indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get();
        }
        return indexesToReduce;
    }
    private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) {
        int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length];
        int toKeepIndex = 0;
        for (int i = 0; i < argumentType.rank(); i++) {
            if ( ! contains(indexesToReduce, i))
                indexesToKeep[toKeepIndex++] = i;
        }
        return indexesToKeep;
    }
    private static boolean contains(int[] list, int key) {
        for (int candidate : list) {
            if (candidate == key) return true;
        }
        return false;
    }

    private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (Iterator i = argument.valueIterator(); i.hasNext(); )
            valueAggregator.aggregate(i.next());
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
    }

    private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0));
        for (int i = 0; i < dimensionSize ; i++)
            valueAggregator.aggregate(argument.get(i));
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
    }

    static abstract class ValueAggregator {

        static ValueAggregator ofType(Aggregator aggregator) {
            return switch (aggregator) {
                case avg -> new AvgAggregator();
                case count -> new CountAggregator();
                case max -> new MaxAggregator();
                case median -> new MedianAggregator();
                case min -> new MinAggregator();
                case prod -> new ProdAggregator();
                case sum -> new SumAggregator();
                default -> throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
            };

        }

        /** Add a new value to those aggregated by this */
        public abstract void aggregate(double value);

        /** Returns the value aggregated by this */
        public abstract double aggregatedValue();

        /** Resets the aggregator */
        public abstract void reset();

        /** Returns a hash of this aggregator which only depends on its identity */
        @Override
        public abstract int hashCode();

    }

    private static class AvgAggregator extends ValueAggregator {

        private int valueCount = 0;
        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueCount++;
            valueSum+= value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum / valueCount;
        }

        @Override
        public void reset() {
            valueCount = 0;
            valueSum = 0.0;
        }

        @Override
        public int hashCode() { return "avgAggregator".hashCode(); }

    }

    private static class CountAggregator extends ValueAggregator {

        private int valueCount = 0;

        @Override
        public void aggregate(double value) {
            valueCount++;
        }

        @Override
        public double aggregatedValue() {
            return valueCount;
        }

        @Override
        public void reset() {
            valueCount = 0;
        }

        @Override
        public int hashCode() { return "countAggregator".hashCode(); }

    }

    private static class MaxAggregator extends ValueAggregator {

        private double maxValue = Double.NEGATIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value > maxValue)
                maxValue = value;
        }

        @Override
        public double aggregatedValue() {
            return maxValue;
        }

        @Override
        public void reset() {
            maxValue = Double.NEGATIVE_INFINITY;
        }

        @Override
        public int hashCode() { return "maxAggregator".hashCode(); }

    }

    private static class MedianAggregator extends ValueAggregator {

        /** If any NaN is added, the result should be NaN */
        private boolean isNaN = false;

        private List values = new ArrayList<>();

        @Override
        public void aggregate(double value) {
            if ( Double.isNaN(value))
                isNaN = true;
            if ( ! isNaN)
                values.add(value);
        }

        @Override
        public double aggregatedValue() {
            if (isNaN || values.isEmpty()) return Double.NaN;
            Collections.sort(values);
            if (values.size() % 2 == 0) // even: average the two middle values
                return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2;
            else
                return values.get((values.size() - 1)/ 2);
        }

        @Override
        public void reset() {
            isNaN = false;
            values = new ArrayList<>();
        }

        @Override
        public int hashCode() { return "medianAggregator".hashCode(); }

    }

    private static class MinAggregator extends ValueAggregator {

        private double minValue = Double.POSITIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value < minValue)
                minValue = value;
        }

        @Override
        public double aggregatedValue() {
            return minValue;
        }

        @Override
        public void reset() {
            minValue = Double.POSITIVE_INFINITY;
        }

        @Override
        public int hashCode() { return "minAggregator".hashCode(); }

    }

    private static class ProdAggregator extends ValueAggregator {

        private double valueProd = 1.0;

        @Override
        public void aggregate(double value) {
            valueProd *= value;
        }

        @Override
        public double aggregatedValue() {
            return valueProd;
        }

        @Override
        public void reset() {
            valueProd = 1.0;
        }

        @Override
        public int hashCode() { return "prodAggregator".hashCode(); }

    }

    private static class SumAggregator extends ValueAggregator {

        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum;
        }

        @Override
        public void reset() {
            valueSum = 0.0;
        }

        @Override
        public int hashCode() { return "sumAggregator".hashCode(); }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy