// ![JAR search and dependency download from the Maven repository](/logo.png)
// io.trino.sql.planner.iterative.rule.PreAggregateCaseAggregations Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.TrinoException;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WhenClause;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.SystemSessionProperties.isPreAggregateCaseAggregationsEnabled;
import static io.trino.matching.Capture.newCapture;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.project;
import static io.trino.sql.planner.plan.Patterns.source;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.function.Predicate.not;
/**
* Rule that transforms selected aggregations from:
* {@code
* - Aggregation[key]: sum(aggr1), sum(aggr2), ..., sum(aggrN)
* - Project:
* aggr1 = CASE WHEN col=1 THEN expr ELSE 0
* aggr2 = CASE WHEN col=2 THEN expr ELSE null
* ...
* aggrN = CASE WHEN col=N THEN expr
* - source
* }
*
* into:
* {@code
* - Aggregation[key]: sum(aggr1), sum(aggr2), ..., sum(aggrN)
* - Project:
* aggr1 = CASE WHEN col=1 THEN pre_aggregate
* aggr2 = CASE WHEN col=2 THEN pre_aggregate
* ..
* aggrN = CASE WHEN col=N THEN pre_aggregate
* - Aggregation[key, col]: pre_aggregate = sum(expr)
* - source
* }
*
*/
public class PreAggregateCaseAggregations
implements Rule
{
private static final int MIN_AGGREGATION_COUNT = 4;
// BE EXTREMELY CAREFUL WHEN ADDING NEW FUNCTIONS TO THIS SET
// This code appears to be generic, but is not. It only works because the allowed functions have very specific behavior.
private static final CatalogSchemaFunctionName MAX = builtinFunctionName("max");
private static final CatalogSchemaFunctionName MIN = builtinFunctionName("min");
private static final CatalogSchemaFunctionName SUM = builtinFunctionName("sum");
private static final Set ALLOWED_FUNCTIONS = ImmutableSet.of(MAX, MIN, SUM);
private static final Capture PROJECT_CAPTURE = newCapture();
private static final Pattern PATTERN = aggregation()
.matching(aggregation -> aggregation.getStep() == SINGLE && aggregation.getGroupingSetCount() == 1)
.with(source().matching(project().capturedAs(PROJECT_CAPTURE)
// prevent rule from looping by ensuring that projection source is not aggregation
.with(source().matching(not(AggregationNode.class::isInstance)))));
private final PlannerContext plannerContext;
private final TypeAnalyzer typeAnalyzer;
public PreAggregateCaseAggregations(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
}
@Override
public Pattern getPattern()
{
return PATTERN;
}
@Override
public boolean isEnabled(Session session)
{
return isPreAggregateCaseAggregationsEnabled(session);
}
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
{
ProjectNode projectNode = captures.get(PROJECT_CAPTURE);
Optional> aggregationsOptional = extractCaseAggregations(aggregationNode, projectNode, context);
if (aggregationsOptional.isEmpty()) {
return Result.empty();
}
List aggregations = aggregationsOptional.get();
if (aggregations.size() < MIN_AGGREGATION_COUNT) {
return Result.empty();
}
Set extraGroupingKeys = aggregations.stream()
.flatMap(expression -> SymbolsExtractor.extractUnique(expression.getOperand()).stream())
.collect(toImmutableSet());
if (extraGroupingKeys.size() != 1) {
// pre-aggregation can have single extra symbol
return Result.empty();
}
Map preAggregations = getPreAggregations(aggregations, context);
Assignments.Builder preGroupingExpressionsBuilder = Assignments.builder();
preGroupingExpressionsBuilder.putIdentities(extraGroupingKeys);
aggregationNode.getGroupingKeys().forEach(symbol -> preGroupingExpressionsBuilder.put(symbol, projectNode.getAssignments().get(symbol)));
Assignments preGroupingExpressions = preGroupingExpressionsBuilder.build();
ProjectNode preProjection = createPreProjection(
projectNode.getSource(),
preGroupingExpressions,
preAggregations,
context);
AggregationNode preAggregation = createPreAggregation(
preProjection,
preGroupingExpressions.getOutputs(),
preAggregations,
context);
Map newProjectionSymbols = getNewProjectionSymbols(aggregations, context);
ProjectNode newProjection = createNewProjection(
preAggregation,
aggregationNode,
projectNode,
newProjectionSymbols,
preAggregations);
return Result.ofPlanNode(createNewAggregation(newProjection, aggregationNode, newProjectionSymbols));
}
private AggregationNode createNewAggregation(
PlanNode source,
AggregationNode aggregationNode,
Map newProjectionSymbols)
{
return new AggregationNode(
aggregationNode.getId(),
source,
newProjectionSymbols.entrySet().stream()
.collect(toImmutableMap(
entry -> entry.getKey().getAggregationSymbol(),
entry -> new Aggregation(
entry.getKey().getCumulativeFunction(),
ImmutableList.of(entry.getValue().toSymbolReference()),
false,
Optional.empty(),
Optional.empty(),
Optional.empty()))),
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol());
}
private ProjectNode createNewProjection(
PlanNode source,
AggregationNode aggregationNode,
ProjectNode projectNode,
Map newProjectionSymbols,
Map preAggregations)
{
Assignments.Builder assignments = Assignments.builder();
// grouping key expressions are already evaluated in pre-projection
assignments.putIdentities(aggregationNode.getGroupingKeys());
newProjectionSymbols.forEach((aggregation, symbol) -> assignments.put(
symbol,
new SearchedCaseExpression(ImmutableList.of(
new WhenClause(
aggregation.getOperand(),
preAggregations.get(new PreAggregationKey(aggregation)).getAggregationSymbol().toSymbolReference())),
aggregation.getCumulativeAggregationDefaultValue())));
return new ProjectNode(projectNode.getId(), source, assignments.build());
}
private Map getNewProjectionSymbols(List aggregations, Context context)
{
return aggregations.stream()
.collect(toImmutableMap(
identity(),
// new projection has the same type as original aggregation
aggregation -> context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol())));
}
private AggregationNode createPreAggregation(
PlanNode source,
List groupingKeys,
Map preAggregations,
Context context)
{
return new AggregationNode(
context.getIdAllocator().getNextId(),
source,
preAggregations.entrySet().stream()
.collect(toImmutableMap(
entry -> entry.getValue().getAggregationSymbol(),
entry -> new Aggregation(
entry.getKey().getFunction(),
ImmutableList.of(entry.getValue().getProjectionSymbol().toSymbolReference()),
false,
Optional.empty(),
Optional.empty(),
Optional.empty()))),
singleGroupingSet(groupingKeys),
ImmutableList.of(),
SINGLE,
Optional.empty(),
Optional.empty());
}
private ProjectNode createPreProjection(
PlanNode source,
Assignments groupingExpressions,
Map preAggregations,
Context context)
{
Assignments.Builder assignments = Assignments.builder();
assignments.putAll(groupingExpressions);
preAggregations.values().forEach(aggregation -> assignments.put(aggregation.getProjectionSymbol(), aggregation.getProjection()));
return new ProjectNode(context.getIdAllocator().getNextId(), source, assignments.build());
}
private Map getPreAggregations(List aggregations, Context context)
{
Set keys = new HashSet<>();
ImmutableMap.Builder preAggregations = ImmutableMap.builder();
for (CaseAggregation aggregation : aggregations) {
PreAggregationKey preAggregationKey = new PreAggregationKey(aggregation);
if (keys.contains(preAggregationKey)) {
continue;
}
// Cast pre-projection if needed to match aggregation input type.
// This is because entire "CASE WHEN" expression could be wrapped in CAST.
Expression preProjection = aggregation.getResult();
Type preProjectionType = getType(context, preProjection);
Type aggregationInputType = getOnlyElement(aggregation.getFunction().getSignature().getArgumentTypes());
if (!preProjectionType.equals(aggregationInputType)) {
preProjection = new Cast(preProjection, toSqlType(aggregationInputType));
preProjectionType = aggregationInputType;
}
Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection, preProjectionType);
Symbol preAggregationSymbol = context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol());
preAggregations.put(preAggregationKey, new PreAggregation(preAggregationSymbol, preProjection, preProjectionSymbol));
keys.add(preAggregationKey);
}
return ImmutableMap.copyOf(preAggregations.buildOrThrow());
}
private Optional> extractCaseAggregations(AggregationNode aggregationNode, ProjectNode projectNode, Context context)
{
ImmutableList.Builder caseAggregations = ImmutableList.builder();
for (Map.Entry aggregation : aggregationNode.getAggregations().entrySet()) {
Optional caseAggregation = extractCaseAggregation(
aggregation.getKey(),
aggregation.getValue(),
projectNode,
context);
if (caseAggregation.isEmpty()) {
return Optional.empty();
}
caseAggregations.add(caseAggregation.get());
}
return Optional.of(caseAggregations.build());
}
private Optional extractCaseAggregation(Symbol aggregationSymbol, Aggregation aggregation, ProjectNode projectNode, Context context)
{
if (aggregation.getArguments().size() != 1
|| !(aggregation.getArguments().get(0) instanceof SymbolReference)
|| aggregation.isDistinct()
|| aggregation.getFilter().isPresent()
|| aggregation.getMask().isPresent()
|| aggregation.getOrderingScheme().isPresent()) {
// aggregation must be a basic aggregation
return Optional.empty();
}
ResolvedFunction resolvedFunction = aggregation.getResolvedFunction();
CatalogSchemaFunctionName name = resolvedFunction.getSignature().getName();
if (!ALLOWED_FUNCTIONS.contains(name)) {
// only cumulative aggregations (e.g. that can be split into aggregation of aggregations) are supported
return Optional.empty();
}
Symbol projectionSymbol = Symbol.from(aggregation.getArguments().get(0));
Expression projection = projectNode.getAssignments().get(projectionSymbol);
Expression unwrappedProjection;
// unwrap top-level cast
if (projection instanceof Cast) {
unwrappedProjection = ((Cast) projection).getExpression();
}
else {
unwrappedProjection = projection;
}
if (!(unwrappedProjection instanceof SearchedCaseExpression caseExpression)) {
return Optional.empty();
}
if (caseExpression.getWhenClauses().size() != 1) {
return Optional.empty();
}
Type aggregationType = resolvedFunction.getSignature().getReturnType();
ResolvedFunction cumulativeFunction;
try {
cumulativeFunction = plannerContext.getMetadata().resolveBuiltinFunction(name.getFunctionName(), fromTypes(aggregationType));
}
catch (TrinoException e) {
// there is no cumulative aggregation
return Optional.empty();
}
if (!cumulativeFunction.getSignature().getReturnType().equals(aggregationType)) {
// aggregation type after rewrite must not change
return Optional.empty();
}
Optional cumulativeAggregationDefaultValue = Optional.empty();
if (caseExpression.getDefaultValue().isPresent()) {
Type defaultType = getType(context, caseExpression.getDefaultValue().get());
Object defaultValue = optimizeExpression(caseExpression.getDefaultValue().get(), context);
if (defaultValue != null) {
if (!name.equals(SUM)) {
return Optional.empty();
}
// sum aggregation is only supported if default value is null or 0, otherwise it wouldn't be cumulative
if (defaultType instanceof BigintType
|| defaultType == INTEGER
|| defaultType == SMALLINT
|| defaultType == TINYINT
|| defaultType == DOUBLE
|| defaultType == REAL
|| defaultType instanceof DecimalType) {
if (!defaultValue.equals(0L) && !defaultValue.equals(0.0d) && !defaultValue.equals(Int128.ZERO)) {
return Optional.empty();
}
}
else {
return Optional.empty();
}
}
// cumulative aggregation default value need to be CAST to cumulative aggregation input type
cumulativeAggregationDefaultValue = Optional.of(new Cast(
caseExpression.getDefaultValue().get(),
toSqlType(aggregationType)));
}
return Optional.of(new CaseAggregation(
aggregationSymbol,
resolvedFunction,
cumulativeFunction,
name,
caseExpression.getWhenClauses().get(0).getOperand(),
caseExpression.getWhenClauses().get(0).getResult(),
cumulativeAggregationDefaultValue));
}
private Type getType(Context context, Expression expression)
{
return typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), expression);
}
private Object optimizeExpression(Expression expression, Context context)
{
Map, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), expression);
ExpressionInterpreter expressionInterpreter = new ExpressionInterpreter(expression, plannerContext, context.getSession(), expressionTypes);
return expressionInterpreter.optimize(Symbol::toSymbolReference);
}
private static class CaseAggregation
{
// original aggregation symbol
private final Symbol aggregationSymbol;
// original aggregation function
private final ResolvedFunction function;
// cumulative aggregation function (e.g. aggregation of aggregations)
private final ResolvedFunction cumulativeFunction;
// aggregation function name
private final CatalogSchemaFunctionName name;
// CASE expression only operand expression
private final Expression operand;
// CASE expression only result expression
private final Expression result;
// default value of cumulative aggregation
private final Optional cumulativeAggregationDefaultValue;
public CaseAggregation(
Symbol aggregationSymbol,
ResolvedFunction function,
ResolvedFunction cumulativeFunction,
CatalogSchemaFunctionName name,
Expression operand,
Expression result,
Optional cumulativeAggregationDefaultValue)
{
this.aggregationSymbol = requireNonNull(aggregationSymbol, "aggregationSymbol is null");
this.function = requireNonNull(function, "function is null");
this.cumulativeFunction = requireNonNull(cumulativeFunction, "cumulativeFunction is null");
this.name = requireNonNull(name, "name is null");
this.operand = requireNonNull(operand, "operand is null");
this.result = requireNonNull(result, "result is null");
this.cumulativeAggregationDefaultValue = requireNonNull(cumulativeAggregationDefaultValue, "cumulativeAggregationDefaultValue is null");
}
public Symbol getAggregationSymbol()
{
return aggregationSymbol;
}
public ResolvedFunction getFunction()
{
return function;
}
public ResolvedFunction getCumulativeFunction()
{
return cumulativeFunction;
}
public CatalogSchemaFunctionName getName()
{
return name;
}
public Expression getOperand()
{
return operand;
}
public Expression getResult()
{
return result;
}
public Optional getCumulativeAggregationDefaultValue()
{
return cumulativeAggregationDefaultValue;
}
}
private static class PreAggregationKey
{
// original aggregation function
private final ResolvedFunction function;
// projected input to aggregation (CASE expression only result expression)
private final Expression projection;
private PreAggregationKey(CaseAggregation aggregation)
{
this.function = aggregation.getFunction();
this.projection = aggregation.getResult();
}
public ResolvedFunction getFunction()
{
return function;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PreAggregationKey that = (PreAggregationKey) o;
return Objects.equals(function, that.function) &&
Objects.equals(projection, that.projection);
}
@Override
public int hashCode()
{
return Objects.hash(function, projection);
}
}
private static class PreAggregation
{
private final Symbol aggregationSymbol;
private final Expression projection;
private final Symbol projectionSymbol;
public PreAggregation(Symbol aggregationSymbol, Expression projection, Symbol projectionSymbol)
{
this.aggregationSymbol = requireNonNull(aggregationSymbol, "aggregationSymbol is null");
this.projection = requireNonNull(projection, "projection is null");
this.projectionSymbol = requireNonNull(projectionSymbol, "projectionSymbol is null");
}
public Symbol getAggregationSymbol()
{
return aggregationSymbol;
}
public Expression getProjection()
{
return projection;
}
public Symbol getProjectionSymbol()
{
return projectionSymbol;
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy