// org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule (Maven / Gradle / Ivy source listing)
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Planner rule that expands distinct aggregates
 * (such as {@code COUNT(DISTINCT x)}) from a
 * {@link org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>How this is done depends upon the arguments to the function. If all
 * functions have the same argument
 * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
 * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
 * sufficient.
 *
 * <p>If there are multiple arguments
 * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
 * the general approach is to create separate {@code Aggregate}s and combine
 * them using a {@link org.apache.calcite.rel.core.Join}. This stripped-down
 * version does not implement that case: {@code onMatch} only rewrites an
 * aggregate whose calls are all distinct and share one argument list, and
 * leaves every other aggregate untouched.
 */
// Stripped down version of org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
// This is adapted for Hive, but should eventually be deleted from Hive and make use of above.
public final class HiveExpandDistinctAggregatesRule extends RelOptRule {
//~ Static fields/initializers ---------------------------------------------
/** The default instance of the rule; operates only on logical expressions. */
public static final HiveExpandDistinctAggregatesRule INSTANCE =
new HiveExpandDistinctAggregatesRule(HiveAggregate.class,
HiveProject.DEFAULT_PROJECT_FACTORY);
private static RelFactories.ProjectFactory projFactory;
//~ Constructors -----------------------------------------------------------
public HiveExpandDistinctAggregatesRule(
Class clazz,RelFactories.ProjectFactory projectFactory) {
super(operand(clazz, any()));
projFactory = projectFactory;
}
//~ Methods ----------------------------------------------------------------
@Override
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
if (!aggregate.containsDistinctCall()) {
return;
}
// Find all of the agg expressions. We use a LinkedHashSet to ensure
// determinism.
int nonDistinctCount = 0;
Set> argListSets = new LinkedHashSet>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (!aggCall.isDistinct()) {
++nonDistinctCount;
continue;
}
ArrayList argList = new ArrayList();
for (Integer arg : aggCall.getArgList()) {
argList.add(arg);
}
argListSets.add(argList);
}
Util.permAssert(argListSets.size() > 0, "containsDistinctCall lied");
// If all of the agg expressions are distinct and have the same
// arguments then we can use a more efficient form.
if ((nonDistinctCount == 0) && (argListSets.size() == 1)) {
for (Integer arg : argListSets.iterator().next()) {
Set colOrigs = RelMetadataQuery.getColumnOrigins(aggregate, arg);
if (null != colOrigs) {
for (RelColumnOrigin colOrig : colOrigs) {
RelOptHiveTable hiveTbl = (RelOptHiveTable)colOrig.getOriginTable();
if(hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
// Encountered partitioning column, this will be better handled by MetadataOnly optimizer.
return;
}
}
}
}
RelNode converted =
convertMonopole(
aggregate,
argListSets.iterator().next());
call.transformTo(converted);
return;
}
}
/**
* Converts an aggregate relational expression that contains just one
* distinct aggregate function (or perhaps several over the same arguments)
* and no non-distinct aggregate functions.
*/
private RelNode convertMonopole(
Aggregate aggregate,
List argList) {
// For example,
// SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
// FROM (
// SELECT DISTINCT deptno, sal AS distinct_sal
// FROM EMP GROUP BY deptno)
// GROUP BY deptno
// Project the columns of the GROUP BY plus the arguments
// to the agg function.
Map sourceOf = new HashMap();
final Aggregate distinct =
createSelectDistinct(aggregate, argList, sourceOf);
// Create an aggregate on top, with the new aggregate list.
final List newAggCalls =
Lists.newArrayList(aggregate.getAggCallList());
rewriteAggCalls(newAggCalls, argList, sourceOf);
final int cardinality = aggregate.getGroupSet().cardinality();
return aggregate.copy(aggregate.getTraitSet(), distinct,
aggregate.indicator, ImmutableBitSet.range(cardinality), null,
newAggCalls);
}
private static void rewriteAggCalls(
List newAggCalls,
List argList,
Map sourceOf) {
// Rewrite the agg calls. Each distinct agg becomes a non-distinct call
// to the corresponding field from the right; for example,
// "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
for (int i = 0; i < newAggCalls.size(); i++) {
final AggregateCall aggCall = newAggCalls.get(i);
// Ignore agg calls which are not distinct or have the wrong set
// arguments. If we're rewriting aggs whose args are {sal}, we will
// rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
// COUNT(DISTINCT gender) or SUM(sal).
if (!aggCall.isDistinct()) {
continue;
}
if (!aggCall.getArgList().equals(argList)) {
continue;
}
// Re-map arguments.
final int argCount = aggCall.getArgList().size();
final List newArgs = new ArrayList(argCount);
for (int j = 0; j < argCount; j++) {
final Integer arg = aggCall.getArgList().get(j);
newArgs.add(sourceOf.get(arg));
}
final AggregateCall newAggCall =
new AggregateCall(
aggCall.getAggregation(),
false,
newArgs,
aggCall.getType(),
aggCall.getName());
newAggCalls.set(i, newAggCall);
}
}
/**
* Given an {@link org.apache.calcite.rel.logical.LogicalAggregate}
* and the ordinals of the arguments to a
* particular call to an aggregate function, creates a 'select distinct'
* relational expression which projects the group columns and those
* arguments but nothing else.
*
* For example, given
*
*
* select f0, count(distinct f1), count(distinct f2)
* from t group by f0
*
*
* and the arglist
*
* {2}
*
* returns
*
*
* select distinct f0, f2 from t
*
*
* '
*
* The sourceOf
map is populated with the source of each
* column; in this case sourceOf.get(0) = 0, and sourceOf.get(1) = 2.
*
* @param aggregate Aggregate relational expression
* @param argList Ordinals of columns to make distinct
* @param sourceOf Out parameter, is populated with a map of where each
* output field came from
* @return Aggregate relational expression which projects the required
* columns
*/
private static Aggregate createSelectDistinct(
Aggregate aggregate,
List argList,
Map sourceOf) {
final List> projects =
new ArrayList>();
final RelNode child = aggregate.getInput();
final List childFields =
child.getRowType().getFieldList();
for (int i : aggregate.getGroupSet()) {
sourceOf.put(i, projects.size());
projects.add(RexInputRef.of2(i, childFields));
}
for (Integer arg : argList) {
if (sourceOf.get(arg) != null) {
continue;
}
sourceOf.put(arg, projects.size());
projects.add(RexInputRef.of2(arg, childFields));
}
final RelNode project =
projFactory.createProject(child, Pair.left(projects), Pair.right(projects));
// Get the distinct values of the GROUP BY fields and the arguments
// to the agg functions.
return aggregate.copy(aggregate.getTraitSet(), project, false,
ImmutableBitSet.range(projects.size()),
null, ImmutableList.of());
}
}
// End HiveExpandDistinctAggregatesRule.java