// org.apache.flink.table.planner.plan.rules.logical.FlinkAggregateExpandDistinctAggregatesRule
// (source listing taken from the flink-table-planner-blink_2.11 artifact)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.rules.logical;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.util.Preconditions;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
/**
 * This rule is copied from Calcite's {@link
 * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}. Modifications:
 *
 * <ul>
 *   <li>Throws an exception if an aggregate contains both an approximate distinct aggregate call
 *       and an accurate distinct aggregate call.
 *   <li>Excludes non-simple aggregates (e.g. CUBE, ROLLUP).
 *   <li>Bug fix: some aggregate functions (e.g. COUNT) have a non-null result even without any
 *       input.
 *   <li>Bug fix: adds the filter argument to the rewritten aggregate call if its filter argument
 *       is not -1.
 * </ul>
 */
/**
 * Planner rule that expands distinct aggregates (such as {@code COUNT(DISTINCT x)}) from a {@link
 * org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>How this is done depends upon the arguments to the function. If all functions have the same
 * argument (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument {@code x}) then
 * one extra {@link org.apache.calcite.rel.core.Aggregate} is sufficient.
 *
 * <p>If there are multiple arguments (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)}) the rule
 * creates separate {@code Aggregate}s and combines using a {@link
 * org.apache.calcite.rel.core.Join}.
 */
public final class FlinkAggregateExpandDistinctAggregatesRule extends RelOptRule {
// ~ Static fields/initializers ---------------------------------------------
/**
 * The default instance of the rule; operates only on logical expressions. It is created with
 * {@code useGroupingSets = true}.
 */
public static final FlinkAggregateExpandDistinctAggregatesRule INSTANCE =
new FlinkAggregateExpandDistinctAggregatesRule(
LogicalAggregate.class, true, RelFactories.LOGICAL_BUILDER);
/**
 * Instance of the rule that operates only on logical expressions and generates a join. It is
 * created with {@code useGroupingSets = false}.
 */
public static final FlinkAggregateExpandDistinctAggregatesRule JOIN =
new FlinkAggregateExpandDistinctAggregatesRule(
LogicalAggregate.class, false, RelFactories.LOGICAL_BUILDER);
/**
 * When {@code true}, {@code onMatch} hands aggregates that do not qualify for the
 * single-argument-list shortcut to the grouping-sets based rewrite; when {@code false} it uses the
 * multi-phase or join based rewrites instead.
 */
public final boolean useGroupingSets;
// ~ Constructors -----------------------------------------------------------
/**
 * Creates the rule.
 *
 * @param clazz class of the {@link Aggregate} relational expression this rule matches
 * @param useGroupingSets whether distinct aggregates are expanded with grouping sets ({@code
 *     true}) or with multi-phase aggregates / joins ({@code false})
 * @param relBuilderFactory builder factory handed to the {@link RelOptRule} super constructor
 */
public FlinkAggregateExpandDistinctAggregatesRule(
        Class<? extends Aggregate> clazz,
        boolean useGroupingSets,
        RelBuilderFactory relBuilderFactory) {
    super(operand(clazz, any()), relBuilderFactory, null);
    this.useGroupingSets = useGroupingSets;
}
/**
 * Creates the rule from a join factory.
 *
 * @param clazz class of the logical aggregate this rule matches
 * @param useGroupingSets whether distinct aggregates are expanded with grouping sets
 * @param joinFactory join factory wrapped into a builder factory via {@link
 *     RelBuilder#proto(org.apache.calcite.plan.Context)}
 * @deprecated use the constructor that takes a {@link RelBuilderFactory}
 */
@Deprecated // to be removed before 2.0
public FlinkAggregateExpandDistinctAggregatesRule(
        Class<? extends LogicalAggregate> clazz,
        boolean useGroupingSets,
        RelFactories.JoinFactory joinFactory) {
    this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
}
/**
 * Creates the rule from a join factory with {@code useGroupingSets = false}.
 *
 * @param clazz class of the logical aggregate this rule matches
 * @param joinFactory join factory wrapped into a builder factory via {@link
 *     RelBuilder#proto(org.apache.calcite.plan.Context)}
 * @deprecated use the constructor that takes a {@link RelBuilderFactory}
 */
@Deprecated // to be removed before 2.0
public FlinkAggregateExpandDistinctAggregatesRule(
        Class<? extends LogicalAggregate> clazz, RelFactories.JoinFactory joinFactory) {
    this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
}
// ~ Methods ----------------------------------------------------------------
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
if (!AggregateUtil.containsAccurateDistinctCall(aggregate.getAggCallList())) {
return;
}
// Check unsupported aggregate which contains both approximate distinct call and
// accurate distinct call.
if (AggregateUtil.containsApproximateDistinctCall(aggregate.getAggCallList())) {
throw new TableException(
"There are both Distinct AggCall and Approximate Distinct AggCall in one sql statement, "
+ "it is not supported yet.\nPlease choose one of them.");
}
// If this aggregate is a non-simple aggregate(e.g. CUBE, ROLLUP)
// and contains distinct calls, it should be transformed to simple aggregate first
// by DecomposeGroupingSetsRule. Then this rule expands it's distinct aggregates.
if (aggregate.getGroupSets().size() > 1) {
return;
}
// Find all of the agg expressions. We use a LinkedHashSet to ensure determinism.
// Find all aggregate calls without distinct
int nonDistinctAggCallCount = 0;
// Find all aggregate calls without distinct but ignore MAX, MIN, BIT_AND, BIT_OR
int nonDistinctAggCallExcludingIgnoredCount = 0;
int filterCount = 0;
int unsupportedNonDistinctAggCallCount = 0;
final Set, Integer>> argLists = new LinkedHashSet<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (aggCall.filterArg >= 0) {
++filterCount;
}
if (!aggCall.isDistinct()) {
++nonDistinctAggCallCount;
final SqlKind aggCallKind = aggCall.getAggregation().getKind();
// We only support COUNT/SUM/MIN/MAX for the "single" count distinct optimization
switch (aggCallKind) {
case COUNT:
case SUM:
case SUM0:
case MIN:
case MAX:
break;
default:
++unsupportedNonDistinctAggCallCount;
}
if (aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED) {
argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
} else {
++nonDistinctAggCallExcludingIgnoredCount;
}
} else {
argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
}
}
final int distinctAggCallCount =
aggregate.getAggCallList().size() - nonDistinctAggCallCount;
Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");
// If all of the agg expressions are distinct and have the same
// arguments then we can use a more efficient form.
// MAX, MIN, BIT_AND, BIT_OR always ignore distinct attribute,
// when they are mixed in with other distinct agg calls,
// we can still use this promotion.
if (nonDistinctAggCallExcludingIgnoredCount == 0
&& argLists.size() == 1
&& aggregate.getGroupType() == Group.SIMPLE) {
final Pair, Integer> pair =
com.google.common.collect.Iterables.getOnlyElement(argLists);
final RelBuilder relBuilder = call.builder();
convertMonopole(relBuilder, aggregate, pair.left, pair.right);
call.transformTo(relBuilder.build());
return;
}
if (useGroupingSets) {
rewriteUsingGroupingSets(call, aggregate);
return;
}
// If only one distinct aggregate and one or more non-distinct aggregates,
// we can generate multi-phase aggregates
if (distinctAggCallCount == 1 // one distinct aggregate
&& filterCount == 0 // no filter
&& unsupportedNonDistinctAggCallCount
== 0 // sum/min/max/count in non-distinct aggregate
&& nonDistinctAggCallCount > 0) { // one or more non-distinct aggregates
final RelBuilder relBuilder = call.builder();
convertSingletonDistinct(relBuilder, aggregate, argLists);
call.transformTo(relBuilder.build());
return;
}
// Create a list of the expressions which will yield the final result.
// Initially, the expressions point to the input field.
final List aggFields = aggregate.getRowType().getFieldList();
final List refs = new ArrayList<>();
final List fieldNames = aggregate.getRowType().getFieldNames();
final ImmutableBitSet groupSet = aggregate.getGroupSet();
final int groupCount = aggregate.getGroupCount();
for (int i : Util.range(groupCount)) {
refs.add(RexInputRef.of(i, aggFields));
}
// Aggregate the original relation, including any non-distinct aggregates.
final List newAggCallList = new ArrayList<>();
int i = -1;
for (AggregateCall aggCall : aggregate.getAggCallList()) {
++i;
if (aggCall.isDistinct()) {
refs.add(null);
continue;
}
refs.add(
new RexInputRef(
groupCount + newAggCallList.size(),
aggFields.get(groupCount + i).getType()));
newAggCallList.add(aggCall);
}
// In the case where there are no non-distinct aggregates (regardless of
// whether there are group bys), there's no need to generate the
// extra aggregate and join.
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
int n = 0;
if (!newAggCallList.isEmpty()) {
final RelBuilder.GroupKey groupKey =
relBuilder.groupKey(groupSet, aggregate.getGroupSets());
relBuilder.aggregate(groupKey, newAggCallList);
++n;
}
// For each set of operands, find and rewrite all calls which have that
// set of operands.
for (Pair, Integer> argList : argLists) {
doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
}
relBuilder.project(refs, fieldNames);
call.transformTo(relBuilder.build());
}
/**
* Converts an aggregate with one distinct aggregate and one or more non-distinct aggregates to
* multi-phase aggregates (see reference example below).
*
* @param relBuilder Contains the input relational expression
* @param aggregate Original aggregate
* @param argLists Arguments and filters to the distinct aggregate function
*/
private RelBuilder convertSingletonDistinct(
RelBuilder relBuilder,
Aggregate aggregate,
Set, Integer>> argLists) {
// In this case, we are assuming that there is a single distinct function.
// So make sure that argLists is of size one.
Preconditions.checkArgument(argLists.size() == 1);
// For example,
// SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
// FROM (
// SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
// FROM EMP
// GROUP BY deptno, sal) // Aggregate B
// GROUP BY deptno // Aggregate A
relBuilder.push(aggregate.getInput());
final List originalAggCalls = aggregate.getAggCallList();
final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
// Add the distinct aggregate column(s) to the group-by columns,
// if not already a part of the group-by
final SortedSet bottomGroupSet = new TreeSet<>();
bottomGroupSet.addAll(aggregate.getGroupSet().asList());
for (AggregateCall aggCall : originalAggCalls) {
if (aggCall.isDistinct()) {
bottomGroupSet.addAll(aggCall.getArgList());
break; // since we only have single distinct call
}
}
// Generate the intermediate aggregate B, the one on the bottom that converts
// a distinct call to group by call.
// Bottom aggregate is the same as the original aggregate, except that
// the bottom aggregate has converted the DISTINCT aggregate to a group by clause.
final List bottomAggregateCalls = new ArrayList<>();
for (AggregateCall aggCall : originalAggCalls) {
// Project the column corresponding to the distinct aggregate. Project
// as-is all the non-distinct aggregates
if (!aggCall.isDistinct()) {
final AggregateCall newCall =
AggregateCall.create(
aggCall.getAggregation(),
false,
aggCall.isApproximate(),
false,
aggCall.getArgList(),
-1,
RelCollations.EMPTY,
ImmutableBitSet.of(bottomGroupSet).cardinality(),
relBuilder.peek(),
null,
aggCall.name);
bottomAggregateCalls.add(newCall);
}
}
// Generate the aggregate B (see the reference example above)
relBuilder.push(
aggregate.copy(
aggregate.getTraitSet(),
relBuilder.build(),
ImmutableBitSet.of(bottomGroupSet),
null,
bottomAggregateCalls));
// Add aggregate A (see the reference example above), the top aggregate
// to handle the rest of the aggregation that the bottom aggregate hasn't handled
final List topAggregateCalls =
com.google.common.collect.Lists.newArrayList();
// Use the remapped arguments for the (non)distinct aggregate calls
int nonDistinctAggCallProcessedSoFar = 0;
for (AggregateCall aggCall : originalAggCalls) {
final AggregateCall newCall;
if (aggCall.isDistinct()) {
List newArgList = new ArrayList<>();
for (int arg : aggCall.getArgList()) {
newArgList.add(bottomGroupSet.headSet(arg).size());
}
newCall =
AggregateCall.create(
aggCall.getAggregation(),
false,
aggCall.isApproximate(),
false,
newArgList,
-1,
RelCollations.EMPTY,
originalGroupSet.cardinality(),
relBuilder.peek(),
aggCall.getType(),
aggCall.name);
} else {
// If aggregate B had a COUNT aggregate call the corresponding aggregate at
// aggregate A must be SUM. For other aggregates, it remains the same.
final List newArgs =
com.google.common.collect.Lists.newArrayList(
bottomGroupSet.size() + nonDistinctAggCallProcessedSoFar);
if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
newCall =
AggregateCall.create(
new SqlSumEmptyIsZeroAggFunction(),
false,
aggCall.isApproximate(),
false,
newArgs,
-1,
RelCollations.EMPTY,
originalGroupSet.cardinality(),
relBuilder.peek(),
aggCall.getType(),
aggCall.getName());
} else {
newCall =
AggregateCall.create(
aggCall.getAggregation(),
false,
aggCall.isApproximate(),
false,
newArgs,
-1,
RelCollations.EMPTY,
originalGroupSet.cardinality(),
relBuilder.peek(),
aggCall.getType(),
aggCall.name);
}
nonDistinctAggCallProcessedSoFar++;
}
topAggregateCalls.add(newCall);
}
// Populate the group-by keys with the remapped arguments for aggregate A
// The top groupset is basically an identity (first X fields of aggregate B's
// output), minus the distinct aggCall's input.
final Set topGroupSet = new HashSet<>();
int groupSetToAdd = 0;
for (int bottomGroup : bottomGroupSet) {
if (originalGroupSet.get(bottomGroup)) {
topGroupSet.add(groupSetToAdd);
}
groupSetToAdd++;
}
relBuilder.push(
aggregate.copy(
aggregate.getTraitSet(),
relBuilder.build(),
ImmutableBitSet.of(topGroupSet),
null,
topAggregateCalls));
return relBuilder;
}
/**
 * Rewrites the aggregate with grouping sets instead of joins.
 *
 * <p>A bottom aggregate is built over the union of the original group keys and, per distinct
 * call, its arguments (plus its filter column, if any), with one grouping set per distinct
 * key combination and a {@code GROUPING} call named {@code $g}. A projection then turns
 * {@code $g} into one boolean column per grouping set, and a top aggregate over the original
 * group keys evaluates every call with the matching boolean column as its filter. Non-distinct
 * calls are computed in the bottom aggregate and picked up with {@code MIN} in the top one.
 * For an aggregate without group keys, a final projection replaces a null result of an
 * original non-distinct COUNT by zero.
 *
 * @param call rule call used to obtain the builder and to register the result
 * @param aggregate original aggregate
 */
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    final Map<ImmutableBitSet, Integer> groupSetToDistinctAggCallFilterArg = new HashMap<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (!aggCall.isDistinct()) {
            groupSetTreeSet.add(aggregate.getGroupSet());
        } else {
            ImmutableBitSet groupSet =
                    ImmutableBitSet.of(aggCall.getArgList())
                            .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                            .union(aggregate.getGroupSet());
            groupSetToDistinctAggCallFilterArg.put(groupSet, aggCall.filterArg);
            groupSetTreeSet.add(groupSet);
        }
    }

    final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets =
            com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

    // Calls of the bottom aggregate: the non-distinct calls, followed by GROUPING.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            AggregateCall newAggCall =
                    aggCall.left.adaptTo(
                            aggregate.getInput(),
                            aggCall.left.getArgList(),
                            aggCall.left.filterArg,
                            aggregate.getGroupCount(),
                            fullGroupSet.cardinality());
            distinctAggCalls.add(newAggCall.rename(aggCall.right));
        }
    }

    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    final int groupCount = fullGroupSet.cardinality();

    // Maps each grouping set to the ordinal of its boolean filter column in the projection
    // built below; z is the ordinal of the GROUPING call's output.
    final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
    final int z = groupCount + distinctAggCalls.size();
    distinctAggCalls.add(
            AggregateCall.create(
                    SqlStdOperatorTable.GROUPING,
                    false,
                    false,
                    false,
                    ImmutableIntList.copyOf(fullGroupSet),
                    -1,
                    RelCollations.EMPTY,
                    groupSets.size(),
                    relBuilder.peek(),
                    null,
                    "$g"));
    for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
        filters.put(groupSet.e, z + groupSet.i);
    }

    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();

    // GROUPING returns an integer (0 or 1). Add a project to convert those
    // values to BOOLEAN.
    if (!filters.isEmpty()) {
        final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
        final RexNode nodeZ = nodes.remove(nodes.size() - 1);
        for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
            final long v = groupValue(fullGroupSet, entry.getKey());
            // Get and remap the filterArg of the distinct aggregate call.
            int distinctAggCallFilterArg =
                    remap(
                            fullGroupSet,
                            groupSetToDistinctAggCallFilterArg.getOrDefault(
                                    entry.getKey(), -1));
            RexNode expr;
            if (distinctAggCallFilterArg < 0) {
                expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
            } else {
                RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                // merge the filter of the distinct aggregate call itself.
                expr =
                        relBuilder.and(
                                relBuilder.equals(nodeZ, relBuilder.literal(v)),
                                rexBuilder.makeCall(
                                        SqlStdOperatorTable.IS_TRUE,
                                        relBuilder.field(distinctAggCallFilterArg)));
            }
            nodes.add(relBuilder.alias(expr, "$g_" + v));
        }
        relBuilder.project(nodes);
    }

    int aggCallIdx = 0;
    int x = groupCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    // TODO supports more aggCalls (currently only supports COUNT)
    // Some aggregate functions (e.g. COUNT) have the special property that they can return a
    // non-null result without any input. We need to make sure we return a result in this case.
    final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // The value was already computed by the bottom aggregate; MIN just picks it up
            // from the rows of the original grouping set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
            switch (aggCall.getAggregation().getKind()) {
                case COUNT:
                    needDefaultValueAggCalls.add(aggCallIdx);
                    break;
                default:
            }
        } else {
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg =
                    filters.get(
                            ImmutableBitSet.of(aggCall.getArgList())
                                    .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
                                    .union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall =
                AggregateCall.create(
                        aggregation,
                        false,
                        aggCall.isApproximate(),
                        false,
                        newArgList,
                        newFilterArg,
                        RelCollations.EMPTY,
                        aggregate.getGroupCount(),
                        distinct,
                        null,
                        aggCall.name);
        newCalls.add(newCall);
        aggCallIdx++;
    }

    relBuilder.aggregate(
            relBuilder.groupKey(
                    remap(fullGroupSet, aggregate.getGroupSet()),
                    remap(fullGroupSet, aggregate.getGroupSets())),
            newCalls);

    if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
        final Aggregate newAgg = (Aggregate) relBuilder.peek();
        final List<RexNode> nodes = new ArrayList<>();
        for (int i = 0; i < newAgg.getGroupCount(); ++i) {
            nodes.add(RexInputRef.of(i, newAgg.getRowType()));
        }
        for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
            final RexNode inputRef =
                    RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
            RexNode newNode = inputRef;
            if (needDefaultValueAggCalls.contains(i)) {
                SqlKind originalFunKind =
                        aggregate.getAggCallList().get(i).getAggregation().getKind();
                switch (originalFunKind) {
                    case COUNT:
                        newNode =
                                relBuilder.call(
                                        SqlStdOperatorTable.CASE,
                                        relBuilder.isNotNull(inputRef),
                                        inputRef,
                                        relBuilder.literal(BigDecimal.ZERO));
                        break;
                    default:
                }
            }
            nodes.add(newNode);
        }
        relBuilder.project(nodes);
    }

    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
/**
 * Computes a bit mask describing which columns of {@code fullGroupSet} are missing from {@code
 * groupSet}. The columns of {@code fullGroupSet} are visited in iteration order; the first one
 * corresponds to the highest bit of the mask, and a bit is set when the column is not contained
 * in {@code groupSet}.
 *
 * @param fullGroupSet all grouping columns
 * @param groupSet one grouping set; asserted to be contained in {@code fullGroupSet}
 * @return the bit mask of absent columns
 */
private static long groupValue(ImmutableBitSet fullGroupSet, ImmutableBitSet groupSet) {
    long mask = 0;
    long currentBit = 1L << (fullGroupSet.cardinality() - 1);
    assert fullGroupSet.contains(groupSet);
    for (int column : fullGroupSet) {
        final boolean present = groupSet.get(column);
        mask = present ? mask : (mask | currentBit);
        currentBit = currentBit >> 1;
    }
    return mask;
}
/**
 * Translates every bit of {@code bitSet} through the single-ordinal {@code remap} overload with
 * {@code groupSet} and collects the results in a new bit set.
 *
 * @param groupSet bit set that defines the new ordinals
 * @param bitSet bits to translate
 * @return bit set of the translated ordinals
 */
private static ImmutableBitSet remap(ImmutableBitSet groupSet, ImmutableBitSet bitSet) {
    final ImmutableBitSet.Builder remapped = ImmutableBitSet.builder();
    for (Integer sourceBit : bitSet) {
        final int targetBit = remap(groupSet, sourceBit);
        remapped.set(targetBit);
    }
    return remapped.build();
}
/**
 * Translates each bit set of {@code bitSets} with the bit-set {@code remap} overload, keeping
 * the iteration order.
 *
 * @param groupSet bit set that defines the new ordinals
 * @param bitSets bit sets to translate
 * @return immutable list of the translated bit sets
 */
private static com.google.common.collect.ImmutableList<ImmutableBitSet> remap(
        ImmutableBitSet groupSet, Iterable<ImmutableBitSet> bitSets) {
    final com.google.common.collect.ImmutableList.Builder<ImmutableBitSet> builder =
            com.google.common.collect.ImmutableList.builder();
    for (ImmutableBitSet bitSet : bitSets) {
        builder.add(remap(groupSet, bitSet));
    }
    return builder.build();
}
/**
 * Translates each ordinal of {@code argList} with the single-ordinal {@code remap} overload,
 * keeping the order.
 *
 * @param groupSet bit set that defines the new ordinals
 * @param argList ordinals to translate
 * @return immutable list of the translated ordinals
 */
private static List<Integer> remap(ImmutableBitSet groupSet, List<Integer> argList) {
    // Fill an array first; appending to an ImmutableIntList copies the list on every step.
    final int[] remapped = new int[argList.size()];
    int position = 0;
    for (int arg : argList) {
        remapped[position++] = remap(groupSet, arg);
    }
    return ImmutableIntList.of(remapped);
}
/**
 * Translates a single ordinal: negative ordinals map to {@code -1}, any other ordinal maps to
 * {@code groupSet.indexOf(arg)}.
 *
 * @param groupSet bit set that defines the new ordinals
 * @param arg ordinal to translate
 * @return the translated ordinal, or {@code -1} for a negative input
 */
private static int remap(ImmutableBitSet groupSet, int arg) {
    if (arg < 0) {
        return -1;
    }
    return groupSet.indexOf(arg);
}
/**
 * Converts an aggregate relational expression that contains just one distinct aggregate
 * function (or perhaps several over the same arguments) and no non-distinct aggregate
 * functions.
 *
 * @param relBuilder builder that receives the rewritten expression
 * @param aggregate original aggregate
 * @param argList arguments shared by the distinct aggregate calls
 * @param filterArg ordinal of the filter column of those calls, or -1
 * @return {@code relBuilder}, with the new top aggregate pushed onto it
 */
private RelBuilder convertMonopole(
        RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg) {
    // For example,
    //    SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
    //    FROM emp
    //    GROUP BY deptno
    //
    // becomes
    //
    //    SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
    //    FROM (
    //      SELECT DISTINCT deptno, sal AS distinct_sal
    //      FROM EMP GROUP BY deptno)
    //    GROUP BY deptno

    // Project the columns of the GROUP BY plus the arguments
    // to the agg function.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Create an aggregate on top, with the new aggregate list.
    final List<AggregateCall> newAggCalls = new ArrayList<>(aggregate.getAggCallList());
    rewriteAggCalls(newAggCalls, argList, sourceOf);
    // The select-distinct outputs the group keys first, so they occupy ordinals 0..n-1.
    final int cardinality = aggregate.getGroupSet().cardinality();
    relBuilder.push(
            aggregate.copy(
                    aggregate.getTraitSet(),
                    relBuilder.build(),
                    ImmutableBitSet.range(cardinality),
                    null,
                    newAggCalls));
    return relBuilder;
}
/**
 * Converts all distinct aggregate calls to a given set of arguments.
 *
 * <p>This method is called several times, one for each set of arguments. Each time it is
 * called, it generates a JOIN to a new SELECT DISTINCT relational expression, and modifies the
 * set of top-level calls.
 *
 * <p>Only calls whose argument list <em>and</em> filter argument both match are rewritten,
 * because the SELECT DISTINCT built here is specific to that pair.
 *
 * @param relBuilder Builder; contains the input relational expression (either the original
 *     aggregate or the output from the previous call to this method)
 * @param aggregate Original aggregate
 * @param n Ordinal of this in a join. {@code n} is 0 if we're converting the first distinct
 *     aggregate in a query with no non-distinct aggregates
 * @param argList Arguments to the distinct aggregate function
 * @param filterArg Argument that filters input to aggregate function, or -1
 * @param refs Array of expressions which will be the projected by the result of this rule.
 *     Those relating to this arg list will be modified
 */
private void doRewrite(
        RelBuilder relBuilder,
        Aggregate aggregate,
        int n,
        List<Integer> argList,
        int filterArg,
        List<RexInputRef> refs) {
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final List<RelDataTypeField> leftFields;
    if (n == 0) {
        leftFields = null;
    } else {
        leftFields = relBuilder.peek().getRowType().getFieldList();
    }

    // Aggregate(
    //     child,
    //     {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
    //
    // becomes
    //
    // Aggregate(
    //     Join(
    //         child,
    //         Aggregate(child, < all columns > {}),
    //         INNER,
    //         <f2 = f5>))
    //
    // E.g.
    //   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
    //   FROM Emps
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
    //   FROM (
    //     SELECT deptno, MAX(age) as max_age
    //     FROM Emps GROUP BY deptno) AS e
    //   JOIN (
    //     SELECT deptno, COUNT(gender) AS count_gender FROM (
    //       SELECT DISTINCT deptno, gender FROM Emps) AS dgender
    //     GROUP BY deptno) AS adgender
    //     ON e.deptno = adgender.deptno
    //   JOIN (
    //     SELECT deptno, SUM(sal) AS sum_sal FROM (
    //       SELECT DISTINCT deptno, sal FROM Emps) AS dsal
    //     GROUP BY deptno) AS adsal
    //   ON e.deptno = adsal.deptno
    //   GROUP BY e.deptno
    //
    // Note that if a query contains no non-distinct aggregates, then the
    // very first join/group by is omitted.  In the example above, if
    // MAX(age) is removed, then the sub-select of "e" is not needed, and
    // instead the two other group by's are joined to one another.

    // Project the columns of the GROUP BY plus the arguments
    // to the agg function.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Now compute the aggregate functions on top of the distinct dataset.
    // Each distinct agg becomes a non-distinct call to the corresponding
    // field from the right; for example,
    //   "COUNT(DISTINCT e.sal)"
    // becomes
    //   "COUNT(distinct_e.sal)".
    final List<AggregateCall> aggCallList = new ArrayList<>();
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();

    final int groupCount = aggregate.getGroupCount();
    int i = groupCount - 1;
    for (AggregateCall aggCall : aggCalls) {
        ++i;

        // Ignore agg calls which are not distinct or have the wrong set
        // arguments. If we're rewriting aggs whose args are {sal}, we will
        // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
        // COUNT(DISTINCT gender) or SUM(sal).
        if (!aggCall.isDistinct()) {
            continue;
        }
        if (!aggCall.getArgList().equals(argList)) {
            continue;
        }
        // The select-distinct above was built for exactly one (argList, filterArg) pair: it
        // only projects that filter column and wraps the arguments in a CASE on it. A call
        // with the same arguments but another filter is handled by its own invocation of
        // this method; rewriting it here would look up a filter column that was never
        // projected, or aggregate over the wrongly filtered argument.
        if (aggCall.filterArg != filterArg) {
            continue;
        }

        // Re-map arguments.
        final int argCount = aggCall.getArgList().size();
        final List<Integer> newArgs = new ArrayList<>(argCount);
        for (int j = 0; j < argCount; j++) {
            final Integer arg = aggCall.getArgList().get(j);
            newArgs.add(sourceOf.get(arg));
        }
        final int newFilterArg = aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1;
        final AggregateCall newAggCall =
                AggregateCall.create(
                        aggCall.getAggregation(),
                        false,
                        aggCall.isApproximate(),
                        false,
                        newArgs,
                        newFilterArg,
                        RelCollations.EMPTY,
                        aggCall.getType(),
                        aggCall.getName());
        assert refs.get(i) == null;
        if (n == 0) {
            refs.set(i, new RexInputRef(groupCount + aggCallList.size(), newAggCall.getType()));
        } else {
            // After the join, the fields of this aggregate follow all fields of the left side.
            refs.set(
                    i,
                    new RexInputRef(
                            leftFields.size() + groupCount + aggCallList.size(),
                            newAggCall.getType()));
        }
        aggCallList.add(newAggCall);
    }

    final Map<Integer, Integer> map = new HashMap<>();
    for (Integer key : aggregate.getGroupSet()) {
        map.put(key, map.size());
    }
    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
    assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));

    relBuilder.push(
            aggregate.copy(
                    aggregate.getTraitSet(),
                    relBuilder.build(),
                    newGroupSet,
                    null,
                    aggCallList));

    // If there's no left child yet, no need to create the join
    if (n == 0) {
        return;
    }

    // Create the join condition. It is of the form
    //  'left.f0 = right.f0 and left.f1 = right.f1 and ...'
    // where {f0, f1, ...} are the GROUP BY fields.
    final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
    final List<RexNode> conditions = new ArrayList<>();
    for (i = 0; i < groupCount; ++i) {
        // null values form its own group
        // use "is not distinct from" so that the join condition
        // allows null values to match.
        conditions.add(
                rexBuilder.makeCall(
                        SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
                        RexInputRef.of(i, leftFields),
                        new RexInputRef(
                                leftFields.size() + i, distinctFields.get(i).getType())));
    }

    // Join in the new 'select distinct' relation.
    relBuilder.join(JoinRelType.INNER, conditions);
}
/**
 * Replaces, in place, every call of {@code newAggCalls} that is distinct (or whose aggregation
 * ignores the distinct attribute) and whose arguments equal {@code argList} by a non-distinct,
 * unfiltered call over the remapped arguments.
 *
 * @param newAggCalls mutable list of aggregate calls; matching entries are overwritten
 * @param argList argument list of the calls to rewrite
 * @param sourceOf maps an input column ordinal to its ordinal in the select-distinct output
 */
private static void rewriteAggCalls(
        List<AggregateCall> newAggCalls,
        List<Integer> argList,
        Map<Integer, Integer> sourceOf) {
    // Rewrite the agg calls. Each distinct agg becomes a non-distinct call
    // to the corresponding field from the right; for example,
    // "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
    for (int i = 0; i < newAggCalls.size(); i++) {
        final AggregateCall aggCall = newAggCalls.get(i);

        // Ignore agg calls which are not distinct or have the wrong set
        // arguments. If we're rewriting aggregates whose args are {sal}, we will
        // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
        // COUNT(DISTINCT gender) or SUM(sal).
        if (!aggCall.isDistinct()
                && aggCall.getAggregation().getDistinctOptionality() != Optionality.IGNORED) {
            continue;
        }
        if (!aggCall.getArgList().equals(argList)) {
            continue;
        }

        // Re-map arguments.
        final int argCount = aggCall.getArgList().size();
        final List<Integer> newArgs = new ArrayList<>(argCount);
        for (int j = 0; j < argCount; j++) {
            final Integer arg = aggCall.getArgList().get(j);
            newArgs.add(sourceOf.get(arg));
        }
        final AggregateCall newAggCall =
                AggregateCall.create(
                        aggCall.getAggregation(),
                        false,
                        aggCall.isApproximate(),
                        false,
                        newArgs,
                        -1,
                        RelCollations.EMPTY,
                        aggCall.getType(),
                        aggCall.getName());
        newAggCalls.set(i, newAggCall);
    }
}
/**
* Given an {@link org.apache.calcite.rel.core.Aggregate} and the ordinals of the arguments to a
* particular call to an aggregate function, creates a 'select distinct' relational expression
* which projects the group columns and those arguments but nothing else.
*
* For example, given
*
*
*
* select f0, count(distinct f1), count(distinct f2)
* from t group by f0
*
*
*
* and the argument list
*
*
*
* {2}
*
*
*
* returns
*
*
*
* select distinct f0, f2 from t
*
*
*
* The sourceOf
map is populated with the source of each column; in this case
* sourceOf.get(0) = 0, and sourceOf.get(1) = 2.
*
* @param relBuilder Relational expression builder
* @param aggregate Aggregate relational expression
* @param argList Ordinals of columns to make distinct
* @param filterArg Ordinal of column to filter on, or -1
* @param sourceOf Out parameter, is populated with a map of where each output field came from
* @return Aggregate relational expression which projects the required columns
*/
private RelBuilder createSelectDistinct(
RelBuilder relBuilder,
Aggregate aggregate,
List argList,
int filterArg,
Map sourceOf) {
relBuilder.push(aggregate.getInput());
final List> projects = new ArrayList<>();
final List childFields = relBuilder.peek().getRowType().getFieldList();
for (int i : aggregate.getGroupSet()) {
sourceOf.put(i, projects.size());
projects.add(RexInputRef.of2(i, childFields));
}
if (filterArg >= 0) {
sourceOf.put(filterArg, projects.size());
projects.add(RexInputRef.of2(filterArg, childFields));
}
for (Integer arg : argList) {
if (filterArg >= 0) {
// Implement
// agg(DISTINCT arg) FILTER $f
// by generating
// SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
// and then applying
// agg(arg)
// as usual.
//
// It works except for (rare) agg functions that need to see null
// values.
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
final Pair argRef = RexInputRef.of2(arg, childFields);
RexNode condition =
rexBuilder.makeCall(
SqlStdOperatorTable.CASE,
filterRef,
argRef.left,
rexBuilder.makeNullLiteral(argRef.left.getType()));
sourceOf.put(arg, projects.size());
projects.add(Pair.of(condition, "i$" + argRef.right));
continue;
}
if (sourceOf.get(arg) != null) {
continue;
}
sourceOf.put(arg, projects.size());
projects.add(RexInputRef.of2(arg, childFields));
}
relBuilder.project(Pair.left(projects), Pair.right(projects));
// Get the distinct values of the GROUP BY fields and the arguments
// to the agg functions.
relBuilder.push(
aggregate.copy(
aggregate.getTraitSet(),
relBuilder.build(),
ImmutableBitSet.range(projects.size()),
null,
com.google.common.collect.ImmutableList.of()));
return relBuilder;
}
}
// End FlinkAggregateExpandDistinctAggregatesRule.java