All downloads are free. Search and download functionality uses the official Maven repository.

io.deephaven.server.table.ops.ComboAggregateGrpcImpl Maven / Gradle / Ivy

The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.table.ops;

import com.google.rpc.Code;
import io.deephaven.api.ColumnName;
import io.deephaven.api.agg.Aggregation;
import io.deephaven.api.util.NameValidator;
import io.deephaven.auth.codegen.impl.TableServiceContextualAuthWiring;
import io.deephaven.base.verify.Assert;
import io.deephaven.engine.table.ColumnDefinition;
import io.deephaven.engine.table.Table;
import io.deephaven.proto.backplane.grpc.BatchTableRequest;
import io.deephaven.proto.backplane.grpc.ComboAggregateRequest;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.session.SessionState;
import io.grpc.StatusRuntimeException;
import org.jetbrains.annotations.NotNull;

import javax.inject.Inject;
import javax.inject.Singleton;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static io.deephaven.api.agg.Aggregation.*;

@Singleton
public class ComboAggregateGrpcImpl extends GrpcTableOperation {

    @Inject
    public ComboAggregateGrpcImpl(final TableServiceContextualAuthWiring authWiring) {
        super(authWiring::checkPermissionComboAggregate, BatchTableRequest.Operation::getComboAggregate,
                ComboAggregateRequest::getResultId, ComboAggregateRequest::getSourceId);
    }

    @Override
    public void validateRequest(ComboAggregateRequest request) throws StatusRuntimeException {
        if (request.getAggregatesCount() == 0) {
            throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                    "ComboAggregateRequest incorrectly has zero aggregates provided");
        }
        for (String groupByColumn : request.getGroupByColumnsList()) {
            if (!NameValidator.isValidColumnName(groupByColumn)) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT, "ComboAggregateRequest group by");
            }
        }
        if (isSimpleAggregation(request)) {
            // this is a simple aggregation, make sure the user didn't mistakenly set extra properties
            // which would suggest they meant to set force_combo=true
            ComboAggregateRequest.Aggregate aggregate = request.getAggregates(0);
            if (aggregate.getMatchPairsCount() != 0) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                        "force_combo is false and only one aggregate provided, but match_pairs is specified");
            }
            if (aggregate.getPercentile() != 0) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                        "force_combo is false and only one aggregate provided, but percentile is specified");
            }
            if (aggregate.getAvgMedian()) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                        "force_combo is false and only one aggregate provided, but avg_median is specified");
            }
            if (aggregate.getType() != ComboAggregateRequest.AggType.COUNT
                    && aggregate.getType() != ComboAggregateRequest.AggType.WEIGHTED_AVG) {
                if (!aggregate.getColumnName().isEmpty()) {
                    throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                            "force_combo is false and only one aggregate provided, but column_name is specified for type other than COUNT or WEIGHTED_AVG");
                }
            }
        } else {
            for (ComboAggregateRequest.Aggregate aggregate : request.getAggregatesList()) {
                if (aggregate.getType() != ComboAggregateRequest.AggType.PERCENTILE) {
                    if (aggregate.getPercentile() != 0) {
                        throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                                "percentile is specified for type " + aggregate.getType());
                    }
                    if (aggregate.getAvgMedian()) {
                        throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                                "avg_median is specified for type " + aggregate.getType());
                    }
                }
                if (aggregate.getType() == ComboAggregateRequest.AggType.COUNT) {
                    if (aggregate.getMatchPairsCount() != 0) {
                        throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                                "match_pairs is specified for type COUNT");
                    }
                }
                if (aggregate.getType() != ComboAggregateRequest.AggType.COUNT
                        && aggregate.getType() != ComboAggregateRequest.AggType.WEIGHTED_AVG) {
                    if (!aggregate.getColumnName().isEmpty()) {
                        throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                                "column_name is specified for type " + aggregate.getType());
                    }
                }
            }
        }
    }

    private boolean isSimpleAggregation(ComboAggregateRequest request) {
        return !request.getForceCombo() && request.getAggregatesCount() == 1
                && request.getAggregates(0).getColumnName().isEmpty()
                && request.getAggregates(0).getType() != ComboAggregateRequest.AggType.PERCENTILE
                && request.getAggregates(0).getMatchPairsCount() == 0;
    }

    @Override
    public Table create(final ComboAggregateRequest request,
            final List> sourceTables) {
        Assert.eq(sourceTables.size(), "sourceTables.size()", 1);
        final Table parent = sourceTables.get(0).get();
        final ColumnName[] groupByColumns = request.getGroupByColumnsList()
                .stream()
                .map(ColumnName::of)
                .toArray(ColumnName[]::new);
        if (isSimpleAggregation(request)) {
            // This is a special case with a special operator that can be invoked right off of the table api.
            return singleAggregateHelper(parent, groupByColumns, request.getAggregates(0));
        } else {
            return comboAggregateHelper(parent, groupByColumns, request.getAggregatesList());
        }
    }

    private static Table singleAggregateHelper(final Table parent, final ColumnName[] groupByColumns,
            final ComboAggregateRequest.Aggregate aggregate) {
        switch (aggregate.getType()) {
            case SUM:
                return parent.sumBy(groupByColumns);
            case ABS_SUM:
                return parent.absSumBy(groupByColumns);
            case GROUP:
                return parent.groupBy(Arrays.asList(groupByColumns));
            case AVG:
                return parent.avgBy(groupByColumns);
            case COUNT:
                return parent.countBy(aggregate.getColumnName(), groupByColumns);
            case FIRST:
                return parent.firstBy(groupByColumns);
            case LAST:
                return parent.lastBy(groupByColumns);
            case MIN:
                return parent.minBy(groupByColumns);
            case MAX:
                return parent.maxBy(groupByColumns);
            case MEDIAN:
                return parent.medianBy(groupByColumns);
            case STD:
                return parent.stdBy(groupByColumns);
            case VAR:
                return parent.varBy(groupByColumns);
            case WEIGHTED_AVG:
                return parent.wavgBy(aggregate.getColumnName(), groupByColumns);
            default:
                throw new UnsupportedOperationException("Unsupported aggregate: " + aggregate.getType());
        }
    }

    private static Table comboAggregateHelper(final Table parent, final ColumnName[] groupByColumns,
            final List aggregates) {
        final Set groupByColumnSet =
                Arrays.stream(groupByColumns).map(ColumnName::name).collect(Collectors.toSet());
        final Function getPairs =
                agg -> getColumnPairs(parent, groupByColumnSet, agg);

        final Collection aggregations = aggregates.stream().map(
                agg -> makeAggregation(agg, getPairs)).collect(Collectors.toList());

        return parent.aggBy(aggregations, Arrays.asList(groupByColumns));
    }

    private static String[] getColumnPairs(@NotNull final Table parent,
            @NotNull final Set groupByColumnSet,
            @NotNull final ComboAggregateRequest.Aggregate agg) {
        // See io.deephaven.qst.table.AggAllByExclusions
        if (agg.getMatchPairsCount() == 0) {
            // If not specified, we apply the aggregate to all columns not "otherwise involved"
            return parent.getDefinition().getColumnStream()
                    .map(ColumnDefinition::getName)
                    .filter(n -> !(groupByColumnSet.contains(n) ||
                            (agg.getType() == ComboAggregateRequest.AggType.WEIGHTED_AVG
                                    && agg.getColumnName().equals(n))))
                    .toArray(String[]::new);
        }
        return agg.getMatchPairsList().toArray(String[]::new);
    }

    private static Aggregation makeAggregation(
            @NotNull final ComboAggregateRequest.Aggregate agg,
            @NotNull final Function getPairs) {
        switch (agg.getType()) {
            case SUM:
                return AggSum(getPairs.apply(agg));
            case ABS_SUM:
                return AggAbsSum(getPairs.apply(agg));
            case GROUP:
                return AggGroup(getPairs.apply(agg));
            case AVG:
                return Aggregation.AggAvg(getPairs.apply(agg));
            case COUNT:
                return Aggregation.AggCount(agg.getColumnName());
            case FIRST:
                return Aggregation.AggFirst(getPairs.apply(agg));
            case LAST:
                return Aggregation.AggLast(getPairs.apply(agg));
            case MIN:
                return Aggregation.AggMin(getPairs.apply(agg));
            case MAX:
                return Aggregation.AggMax(getPairs.apply(agg));
            case MEDIAN:
                return Aggregation.AggMed(getPairs.apply(agg));
            case PERCENTILE:
                return Aggregation.AggPct(agg.getPercentile(), agg.getAvgMedian(), getPairs.apply(agg));
            case STD:
                return Aggregation.AggStd(getPairs.apply(agg));
            case VAR:
                return Aggregation.AggVar(getPairs.apply(agg));
            case WEIGHTED_AVG:
                return Aggregation.AggWAvg(agg.getColumnName(), getPairs.apply(agg));
            default:
                throw new UnsupportedOperationException("Unsupported aggregate: " + agg.getType());
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy