// io.trino.sql.planner.plan.AggregationNode (Maven / Gradle / Ivy)
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.plan;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.errorprone.annotations.Immutable;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.LambdaExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.type.FunctionType;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static java.util.Objects.requireNonNull;
@Immutable
public class AggregationNode
extends PlanNode
{
// Upstream plan node; exposed as the only element of getSources().
private final PlanNode source;
// Aggregation calls keyed by the symbol each one outputs (the keys are appended to the node outputs).
private final Map<Symbol, Aggregation> aggregations;
// Grouping keys together with the grouping-set count and the indexes of global (empty) grouping sets.
private final GroupingSetDescriptor groupingSets;
// Validated by the constructor to be a subset of the grouping keys.
private final List<Symbol> preGroupedSymbols;
private final Step step;
// When present, emitted in the outputs right after the grouping keys.
private final Optional<Symbol> hashSymbol;
// When present, validated by the constructor to be one of the grouping keys.
private final Optional<Symbol> groupIdSymbol;
// Grouping keys, then the optional hash symbol, then the aggregation output symbols.
private final List<Symbol> outputs;
/**
 * Creates a {@link Step#SINGLE} aggregation with no pre-grouped symbols,
 * no hash symbol and no group-id symbol.
 *
 * @param id identifier of the new plan node
 * @param source node whose output is aggregated
 * @param aggregations aggregation calls keyed by their output symbol
 * @param groupingSets grouping keys and grouping-set layout
 * @return the new aggregation node
 */
public static AggregationNode singleAggregation(
        PlanNodeId id,
        PlanNode source,
        Map<Symbol, Aggregation> aggregations,
        GroupingSetDescriptor groupingSets)
{
    return new AggregationNode(id, source, aggregations, groupingSets, ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty());
}
/**
 * Creates an aggregation node and validates its arguments.
 *
 * @param id identifier of this plan node
 * @param source node whose output is aggregated
 * @param aggregations aggregation calls keyed by their output symbol
 * @param groupingSets grouping keys and grouping-set layout
 * @param preGroupedSymbols symbols that must all be grouping keys (may be empty)
 * @param step aggregation step; anything other than {@code SINGLE} forbids ORDER BY aggregations
 * @param hashSymbol optional symbol added to the outputs after the grouping keys
 * @param groupIdSymbol optional symbol that must be one of the grouping keys
 * @throws NullPointerException if any argument checked here is null
 * @throws IllegalArgumentException if an aggregation has the wrong number of arguments for {@code step},
 *         if {@code groupIdSymbol} is not a grouping key, if an ORDER BY aggregation is used with a
 *         non-SINGLE step, or if {@code preGroupedSymbols} is not a subset of the grouping keys
 */
@JsonCreator
public AggregationNode(
        @JsonProperty("id") PlanNodeId id,
        @JsonProperty("source") PlanNode source,
        @JsonProperty("aggregations") Map<Symbol, Aggregation> aggregations,
        @JsonProperty("groupingSets") GroupingSetDescriptor groupingSets,
        @JsonProperty("preGroupedSymbols") List<Symbol> preGroupedSymbols,
        @JsonProperty("step") Step step,
        @JsonProperty("hashSymbol") Optional<Symbol> hashSymbol,
        @JsonProperty("groupIdSymbol") Optional<Symbol> groupIdSymbol)
{
    super(id);

    // Reject nulls before any argument is dereferenced, so failures carry a descriptive message
    // instead of surfacing later as an anonymous NullPointerException.
    requireNonNull(source, "source is null");
    requireNonNull(aggregations, "aggregations is null");
    requireNonNull(groupingSets, "groupingSets is null");
    requireNonNull(preGroupedSymbols, "preGroupedSymbols is null");
    requireNonNull(step, "step is null");
    requireNonNull(hashSymbol, "hashSymbol is null");
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");

    this.source = source;

    // Validate the immutable snapshot rather than the caller-owned map.
    Map<Symbol, Aggregation> aggregationsCopy = ImmutableMap.copyOf(aggregations);
    aggregationsCopy.values().forEach(aggregation -> aggregation.verifyArguments(step));
    this.aggregations = aggregationsCopy;

    List<Symbol> groupingKeys = groupingSets.getGroupingKeys();
    groupIdSymbol.ifPresent(symbol -> checkArgument(groupingKeys.contains(symbol), "Grouping columns does not contain groupId column"));
    this.groupingSets = groupingSets;
    this.groupIdSymbol = groupIdSymbol;

    boolean noOrderBy = aggregationsCopy.values().stream()
            .map(Aggregation::getOrderingScheme)
            .noneMatch(Optional::isPresent);
    checkArgument(noOrderBy || step == SINGLE, "ORDER BY does not support distributed aggregation");
    this.step = step;
    this.hashSymbol = hashSymbol;

    checkArgument(preGroupedSymbols.isEmpty() || groupingKeys.containsAll(preGroupedSymbols), "Pre-grouped symbols must be a subset of the grouping keys");
    this.preGroupedSymbols = ImmutableList.copyOf(preGroupedSymbols);

    // Output layout: grouping keys, optional hash symbol, aggregation outputs.
    ImmutableList.Builder<Symbol> outputs = ImmutableList.builder();
    outputs.addAll(groupingKeys);
    hashSymbol.ifPresent(outputs::add);
    outputs.addAll(aggregationsCopy.keySet());
    this.outputs = outputs.build();
}
/**
 * @return the grouping keys held by this node's grouping-set descriptor
 */
public List<Symbol> getGroupingKeys()
{
    return groupingSets.getGroupingKeys();
}
/**
 * @return the grouping-set descriptor this node was constructed with
 */
@JsonProperty("groupingSets")
public GroupingSetDescriptor getGroupingSets()
{
    return this.groupingSets;
}
/**
 * Tells whether this node has exactly one grouping set and that set is a global (empty) one,
 * e.g. as a result of a GROUP BY () query, so that all rows collapse into a single group.
 *
 * @return true if there is exactly one grouping set and it is global; otherwise false
 */
public boolean hasSingleGlobalAggregation()
{
    if (getGroupingSetCount() != 1) {
        return false;
    }
    return hasEmptyGroupingSet();
}
/**
 * Tells whether this node should produce default output in case of no input pages.
 * For example, for the query:
 *
 * SELECT count(*) FROM nation WHERE nationkey < 0
 *
 * a default output of "0" is expected.
 *
 * The predicate holds when the node has an empty grouping set and its step either
 * produces partial output or is SINGLE.
 * NOTE(review): the earlier wording attributed the default row to the FINAL aggregation
 * operator, yet a FINAL step does not satisfy this predicate; confirm the intended reading.
 *
 * @return true if default output is required; otherwise false
 */
public boolean hasDefaultOutput()
{
    if (!hasEmptyGroupingSet()) {
        return false;
    }
    return step.isOutputPartial() || step == SINGLE;
}
/**
 * @return true if at least one global (empty) grouping set is registered
 */
public boolean hasEmptyGroupingSet()
{
    boolean noGlobalSets = groupingSets.getGlobalGroupingSets().isEmpty();
    return !noGlobalSets;
}
/**
 * @return true if there are more grouping sets in total than global (empty) grouping sets
 */
public boolean hasNonEmptyGroupingSet()
{
    int totalSets = groupingSets.getGroupingSetCount();
    int globalSets = groupingSets.getGlobalGroupingSets().size();
    return globalSets < totalSets;
}
/**
 * @return an immutable single-element list holding this node's source
 */
@Override
public List<PlanNode> getSources()
{
    return ImmutableList.of(source);
}
/**
 * @return the output symbols computed at construction: grouping keys, then the hash symbol
 *         if present, then the aggregation output symbols
 */
@Override
public List<Symbol> getOutputSymbols()
{
    return outputs;
}
/**
 * @return the immutable map of aggregation calls keyed by their output symbol
 */
@JsonProperty
public Map<Symbol, Aggregation> getAggregations()
{
    return aggregations;
}
/**
 * @return the immutable list of pre-grouped symbols (a subset of the grouping keys)
 */
@JsonProperty("preGroupedSymbols")
public List<Symbol> getPreGroupedSymbols()
{
    return preGroupedSymbols;
}
/**
 * @return the number of grouping sets recorded in the grouping-set descriptor
 */
public int getGroupingSetCount()
{
    return this.groupingSets.getGroupingSetCount();
}
/**
 * @return the global (empty) grouping sets recorded in the grouping-set descriptor
 */
public Set<Integer> getGlobalGroupingSets()
{
    return groupingSets.getGlobalGroupingSets();
}
/**
 * @return the plan node feeding this aggregation
 */
@JsonProperty("source")
public PlanNode getSource()
{
    return this.source;
}
/**
 * @return the aggregation step of this node
 */
@JsonProperty("step")
public Step getStep()
{
    return this.step;
}
/**
 * @return the optional hash symbol; when present it is part of the output symbols
 */
@JsonProperty("hashSymbol")
public Optional<Symbol> getHashSymbol()
{
    return hashSymbol;
}
/**
 * @return the optional group-id symbol; when present it is one of the grouping keys
 */
@JsonProperty("groupIdSymbol")
public Optional<Symbol> getGroupIdSymbol()
{
    return groupIdSymbol;
}
/**
 * @return true if at least one aggregation call carries an ordering scheme
 */
public boolean hasOrderings()
{
    for (Aggregation aggregation : aggregations.values()) {
        if (aggregation.getOrderingScheme().isPresent()) {
            return true;
        }
    }
    return false;
}
/**
 * Dispatches to {@code visitor.visitAggregation}.
 *
 * @param visitor visitor to invoke
 * @param context context value handed through to the visitor
 * @param <R> result type of the visitor
 * @param <C> context type of the visitor
 * @return whatever the visitor returns for this node
 */
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context)
{
    return visitor.visitAggregation(this, context);
}
/**
 * Returns a copy of this node whose source is the single element of {@code newChildren};
 * every other property is carried over unchanged.
 *
 * @param newChildren replacement children; {@code Iterables.getOnlyElement} rejects anything
 *        other than exactly one element
 * @return the rebuilt aggregation node
 */
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
    return builderFrom(this)
            .setSource(Iterables.getOnlyElement(newChildren))
            .build();
}
/**
 * Tells whether this node has no aggregation calls, at least one grouping key,
 * and outputs that consist of exactly the grouping keys.
 *
 * @return true if all of the above hold; otherwise false
 */
public boolean producesDistinctRows()
{
    if (!aggregations.isEmpty()) {
        return false;
    }
    if (groupingSets.getGroupingKeys().isEmpty()) {
        return false;
    }
    if (outputs.size() != groupingSets.getGroupingKeys().size()) {
        return false;
    }
    return outputs.containsAll(new HashSet<>(groupingSets.getGroupingKeys()));
}
/**
 * Tells whether every aggregation call of this node is decomposable: none has an ordering
 * scheme, none is DISTINCT, and the function metadata of each reports itself as decomposable.
 *
 * @param session session passed to the metadata lookup
 * @param metadata source of the aggregation function metadata
 * @return true if all aggregation calls satisfy the conditions above
 */
public boolean isDecomposable(Session session, Metadata metadata)
{
    boolean ordered = getAggregations().values().stream()
            .anyMatch(aggregation -> aggregation.getOrderingScheme().isPresent());

    boolean distinct = getAggregations().values().stream()
            .anyMatch(aggregation -> aggregation.isDistinct());

    // Evaluated unconditionally, as before, so the metadata lookups happen regardless of the flags above.
    boolean allFunctionsDecomposable = getAggregations().values().stream()
            .allMatch(aggregation -> metadata.getAggregationFunctionMetadata(session, aggregation.getResolvedFunction()).isDecomposable());

    if (ordered || distinct) {
        return false;
    }
    return allFunctionsDecomposable;
}
/**
 * Tells whether this aggregation prefers to run on a single node.
 *
 * @param session session passed on to {@link #isDecomposable}
 * @param metadata metadata passed on to {@link #isDecomposable}
 * @return true if one of the two cases described in the body applies
 */
public boolean hasSingleNodeExecutionPreference(Session session, Metadata metadata)
{
    // There are two kinds of aggregations that have single node execution preference:
    //
    // 1. aggregations with only empty grouping sets like
    //
    // SELECT count(*) FROM lineitem;
    //
    // there is no need for distributed aggregation. Single node FINAL aggregation will suffice,
    // since all input has to be aggregated into one line of output.
    if (hasEmptyGroupingSet() && !hasNonEmptyGroupingSet()) {
        return true;
    }

    // 2. aggregations that must produce default output and are not decomposable, we cannot distribute them.
    return hasDefaultOutput() && !isDecomposable(session, metadata);
}
/**
 * Tells whether the pre-grouped symbols cover exactly the grouping keys (compared as sets),
 * there is a single grouping set, and no grouping set is global.
 *
 * @return true if all three conditions hold; otherwise false
 */
public boolean isStreamable()
{
    boolean fullyPreGrouped = ImmutableSet.copyOf(preGroupedSymbols).equals(ImmutableSet.copyOf(groupingSets.getGroupingKeys()));
    if (!fullyPreGrouped) {
        return false;
    }
    if (groupingSets.getGroupingSetCount() != 1) {
        return false;
    }
    return groupingSets.getGlobalGroupingSets().isEmpty();
}
/**
 * @return the descriptor produced by {@link #singleGroupingSet} for an empty list of grouping keys
 */
public static GroupingSetDescriptor globalAggregation()
{
    GroupingSetDescriptor global = singleGroupingSet(ImmutableList.of());
    return global;
}
/**
 * Builds a descriptor with exactly one grouping set over the given keys.
 * When {@code groupingKeys} is empty, that single set (index 0) is registered as global.
 *
 * @param groupingKeys grouping keys of the single grouping set; may be empty
 * @return the descriptor
 */
public static GroupingSetDescriptor singleGroupingSet(List<Symbol> groupingKeys)
{
    Set<Integer> globalGroupingSets = groupingKeys.isEmpty() ? ImmutableSet.of(0) : ImmutableSet.of();
    return new GroupingSetDescriptor(groupingKeys, 1, globalGroupingSets);
}
/**
 * Builds a descriptor from explicit values; validation is performed by the
 * {@link GroupingSetDescriptor} constructor.
 *
 * @param groupingKeys grouping keys
 * @param groupingSetCount number of grouping sets
 * @param globalGroupingSets global (empty) grouping sets
 * @return the descriptor
 */
public static GroupingSetDescriptor groupingSets(List<Symbol> groupingKeys, int groupingSetCount, Set<Integer> globalGroupingSets)
{
    return new GroupingSetDescriptor(groupingKeys, groupingSetCount, globalGroupingSets);
}
/**
 * Immutable description of the grouping of an aggregation: the grouping keys, the number of
 * grouping sets, and which grouping sets are global (empty).
 */
public static class GroupingSetDescriptor
{
    private final List<Symbol> groupingKeys;
    private final int groupingSetCount;
    private final Set<Integer> globalGroupingSets;

    /**
     * @param groupingKeys grouping keys; if empty, at least one global grouping set is required
     * @param groupingSetCount number of grouping sets; must be positive
     * @param globalGroupingSets global grouping sets; at most {@code groupingSetCount} entries
     * @throws NullPointerException if {@code groupingKeys} or {@code globalGroupingSets} is null
     * @throws IllegalArgumentException if any of the constraints above is violated
     */
    @JsonCreator
    public GroupingSetDescriptor(
            @JsonProperty("groupingKeys") List<Symbol> groupingKeys,
            @JsonProperty("groupingSetCount") int groupingSetCount,
            @JsonProperty("globalGroupingSets") Set<Integer> globalGroupingSets)
    {
        requireNonNull(globalGroupingSets, "globalGroupingSets is null");
        checkArgument(groupingSetCount > 0, "grouping set count must be larger than 0");
        checkArgument(globalGroupingSets.size() <= groupingSetCount, "list of empty global grouping sets must be no larger than grouping set count");
        requireNonNull(groupingKeys, "groupingKeys is null");
        if (groupingKeys.isEmpty()) {
            checkArgument(!globalGroupingSets.isEmpty(), "no grouping keys implies at least one global grouping set, but none provided");
        }

        // Defensive immutable copies so later changes to the caller's collections are not observed.
        this.groupingKeys = ImmutableList.copyOf(groupingKeys);
        this.groupingSetCount = groupingSetCount;
        this.globalGroupingSets = ImmutableSet.copyOf(globalGroupingSets);
    }

    /**
     * @return the immutable list of grouping keys
     */
    @JsonProperty
    public List<Symbol> getGroupingKeys()
    {
        return groupingKeys;
    }

    /**
     * @return the number of grouping sets
     */
    @JsonProperty
    public int getGroupingSetCount()
    {
        return groupingSetCount;
    }

    /**
     * @return the immutable set of global (empty) grouping sets
     */
    @JsonProperty
    public Set<Integer> getGlobalGroupingSets()
    {
        return globalGroupingSets;
    }
}
/**
 * Aggregation step, characterised by two flags: whether the step's input is raw and
 * whether the step's output is partial.
 */
public enum Step
{
    PARTIAL(true, true),
    FINAL(false, false),
    INTERMEDIATE(false, true),
    SINGLE(true, false);

    private final boolean inputRaw;
    private final boolean outputPartial;

    Step(boolean inputRaw, boolean outputPartial)
    {
        this.inputRaw = inputRaw;
        this.outputPartial = outputPartial;
    }

    /**
     * @return the raw-input flag of this step
     */
    public boolean isInputRaw()
    {
        return this.inputRaw;
    }

    /**
     * @return the partial-output flag of this step
     */
    public boolean isOutputPartial()
    {
        return this.outputPartial;
    }

    /**
     * Maps a step to the partial-output step with the same input kind.
     *
     * @param step step to convert
     * @return {@code PARTIAL} if {@code step} takes raw input, otherwise {@code INTERMEDIATE}
     */
    public static Step partialOutput(Step step)
    {
        return step.isInputRaw() ? PARTIAL : INTERMEDIATE;
    }

    /**
     * Maps a step to the non-raw-input step with the same output kind.
     *
     * @param step step to convert
     * @return {@code INTERMEDIATE} if {@code step} produces partial output, otherwise {@code FINAL}
     */
    public static Step partialInput(Step step)
    {
        return step.isOutputPartial() ? INTERMEDIATE : FINAL;
    }
}
/**
 * A single aggregation call: the resolved function, its arguments, the DISTINCT flag,
 * and the optional filter, ordering scheme and mask. Instances compare by value.
 */
public static class Aggregation
{
    private final ResolvedFunction resolvedFunction;
    private final List<Expression> arguments;
    private final boolean distinct;
    private final Optional<Symbol> filter;
    private final Optional<OrderingScheme> orderingScheme;
    private final Optional<Symbol> mask;

    /**
     * @param resolvedFunction function being invoked
     * @param arguments call arguments; each must be a {@link SymbolReference} or a {@link LambdaExpression}
     * @param distinct DISTINCT flag of the call
     * @param filter optional filter symbol
     * @param orderingScheme optional ordering scheme
     * @param mask optional mask symbol
     * @throws NullPointerException if any reference argument is null
     * @throws IllegalArgumentException if an argument is neither a symbol reference nor a lambda
     */
    @JsonCreator
    public Aggregation(
            @JsonProperty("resolvedFunction") ResolvedFunction resolvedFunction,
            @JsonProperty("arguments") List<Expression> arguments,
            @JsonProperty("distinct") boolean distinct,
            @JsonProperty("filter") Optional<Symbol> filter,
            @JsonProperty("orderingScheme") Optional<OrderingScheme> orderingScheme,
            @JsonProperty("mask") Optional<Symbol> mask)
    {
        this.resolvedFunction = requireNonNull(resolvedFunction, "resolvedFunction is null");
        this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments is null"));
        for (Expression argument : arguments) {
            checkArgument(argument instanceof SymbolReference || argument instanceof LambdaExpression,
                    "argument must be symbol or lambda expression: %s", argument.getClass().getSimpleName());
        }
        this.distinct = distinct;
        this.filter = requireNonNull(filter, "filter is null");
        this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null");
        this.mask = requireNonNull(mask, "mask is null");
    }

    /**
     * @return the resolved function of this call
     */
    @JsonProperty
    public ResolvedFunction getResolvedFunction()
    {
        return resolvedFunction;
    }

    /**
     * @return the immutable list of call arguments
     */
    @JsonProperty
    public List<Expression> getArguments()
    {
        return arguments;
    }

    /**
     * @return the DISTINCT flag of this call
     */
    @JsonProperty
    public boolean isDistinct()
    {
        return distinct;
    }

    /**
     * @return the optional filter symbol
     */
    @JsonProperty
    public Optional<Symbol> getFilter()
    {
        return filter;
    }

    /**
     * @return the optional ordering scheme
     */
    @JsonProperty
    public Optional<OrderingScheme> getOrderingScheme()
    {
        return orderingScheme;
    }

    /**
     * @return the optional mask symbol
     */
    @JsonProperty
    public Optional<Symbol> getMask()
    {
        return mask;
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Aggregation that = (Aggregation) o;
        return distinct == that.distinct &&
                Objects.equals(resolvedFunction, that.resolvedFunction) &&
                Objects.equals(arguments, that.arguments) &&
                Objects.equals(filter, that.filter) &&
                Objects.equals(orderingScheme, that.orderingScheme) &&
                Objects.equals(mask, that.mask);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(resolvedFunction, arguments, distinct, filter, orderingScheme, mask);
    }

    /**
     * Checks that the number of supplied arguments matches what the given step expects.
     *
     * @param step step of the enclosing aggregation node
     * @throws IllegalArgumentException if the argument count does not match
     */
    private void verifyArguments(Step step)
    {
        int expectedArgumentCount;
        if (step == SINGLE || step == Step.PARTIAL) {
            // Raw-input steps take the arguments declared by the function signature.
            expectedArgumentCount = resolvedFunction.getSignature().getArgumentTypes().size();
        }
        else {
            // Intermediate and final steps get the intermediate value and the lambda functions
            expectedArgumentCount = 1 + (int) resolvedFunction.getSignature().getArgumentTypes().stream()
                    .filter(FunctionType.class::isInstance)
                    .count();
        }

        checkArgument(
                expectedArgumentCount == arguments.size(),
                "%s aggregation function %s has %s arguments, but %s arguments were provided to function call",
                step,
                resolvedFunction.getSignature(),
                expectedArgumentCount,
                arguments.size());
    }
}
/**
 * Creates a builder pre-populated with every property of {@code node}.
 *
 * @param node node to copy the initial values from
 * @return the new builder
 */
public static Builder builderFrom(AggregationNode node)
{
    Builder builder = new Builder(node);
    return builder;
}
/**
 * Mutable builder that starts from an existing {@link AggregationNode} and lets individual
 * properties be replaced before a new node is constructed. Every setter rejects null and
 * returns this builder for chaining.
 */
public static class Builder
{
    private PlanNodeId id;
    private PlanNode source;
    private Map<Symbol, Aggregation> aggregations;
    private GroupingSetDescriptor groupingSets;
    private List<Symbol> preGroupedSymbols;
    private Step step;
    private Optional<Symbol> hashSymbol;
    private Optional<Symbol> groupIdSymbol;

    /**
     * @param node node whose properties seed the builder
     * @throws NullPointerException if {@code node} is null
     */
    public Builder(AggregationNode node)
    {
        requireNonNull(node, "node is null");
        this.id = node.getId();
        this.source = node.getSource();
        this.aggregations = node.getAggregations();
        this.groupingSets = node.getGroupingSets();
        this.preGroupedSymbols = node.getPreGroupedSymbols();
        this.step = node.getStep();
        this.hashSymbol = node.getHashSymbol();
        this.groupIdSymbol = node.getGroupIdSymbol();
    }

    public Builder setId(PlanNodeId id)
    {
        this.id = requireNonNull(id, "id is null");
        return this;
    }

    public Builder setSource(PlanNode source)
    {
        this.source = requireNonNull(source, "source is null");
        return this;
    }

    public Builder setAggregations(Map<Symbol, Aggregation> aggregations)
    {
        this.aggregations = requireNonNull(aggregations, "aggregations is null");
        return this;
    }

    public Builder setGroupingSets(GroupingSetDescriptor groupingSets)
    {
        this.groupingSets = requireNonNull(groupingSets, "groupingSets is null");
        return this;
    }

    public Builder setPreGroupedSymbols(List<Symbol> preGroupedSymbols)
    {
        this.preGroupedSymbols = requireNonNull(preGroupedSymbols, "preGroupedSymbols is null");
        return this;
    }

    public Builder setStep(Step step)
    {
        this.step = requireNonNull(step, "step is null");
        return this;
    }

    public Builder setHashSymbol(Optional<Symbol> hashSymbol)
    {
        this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null");
        return this;
    }

    public Builder setGroupIdSymbol(Optional<Symbol> groupIdSymbol)
    {
        this.groupIdSymbol = requireNonNull(groupIdSymbol, "groupIdSymbol is null");
        return this;
    }

    /**
     * Constructs the node from the current builder state; all validation is performed by the
     * {@link AggregationNode} constructor.
     *
     * @return the new aggregation node
     */
    public AggregationNode build()
    {
        return new AggregationNode(
                id,
                source,
                aggregations,
                groupingSets,
                preGroupedSymbols,
                step,
                hashSymbol,
                groupIdSymbol);
    }
}
}