All downloads are FREE. Search and download functionality uses the official Maven repository.

io.deephaven.engine.table.impl.by.ChunkedWeightedAverageOperator Maven / Gradle / Ivy

There is a newer version: 0.37.1
Show newest version
/**
 * Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending
 */
package io.deephaven.engine.table.impl.by;

import io.deephaven.base.verify.Assert;
import io.deephaven.chunk.attributes.ChunkLengths;
import io.deephaven.chunk.attributes.ChunkPositions;
import io.deephaven.chunk.attributes.Values;
import io.deephaven.engine.table.ColumnSource;
import io.deephaven.engine.rowset.chunkattributes.RowKeys;
import io.deephaven.util.QueryConstants;
import io.deephaven.engine.util.NullSafeAddition;
import io.deephaven.engine.table.impl.sources.*;
import io.deephaven.chunk.*;
import io.deephaven.engine.table.impl.util.cast.ToDoubleCast;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.apache.commons.lang3.mutable.MutableInt;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import static io.deephaven.engine.table.impl.by.RollupConstants.*;

class ChunkedWeightedAverageOperator implements IterativeChunkedAggregationOperator {
    private final ChunkType chunkType;
    private final DoubleWeightRecordingInternalOperator weightOperator;
    private final String resultName;
    private final boolean exposeInternalColumns;

    private long tableSize;
    private final LongArraySource normalCount;
    private LongArraySource nanCount;
    private final DoubleArraySource sumOfWeights;
    private final DoubleArraySource weightedSum;
    private final DoubleArraySource resultColumn;

    ChunkedWeightedAverageOperator(
            ChunkType chunkType,
            DoubleWeightRecordingInternalOperator weightOperator,
            String name,
            boolean exposeInternalColumns) {
        this.chunkType = chunkType;
        this.weightOperator = weightOperator;
        this.resultName = name;
        this.exposeInternalColumns = exposeInternalColumns;

        tableSize = 0;
        normalCount = new LongArraySource();
        weightedSum = new DoubleArraySource();
        sumOfWeights = new DoubleArraySource();
        resultColumn = new DoubleArraySource();
    }

    @Override
    public void addChunk(BucketedContext bucketedContext, Chunk values,
            LongChunk inputRowKeys, IntChunk destinations,
            IntChunk startPositions, IntChunk length,
            WritableBooleanChunk stateModified) {
        final Context context = (Context) bucketedContext;
        final DoubleChunk doubleValues = context.toDoubleCast.cast(values);
        final DoubleChunk weightValues = weightOperator.getAddedWeights();
        Assert.neqNull(weightValues, "weightValues");
        for (int ii = 0; ii < startPositions.size(); ++ii) {
            final int startPosition = startPositions.get(ii);
            stateModified.set(ii, addChunk(doubleValues, weightValues, startPosition, length.get(ii),
                    destinations.get(startPosition)));
        }
    }

    @Override
    public void removeChunk(BucketedContext bucketedContext, Chunk values,
            LongChunk inputRowKeys, IntChunk destinations,
            IntChunk startPositions, IntChunk length,
            WritableBooleanChunk stateModified) {
        final Context context = (Context) bucketedContext;
        final DoubleChunk doubleValues = context.prevToDoubleCast.cast(values);
        final DoubleChunk weightValues = weightOperator.getRemovedWeights();
        Assert.neqNull(weightValues, "weightValues");
        for (int ii = 0; ii < startPositions.size(); ++ii) {
            final int startPosition = startPositions.get(ii);
            stateModified.set(ii, removeChunk(doubleValues, weightValues, startPosition, length.get(ii),
                    destinations.get(startPosition)));
        }
    }

    @Override
    public void modifyChunk(BucketedContext bucketedContext, Chunk previousValues,
            Chunk newValues, LongChunk postShiftRowKeys,
            IntChunk destinations, IntChunk startPositions, IntChunk length,
            WritableBooleanChunk stateModified) {
        final Context context = (Context) bucketedContext;
        final DoubleChunk prevDoubleValues = context.prevToDoubleCast.cast(previousValues);
        final DoubleChunk prevWeightValues = weightOperator.getRemovedWeights();
        final DoubleChunk newDoubleValues = context.toDoubleCast.cast(newValues);
        final DoubleChunk newWeightValues = weightOperator.getAddedWeights();
        for (int ii = 0; ii < startPositions.size(); ++ii) {
            final int startPosition = startPositions.get(ii);
            stateModified.set(ii, modifyChunk(prevDoubleValues, prevWeightValues, newDoubleValues, newWeightValues,
                    startPosition, length.get(ii), destinations.get(startPosition)));
        }
    }

    @Override
    public boolean addChunk(SingletonContext singletonContext, int chunkSize, Chunk values,
            LongChunk inputRowKeys, long destination) {
        final Context context = (Context) singletonContext;
        final DoubleChunk doubleValues = context.toDoubleCast.cast(values);
        final DoubleChunk weightValues = weightOperator.getAddedWeights();
        return addChunk(doubleValues, weightValues, 0, values.size(), destination);
    }

    @Override
    public boolean removeChunk(SingletonContext singletonContext, int chunkSize, Chunk values,
            LongChunk inputRowKeys, long destination) {
        final Context context = (Context) singletonContext;
        final DoubleChunk doubleValues = context.prevToDoubleCast.cast(values);
        final DoubleChunk weightValues = weightOperator.getRemovedWeights();
        return removeChunk(doubleValues, weightValues, 0, values.size(), destination);
    }

    @Override
    public boolean modifyChunk(SingletonContext singletonContext, int chunkSize, Chunk previousValues,
            Chunk newValues, LongChunk postShiftRowKeys, long destination) {
        final Context context = (Context) singletonContext;
        final DoubleChunk newDoubleValues = context.toDoubleCast.cast(newValues);
        final DoubleChunk newWeightValues = weightOperator.getAddedWeights();

        final DoubleChunk prevDoubleValues = context.prevToDoubleCast.cast(previousValues);
        final DoubleChunk prevWeightValues = weightOperator.getRemovedWeights();

        return modifyChunk(prevDoubleValues, prevWeightValues, newDoubleValues, newWeightValues, 0,
                newDoubleValues.size(), destination);
    }

    private static void sumChunks(DoubleChunk doubleValues,
            DoubleChunk weightValues,
            int start,
            int length,
            MutableInt nansOut,
            MutableInt normalOut,
            MutableDouble sumOfWeightsOut,
            MutableDouble weightedSumOut) {
        long nans = 0;
        long normal = 0;
        double sumOfWeights = 0.0;
        double weightedSum = 0.0;

        for (int ii = 0; ii < length; ++ii) {
            final double weight = weightValues.get(start + ii);
            final double component = doubleValues.get(start + ii);

            if (Double.isNaN(weight) || Double.isNaN(component)) {
                nans++;
                continue;
            }

            if (Double.isInfinite(weight) || Double.isInfinite(component)) {
                nans++;
                continue;
            }

            if (weight == QueryConstants.NULL_DOUBLE || component == QueryConstants.NULL_DOUBLE) {
                continue;
            }

            normal++;
            sumOfWeights += weight;
            weightedSum += weight * component;
        }

        nansOut.setValue(nans);
        normalOut.setValue(normal);
        sumOfWeightsOut.setValue(sumOfWeights);
        weightedSumOut.setValue(weightedSum);
    }

    private boolean addChunk(DoubleChunk doubleValues, DoubleChunk weightValues,
            int start, int length, long destination) {
        final MutableInt nanOut = new MutableInt();
        final MutableInt normalOut = new MutableInt();
        final MutableDouble sumOfWeightsOut = new MutableDouble();
        final MutableDouble weightedSumOut = new MutableDouble();

        sumChunks(doubleValues, weightValues, start, length, nanOut, normalOut, sumOfWeightsOut, weightedSumOut);

        final long newNans = nanOut.intValue();
        final long newNormal = normalOut.intValue();
        final double newSumOfWeights = sumOfWeightsOut.doubleValue();
        final double newWeightedSum = weightedSumOut.doubleValue();

        final long totalNans;
        if (nanCount == null && newNans > 0) {
            totalNans = allocateNans(destination, newNans);
        } else if (nanCount != null) {
            final long oldNans = nanCount.getUnsafe(destination);
            totalNans = NullSafeAddition.plusLong(oldNans, newNans);
            if (newNans > 0) {
                nanCount.set(destination, totalNans);
            }
        } else {
            totalNans = 0;
        }

        final long totalNormal;
        final long existingNormal = normalCount.getUnsafe(destination);
        totalNormal = NullSafeAddition.plusLong(existingNormal, newNormal);
        Assert.geq(totalNormal, "totalNormal", newNormal, "newNormal");
        if (newNormal > 0) {
            normalCount.set(destination, totalNormal);
        }

        if (totalNormal > 0) {
            final double existingSumOfWeights = sumOfWeights.getUnsafe(destination);
            final double existingWeightedSum = weightedSum.getUnsafe(destination);

            final double totalWeightedSum = NullSafeAddition.plusDouble(existingWeightedSum, newWeightedSum);
            final double totalSumOfWeights = NullSafeAddition.plusDouble(existingSumOfWeights, newSumOfWeights);

            if (newNormal > 0) {
                weightedSum.set(destination, totalWeightedSum);
                sumOfWeights.set(destination, totalSumOfWeights);
            }

            if (totalNans > 0) {
                if (newNans == totalNans) {
                    resultColumn.set(destination, Double.NaN);
                    return true;
                }
                return false;
            } else {
                final double newResult = totalWeightedSum / totalSumOfWeights;
                final double existingResult = resultColumn.getAndSetUnsafe(destination, newResult);
                return newResult != existingResult;
            }
        } else {
            if (totalNans > 0 && totalNans == newNans) {
                resultColumn.set(destination, Double.NaN);
                return true;
            }
            return false;
        }
    }

    private long allocateNans(long destination, long newNans) {
        nanCount = new LongArraySource();
        nanCount.ensureCapacity(tableSize);
        nanCount.set(destination, newNans);
        return newNans;
    }

    private boolean removeChunk(DoubleChunk doubleValues, DoubleChunk weightValues,
            int start, int length, long destination) {
        final MutableInt nanOut = new MutableInt();
        final MutableInt normalOut = new MutableInt();
        final MutableDouble sumOfWeightsOut = new MutableDouble();
        final MutableDouble weightedSumOut = new MutableDouble();

        sumChunks(doubleValues, weightValues, start, length, nanOut, normalOut, sumOfWeightsOut, weightedSumOut);

        final int newNans = nanOut.intValue();
        final int newNormal = normalOut.intValue();
        final double newSumOfWeights = sumOfWeightsOut.doubleValue();
        final double newWeightedSum = weightedSumOut.doubleValue();

        final long totalNans;
        if (newNans > 0) {
            final long oldNans = nanCount.getUnsafe(destination);
            totalNans = NullSafeAddition.minusLong(oldNans, newNans);
            nanCount.set(destination, totalNans);
        } else if (nanCount != null) {
            totalNans = nanCount.getUnsafe(destination);
        } else {
            totalNans = 0;
        }

        final long totalNormal;
        final long existingNormal = normalCount.getUnsafe(destination);
        if (newNormal > 0) {
            totalNormal = NullSafeAddition.minusLong(existingNormal, newNormal);
            normalCount.set(destination, totalNormal);
        } else {
            totalNormal = NullSafeAddition.plusLong(existingNormal, 0);
        }
        Assert.geqZero(totalNormal, "totalNormal");

        final double totalWeightedSum;
        final double totalSumOfWeights;
        if (newNormal > 0) {
            if (totalNormal == 0) {
                weightedSum.set(destination, totalWeightedSum = 0.0);
                sumOfWeights.set(destination, totalSumOfWeights = 0.0);
            } else {
                final double existingSumOfWeights = sumOfWeights.getUnsafe(destination);
                final double existingWeightedSum = weightedSum.getUnsafe(destination);
                totalWeightedSum = existingWeightedSum - newWeightedSum;
                totalSumOfWeights = existingSumOfWeights - newSumOfWeights;
                weightedSum.set(destination, totalWeightedSum);
                sumOfWeights.set(destination, totalSumOfWeights);
            }
        } else {
            totalWeightedSum = weightedSum.getUnsafe(destination);
            totalSumOfWeights = sumOfWeights.getUnsafe(destination);
        }

        if (totalNans > 0) {
            // if we had nans before and removed some, but not all nothing could have changed
            return false;
        } else if (totalNormal == 0) {
            if (newNans > 0 || newNormal > 0) {
                resultColumn.set(destination, QueryConstants.NULL_DOUBLE);
                return true;
            }
            return false;
        } else {
            final double newResult = totalWeightedSum / totalSumOfWeights;
            final double existingResult = resultColumn.getAndSetUnsafe(destination, newResult);
            return newResult != existingResult;
        }
    }

    private boolean modifyChunk(DoubleChunk prevDoubleValues,
            DoubleChunk prevWeightValues, DoubleChunk newDoubleValues,
            DoubleChunk newWeightValues, int start, int length, long destination) {
        final MutableInt nanOut = new MutableInt();
        final MutableInt normalOut = new MutableInt();
        final MutableDouble sumOfWeightsOut = new MutableDouble();
        final MutableDouble weightedSumOut = new MutableDouble();

        sumChunks(prevDoubleValues, prevWeightValues, start, length, nanOut, normalOut, sumOfWeightsOut,
                weightedSumOut);

        final int prevNans = nanOut.intValue();
        final int prevNormal = normalOut.intValue();
        final double prevSumOfWeights = sumOfWeightsOut.doubleValue();
        final double prevWeightedSum = weightedSumOut.doubleValue();

        sumChunks(newDoubleValues, newWeightValues, start, length, nanOut, normalOut, sumOfWeightsOut, weightedSumOut);

        final int newNans = nanOut.intValue();
        final int newNormal = normalOut.intValue();
        final double newSumOfWeights = sumOfWeightsOut.doubleValue();
        final double newWeightedSum = weightedSumOut.doubleValue();


        final long totalNans;
        if (nanCount == null && newNans > 0) {
            totalNans = allocateNans(destination, newNans);
        } else if (nanCount != null) {
            final long oldNans = nanCount.getUnsafe(destination);
            totalNans = NullSafeAddition.plusLong(oldNans, newNans - prevNans);
            if (newNans != prevNans) {
                nanCount.set(destination, totalNans);
            }
        } else {
            totalNans = 0;
        }

        final long totalNormal;
        final long existingNormal = normalCount.getUnsafe(destination);
        totalNormal = NullSafeAddition.plusLong(existingNormal, newNormal - prevNormal);
        Assert.geq(totalNormal, "totalNormal", newNormal, "newNormal");
        if (newNormal != prevNormal) {
            normalCount.set(destination, totalNormal);
        }

        if (totalNormal > 0) {
            final double existingSumOfWeights = sumOfWeights.getUnsafe(destination);
            final double existingWeightedSum = weightedSum.getUnsafe(destination);

            final double totalWeightedSum =
                    NullSafeAddition.plusDouble(existingWeightedSum, newWeightedSum - prevWeightedSum);
            final double totalSumOfWeights =
                    NullSafeAddition.plusDouble(existingSumOfWeights, newSumOfWeights - prevSumOfWeights);

            if (totalWeightedSum != existingWeightedSum) {
                weightedSum.set(destination, totalWeightedSum);
            }
            if (totalSumOfWeights != existingWeightedSum) {
                sumOfWeights.set(destination, totalSumOfWeights);
            }

            if (totalNans > 0) {
                resultColumn.set(destination, Double.NaN);
                return prevNans == 0;
            } else {
                final double newResult = totalWeightedSum / totalSumOfWeights;
                final double existingResult = resultColumn.getAndSetUnsafe(destination, newResult);
                return newResult != existingResult;
            }
        } else {
            if (prevNormal > 0) {
                weightedSum.set(destination, 0.0);
                sumOfWeights.set(destination, 0.0);
            }
            if (totalNans == 0) {
                if (prevNans > 0 || prevNormal > 0) {
                    resultColumn.set(destination, QueryConstants.NULL_DOUBLE);
                    return true;
                }
                return false;
            } else {
                if (prevNans == 0) {
                    resultColumn.set(destination, Double.NaN);
                    return true;
                }
                return false;
            }
        }
    }

    @Override
    public void ensureCapacity(long tableSize) {
        this.tableSize = tableSize;
        if (nanCount != null) {
            nanCount.ensureCapacity(tableSize);
        }
        normalCount.ensureCapacity(tableSize);
        weightedSum.ensureCapacity(tableSize);
        sumOfWeights.ensureCapacity(tableSize);
        resultColumn.ensureCapacity(tableSize);
    }

    @Override
    public Map> getResultColumns() {
        if (exposeInternalColumns) {
            final Map> results = new LinkedHashMap<>(2);
            results.put(resultName, resultColumn);
            results.put(resultName + ROLLUP_SUM_WEIGHTS_COLUMN_ID + ROLLUP_COLUMN_SUFFIX, sumOfWeights);
            return results;
        } else {
            return Collections.singletonMap(resultName, resultColumn);
        }
    }

    @Override
    public void startTrackingPrevValues() {
        resultColumn.startTrackingPrevValues();
        if (exposeInternalColumns) {
            sumOfWeights.startTrackingPrevValues();
        }
    }

    private class Context implements BucketedContext, SingletonContext {
        private final ToDoubleCast toDoubleCast;
        private final ToDoubleCast prevToDoubleCast;

        private Context(int size) {
            toDoubleCast = ToDoubleCast.makeToDoubleCast(chunkType, size);
            prevToDoubleCast = ToDoubleCast.makeToDoubleCast(chunkType, size);
        }

        @Override
        public void close() {
            toDoubleCast.close();
            prevToDoubleCast.close();
        }
    }

    @Override
    public BucketedContext makeBucketedContext(int size) {
        return new Context(size);
    }

    @Override
    public SingletonContext makeSingletonContext(int size) {
        return new Context(size);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy