All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.org.apache.calcite.rel.rules.AggregateReduceFunctionsRule Maven / Gradle / Ivy

There is a newer version: 5.4.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleOperand;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelCollations;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.AggregateCall;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalAggregate;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.sql.SqlAggFunction;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.CompositeList;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Util;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.ImmutableSet;

import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.IntPredicate;
import java.util.function.Predicate;

/**
 * Planner rule that reduces aggregate functions in
 * {@link com.hazelcast.org.apache.calcite.rel.core.Aggregate}s to simpler forms.
 *
 * 

Rewrites: *

    * *
  • AVG(x) → SUM(x) / COUNT(x) * *
  • STDDEV_POP(x) → SQRT( * (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) * / COUNT(x)) * *
  • STDDEV_SAMP(x) → SQRT( * (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) * / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) * *
  • VAR_POP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) * / COUNT(x) * *
  • VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) * / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END * *
  • COVAR_POP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) * / REGR_COUNT(x, y)) / REGR_COUNT(x, y) * *
  • COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) * / CASE REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END * *
  • REGR_SXX(x, y) → REGR_COUNT(x, y) * VAR_POP(y) * *
  • REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x) * *
* *

Since many of these rewrites introduce multiple occurrences of simpler * forms like {@code COUNT(x)}, the rule gathers common sub-expressions as it * goes. * * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS */ @Value.Enclosing public class AggregateReduceFunctionsRule extends RelRule implements TransformationRule { //~ Static fields/initializers --------------------------------------------- private static void validateFunction(SqlKind function) { if (!isValid(function)) { throw new IllegalArgumentException("AggregateReduceFunctionsRule doesn't " + "support function: " + function.sql); } } private static boolean isValid(SqlKind function) { return SqlKind.AVG_AGG_FUNCTIONS.contains(function) || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(function) || function == SqlKind.SUM; } private final Set functionsToReduce; //~ Constructors ----------------------------------------------------------- /** Creates an AggregateReduceFunctionsRule. */ protected AggregateReduceFunctionsRule(Config config) { super(config); this.functionsToReduce = ImmutableSet.copyOf(config.actualFunctionsToReduce()); } @Deprecated // to be removed before 2.0 public AggregateReduceFunctionsRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) { this(Config.DEFAULT .withRelBuilderFactory(relBuilderFactory) .withOperandSupplier(b -> b.exactly(operand)) .as(Config.class) // reduce all functions handled by this rule .withFunctionsToReduce(null)); } @Deprecated // to be removed before 2.0 public AggregateReduceFunctionsRule(Class aggregateClass, RelBuilderFactory relBuilderFactory, EnumSet functionsToReduce) { this(Config.DEFAULT .withRelBuilderFactory(relBuilderFactory) .as(Config.class) .withOperandFor(aggregateClass) // reduce specific functions provided by the client .withFunctionsToReduce(Objects.requireNonNull(functionsToReduce, "functionsToReduce"))); } //~ Methods ---------------------------------------------------------------- @Override public boolean matches(RelOptRuleCall call) { if (!super.matches(call)) { return false; } Aggregate oldAggRel = (Aggregate) call.rels[0]; return containsAvgStddevVarCall(oldAggRel.getAggCallList()); } @Override public void onMatch(RelOptRuleCall ruleCall) { Aggregate oldAggRel = (Aggregate) ruleCall.rels[0]; reduceAggs(ruleCall, oldAggRel); } /** * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*. * * @param aggCallList List of aggregate calls */ private boolean containsAvgStddevVarCall(List aggCallList) { return aggCallList.stream().anyMatch(this::canReduce); } /** Returns whether this rule can reduce a given aggregate function call. */ public boolean canReduce(AggregateCall call) { return functionsToReduce.contains(call.getAggregation().getKind()) && config.extraCondition().test(call); } /** Returns whether this rule can reduce some agg-call, * which its arg exists in the aggregate's group. */ public boolean canReduceAggCallByGrouping(Aggregate oldAggRel, AggregateCall call) { if (!Aggregate.isSimple(oldAggRel)) { return false; } if (call.hasFilter() || call.distinctKeys != null || call.collation != RelCollations.EMPTY) { return false; } final List argList = call.getArgList(); if (argList.size() != 1) { return false; } if (!oldAggRel.getGroupSet().asSet().contains(argList.get(0))) { // arg doesn't exist in aggregate's group. return false; } final SqlKind kind = call.getAggregation().getKind(); switch (kind) { case AVG: case MAX: case MIN: case ANY_VALUE: case FIRST_VALUE: case LAST_VALUE: return true; default: return false; } } /** * Reduces calls to functions AVG, SUM, STDDEV_POP, STDDEV_SAMP, VAR_POP, * VAR_SAMP, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY if the function is * present in {@link AggregateReduceFunctionsRule#functionsToReduce} * *

It handles newly generated common subexpressions since this was done * at the sql2rel stage. */ private void reduceAggs( RelOptRuleCall ruleCall, Aggregate oldAggRel) { RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); List oldCalls = oldAggRel.getAggCallList(); final int groupCount = oldAggRel.getGroupCount(); final List newCalls = new ArrayList<>(); final Map aggCallMapping = new HashMap<>(); final List projList = new ArrayList<>(); // pass through group key for (int i = 0; i < groupCount; ++i) { projList.add(rexBuilder.makeInputRef(oldAggRel, i)); } // List of input expressions. If a particular aggregate needs more, it // will add an expression to the end, and we will create an extra // project. final RelBuilder relBuilder = ruleCall.builder(); relBuilder.push(oldAggRel.getInput()); final List inputExprs = new ArrayList<>(relBuilder.fields()); // create new aggregate function calls and rest of project list together for (AggregateCall oldCall : oldCalls) { projList.add( reduceAgg( oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs)); } final int extraArgCount = inputExprs.size() - relBuilder.peek().getRowType().getFieldCount(); if (extraArgCount > 0) { relBuilder.project(inputExprs, CompositeList.of( relBuilder.peek().getRowType().getFieldNames(), Collections.nCopies(extraArgCount, null))); } newAggregateRel(relBuilder, oldAggRel, newCalls); newCalcRel(relBuilder, oldAggRel.getRowType(), projList); final RelNode build = relBuilder.build(); ruleCall.transformTo(build); } private RexNode reduceAgg( Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, List inputExprs) { if (canReduceAggCallByGrouping(oldAggRel, oldCall)) { // replace original MAX/MIN/AVG/ANY_VALUE/FIRST_VALUE/LAST_VALUE(x) with // target field of x, when x exists in group final RexNode reducedNode = reduceAggCallByGrouping(oldAggRel, oldCall); return reducedNode; } else if (canReduce(oldCall)) { final Integer y; final Integer x; final SqlKind kind = oldCall.getAggregation().getKind(); switch (kind) { case SUM: // replace original SUM(x) with // case COUNT(x) when 0 then null else SUM0(x) end return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); case AVG: // replace original AVG(x) with SUM(x) / COUNT(x) return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs); case COVAR_POP: // replace original COVAR_POP(x, y) with // (SUM(x * y) - SUM(y) * SUM(y) / COUNT(x)) // / COUNT(x)) return reduceCovariance(oldAggRel, oldCall, true, newCalls, aggCallMapping, inputExprs); case COVAR_SAMP: // replace original COVAR_SAMP(x, y) with // SQRT( // (SUM(x * y) - SUM(x) * SUM(y) / COUNT(x)) // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) return reduceCovariance(oldAggRel, oldCall, false, newCalls, aggCallMapping, inputExprs); case REGR_SXX: // replace original REGR_SXX(x, y) with // REGR_COUNT(x, y) * VAR_POP(y) assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); x = oldCall.getArgList().get(0); y = oldCall.getArgList().get(1); //noinspection SuspiciousNameCombination return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, y, y, x); case REGR_SYY: // replace original REGR_SYY(x, y) with // REGR_COUNT(x, y) * VAR_POP(x) assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); x = oldCall.getArgList().get(0); y = oldCall.getArgList().get(1); //noinspection SuspiciousNameCombination return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, x, x, y); case STDDEV_POP: // replace original STDDEV_POP(x) with // SQRT( // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) // / COUNT(x)) return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs); case STDDEV_SAMP: // replace original STDDEV_POP(x) with // SQRT( // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END) return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs); case VAR_POP: // replace original VAR_POP(x) with // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) // / COUNT(x) return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs); case VAR_SAMP: // replace original VAR_POP(x) with // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) // / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs); default: throw Util.unexpected(kind); } } else { // anything else: preserve original call RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); final int nGroups = oldAggRel.getGroupCount(); return rexBuilder.addAggCall(oldCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); } } private static AggregateCall createAggregateCallWithBinding( RelDataTypeFactory typeFactory, SqlAggFunction aggFunction, RelDataType operandType, Aggregate oldAggRel, AggregateCall oldCall, int argOrdinal, int filter) { final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(typeFactory, aggFunction, ImmutableList.of(operandType), oldAggRel.getGroupCount(), filter >= 0); return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), filter, oldCall.distinctKeys, oldCall.collation, aggFunction.inferReturnType(binding), null); } private static RexNode reduceAvg( Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, @SuppressWarnings("unused") List inputExprs) { final int nGroups = oldAggRel.getGroupCount(); final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); final AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null); final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null); // NOTE: these references are with respect to the output // of newAggRel RexNode numeratorRef = rexBuilder.addAggCall(sumCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); final RexNode denominatorRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); final RelDataType avgType = typeFactory.createTypeWithNullability( oldCall.getType(), numeratorRef.getType().isNullable()); numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true); final RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef); return rexBuilder.makeCast(oldCall.getType(), divideRef); } private static RexNode reduceSum( Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping) { final int nGroups = oldAggRel.getGroupCount(); RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); final AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, oldCall.name); final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null); // NOTE: these references are with respect to the output // of newAggRel RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); if (!oldCall.getType().isNullable()) { // If SUM(x) is not nullable, the validator must have determined that // nulls are impossible (because the group is never empty and x is never // null). Therefore we translate to SUM0(x). return sumZeroRef; } RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); return rexBuilder.makeCall(SqlStdOperatorTable.CASE, rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)), rexBuilder.makeNullLiteral(sumZeroRef.getType()), sumZeroRef); } private static RexNode reduceStddev( Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List newCalls, Map aggCallMapping, List inputExprs) { // stddev_pop(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / count(x), // .5) // // stddev_samp(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / nullif(count(x) - 1, 0), // .5) final int nGroups = oldAggRel.getGroupCount(); final RelOptCluster cluster = oldAggRel.getCluster(); final RexBuilder rexBuilder = cluster.getRexBuilder(); final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); assert oldCall.getArgList().size() == 1 : oldCall.getArgList(); final int argOrdinal = oldCall.getArgList().get(0); final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable; final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), fieldIsNullable.test(argOrdinal)); final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true); final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef); final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared); final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM, argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal, -1); final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); final AggregateCall sumArgAggCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null); final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true); final RexNode sumSquaredArg = rexBuilder.makeCall( SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast); final AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null); final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); final RexNode div = divide(biased, rexBuilder, sumArgSquared, sumSquaredArg, countArg); final RexNode result; if (sqrt) { final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5")); result = rexBuilder.makeCall( SqlStdOperatorTable.POWER, div, half); } else { result = div; } return rexBuilder.makeCast( oldCall.getType(), result); } private static RexNode reduceAggCallByGrouping( Aggregate oldAggRel, AggregateCall oldCall) { final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); final List oldGroups = oldAggRel.getGroupSet().asList(); final Integer firstArg = oldCall.getArgList().get(0); final int index = oldGroups.lastIndexOf(firstArg); assert index >= 0; final RexInputRef refByGroup = RexInputRef.of(index, oldAggRel.getRowType().getFieldList()); if (refByGroup.getType().equals(oldCall.getType())) { return refByGroup; } else { return rexBuilder.makeCast(oldCall.getType(), refByGroup); } } private static RexNode getSumAggregatedRexNode(Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, RexBuilder rexBuilder, int argOrdinal, int filterArg) { final AggregateCall aggregateCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null); return rexBuilder.addAggCall(aggregateCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); } private static RexNode getSumAggregatedRexNodeWithBinding(Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, RelDataType operandType, int argOrdinal, int filter) { RelOptCluster cluster = oldAggRel.getCluster(); final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(cluster.getTypeFactory(), SqlStdOperatorTable.SUM, operandType, oldAggRel, oldCall, argOrdinal, filter); return cluster.getRexBuilder().addAggCall(sumArgSquaredAggCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); } private static RexNode getRegrCountRexNode(Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, ImmutableIntList argOrdinals, int filterArg) { final AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.REGR_COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), argOrdinals, filterArg, oldCall.distinctKeys, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null); return oldAggRel.getCluster().getRexBuilder().addAggCall(countArgAggCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable); } private static RexNode reduceRegrSzz( Aggregate oldAggRel, AggregateCall oldCall, List newCalls, Map aggCallMapping, List inputExprs, int xIndex, int yIndex, int nullFilterIndex) { // regr_sxx(x, y) ==> // sum(y * y, x) - sum(y, x) * sum(y, x) / regr_count(x, y) // final RelOptCluster cluster = oldAggRel.getCluster(); final RexBuilder rexBuilder = cluster.getRexBuilder(); final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable; final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), fieldIsNullable.test(xIndex) || fieldIsNullable.test(yIndex) || fieldIsNullable.test(nullFilterIndex)); final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true); final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true); final RexNode argNullFilter = rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true); final RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY); final int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY); final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND, rexBuilder.makeCall(SqlStdOperatorTable.AND, rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX), rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)), rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter)); final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter); final RexNode sumXY = getSumAggregatedRexNodeWithBinding( oldAggRel, oldCall, newCalls, aggCallMapping, argXArgY.getType(), argSquaredOrdinal, argXAndYNotNullFilterOrdinal); final RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true); final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, xIndex, argXAndYNotNullFilterOrdinal); final RexNode sumY = xIndex == yIndex ? sumX : getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, yIndex, argXAndYNotNullFilterOrdinal); final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY); final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, ImmutableIntList.of(xIndex), argXAndYNotNullFilterOrdinal); RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO); RexNode nul = rexBuilder.makeNullLiteral(zero.getType()); final RexNode avgSumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.CASE, rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero), nul, rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg)); final RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true); final RexNode result = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast); return rexBuilder.makeCast(oldCall.getType(), result); } private static RexNode reduceCovariance( Aggregate oldAggRel, AggregateCall oldCall, boolean biased, List newCalls, Map aggCallMapping, List inputExprs) { // covar_pop(x, y) ==> // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) // / regr_count(x, y) // // covar_samp(x, y) ==> // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) // / regr_count(count(x, y) - 1, 0) final RelOptCluster cluster = oldAggRel.getCluster(); final RexBuilder rexBuilder = cluster.getRexBuilder(); final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); assert oldCall.getArgList().size() == 2 : oldCall.getArgList(); final int argXOrdinal = oldCall.getArgList().get(0); final int argYOrdinal = oldCall.getArgList().get(1); final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable; final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), fieldIsNullable.test(argXOrdinal) || fieldIsNullable.test(argYOrdinal)); final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true); final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true); final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND, rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX), rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)); final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter); final RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY); final int argXYOrdinal = lookupOrAdd(inputExprs, argXY); final RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls, aggCallMapping, argXY.getType(), argXYOrdinal, argXAndYNotNullFilterOrdinal); final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, argXOrdinal, argXAndYNotNullFilterOrdinal); final RexNode sumY = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, argYOrdinal, argXAndYNotNullFilterOrdinal); final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY); final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, ImmutableIntList.of(argXOrdinal, argYOrdinal), argXAndYNotNullFilterOrdinal); final RexNode result = divide(biased, rexBuilder, sumXY, sumXSumY, countArg); return rexBuilder.makeCast(oldCall.getType(), result); } private static RexNode divide(boolean biased, RexBuilder rexBuilder, RexNode sumXY, RexNode sumXSumY, RexNode countArg) { final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg); final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg); final RexNode denominator; if (biased) { denominator = countArg; } else { final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE); final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType()); final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one); final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one); denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne); } return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator); } /** * Finds the ordinal of an element in a list, or adds it. * * @param list List * @param element Element to lookup or add * @param Element type * @return Ordinal of element in list */ private static int lookupOrAdd(List list, T element) { int ordinal = list.indexOf(element); if (ordinal == -1) { ordinal = list.size(); list.add(element); } return ordinal; } /** * Does a shallow clone of oldAggRel and updates aggCalls. Could be refactored * into Aggregate and subclasses - but it's only needed for some * subclasses. * * @param relBuilder Builder of relational expressions; at the top of its * stack is its input * @param oldAggregate LogicalAggregate to clone. * @param newCalls New list of AggregateCalls */ protected void newAggregateRel(RelBuilder relBuilder, Aggregate oldAggregate, List newCalls) { relBuilder.aggregate( relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.getGroupSets()), newCalls); } /** * Adds a calculation with the expressions to compute the original aggregate * calls from the decomposed ones. * * @param relBuilder Builder of relational expressions; at the top of its * stack is its input * @param rowType The output row type of the original aggregate. * @param exprs The expressions to compute the original aggregate calls */ protected void newCalcRel(RelBuilder relBuilder, RelDataType rowType, List exprs) { relBuilder.project(exprs, rowType.getFieldNames()); } /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { Config DEFAULT = ImmutableAggregateReduceFunctionsRule.Config.of() .withOperandFor(LogicalAggregate.class); Set DEFAULT_FUNCTIONS_TO_REDUCE = ImmutableSet.builder() .addAll(SqlKind.AVG_AGG_FUNCTIONS) .addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS) .add(SqlKind.SUM) .build(); @Override default AggregateReduceFunctionsRule toRule() { return new AggregateReduceFunctionsRule(this); } /** The set of aggregate function types to try to reduce. * *

Any aggregate function whose type is omitted from this set, OR which * does not pass the {@link #extraCondition}, will be ignored. */ @Nullable Set functionsToReduce(); /** A test that must pass before attempting to reduce any aggregate function. * *

Any aggegate function which does not pass, OR whose type is omitted * from {@link #functionsToReduce}, will be ignored. The default predicate * always passes. */ @Value.Default default Predicate extraCondition() { return ignored -> true; } /** Sets {@link #functionsToReduce}. */ Config withFunctionsToReduce(@Nullable Iterable functionSet); default Config withFunctionsToReduce(@Nullable Set functionSet) { return withFunctionsToReduce((Iterable) functionSet); } /** Sets {@link #extraCondition}. */ Config withExtraCondition(Predicate test); /** Returns the validated set of functions to reduce, or the default set * if not specified. */ default Set actualFunctionsToReduce() { final Set set = Util.first(functionsToReduce(), DEFAULT_FUNCTIONS_TO_REDUCE); set.forEach(AggregateReduceFunctionsRule::validateFunction); return set; } /** Defines an operand tree for the given classes. */ default Config withOperandFor(Class aggregateClass) { return withOperandSupplier(b -> b.operand(aggregateClass).anyInputs()) .as(Config.class); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy