All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.aggregation.NumericHistogram Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Doubles;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.Slices;
import it.unimi.dsi.fastutil.Arrays;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.PriorityQueue;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.instanceSize;
import static java.util.Objects.requireNonNull;

/**
 * Approximate histogram over a stream of weighted {@code double} values, summarized as at most
 * {@code maxBuckets} (value, weight) buckets.
 * <p>
 * Incoming points are appended to parallel arrays with {@code maxBuckets + buffer} slots. When the
 * arrays fill up, or before the buckets are read or serialized, the histogram is compacted:
 * <ol>
 * <li>entries with equal values are coalesced;</li>
 * <li>the adjacent pair with the lowest merge penalty is repeatedly replaced by its weighted
 * average until at most {@code maxBuckets} buckets remain.</li>
 * </ol>
 * <p>
 * NOTE(review): the class uses no synchronization; confirm callers confine an instance to one thread.
 */
public class NumericHistogram
{
    private static final byte FORMAT_TAG = 0;
    private static final int INSTANCE_SIZE = instanceSize(NumericHistogram.class);

    private final int maxBuckets;
    // Parallel arrays: (values[i], weights[i]) is entry i, for 0 <= i < nextIndex.
    private final double[] values;
    private final double[] weights;

    // Number of populated slots in values/weights.
    private int nextIndex;

    /**
     * Creates an empty histogram whose extra buffer is 20% of {@code maxBuckets} (at least one slot).
     *
     * @param maxBuckets maximum number of buckets kept after compaction; must be >= 2
     */
    public NumericHistogram(int maxBuckets)
    {
        this(maxBuckets, Math.max((int) (maxBuckets * 0.2), 1));
    }

    /**
     * Creates an empty histogram.
     *
     * @param maxBuckets maximum number of buckets kept after compaction; must be >= 2
     * @param buffer number of extra slots accepted before a compaction is forced; must be >= 1
     */
    public NumericHistogram(int maxBuckets, int buffer)
    {
        checkArgument(maxBuckets >= 2, "maxBuckets must be >= 2");
        checkArgument(buffer >= 1, "buffer must be >= 1");

        this.maxBuckets = maxBuckets;
        this.values = new double[maxBuckets + buffer];
        this.weights = new double[maxBuckets + buffer];
    }

    /**
     * Restores a histogram from the format written by {@link #serialize()}.
     *
     * @param serialized bytes produced by {@link #serialize()}
     * @param buffer number of extra slots to allocate beyond {@code maxBuckets}; must be >= 1
     * @throws IllegalArgumentException if the format tag, bucket limit or entry count is invalid
     */
    public NumericHistogram(Slice serialized, int buffer)
    {
        requireNonNull(serialized, "serialized is null");
        checkArgument(buffer >= 1, "buffer must be >= 1");

        SliceInput input = serialized.getInput();

        checkArgument(input.readByte() == FORMAT_TAG, "Unsupported format tag");

        maxBuckets = input.readInt();
        nextIndex = input.readInt();

        // Validate the header before allocating and reading, so corrupt input fails with a clear
        // message instead of an array-index or negative-size error.
        checkArgument(maxBuckets >= 2, "maxBuckets must be >= 2");
        checkArgument(nextIndex >= 0 && nextIndex <= maxBuckets + buffer, "Invalid entry count: %s", nextIndex);

        values = new double[maxBuckets + buffer];
        weights = new double[maxBuckets + buffer];

        input.readDoubles(values, 0, nextIndex);
        input.readDoubles(weights, 0, nextIndex);
    }

    /**
     * Compacts the histogram, then encodes it as:
     * format tag (byte), maxBuckets (int), entry count (int), values (doubles), weights (doubles).
     */
    public Slice serialize()
    {
        compact();

        int requiredBytes = SizeOf.SIZE_OF_BYTE + // format
                SizeOf.SIZE_OF_INT + // max buckets
                SizeOf.SIZE_OF_INT + // entry count
                SizeOf.SIZE_OF_DOUBLE * nextIndex + // values
                SizeOf.SIZE_OF_DOUBLE * nextIndex; // weights

        return Slices.allocate(requiredBytes)
                .getOutput()
                .appendByte(FORMAT_TAG)
                .appendInt(maxBuckets)
                .appendInt(nextIndex)
                .appendDoubles(values, 0, nextIndex)
                .appendDoubles(weights, 0, nextIndex)
                .getUnderlyingSlice();
    }

    /**
     * Returns the retained size of this object plus its two backing arrays.
     */
    public long estimatedInMemorySize()
    {
        return INSTANCE_SIZE + SizeOf.sizeOf(values) + SizeOf.sizeOf(weights);
    }

    /**
     * Adds {@code value} with weight 1.
     */
    public void add(double value)
    {
        add(value, 1);
    }

    /**
     * Adds {@code value} with the given weight. The histogram is compacted first when the
     * backing arrays are full.
     */
    public void add(double value, double weight)
    {
        if (nextIndex == values.length) {
            // Compaction leaves at most maxBuckets entries, so there is room for this one afterwards.
            compact();
        }

        values[nextIndex] = value;
        weights[nextIndex] = weight;

        nextIndex++;
    }

    /**
     * Folds all entries of {@code other} into this histogram, merging down to {@code maxBuckets}
     * buckets if necessary. {@code other} is not modified.
     */
    public void mergeWith(NumericHistogram other)
    {
        int count = nextIndex + other.nextIndex;

        double[] newValues = new double[count];
        double[] newWeights = new double[count];

        concat(newValues, this.values, this.nextIndex, other.values, other.nextIndex);
        concat(newWeights, this.weights, this.nextIndex, other.weights, other.nextIndex);

        // Sorts the combined entries by value and coalesces equal values.
        count = mergeSameBuckets(newValues, newWeights, count);

        if (count <= maxBuckets) {
            // Copy back into this.values/this.weights; count <= maxBuckets < values.length.
            System.arraycopy(newValues, 0, this.values, 0, count);
            System.arraycopy(newWeights, 0, this.weights, 0, count);
            nextIndex = count;
            return;
        }

        // mergeSameBuckets already left the entries sorted by value, as mergeBuckets requires.
        store(mergeBuckets(newValues, newWeights, count, maxBuckets));
    }

    /**
     * Compacts the histogram and returns its buckets as value -> weight, in increasing value order.
     * The returned map is a fresh copy.
     */
    public Map<Double, Double> getBuckets()
    {
        compact();

        Map<Double, Double> result = new LinkedHashMap<>();
        for (int i = 0; i < nextIndex; i++) {
            result.put(values[i], weights[i]);
        }
        return result;
    }

    /**
     * Coalesces equal values, then merges adjacent buckets until at most {@code maxBuckets} remain.
     * Entries are left sorted by value.
     */
    @VisibleForTesting
    void compact()
    {
        nextIndex = mergeSameBuckets(values, weights, nextIndex);

        if (nextIndex <= maxBuckets) {
            return;
        }

        // Entries are guaranteed to be sorted as a side effect of the call to mergeSameBuckets.
        store(mergeBuckets(values, weights, nextIndex, maxBuckets));
    }

    /**
     * Greedily merges the adjacent pair with the lowest penalty until {@code targetCount} buckets remain.
     * <p>
     * The first {@code count} entries must be sorted by value in increasing order. The returned queue
     * also holds entries that were superseded during merging. Those are marked invalid and must be
     * skipped by the caller.
     */
    private static PriorityQueue<Entry> mergeBuckets(double[] values, double[] weights, int count, int targetCount)
    {
        checkArgument(targetCount > 0, "targetCount must be > 0");

        PriorityQueue<Entry> queue = initializeQueue(values, weights, count);

        while (count > targetCount) {
            Entry current = queue.poll();
            if (!current.isValid()) {
                // Ignore entries that have already been replaced.
                continue;
            }

            count--;

            Entry right = current.getRight();

            // Right is guaranteed to exist: the last bucket has an infinite penalty, so while a
            // finite-penalty pair remains, the last bucket is never the one polled.
            checkState(right != null, "Expected right to be != null");
            checkState(right.isValid(), "Expected right to be valid");

            // Merge "current" with "right" into their weighted average.
            double newWeight = current.getWeight() + right.getWeight();
            double newValue = (current.getValue() * current.getWeight() + right.getValue() * right.getWeight()) / newWeight;

            // Mark "right" invalid so it is skipped if it later reaches the head of the queue.
            right.invalidate();

            // The merged entry is linked to the right neighbor of "right".
            Entry merged = new Entry(current.getId(), newValue, newWeight, right.getRight());
            queue.add(merged);

            Entry left = current.getLeft();
            if (left != null) {
                checkState(left.isValid(), "Expected left to be valid");

                // The penalty of "left" depends on its right neighbor, which just changed, so replace
                // "left" with a new entry linked to the merged entry.
                left.invalidate();
                queue.add(new Entry(left.getId(), left.getValue(), left.getWeight(), left.getLeft(), merged));
            }
        }

        return queue;
    }

    /**
     * Writes the valid entries in the queue back into the bucket arrays.
     * PriorityQueue iteration order is unspecified, so the arrays are re-sorted by value afterwards.
     */
    private void store(PriorityQueue<Entry> queue)
    {
        nextIndex = 0;
        for (Entry entry : queue) {
            if (entry.isValid()) {
                values[nextIndex] = entry.getValue();
                weights[nextIndex] = entry.getWeight();
                nextIndex++;
            }
        }
        sort(values, weights, nextIndex);
    }

    /**
     * Copies two arrays back-to-back onto the target array, starting at offset 0.
     */
    private static void concat(double[] target, double[] first, int firstLength, double[] second, int secondLength)
    {
        System.arraycopy(first, 0, target, 0, firstLength);
        System.arraycopy(second, 0, target, firstLength, secondLength);
    }

    /**
     * Sorts the first {@code count} entries by value and coalesces runs of equal values into one
     * bucket carrying the summed weight. The coalesced buckets are packed at the front of the arrays,
     * in increasing value order.
     *
     * @return the number of buckets left after coalescing (0 when {@code count} is 0)
     */
    private static int mergeSameBuckets(double[] values, double[] weights, int count)
    {
        if (count == 0) {
            // Without this guard the method would report a phantom (0.0, 0.0) bucket.
            return 0;
        }

        sort(values, weights, count);

        int current = 0;
        for (int i = 1; i < count; i++) {
            if (sameValue(values[current], values[i])) {
                weights[current] += weights[i];
            }
            else {
                current++;
                values[current] = values[i];
                weights[current] = weights[i];
            }
        }
        return current + 1;
    }

    /**
     * Equality used when coalescing buckets: numeric {@code ==} (so 0.0 equals -0.0), except that
     * NaN equals NaN.
     * <p>
     * Coalescing NaNs keeps at most one NaN bucket. Several NaN buckets would have NaN penalties,
     * which Double.compare orders after the last bucket's +Infinity penalty. That would let
     * mergeBuckets poll the last bucket and fail its checkState. getBuckets() would also lose
     * weight, because equal NaN map keys overwrite each other.
     */
    private static boolean sameValue(double first, double second)
    {
        return first == second || (Double.isNaN(first) && Double.isNaN(second));
    }

    /**
     * Creates a priority queue with one entry per bucket, ordered by the penalty of merging that bucket
     * with the bucket to its right.
     * <ul>
     * <li>The inputs must be sorted by value in increasing order.</li>
     * <li>The last bucket has a penalty of +Infinity.</li>
     * <li>Entries are doubly linked to track the relative position of each bucket.</li>
     * </ul>
     */
    private static PriorityQueue<Entry> initializeQueue(double[] values, double[] weights, int count)
    {
        checkArgument(count > 0, "nextIndex must be > 0");

        PriorityQueue<Entry> queue = new PriorityQueue<>(count);

        Entry right = new Entry(count - 1, values[count - 1], weights[count - 1], null);
        queue.add(right);
        for (int i = count - 2; i >= 0; i--) {
            Entry current = new Entry(i, values[i], weights[i], right);
            queue.add(current);
            right = current;
        }

        return queue;
    }

    /**
     * Sorts the first {@code count} entries of the parallel arrays by value, swapping weights in lockstep.
     */
    private static void sort(final double[] values, final double[] weights, int count)
    {
        Arrays.quickSort(0, count, (a, b) -> Doubles.compare(values[a], values[b]), (a, b) -> {
            double temp = values[a];
            values[a] = values[b];
            values[b] = temp;

            temp = weights[a];
            weights[a] = weights[b];
            weights[b] = temp;
        });
    }

    /**
     * Returns the cost of merging bucket (value1, weight1) with bucket (value2, weight2):
     * {@code (w1 + w2) * (v1 - v2)^2 * w1 * w2 / (w1 + w2)^2}.
     */
    private static double computePenalty(double value1, double weight1, double value2, double weight2)
    {
        double weight = weight1 + weight2;
        double squaredDifference = (value1 - value2) * (value1 - value2);
        double proportionsProduct = (weight1 * weight2) / ((weight1 + weight2) * (weight1 + weight2));
        return weight * squaredDifference * proportionsProduct;
    }

    /**
     * A bucket in the merge queue.
     * <p>
     * Entries are doubly linked to their neighbors and ordered by the penalty of merging with the
     * right neighbor, with ties broken by id. An entry is invalidated rather than removed when it
     * is superseded.
     */
    private static class Entry
            implements Comparable<Entry>
    {
        private final double penalty;

        private final int id;
        private final double value;
        private final double weight;

        private boolean valid = true;
        private Entry left;
        private Entry right;

        private Entry(int id, double value, double weight, Entry right)
        {
            this(id, value, weight, null, right);
        }

        /**
         * Creates an entry and links the given neighbors back to it.
         */
        private Entry(int id, double value, double weight, Entry left, Entry right)
        {
            this.id = id;
            this.value = value;
            this.weight = weight;
            this.right = right;
            this.left = left;

            if (right != null) {
                right.left = this;
                penalty = computePenalty(value, weight, right.value, right.weight);
            }
            else {
                // The last bucket has no right neighbor to merge with.
                penalty = Double.POSITIVE_INFINITY;
            }

            if (left != null) {
                left.right = this;
            }
        }

        public int getId()
        {
            return id;
        }

        public Entry getLeft()
        {
            return left;
        }

        public Entry getRight()
        {
            return right;
        }

        public double getValue()
        {
            return value;
        }

        public double getWeight()
        {
            return weight;
        }

        public boolean isValid()
        {
            return valid;
        }

        public void invalidate()
        {
            this.valid = false;
        }

        @Override
        public int compareTo(Entry other)
        {
            int result = Double.compare(penalty, other.penalty);
            if (result == 0) {
                result = Integer.compare(id, other.id);
            }
            return result;
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("id", id)
                    .add("value", value)
                    .add("weight", weight)
                    .add("penalty", penalty)
                    .add("valid", valid)
                    .toString();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy