com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.org.apache.calcite.rel.rules;
import com.hazelcast.org.apache.calcite.linq4j.Ord;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.fun.SqlInternalOperators;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.apache.calcite.util.mapping.IntPair;
import com.hazelcast.com.google.common.collect.ArrayListMultimap;
import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Multimap;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import static com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule.groupValue;
import static com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule.remap;
/**
 * Planner rule that rewrites an {@link Aggregate} that contains
 * {@code WITHIN DISTINCT} aggregate functions.
 *
 * <p>For example,
 *
 * <blockquote><pre>
 * SELECT o.paymentType,
 *     COUNT(*) as "count",
 *     COUNT(*) WITHIN DISTINCT (o.orderId) AS orderCount,
 *     SUM(o.shipping) WITHIN DISTINCT (o.orderId) as sumShipping,
 *     SUM(i.units) as sumUnits
 *   FROM Orders AS o
 *   JOIN OrderItems AS i USING (orderId)
 *   GROUP BY o.paymentType
 * </pre></blockquote>
 *
 * <p>becomes
 *
 * <blockquote><pre>
 * SELECT paymentType,
 *     MIN(cnt) FILTER (WHERE g = 1) as "count",
 *     COUNT(*) FILTER (WHERE g = 0) AS orderCount,
 *     SUM(minShipping) FILTER (WHERE g = 0) AS sumShipping,
 *     MIN(sumUnits) FILTER (WHERE g = 1) as sumUnits
 *   FROM (
 *     SELECT o.paymentType,
 *         GROUPING(o.orderId) AS g,
 *         COUNT(*) AS cnt,
 *         MIN(o.shipping) AS minShipping,
 *         SUM(i.units) AS sumUnits
 *       FROM Orders AS o
 *       JOIN OrderItems AS i ON o.orderId = i.orderId
 *       GROUP BY GROUPING SETS ((o.paymentType), (o.paymentType, o.orderId)))
 *   GROUP BY paymentType
 * </pre></blockquote>
 *
 * <p>Rows with {@code g = 0} come from the {@code (paymentType, orderId)}
 * grouping set (one row per order), and rows with {@code g = 1} come from the
 * {@code (paymentType)} grouping set (one row per payment type). When
 * {@link Config#throwIfNotUnique()} is true, the inner query also computes
 * {@code MAX(o.shipping)} and the outer query raises an error if it differs
 * from {@code MIN(o.shipping)} within an order.
 *
 * <p>By the way, note that {@code COUNT(*) WITHIN DISTINCT (o.orderId)}
 * is equivalent to {@code COUNT(DISTINCT o.orderId)} when {@code o.orderId}
 * is never null. {@code WITHIN DISTINCT} is a generalization of
 * aggregate(DISTINCT), so it is perhaps not surprising that the rewrite to
 * {@code GROUPING SETS} is similar.
 *
 * <p>If aggregate calls have different distinct keys
 * (e.g. {@code SUM(a) WITHIN DISTINCT (x), SUM(a) WITHIN DISTINCT (y)})
 * the rule creates a separate {@code GROUPING SET} for each.
 */
@Value.Enclosing
public class AggregateExpandWithinDistinctRule
    extends RelRule<AggregateExpandWithinDistinctRule.Config> {

  /** Creates an AggregateExpandWithinDistinctRule. */
  protected AggregateExpandWithinDistinctRule(Config config) {
    super(config);
  }

  /** Returns whether this rule should fire on {@code aggregate}: at least one
   * call has {@code WITHIN DISTINCT} keys, no call is still waiting to be
   * reduced by {@code AggregateReduceFunctionsRule}, and the aggregate has a
   * single (simple) grouping set. */
  private static boolean hasWithinDistinct(Aggregate aggregate) {
    return aggregate.getAggCallList().stream()
        .anyMatch(c -> c.distinctKeys != null)
        // Wait until AggregateReduceFunctionsRule has dealt with AVG etc.
        && aggregate.getAggCallList().stream()
            .noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
        // Don't think we can handle GROUPING SETS yet
        && aggregate.getGroupType() == Aggregate.Group.SIMPLE;
  }

  //~ Methods ----------------------------------------------------------------

  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);

    // Throughout this method, assume we are working on the following SQL:
    //
    //   SELECT deptno, SUM(sal), SUM(sal) WITHIN DISTINCT (job)
    //   FROM emp
    //   GROUP BY deptno
    //
    // or in algebra,
    //
    //   Aggregate($0, SUM($2), SUM($2) WITHIN DISTINCT ($4))
    //     Scan(emp)
    //
    // We plan to generate the following:
    //
    //   SELECT deptno,
    //       MIN(sum_sal) FILTER (WHERE g = 1),
    //       SUM(min_sal) FILTER (WHERE g = 0)
    //   FROM (
    //     SELECT deptno, job, GROUPING(deptno, job) AS g,
    //         SUM(sal) AS sum_sal, MIN(sal) AS min_sal
    //     FROM emp
    //     GROUP BY GROUPING SETS ((deptno), (deptno, job)))
    //   GROUP BY deptno
    //
    // This rewrite also handles DISTINCT aggregates. We treat
    //
    //   SUM(DISTINCT sal)
    //   SUM(DISTINCT sal) WITHIN DISTINCT (job)
    //
    // as if the user had written
    //
    //   SUM(sal) WITHIN DISTINCT (sal)
    //
    final List<AggregateCall> aggCallList =
        aggregate.getAggCallList()
            .stream()
            .map(c -> unDistinct(c, aggregate.getInput()::fieldIsNullable))
            .collect(Util.toImmutableList());

    // Find all within-distinct expressions.
    final Multimap<ImmutableBitSet, AggregateCall> argLists =
        ArrayListMultimap.create();
    // A bit set that represents an aggregate with no WITHIN DISTINCT.
    // Different from "WITHIN DISTINCT ()". It refers to a field index one past
    // the last input field, so it can never equal a real set of distinct keys.
    final ImmutableBitSet notDistinct =
        ImmutableBitSet.of(aggregate.getInput().getRowType().getFieldCount());
    for (AggregateCall aggCall : aggCallList) {
      ImmutableBitSet distinctKeys = aggCall.distinctKeys;
      if (distinctKeys == null) {
        distinctKeys = notDistinct;
      } else if (distinctKeys.intersects(aggregate.getGroupSet())) {
        // Remove group keys. E.g.
        //   sum(x) within distinct (y, z) ... group by y
        // can be simplified to
        //   sum(x) within distinct (z) ... group by y
        // Note that this assumes a single grouping set for the original agg.
        distinctKeys = distinctKeys.rebuild()
            .removeAll(aggregate.getGroupSet()).build();
      }
      argLists.put(distinctKeys, aggCall);
    }

    // Compute the set of all grouping sets that will be used in the output
    // query. For each WITHIN DISTINCT aggregate call, we will need a grouping
    // set that is the union of the aggregate call's unique keys and the input
    // query's overall grouping. Redundant grouping sets can be reused for
    // multiple aggregate calls.
    final Set<ImmutableBitSet> groupSetTreeSet =
        new TreeSet<>(ImmutableBitSet.ORDERING);
    for (ImmutableBitSet key : argLists.keySet()) {
      // Identity comparison is intentional: "notDistinct" is a sentinel.
      groupSetTreeSet.add(
          (key == notDistinct)
              ? aggregate.getGroupSet()
              : key.union(aggregate.getGroupSet()));
    }

    final ImmutableList<ImmutableBitSet> groupSets =
        ImmutableList.copyOf(groupSetTreeSet);
    final boolean hasMultipleGroupSets = groupSets.size() > 1;
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // Original group keys first, then the extra distinct keys; this order
    // defines the arguments of the GROUPING() call and hence the value
    // computed by groupValue().
    final Set<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
    fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
    fullGroupOrderedSet.addAll(fullGroupSet.asSet());
    final ImmutableIntList fullGroupList =
        ImmutableIntList.copyOf(fullGroupOrderedSet);

    // Build the inner query
    //
    //   SELECT deptno, job, SUM(sal) AS sum_sal, MIN(sal) AS min_sal,
    //       MAX(sal) AS max_sal, GROUPING(deptno, job) AS g
    //   FROM emp
    //   GROUP BY GROUPING SETS ((deptno), (deptno, job))
    //
    // or in algebra,
    //
    //   Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
    //     Scan(emp)
    //
    // (MAX is only computed if throwIfNotUnique.)

    final RelBuilder b = call.builder();
    b.push(aggregate.getInput());
    final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();

    // Helper class for building the inner query.
    // CHECKSTYLE: IGNORE 1
    class Registrar {
      /** Number of group-key fields output by the inner query; aggregate call
       * outputs start at this ordinal. */
      final int g = fullGroupSet.cardinality();
      /** Map of input fields (below the original aggregation) and filter args
       * to inner query aggregate calls. */
      final Map<IntPair, Integer> args = new HashMap<>();
      /** Map of aggregate calls from the original aggregation to inner query
       * aggregate calls. */
      final Map<Integer, Integer> aggs = new HashMap<>();
      /** Map of filter args of the original aggregation to inner-query
       * {@code COUNT(*)} calls, which are only needed for filters in the outer
       * aggregate when the original aggregate call has a filter. */
      final Map<Integer, Integer> counts = new HashMap<>();

      /** Maps each input field to the inner-query ordinal of its registered
       * {@code MIN} call for the given filter. */
      List<Integer> fields(List<Integer> fields, int filterArg) {
        return Util.transform(fields, f -> this.field(f, filterArg));
      }

      /** Returns the inner-query ordinal of the {@code MIN} call registered
       * for {@code field} and {@code filterArg}; throws
       * {@link NullPointerException} if it was never registered. */
      int field(int field, int filterArg) {
        return Objects.requireNonNull(args.get(IntPair.of(field, filterArg)));
      }

      /** Computes an aggregate call argument's values for a
       * {@code WITHIN DISTINCT} aggregate call.
       *
       * <p>For example, to compute
       * {@code SUM(x) WITHIN DISTINCT (y) GROUP BY (z)},
       * the inner aggregate must first group {@code x} by {@code (y, z)}
       * &mdash; using {@code MIN} to select the (hopefully) unique value of
       * {@code x} for each {@code (y, z)} group. Actually summing over the
       * grouped {@code x} values must occur in an outer aggregate.
       *
       * <p>If {@code throwIfNotUnique} is set, a {@code MAX} call is also
       * registered immediately after the {@code MIN} call (at ordinal + 1),
       * so that the outer query can check that the value was unique.
       *
       * @param field Index of an input field that's used in a
       *              {@code WITHIN DISTINCT} aggregate call
       * @param filterArg Filter arg used in the original aggregate call, or
       *                  {@code -1} if there is no filter. We use the same
       *                  filter in the inner query.
       * @return Index of the inner query aggregate call representing the
       *         grouped field, which can be referenced in the outer query
       *         aggregate call
       */
      int register(int field, int filterArg) {
        return args.computeIfAbsent(IntPair.of(field, filterArg), j -> {
          final int ordinal = g + aggCalls.size();
          RelBuilder.AggCall groupedField =
              b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field));
          aggCalls.add(
              filterArg < 0
                  ? groupedField
                  : groupedField.filter(b.field(filterArg)));
          if (config.throwIfNotUnique()) {
            groupedField =
                b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field));
            aggCalls.add(
                filterArg < 0
                    ? groupedField
                    : groupedField.filter(b.field(filterArg)));
          }
          return ordinal;
        });
      }

      /** Registers an aggregate call that is not a
       * {@code WITHIN DISTINCT} call.
       *
       * <p>Unlike the case handled by {@link #register(int, int)} above,
       * aggregate calls without any distinct keys do not need a second round
       * of aggregation in the outer query, so they can be computed "as-is" in
       * the inner query.
       *
       * @param i Index of the aggregate call in the original aggregation, or
       *          {@code -1} for the extra {@code GROUPING()} call
       * @param aggregateCall Inner-query aggregate call to add
       * @return Index of the aggregate call in the computed inner query
       */
      int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
        final int ordinal = g + aggCalls.size();
        aggs.put(i, ordinal);
        aggCalls.add(aggregateCall);
        return ordinal;
      }

      /** Returns the inner-query ordinal registered for original aggregate
       * call {@code i}; throws {@link NullPointerException} if absent. */
      int getAgg(int i) {
        return Objects.requireNonNull(aggs.get(i));
      }

      /** Registers an extra {@code COUNT} aggregate call when it's needed to
       * filter out null inputs in the outer aggregate.
       *
       * <p>This should only be called for aggregate calls with filters. It's
       * possible that the filter would eliminate all input rows to the
       * {@code MIN} call in the inner query, so calls in the outer
       * aggregate may need to be aware of this. See usage of
       * {@link AggregateExpandWithinDistinctRule#mustBeCounted(AggregateCall)}.
       *
       * @param filterArg The original aggregate call's filter; must be
       *                  non-negative
       * @return Index of the {@code COUNT} call in the computed inner query
       */
      int registerCount(int filterArg) {
        assert filterArg >= 0;
        return counts.computeIfAbsent(filterArg, i -> {
          final int ordinal = g + aggCalls.size();
          aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT)
              .filter(b.field(filterArg)));
          return ordinal;
        });
      }

      /** Returns the inner-query ordinal of the {@code COUNT} call registered
       * for {@code filterArg}; throws {@link NullPointerException} if
       * absent. */
      int getCount(int filterArg) {
        return Objects.requireNonNull(counts.get(filterArg));
      }
    }

    final Registrar registrar = new Registrar();
    Ord.forEach(aggCallList, (c, i) -> {
      if (c.distinctKeys == null) {
        // A plain aggregate call is computed as-is in the inner query, so its
        // FILTER (if any) must be applied there; the outer query only picks
        // the single matching row with MIN.
        final RelBuilder.AggCall plainCall =
            b.aggregateCall(c.getAggregation(), b.fields(c.getArgList()));
        registrar.registerAgg(i,
            c.hasFilter() ? plainCall.filter(b.field(c.filterArg)) : plainCall);
      } else {
        for (int inputIdx : c.getArgList()) {
          registrar.register(inputIdx, c.filterArg);
        }
        if (mustBeCounted(c)) {
          registrar.registerCount(c.filterArg);
        }
      }
    });

    // Add an additional GROUPING() aggregate call so we can select only the
    // relevant inner-aggregate rows from the outer aggregate. If there is only
    // 1 grouping set (i.e. every aggregate call has the same distinct keys),
    // no GROUPING() call is necessary.
    final int grouping =
        hasMultipleGroupSets
            ? registrar.registerAgg(-1,
                b.aggregateCall(
                    SqlStdOperatorTable.GROUPING,
                    b.fields(fullGroupList)))
            : -1;

    b.aggregate(b.groupKey(fullGroupSet, groupSets), aggCalls);

    // Build the outer query
    //
    //   SELECT deptno,
    //       MIN(sum_sal) FILTER (WHERE g = 1),
    //       SUM(min_sal) FILTER (WHERE g = 0)
    //   FROM ( ... )
    //   GROUP BY deptno
    //
    // or in algebra (inner fields: $0 deptno, $1 job, $2 sum_sal,
    // $3 min_sal, $4 max_sal, $5 g),
    //
    //   Aggregate($0, MIN($2) FILTER ($5 = 1), SUM($3) FILTER ($5 = 0))
    //     Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
    //       Scan(emp)
    //
    // GROUPING() sets a bit for each column that is NOT grouped, so the
    // (deptno) grouping set has g = 1 and (deptno, job) has g = 0.
    //
    // If throwIfNotUnique, the "SUM(min_sal) FILTER (WHERE g = 0)" term above
    // becomes
    //
    //   SUM(min_sal) FILTER (WHERE g = 0
    //       AND $THROW_UNLESS(g <> 0 OR min_sal IS NOT DISTINCT FROM max_sal,
    //           'more than one distinct value in agg UNIQUE_VALUE'))

    aggCalls.clear();
    Ord.forEach(aggCallList, (c, i) -> {
      final List<RexNode> filters = new ArrayList<>();
      @Nullable RexNode groupFilter = null;
      if (hasMultipleGroupSets) {
        groupFilter =
            b.equals(
                b.field(grouping),
                b.literal(
                    groupValue(fullGroupList,
                        union(aggregate.getGroupSet(), c.distinctKeys))));
        filters.add(groupFilter);
      }
      RelBuilder.AggCall aggCall;
      if (c.distinctKeys == null) {
        // Exactly one inner row matches the group filter; pick its value.
        aggCall = b.aggregateCall(SqlStdOperatorTable.MIN,
            b.field(registrar.getAgg(i)));
      } else {
        // The inputs to this aggregate are outputs from MIN() calls from the
        // inner agg, and MIN() returns null iff it has no non-null inputs,
        // which can only happen if an original aggregate's filter causes all
        // non-null input rows to be discarded for a particular group in the
        // inner aggregate. In this case, it should be ignored by the outer
        // aggregate as well. In case the aggregate call does not naturally
        // ignore null inputs, we add a filter based on a COUNT() in the inner
        // aggregate.
        aggCall =
            b.aggregateCall(
                c.getAggregation(),
                b.fields(registrar.fields(c.getArgList(), c.filterArg)));

        if (mustBeCounted(c)) {
          filters.add(
              b.greaterThan(b.field(registrar.getCount(c.filterArg)),
                  b.literal(0)));
        }

        if (config.throwIfNotUnique()) {
          for (int j : c.getArgList()) {
            // MIN is at field(j, filterArg); MAX was registered right after it.
            RexNode isUniqueCondition =
                b.isNotDistinctFrom(
                    b.field(registrar.field(j, c.filterArg)),
                    b.field(registrar.field(j, c.filterArg) + 1));
            if (groupFilter != null) {
              isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition);
            }
            String message = "more than one distinct value in agg UNIQUE_VALUE";
            filters.add(
                b.call(SqlInternalOperators.THROW_UNLESS, isUniqueCondition,
                    b.literal(message)));
          }
        }
      }
      if (!filters.isEmpty()) {
        aggCall = aggCall.filter(b.and(filters));
      }
      aggCalls.add(aggCall);
    });

    b.aggregate(
        b.groupKey(
            remap(fullGroupSet, aggregate.getGroupSet()),
            remap(fullGroupSet, aggregate.getGroupSets())),
        aggCalls);
    b.convert(aggregate.getRowType(), false);
    call.transformTo(b.build());
  }

  /** Returns whether the outer aggregate needs a {@code COUNT(*) > 0} filter
   * for {@code aggCall}; currently true exactly when the call has a filter. */
  private static boolean mustBeCounted(AggregateCall aggCall) {
    // Always count filtered inner aggregates to be safe.
    //
    // It's possible that, for some aggregate calls (namely, those that
    // completely ignore null inputs), we could neglect counting the
    // grouped-and-filtered rows of the inner aggregate and filtering the empty
    // ones out from the outer aggregate, since those empty groups would produce
    // null values as the result of MIN and thus be ignored by the outer
    // aggregate anyway.
    //
    // Note that using "aggCall.ignoreNulls()" is not sufficient to determine
    // when it's safe to do this, since for COUNT the value of ignoreNulls()
    // should generally be true even though COUNT(*) will never ignore anything.
    return aggCall.hasFilter();
  }

  /** Converts a {@code DISTINCT} aggregate call into an equivalent one with
   * {@code WITHIN DISTINCT}.
   *
   * <p>Examples:
   * <ul>
   * <li>{@code SUM(DISTINCT x)} &rarr;
   *     {@code SUM(x) WITHIN DISTINCT (x)} has distinct key (x);
   * <li>{@code SUM(DISTINCT x) WITHIN DISTINCT (y)} &rarr;
   *     {@code SUM(x) WITHIN DISTINCT (x)} has distinct key (x);
   * <li>{@code SUM(x) WITHIN DISTINCT (y, z)} has distinct key (y, z);
   * <li>{@code SUM(x)} has no distinct key.
   * </ul>
   *
   * <p>For an unfiltered {@code COUNT(DISTINCT x)} whose argument {@code x}
   * is not nullable, the argument is dropped, giving
   * {@code COUNT(*) WITHIN DISTINCT (x)}.
   *
   * @param aggregateCall Aggregate call to convert
   * @param isNullable Whether a given input field is nullable
   * @return Converted call, or {@code aggregateCall} itself if it is not
   *         {@code DISTINCT}
   */
  private static AggregateCall unDistinct(AggregateCall aggregateCall,
      IntPredicate isNullable) {
    if (aggregateCall.isDistinct()) {
      final List<Integer> newArgList = aggregateCall.getArgList()
          .stream()
          .filter(i ->
              aggregateCall.getAggregation().getKind() != SqlKind.COUNT
                  || aggregateCall.hasFilter()
                  || isNullable.test(i))
          .collect(Collectors.toList());
      return aggregateCall.withDistinct(false)
          .withDistinctKeys(ImmutableBitSet.of(aggregateCall.getArgList()))
          .withArgList(newArgList);
    }
    return aggregateCall;
  }

  /** Returns the union of {@code s0} and {@code s1}, treating a null
   * {@code s1} as empty. */
  private static ImmutableBitSet union(ImmutableBitSet s0,
      @Nullable ImmutableBitSet s1) {
    return s1 == null ? s0 : s0.union(s1);
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    Config DEFAULT = ImmutableAggregateExpandWithinDistinctRule.Config.of()
        .withOperandSupplier(b -> b.operand(LogicalAggregate.class)
            .predicate(AggregateExpandWithinDistinctRule::hasWithinDistinct)
            .anyInputs());

    @Override default AggregateExpandWithinDistinctRule toRule() {
      return new AggregateExpandWithinDistinctRule(this);
    }

    /** Whether the code generated by the rule should throw if the arguments
     * are not functionally dependent.
     *
     * <p>For example, if implementing {@code SUM(sal) WITHIN DISTINCT (job)}
     * ... {@code GROUP BY deptno},
     * suppose that within department 10, (job, sal) has the values
     * ('CLERK', 100), ('CLERK', 120), ('MANAGER', 150), ('MANAGER', 150). If
     * {@code throwIfNotUnique} is true, the query would throw because of the
     * values [100, 120]; if false, the query would keep one value per job
     * (the minimum, since the inner query uses {@code MIN}) and sum
     * [100, 150]. */
    @Value.Default default boolean throwIfNotUnique() {
      return true;
    }

    /** Sets {@link #throwIfNotUnique()}. */
    Config withThrowIfNotUnique(boolean throwIfNotUnique);
  }
}