// org.apache.drill.exec.planner.logical.DrillReduceAggregatesRule Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.logical;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableList;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.drill.shaded.guava.com.google.common.collect.Maps;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import org.apache.drill.exec.planner.sql.TypeInferenceUtils;
import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
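/**
* Rule to reduce aggregates to simpler forms. SUM(x) is rewritten to a CASE over
* COUNT(x) and SUM0(x); AVG(x) to SUM0(x) / COUNT(x); STDDEV_POP, STDDEV_SAMP,
* VAR_POP and VAR_SAMP to expressions over SUM(x * x), SUM(x) and COUNT(x).
* AVG-family calls whose result type is DECIMAL are passed through unchanged.
*/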
public class DrillReduceAggregatesRule extends RelOptRule {
//~ Static fields/initializers ---------------------------------------------

/**
 * Singleton of this rule, matched against any {@link LogicalAggregate}.
 */
public static final DrillReduceAggregatesRule INSTANCE =
    new DrillReduceAggregatesRule(
        operand(LogicalAggregate.class, any()));

/**
 * Singleton of the SUM-to-SUM0 conversion rule, matched against any
 * {@link DrillAggregateRel}.
 */
public static final DrillConvertSumToSumZero INSTANCE_SUM =
    new DrillConvertSumToSumZero(
        operand(DrillAggregateRel.class, any()));

/**
 * Singleton of the window SUM-to-SUM0 conversion rule, matched against any
 * {@link DrillWindowRel}.
 */
public static final DrillConvertWindowSumToSumZero INSTANCE_WINDOW_SUM =
    new DrillConvertWindowSumToSumZero(
        operand(DrillWindowRel.class, any()));

/**
 * One-argument "CastHigh" operator. Its return type is inferred as ANY,
 * carrying the nullability of its first operand.
 */
private static final DrillSqlOperator CastHighOp = new DrillSqlOperator(
    "CastHigh",
    1,
    false,
    // Explicit target type keeps constructor overload resolution unambiguous.
    (SqlReturnTypeInference) opBinding ->
        TypeInferenceUtils.createCalciteTypeWithNullability(
            opBinding.getTypeFactory(),
            SqlTypeName.ANY,
            opBinding.getOperandType(0).isNullable()),
    false);
//~ Constructors -----------------------------------------------------------
/**
 * Creates the rule for the given operand, using
 * {@code DrillRelFactories.LOGICAL_BUILDER} as the builder factory.
 *
 * @param operand root operand this rule is matched against
 */
protected DrillReduceAggregatesRule(RelOptRuleOperand operand) {
// The null third argument is presumably the rule description, left for the
// base class to derive -- TODO confirm against RelOptRule's constructor.
super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
}
//~ Methods ----------------------------------------------------------------

/**
 * Accepts the call only when the base-class check passes and the matched
 * aggregate contains at least one call accepted by
 * {@code containsAvgStddevVarCall}.
 */
@Override
public boolean matches(RelOptRuleCall call) {
  // Short-circuit evaluation: the cast and the scan of the aggregate calls
  // only happen once the base class has accepted the call.
  return super.matches(call)
      && containsAvgStddevVarCall(((Aggregate) call.rels[0]).getAggCallList());
}
/**
 * Rewrites the matched aggregate (the first rel of the call) by delegating
 * to {@code reduceAggs}.
 */
@Override
public void onMatch(RelOptRuleCall ruleCall) {
  reduceAggs(ruleCall, (Aggregate) ruleCall.rels[0]);
}
/**
 * Returns whether any of the aggregates is a call this rule rewrites: an
 * {@link SqlAvgAggFunction} (AVG, STDDEV_*, VAR_*) or a
 * {@link SqlSumAggFunction} (SUM). Each aggregation is first passed through
 * {@code DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper}.
 *
 * @param aggCallList List of aggregate calls
 * @return {@code true} if at least one call is an AVG-family or SUM call
 */
private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
  for (AggregateCall call : aggCallList) {
    SqlAggFunction sqlAggFunction =
        DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getAggregation());
    if (sqlAggFunction instanceof SqlAvgAggFunction
        || sqlAggFunction instanceof SqlSumAggFunction) {
      return true;
    }
  }
  return false;
}
/*
private boolean isMatch(AggregateCall call) {
if (call.getAggregation() instanceof SqlAvgAggFunction) {
final SqlAvgAggFunction.Subtype subtype =
((SqlAvgAggFunction) call.getAggregation()).getSubtype();
return (subtype == SqlAvgAggFunction.Subtype.AVG);
}
return false;
}
*/
/**
 * Reduces all calls to SUM, AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
 * the aggregates list to simpler aggregate calls, and registers with the rule
 * call a projection over the new aggregate that reproduces the original row
 * type's field names.
 *
 * It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * @param ruleCall  rule call that receives the transformed expression
 * @param oldAggRel aggregate whose calls are reduced
 */
private void reduceAggs(
    RelOptRuleCall ruleCall,
    Aggregate oldAggRel) {
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
  final int nGroups = oldAggRel.getGroupCount();

  List<AggregateCall> newCalls = new ArrayList<>();
  Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();

  List<RexNode> projList = new ArrayList<>();

  // pass through group key
  for (int i = 0; i < nGroups; ++i) {
    projList.add(
        rexBuilder.makeInputRef(
            getFieldType(oldAggRel, i),
            i));
  }

  // List of input expressions. If a particular aggregate needs more, it
  // will add an expression to the end, and we will create an extra
  // project.
  RelNode input = oldAggRel.getInput();
  List<RexNode> inputExprs = new ArrayList<>();
  for (RelDataTypeField field : input.getRowType().getFieldList()) {
    inputExprs.add(
        rexBuilder.makeInputRef(
            field.getType(), inputExprs.size()));
  }

  // create new agg function calls and rest of project list together
  for (AggregateCall oldCall : oldCalls) {
    projList.add(
        reduceAgg(
            oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
  }

  final int extraArgCount =
      inputExprs.size() - input.getRowType().getFieldCount();
  if (extraArgCount > 0) {
    // Some reduction appended expressions: project them below the new
    // aggregate. The appended columns are given null names.
    input =
        relBuilderFactory
            .create(input.getCluster(), null)
            .push(input)
            .projectNamed(
                inputExprs,
                CompositeList.of(
                    input.getRowType().getFieldNames(),
                    Collections.<String>nCopies(
                        extraArgCount,
                        null)),
                true)
            .build();
  }
  Aggregate newAggRel =
      newAggregateRel(
          oldAggRel, input, newCalls);

  RelNode projectRel =
      relBuilderFactory
          .create(newAggRel.getCluster(), null)
          .push(newAggRel)
          .projectNamed(projList, oldAggRel.getRowType().getFieldNames(), true)
          .build();

  ruleCall.transformTo(projectRel);
}
/**
 * Reduces a single aggregate call and returns the expression that stands for
 * its value in the projection built over the new aggregate.
 *
 * SUM is delegated to {@code reduceSum}; AVG to {@code reduceAvg};
 * STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP to {@code reduceStddev}.
 * AVG-family calls with a DECIMAL result type, and every other aggregate
 * function, are carried over unchanged.
 *
 * @param oldAggRel      aggregate that owns {@code oldCall}
 * @param oldCall        call to reduce
 * @param newCalls       calls of the new aggregate; may be appended to
 * @param aggCallMapping calls already registered, with their references
 * @param inputExprs     input expressions of the new aggregate; may be
 *                       modified by the STDDEV/VAR reductions
 * @return expression over the output of the new aggregate
 */
private RexNode reduceAgg(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping,
    List<RexNode> inputExprs) {
  final SqlAggFunction sqlAggFunction =
      DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());
  if (sqlAggFunction instanceof SqlSumAggFunction) {
    // replace original SUM(x) with
    // case COUNT(x) when 0 then null else SUM0(x) end
    return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
  }
  if (sqlAggFunction instanceof SqlAvgAggFunction) {
    // for DECIMAL data types does not produce rewriting of complex calls,
    // since SUM returns value with 38 precision and further handling of the value
    // causes the loss of the scale
    if (oldCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
      return oldAggRel.getCluster().getRexBuilder().addAggCall(
          oldCall,
          oldAggRel.getGroupCount(),
          newCalls,
          aggCallMapping,
          ImmutableList.of(getFieldType(
              oldAggRel.getInput(),
              oldCall.getArgList().get(0))));
    }
    final SqlKind subtype = sqlAggFunction.getKind();
    switch (subtype) {
    case AVG:
      // replace original AVG(x) with SUM(x) / COUNT(x)
      return reduceAvg(
          oldAggRel, oldCall, newCalls, aggCallMapping);
    case STDDEV_POP:
      // replace original STDDEV_POP(x) with
      //   SQRT(
      //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
      //     / COUNT(x))
      return reduceStddev(
          oldAggRel, oldCall, true, true, newCalls, aggCallMapping,
          inputExprs);
    case STDDEV_SAMP:
      // replace original STDDEV_SAMP(x) with
      //   SQRT(
      //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
      //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
      return reduceStddev(
          oldAggRel, oldCall, false, true, newCalls, aggCallMapping,
          inputExprs);
    case VAR_POP:
      // replace original VAR_POP(x) with
      //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
      //     / COUNT(x)
      return reduceStddev(
          oldAggRel, oldCall, true, false, newCalls, aggCallMapping,
          inputExprs);
    case VAR_SAMP:
      // replace original VAR_SAMP(x) with
      //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
      //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
      return reduceStddev(
          oldAggRel, oldCall, false, false, newCalls, aggCallMapping,
          inputExprs);
    default:
      throw Util.unexpected(subtype);
    }
  } else {
    // anything else:  preserve original call
    RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    final int nGroups = oldAggRel.getGroupCount();
    List<RelDataType> oldArgTypes = new ArrayList<>();
    List<Integer> ordinals = oldCall.getArgList();
    assert ordinals.size() <= inputExprs.size();
    for (int ordinal : ordinals) {
      oldArgTypes.add(inputExprs.get(ordinal).getType());
    }
    // to solve AggregateCall returns true with equals method but has
    // different RelDataTypes: register the call directly instead of reusing
    // the mapped reference whose type differs.
    if (aggCallMapping.containsKey(oldCall)
        && !aggCallMapping.get(oldCall).getType().equals(oldCall.getType())) {
      int index = newCalls.size() + nGroups;
      newCalls.add(oldCall);
      return rexBuilder.makeInputRef(oldCall.getType(), index);
    }
    return rexBuilder.addAggCall(
        oldCall,
        nGroups,
        newCalls,
        aggCallMapping,
        oldArgTypes);
  }
}
/**
 * Rewrites AVG(x) as
 * {@code CastHigh(CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END) / COUNT(x)}.
 *
 * When type inference is enabled in the planner settings, the division is a
 * Drill "divide" operator typed with the original call's type; otherwise it
 * is the standard DIVIDE operator and the result is cast to ANY.
 *
 * @param oldAggRel      aggregate that owns {@code oldCall}
 * @param oldCall        AVG call to rewrite
 * @param newCalls       calls of the new aggregate; SUM0 and COUNT are added
 * @param aggCallMapping calls already registered, with their references
 * @return expression over the output of the new aggregate
 */
private RexNode reduceAvg(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings plannerSettings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  int iAvgInput = oldCall.getArgList().get(0);
  RelDataType avgInputType =
      getFieldType(
          oldAggRel.getInput(),
          iAvgInput);

  RelDataType sumType =
      TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(),
          ImmutableList.of())
          .inferReturnType(oldCall.createBinding(oldAggRel));
  // The sum is forced nullable when the aggregate has no group keys.
  sumType =
      typeFactory.createTypeWithNullability(
          sumType,
          sumType.isNullable() || nGroups == 0);
  SqlAggFunction sumAgg =
      new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
          new SqlSumEmptyIsZeroAggFunction(), sumType);
  AggregateCall sumCall = getAggCall(oldCall, sumAgg, sumType);
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  AggregateCall countCall = getAggCall(oldCall, countAgg, countType);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  RexNode sumRef =
      rexBuilder.addAggCall(
          sumCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));
  RexNode countRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));

  // CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END
  RexNode nullableSum = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
      rexBuilder.constantNull(),
      sumRef);

  RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, nullableSum);

  RexNode denominatorRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(avgInputType));
  if (isInferenceEnabled) {
    return rexBuilder.makeCall(
        new DrillSqlOperator(
            "divide",
            2,
            true,
            oldCall.getType(), false),
        numeratorRef,
        denominatorRef);
  } else {
    final RexNode divideRef =
        rexBuilder.makeCall(
            SqlStdOperatorTable.DIVIDE,
            numeratorRef,
            denominatorRef);
    return rexBuilder.makeCast(
        typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
  }
}
/**
 * Creates an aggregate call for {@code aggFunction} that copies the
 * distinct, approximate and ignore-nulls flags, the argument list, the
 * filter argument, the distinct keys and the collation from {@code oldCall}.
 * The new call is created with a {@code null} name.
 *
 * @param oldCall     call whose flags and arguments are reused
 * @param aggFunction aggregate function of the new call
 * @param sumType     result type of the new call; despite the name, callers
 *                    in this file also pass the COUNT return type here
 * @return the new aggregate call
 */
private static AggregateCall getAggCall(AggregateCall oldCall,
SqlAggFunction aggFunction,
RelDataType sumType) {
return AggregateCall.create(aggFunction,
oldCall.isDistinct(),
oldCall.isApproximate(),
oldCall.ignoreNulls(),
oldCall.getArgList(),
oldCall.filterArg,
oldCall.distinctKeys,
oldCall.getCollation(),
sumType,
null);
}
/**
 * Rewrites SUM(x) as {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END},
 * or as plain SUM0(x) when the original call's type is not nullable.
 *
 * With type inference enabled, SUM0 keeps the original call's type;
 * otherwise it takes the original type with the nullability of the argument.
 *
 * @param oldAggRel      aggregate that owns {@code oldCall}
 * @param oldCall        SUM call to rewrite
 * @param newCalls       calls of the new aggregate; SUM0 (and possibly COUNT)
 *                       are added
 * @param aggCallMapping calls already registered, with their references
 * @return expression over the output of the new aggregate
 */
private RexNode reduceSum(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings plannerSettings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  int arg = oldCall.getArgList().get(0);
  RelDataType argType =
      getFieldType(
          oldAggRel.getInput(),
          arg);

  final RelDataType sumType;
  if (isInferenceEnabled) {
    sumType = oldCall.getType();
  } else {
    sumType =
        typeFactory.createTypeWithNullability(
            oldCall.getType(), argType.isNullable());
  }
  final SqlAggFunction sumZeroAgg =
      new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
          new SqlSumEmptyIsZeroAggFunction(), sumType);
  AggregateCall sumZeroCall = getAggCall(oldCall, sumZeroAgg, sumType);
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  AggregateCall countCall = getAggCall(oldCall, countAgg, countType);

  // NOTE:  these references are with respect to the output
  // of newAggRel
  RexNode sumZeroRef =
      rexBuilder.addAggCall(
          sumZeroCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  if (!oldCall.getType().isNullable()) {
    // If SUM(x) is not nullable, the validator must have determined that
    // nulls are impossible (because the group is never empty and x is never
    // null). Therefore we translate to SUM0(x).
    return sumZeroRef;
  }
  RexNode countRef =
      rexBuilder.addAggCall(
          countCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));
  return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
      rexBuilder.constantNull(),
      sumZeroRef);
}
/**
 * Rewrites STDDEV_POP, STDDEV_SAMP, VAR_POP or VAR_SAMP over x in terms of
 * SUM(x * x), SUM(x) and COUNT(x), where x is first wrapped in CastHigh.
 *
 * The input expression at the argument's ordinal is replaced by its CastHigh
 * form, and the squared expression is appended to {@code inputExprs} if it
 * is not already there.
 *
 * @param oldAggRel      aggregate that owns {@code oldCall}
 * @param oldCall        call to rewrite; must have exactly one argument
 * @param biased         {@code true} to divide by COUNT(x) (the *_POP forms);
 *                       {@code false} to divide by COUNT(x) - 1, or NULL
 *                       when COUNT(x) is 1 (the *_SAMP forms)
 * @param sqrt           {@code true} to raise the quotient to the power 0.5
 *                       (the STDDEV forms)
 * @param newCalls       calls of the new aggregate; may be appended to
 * @param aggCallMapping calls already registered, with their references
 * @param inputExprs     input expressions of the new aggregate; modified
 * @return expression over the output of the new aggregate
 */
private RexNode reduceStddev(
    Aggregate oldAggRel,
    AggregateCall oldCall,
    boolean biased,
    boolean sqrt,
    List<AggregateCall> newCalls,
    Map<AggregateCall, RexNode> aggCallMapping,
    List<RexNode> inputExprs) {
  // stddev_pop(x) ==>
  //   power(
  //     (sum(x * x) - sum(x) * sum(x) / count(x))
  //     / count(x),
  //     .5)
  //
  // stddev_samp(x) ==>
  //   power(
  //     (sum(x * x) - sum(x) * sum(x) / count(x))
  //     / nullif(count(x) - 1, 0),
  //     .5)
  final PlannerSettings plannerSettings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  RelDataTypeFactory typeFactory =
      oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType =
      getFieldType(
          oldAggRel.getInput(),
          argOrdinal);

  // Replace the argument in place by CastHigh(x) so that both SUM(x) and
  // SUM(x * x) are computed over the cast value.
  RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
  inputExprs.set(argOrdinal, argRef);

  final RexNode argSquared =
      rexBuilder.makeCall(
          SqlStdOperatorTable.MULTIPLY, argRef, argRef);
  final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

  RelDataType sumType =
      TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(),
          ImmutableList.of())
          .inferReturnType(oldCall.createBinding(oldAggRel));
  sumType = typeFactory.createTypeWithNullability(sumType, true);

  // SUM(x * x)
  final AggregateCall sumArgSquaredAggCall =
      AggregateCall.create(
          new DrillCalciteSqlAggFunctionWrapper(
              new SqlSumAggFunction(sumType), sumType),
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.ignoreNulls(),
          ImmutableIntList.of(argSquaredOrdinal),
          oldCall.filterArg,
          oldCall.distinctKeys,
          oldCall.getCollation(),
          sumType,
          null);
  final RexNode sumArgSquared =
      rexBuilder.addAggCall(
          sumArgSquaredAggCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));

  // SUM(x)
  final AggregateCall sumArgAggCall =
      AggregateCall.create(
          new DrillCalciteSqlAggFunctionWrapper(
              new SqlSumAggFunction(sumType), sumType),
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.ignoreNulls(),
          ImmutableIntList.of(argOrdinal),
          oldCall.filterArg,
          oldCall.distinctKeys,
          oldCall.getCollation(),
          sumType,
          null);
  final RexNode sumArg =
      rexBuilder.addAggCall(
          sumArgAggCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));

  // SUM(x) * SUM(x)
  final RexNode sumSquaredArg =
      rexBuilder.makeCall(
          SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

  // COUNT(x)
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  final AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType);
  final RexNode countArg =
      rexBuilder.addAggCall(
          countArgAggCall,
          nGroups,
          newCalls,
          aggCallMapping,
          ImmutableList.of(argType));

  // SUM(x) * SUM(x) / COUNT(x)
  final RexNode avgSumSquaredArg =
      rexBuilder.makeCall(
          SqlStdOperatorTable.DIVIDE,
          sumSquaredArg, countArg);

  // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
  final RexNode diff =
      rexBuilder.makeCall(
          SqlStdOperatorTable.MINUS,
          sumArgSquared, avgSumSquaredArg);

  final RexNode denominator;
  if (biased) {
    denominator = countArg;
  } else {
    // CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END
    final RexLiteral one =
        rexBuilder.makeExactLiteral(BigDecimal.ONE);
    final RexNode nul =
        rexBuilder.makeNullLiteral(countArg.getType());
    final RexNode countMinusOne =
        rexBuilder.makeCall(
            SqlStdOperatorTable.MINUS, countArg, one);
    final RexNode countEqOne =
        rexBuilder.makeCall(
            SqlStdOperatorTable.EQUALS, countArg, one);
    denominator =
        rexBuilder.makeCall(
            SqlStdOperatorTable.CASE,
            countEqOne, nul, countMinusOne);
  }

  final SqlOperator divide;
  if (isInferenceEnabled) {
    divide = new DrillSqlOperator(
        "divide",
        2,
        true,
        oldCall.getType(), false);
  } else {
    divide = SqlStdOperatorTable.DIVIDE;
  }

  final RexNode div =
      rexBuilder.makeCall(
          divide, diff, denominator);

  RexNode result = div;
  if (sqrt) {
    final RexNode half =
        rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
    result =
        rexBuilder.makeCall(
            SqlStdOperatorTable.POWER, div, half);
  }

  if (isInferenceEnabled) {
    return result;
  } else {
    /*
     * Currently calcite's strategy to infer the return type of aggregate functions
     * is wrong because it uses the first known argument to determine output type. For
     * instance if we are performing stddev on an integer column then it interprets the
     * output type to be integer which is incorrect as it should be double. So based on
     * this if we add cast after rewriting the aggregate we add an additional cast which
     * would cause wrong results. So we simply add a cast to ANY.
     */
    return rexBuilder.makeCast(
        typeFactory.createSqlType(SqlTypeName.ANY), result);
  }
}
/**
 * Finds the ordinal of an element in a list, or adds it.
 *
 * If the element is absent it is appended to the end of the list, and the
 * ordinal returned is the position it was appended at.
 *
 * @param list    List
 * @param element Element to lookup or add
 * @param <T>     Element type
 * @return Ordinal of element in list
 */
private static <T> int lookupOrAdd(List<T> list, T element) {
  int ordinal = list.indexOf(element);
  if (ordinal == -1) {
    ordinal = list.size();
    list.add(element);
  }
  return ordinal;
}
/**
 * Do a shallow clone of oldAggRel and update aggCalls. Could be refactored
 * into Aggregate and subclasses - but it's only needed for some
 * subclasses.
 *
 * The result is a {@link LogicalAggregate} with {@link Convention#NONE},
 * an empty first list argument, and the group set and group sets of
 * {@code oldAggRel}.
 *
 * @param oldAggRel AggregateRel to clone.
 * @param inputRel  Input relational expression
 * @param newCalls  New list of AggregateCalls
 * @return shallow clone with new list of AggregateCalls.
 */
protected Aggregate newAggregateRel(
    Aggregate oldAggRel,
    RelNode inputRel,
    List<AggregateCall> newCalls) {
  RelOptCluster cluster = inputRel.getCluster();
  return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), Collections.emptyList(),
      inputRel, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
}
/**
 * Looks up the type of one field in a relational expression's row type.
 *
 * @param relNode relational expression whose row type is inspected
 * @param i       zero-based position of the field in the row type
 * @return type of the field at position {@code i}
 */
private RelDataType getFieldType(RelNode relNode, int i) {
  return relNode.getRowType().getFieldList().get(i).getType();
}
/**
 * Rule that replaces every SUM call with a non-nullable type in a
 * {@link DrillAggregateRel} by a SUM0 call of the same type; all other calls
 * are kept as they are.
 */
private static class DrillConvertSumToSumZero extends RelOptRule {
  public DrillConvertSumToSumZero(RelOptRuleOperand operand) {
    super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
  }

  /**
   * Matches when at least one aggregate call needs conversion according to
   * {@code isConversionToSumZeroNeeded}.
   */
  @Override
  public boolean matches(RelOptRuleCall call) {
    DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0];
    for (AggregateCall aggregateCall : oldAggRel.getAggCallList()) {
      if (isConversionToSumZeroNeeded(aggregateCall.getAggregation(), aggregateCall.getType())) {
        return true;
      }
    }
    return false;
  }

  /**
   * Builds a new {@link DrillAggregateRel} over the same input, group set and
   * group sets, with the convertible SUM calls replaced by SUM0 calls.
   */
  @Override
  public void onMatch(RelOptRuleCall call) {
    final DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0];

    final Map<AggregateCall, RexNode> aggCallMapping = Maps.newHashMap();
    final List<AggregateCall> newAggregateCalls = Lists.newArrayList();
    for (AggregateCall oldAggregateCall : oldAggRel.getAggCallList()) {
      if (isConversionToSumZeroNeeded(oldAggregateCall.getAggregation(), oldAggregateCall.getType())) {
        // NOTE: despite its name this is the result type of the old call,
        // not the type of its argument.
        final RelDataType argType = oldAggregateCall.getType();
        final RelDataType sumType = oldAggRel.getCluster().getTypeFactory()
            .createTypeWithNullability(argType, argType.isNullable());
        final SqlAggFunction sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
            new SqlSumEmptyIsZeroAggFunction(), sumType);
        AggregateCall sumZeroCall =
            AggregateCall.create(
                sumZeroAgg,
                oldAggregateCall.isDistinct(),
                oldAggregateCall.isApproximate(),
                oldAggregateCall.ignoreNulls(),
                oldAggregateCall.getArgList(),
                oldAggregateCall.filterArg,
                oldAggregateCall.distinctKeys,
                oldAggregateCall.getCollation(),
                sumType,
                oldAggregateCall.getName());
        // The returned reference is not needed; the call is made for its
        // effect of registering sumZeroCall in newAggregateCalls.
        oldAggRel.getCluster().getRexBuilder()
            .addAggCall(sumZeroCall,
                oldAggRel.getGroupCount(),
                newAggregateCalls,
                aggCallMapping,
                ImmutableList.of(argType));
      } else {
        newAggregateCalls.add(oldAggregateCall);
      }
    }

    call.transformTo(new DrillAggregateRel(
        oldAggRel.getCluster(),
        oldAggRel.getTraitSet(),
        oldAggRel.getInput(),
        oldAggRel.getGroupSet(),
        oldAggRel.getGroupSets(),
        newAggregateCalls));
  }
}
/**
 * Rule that replaces every SUM window aggregate with a non-nullable type in
 * a {@link DrillWindowRel} by a SUM0 window aggregate of the same type; all
 * other window aggregates are kept as they are.
 */
private static class DrillConvertWindowSumToSumZero extends RelOptRule {
  public DrillConvertWindowSumToSumZero(RelOptRuleOperand operand) {
    super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
  }

  /**
   * Matches when at least one window aggregate in any group needs conversion
   * according to {@code isConversionToSumZeroNeeded}.
   */
  @Override
  public boolean matches(RelOptRuleCall call) {
    final DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
    for (Window.Group group : oldWinRel.groups) {
      for (Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
        if (isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
          return true;
        }
      }
    }
    return false;
  }

  /**
   * Rebuilds every window group with its convertible SUM calls replaced by
   * SUM0 calls, then creates a new {@link DrillWindowRel} with the same
   * input, constants and row type.
   */
  @Override
  public void onMatch(RelOptRuleCall call) {
    final DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
    final ImmutableList.Builder<Window.Group> builder = ImmutableList.builder();

    for (Window.Group group : oldWinRel.groups) {
      final List<Window.RexWinAggCall> aggCalls = Lists.newArrayList();
      for (Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
        if (isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
          // NOTE: despite its name this is the result type of the old call,
          // not the type of its argument.
          final RelDataType argType = rexWinAggCall.getType();
          final RelDataType sumType = oldWinRel.getCluster().getTypeFactory()
              .createTypeWithNullability(argType, argType.isNullable());
          final SqlAggFunction sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
              new SqlSumEmptyIsZeroAggFunction(), sumType);
          final Window.RexWinAggCall sumZeroCall =
              new Window.RexWinAggCall(
                  sumZeroAgg,
                  sumType,
                  rexWinAggCall.operands,
                  rexWinAggCall.ordinal,
                  rexWinAggCall.distinct);
          aggCalls.add(sumZeroCall);
        } else {
          aggCalls.add(rexWinAggCall);
        }
      }

      final Window.Group newGroup = new Window.Group(
          group.keys,
          group.isRows,
          group.lowerBound,
          group.upperBound,
          group.orderKeys,
          aggCalls);
      builder.add(newGroup);
    }

    call.transformTo(new DrillWindowRel(
        oldWinRel.getCluster(),
        oldWinRel.getTraitSet(),
        oldWinRel.getInput(),
        oldWinRel.constants,
        oldWinRel.getRowType(),
        builder.build()));
  }
}
/**
 * Tells whether an aggregate should be converted to SUM0: the operator, once
 * unwrapped, must be a SUM and the given type must not be nullable.
 *
 * @param sqlOperator operator of the aggregate, possibly a Drill wrapper
 * @param type        type of the aggregate call
 * @return {@code true} when the conversion to SUM0 applies
 */
private static boolean isConversionToSumZeroNeeded(SqlOperator sqlOperator, RelDataType type) {
  final SqlOperator unwrapped =
      DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(sqlOperator);
  // If SUM(x) is not nullable, the validator must have determined that
  // nulls are impossible (because the group is never empty and x is never
  // null). Therefore we translate to SUM0(x).
  return unwrapped instanceof SqlSumAggFunction && !type.isNullable();
}
}