All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.org.apache.calcite.linq4j.Ord;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.fun.SqlInternalOperators;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.apache.calcite.util.mapping.IntPair;

import com.hazelcast.com.google.common.collect.ArrayListMultimap;
import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Multimap;

import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;

import static com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule.groupValue;
import static com.hazelcast.org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule.remap;

/**
 * Planner rule that rewrites an {@link Aggregate} that contains
 * {@code WITHIN DISTINCT} aggregate functions.
 *
 * 

For example, *

*
SELECT o.paymentType,
 *     COUNT(*) as "count",
 *     COUNT(*) WITHIN DISTINCT (o.orderId) AS orderCount,
 *     SUM(o.shipping) WITHIN DISTINCT (o.orderId) as sumShipping,
 *     SUM(i.units) as sumUnits
 *   FROM Orders AS o
 *   JOIN OrderItems AS i USING (orderId)
 *   GROUP BY o.paymentType
*
* *

becomes * *

*
 *     SELECT paymentType,
 *       COUNT(*) as "count",
 *       COUNT(*) FILTER (WHERE g = 0) AS orderCount,
 *       SUM(minShipping) FILTER (WHERE g = 0) AS sumShipping,
 *       SUM(sumUnits) FILTER (WHERE g = 1) as sumUnits
 *    FROM (
 *      SELECT o.paymentType,
 *        GROUPING(o.orderId) AS g,
 *        SUM(o.shipping) AS sumShipping,
 *        MIN(o.shipping) AS minShipping,
 *        SUM(i.units) AS sumUnits
 *        FROM Orders AS o
 *        JOIN OrderItems ON o.orderId = i.orderId
 *        GROUP BY GROUPING SETS ((o.paymentType), (o.paymentType, o.orderId)))
 *     GROUP BY o.paymentType
*
* *

By the way, note that {@code COUNT(*) WITHIN DISTINCT (o.orderId)} * is identical to {@code COUNT(DISTINCT o.orderId)}. * {@code WITHIN DISTINCT} is a generalization of aggregate(DISTINCT). * So, it is perhaps not surprising that the rewrite to {@code GROUPING SETS} * is similar. * *

If there are multiple arguments * (e.g. {@code SUM(a) WITHIN DISTINCT (x), SUM(a) WITHIN DISTINCT (y)}) * the rule creates separate {@code GROUPING SET}s. */ @Value.Enclosing public class AggregateExpandWithinDistinctRule extends RelRule { /** Creates an AggregateExpandWithinDistinctRule. */ protected AggregateExpandWithinDistinctRule(Config config) { super(config); } private static boolean hasWithinDistinct(Aggregate aggregate) { return aggregate.getAggCallList().stream() .anyMatch(c -> c.distinctKeys != null) // Wait until AggregateReduceFunctionsRule has dealt with AVG etc. && aggregate.getAggCallList().stream() .noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce) // Don't think we can handle GROUPING SETS yet && aggregate.getGroupType() == Aggregate.Group.SIMPLE; } //~ Methods ---------------------------------------------------------------- @Override public void onMatch(RelOptRuleCall call) { final Aggregate aggregate = call.rel(0); // Throughout this method, assume we are working on the following SQL: // // SELECT deptno, SUM(sal), SUM(sal) WITHIN DISTINCT (job) // FROM emp // GROUP BY deptno // // or in algebra, // // Aggregate($0, SUM($2), SUM($2) WITHIN DISTINCT ($4)) // Scan(emp) // // We plan to generate the following: // // SELECT deptno, SUM(sal), SUM(sal) WITHIN DISTINCT (job) // FROM ( // SELECT deptno, GROUPING(deptno, job), SUM(sal), MIN(sal) // FROM emp // GROUP BY GROUPING SETS ((deptno), (deptno, job))) // GROUP BY deptno // // This rewrite also handles DISTINCT aggregates. We treat // // SUM(DISTINCT sal) // SUM(DISTINCT sal) WITHIN DISTINCT (job) // // as if the user had written // // SUM(sal) WITHIN DISTINCT (sal) // final List aggCallList = aggregate.getAggCallList() .stream() .map(c -> unDistinct(c, aggregate.getInput()::fieldIsNullable)) .collect(Util.toImmutableList()); // Find all within-distinct expressions. final Multimap argLists = ArrayListMultimap.create(); // A bit set that represents an aggregate with no WITHIN DISTINCT. // Different from "WITHIN DISTINCT ()". final ImmutableBitSet notDistinct = ImmutableBitSet.of(aggregate.getInput().getRowType().getFieldCount()); for (AggregateCall aggCall : aggCallList) { ImmutableBitSet distinctKeys = aggCall.distinctKeys; if (distinctKeys == null) { distinctKeys = notDistinct; } else { if (distinctKeys.intersects(aggregate.getGroupSet())) { // Remove group keys. E.g. // sum(x) within distinct (y, z) ... group by y // can be simplified to // sum(x) within distinct (z) ... group by y // Note that this assumes a single grouping set for the original agg. distinctKeys = distinctKeys.rebuild() .removeAll(aggregate.getGroupSet()).build(); } } argLists.put(distinctKeys, aggCall); } // Compute the set of all grouping sets that will be used in the output // query. For each WITHIN DISTINCT aggregate call, we will need a grouping // set that is the union of the aggregate call's unique keys and the input // query's overall grouping. Redundant grouping sets can be reused for // multiple aggregate calls. final Set groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING); for (ImmutableBitSet key : argLists.keySet()) { groupSetTreeSet.add( (key == notDistinct) ? aggregate.getGroupSet() : ImmutableBitSet.of(key).union(aggregate.getGroupSet())); } final ImmutableList groupSets = ImmutableList.copyOf(groupSetTreeSet); final boolean hasMultipleGroupSets = groupSets.size() > 1; final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets); final Set fullGroupOrderedSet = new LinkedHashSet<>(); fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet()); fullGroupOrderedSet.addAll(fullGroupSet.asSet()); final ImmutableIntList fullGroupList = ImmutableIntList.copyOf(fullGroupOrderedSet); // Build the inner query // // SELECT deptno, SUM(sal) AS sum_sal, MIN(sal) AS min_sal, // MAX(sal) AS max_sal, GROUPING(deptno, job) AS g // FROM emp // GROUP BY GROUPING SETS ((deptno), (deptno, job)) // // or in algebra, // // Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4)) // Scan(emp) final RelBuilder b = call.builder(); b.push(aggregate.getInput()); final List aggCalls = new ArrayList<>(); // Helper class for building the inner query. // CHECKSTYLE: IGNORE 1 class Registrar { final int g = fullGroupSet.cardinality(); /** Map of input fields (below the original aggregation) and filter args * to inner query aggregate calls. */ final Map args = new HashMap<>(); /** Map of aggregate calls from the original aggregation to inner query * aggregate calls. */ final Map aggs = new HashMap<>(); /** Map of aggregate calls from the original aggregation to inner-query * {@code COUNT(*)} calls, which are only needed for filters in the outer * aggregate when the original aggregate call does not ignore null * inputs. */ final Map counts = new HashMap<>(); List fields(List fields, int filterArg) { return Util.transform(fields, f -> this.field(f, filterArg)); } int field(int field, int filterArg) { return Objects.requireNonNull(args.get(IntPair.of(field, filterArg))); } /** Computes an aggregate call argument's values for a * {@code WITHIN DISTINCT} aggregate call. * *

For example, to compute * {@code SUM(x) WITHIN DISTINCT (y) GROUP BY (z)}, * the inner aggregate must first group {@code x} by {@code (y, z)} * — using {@code MIN} to select the (hopefully) unique value of * {@code x} for each {@code (y, z)} group. Actually summing over the * grouped {@code x} values must occur in an outer aggregate. * * @param field Index of an input field that's used in a * {@code WITHIN DISTINCT} aggregate call * @param filterArg Filter arg used in the original aggregate call, or * {@code -1} if there is no filter. We use the same filter in * the inner query. * @return Index of the inner query aggregate call representing the * grouped field, which can be referenced in the outer query * aggregate call */ int register(int field, int filterArg) { return args.computeIfAbsent(IntPair.of(field, filterArg), j -> { final int ordinal = g + aggCalls.size(); RelBuilder.AggCall groupedField = b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field)); aggCalls.add( filterArg < 0 ? groupedField : groupedField.filter(b.field(filterArg))); if (config.throwIfNotUnique()) { groupedField = b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field)); aggCalls.add( filterArg < 0 ? groupedField : groupedField.filter(b.field(filterArg))); } return ordinal; }); } /** Registers an aggregate call that is not a * {@code WITHIN DISTINCT} call. * *

Unlike the case handled by {@link #register(int, int)} above, * aggregate calls without any distinct keys do not need a second round * of aggregation in the outer query, so they can be computed "as-is" in * the inner query. * * @param i Index of the aggregate call in the original aggregation * @param aggregateCall Original aggregate call * @return Index of the aggregate call in the computed inner query */ int registerAgg(int i, RelBuilder.AggCall aggregateCall) { final int ordinal = g + aggCalls.size(); aggs.put(i, ordinal); aggCalls.add(aggregateCall); return ordinal; } int getAgg(int i) { return Objects.requireNonNull(aggs.get(i)); } /** Registers an extra {@code COUNT} aggregate call when it's needed to * filter out null inputs in the outer aggregate. * *

This should only be called for aggregate calls with filters. It's * possible that the filter would eliminate all input rows to the * {@code MIN} call in the inner query, so calls in the outer * aggregate may need to be aware of this. See usage of * {@link AggregateExpandWithinDistinctRule#mustBeCounted(AggregateCall)}. * * @param filterArg The original aggregate call's filter; must be * non-negative * @return Index of the {@code COUNT} call in the computed inner query */ int registerCount(int filterArg) { assert filterArg >= 0; return counts.computeIfAbsent(filterArg, i -> { final int ordinal = g + aggCalls.size(); aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT) .filter(b.field(filterArg))); return ordinal; }); } int getCount(int filterArg) { return Objects.requireNonNull(counts.get(filterArg)); } } final Registrar registrar = new Registrar(); Ord.forEach(aggCallList, (c, i) -> { if (c.distinctKeys == null) { registrar.registerAgg(i, b.aggregateCall(c.getAggregation(), b.fields(c.getArgList()))); } else { for (int inputIdx : c.getArgList()) { registrar.register(inputIdx, c.filterArg); } if (mustBeCounted(c)) { registrar.registerCount(c.filterArg); } } }); // Add an additional GROUPING() aggregate call so we can select only the // relevant inner-aggregate rows from the outer aggregate. If there is only // 1 grouping set (i.e. every aggregate call has the same distinct keys), // no GROUPING() call is necessary. final int grouping = hasMultipleGroupSets ? registrar.registerAgg(-1, b.aggregateCall( SqlStdOperatorTable.GROUPING, b.fields(fullGroupList))) : -1; b.aggregate(b.groupKey(fullGroupSet, groupSets), aggCalls); // Build the outer query // // SELECT deptno, // MIN(sum_sal) FILTER (g = 0), // SUM(min_sal) FILTER (g = 1) // FROM ( ... ) // GROUP BY deptno // // or in algebra, // // Aggregate($0, SUM($2 WHERE $4 = 0), SUM($3 WHERE $4 = 1)) // Aggregate([($0), ($0, $2)], SUM($2), MIN($2), GROUPING($0, $4)) // Scan(emp) // // If throwIfNotUnique, the "SUM(min_sal) FILTER (g = 1)" term above becomes // // SUM(min_sal) FILTER ( // $THROW_UNLESS(g <> 1 OR min_sal IS NOT DISTINCT FROM max_sal, // 'more than one distinct value in agg UNIQUE_VALUE') // AND g = 1) aggCalls.clear(); Ord.forEach(aggCallList, (c, i) -> { final List filters = new ArrayList<>(); RexNode groupFilter = null; if (hasMultipleGroupSets) { groupFilter = b.equals( b.field(grouping), b.literal( groupValue(fullGroupList, union(aggregate.getGroupSet(), c.distinctKeys)))); filters.add(groupFilter); } RelBuilder.AggCall aggCall; if (c.distinctKeys == null) { aggCall = b.aggregateCall(SqlStdOperatorTable.MIN, b.field(registrar.getAgg(i))); } else { // The inputs to this aggregate are outputs from MIN() calls from the // inner agg, and MIN() returns null iff it has no non-null inputs, // which can only happen if an original aggregate's filter causes all // non-null input rows to be discarded for a particular group in the // inner aggregate. In this case, it should be ignored by the outer // aggregate as well. In case the aggregate call does not naturally // ignore null inputs, we add a filter based on a COUNT() in the inner // aggregate. aggCall = b.aggregateCall( c.getAggregation(), b.fields(registrar.fields(c.getArgList(), c.filterArg))); if (mustBeCounted(c)) { filters.add(b.greaterThan(b.field(registrar.getCount(c.filterArg)), b.literal(0))); } if (config.throwIfNotUnique()) { for (int j : c.getArgList()) { RexNode isUniqueCondition = b.isNotDistinctFrom( b.field(registrar.field(j, c.filterArg)), b.field(registrar.field(j, c.filterArg) + 1)); if (groupFilter != null) { isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition); } String message = "more than one distinct value in agg UNIQUE_VALUE"; filters.add( b.call(SqlInternalOperators.THROW_UNLESS, isUniqueCondition, b.literal(message))); } } } if (filters.size() > 0) { aggCall = aggCall.filter(b.and(filters)); } aggCalls.add(aggCall); }); b.aggregate( b.groupKey( remap(fullGroupSet, aggregate.getGroupSet()), remap(fullGroupSet, aggregate.getGroupSets())), aggCalls); b.convert(aggregate.getRowType(), false); call.transformTo(b.build()); } private static boolean mustBeCounted(AggregateCall aggCall) { // Always count filtered inner aggregates to be safe. // // It's possible that, for some aggregate calls (namely, those that // completely ignore null inputs), we could neglect counting the // grouped-and-filtered rows of the inner aggregate and filtering the empty // ones out from the outer aggregate, since those empty groups would produce // null values as the result of MIN and thus be ignored by the outer // aggregate anyway. // // Note that using "aggCall.ignoreNulls()" is not sufficient to determine // when it's safe to do this, since for COUNT the value of ignoreNulls() // should generally be true even though COUNT(*) will never ignore anything. return aggCall.hasFilter(); } /** Converts a {@code DISTINCT} aggregate call into an equivalent one with * {@code WITHIN DISTINCT}. * *

Examples: *

    *
  • {@code SUM(DISTINCT x)} → * {@code SUM(x) WITHIN DISTINCT (x)} has distinct key (x); *
  • {@code SUM(DISTINCT x)} WITHIN DISTINCT (y) → * {@code SUM(x) WITHIN DISTINCT (x)} has distinct key (x); *
  • {@code SUM(x)} WITHIN DISTINCT (y, z) has distinct key (y, z); *
  • {@code SUM(x)} has no distinct key. *
*/ private static AggregateCall unDistinct(AggregateCall aggregateCall, IntPredicate isNullable) { if (aggregateCall.isDistinct()) { final List newArgList = aggregateCall.getArgList() .stream() .filter(i -> aggregateCall.getAggregation().getKind() != SqlKind.COUNT || aggregateCall.hasFilter() || isNullable.test(i)) .collect(Collectors.toList()); return aggregateCall.withDistinct(false) .withDistinctKeys(ImmutableBitSet.of(aggregateCall.getArgList())) .withArgList(newArgList); } return aggregateCall; } private static ImmutableBitSet union(ImmutableBitSet s0, @Nullable ImmutableBitSet s1) { return s1 == null ? s0 : s0.union(s1); } /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { Config DEFAULT = ImmutableAggregateExpandWithinDistinctRule.Config.of() .withOperandSupplier(b -> b.operand(LogicalAggregate.class) .predicate(AggregateExpandWithinDistinctRule::hasWithinDistinct) .anyInputs()); @Override default AggregateExpandWithinDistinctRule toRule() { return new AggregateExpandWithinDistinctRule(this); } /** Whether the code generated by the rule should throw if the arguments * are not functionally dependent. * *

For example, if implementing {@code SUM(sal) WITHIN DISTINCT job)} ... * {@code GROUP BY deptno}, * suppose that within department 10, (job, sal) has the values * ('CLERK', 100), ('CLERK', 120), ('MANAGER', 150), ('MANAGER', 150). If * {@code throwIfNotUnique} is true, the query would throw because of the * values [100, 120]; if false, the query would sum the distinct values * [100, 120, 150]. */ @Value.Default default boolean throwIfNotUnique() { return true; } /** Sets {@link #throwIfNotUnique()}. */ Config withThrowIfNotUnique(boolean throwIfNotUnique); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy