// com.facebook.presto.jdbc.internal.spi.plan.AggregationNode Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.jdbc.internal.spi.plan;
import com.facebook.presto.jdbc.internal.spi.SourceLocation;
import com.facebook.presto.jdbc.internal.spi.function.FunctionHandle;
import com.facebook.presto.jdbc.internal.spi.relation.CallExpression;
import com.facebook.presto.jdbc.internal.spi.relation.RowExpression;
import com.facebook.presto.jdbc.internal.spi.relation.VariableReferenceExpression;
import com.facebook.presto.jdbc.internal.jackson.annotation.JsonCreator;
import com.facebook.presto.jdbc.internal.jackson.annotation.JsonProperty;
import com.facebook.presto.jdbc.internal.javax.annotation.concurrent.Immutable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.jdbc.internal.spi.plan.AggregationNode.Step.SINGLE;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
import static java.util.Objects.requireNonNull;
@Immutable
public final class AggregationNode
extends PlanNode
{
private final PlanNode source;
private final Map aggregations;
private final GroupingSetDescriptor groupingSets;
private final List preGroupedVariables;
private final Step step;
private final Optional hashVariable;
private final Optional groupIdVariable;
private final List outputs;
@JsonCreator
public AggregationNode(
Optional sourceLocation,
@JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source,
@JsonProperty("aggregations") Map aggregations,
@JsonProperty("groupingSets") GroupingSetDescriptor groupingSets,
@JsonProperty("preGroupedVariables") List preGroupedVariables,
@JsonProperty("step") Step step,
@JsonProperty("hashVariable") Optional hashVariable,
@JsonProperty("groupIdVariable") Optional groupIdVariable)
{
this(sourceLocation, id, Optional.empty(), source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable);
}
public AggregationNode(
Optional sourceLocation,
PlanNodeId id,
Optional statsEquivalentPlanNode,
PlanNode source,
Map aggregations,
GroupingSetDescriptor groupingSets,
List preGroupedVariables,
Step step,
Optional hashVariable,
Optional groupIdVariable)
{
super(sourceLocation, id, statsEquivalentPlanNode);
this.source = source;
this.aggregations = unmodifiableMap(new LinkedHashMap<>(requireNonNull(aggregations, "aggregations is null")));
requireNonNull(groupingSets, "groupingSets is null");
groupIdVariable.ifPresent(variable -> checkArgument(groupingSets.getGroupingKeys().contains(variable), "Grouping columns does not contain groupId column"));
this.groupingSets = groupingSets;
this.groupIdVariable = requireNonNull(groupIdVariable);
boolean noOrderBy = aggregations.values().stream()
.map(Aggregation::getOrderBy)
.noneMatch(Optional::isPresent);
checkArgument(noOrderBy || step == SINGLE, "ORDER BY does not support distributed aggregation");
this.step = step;
this.hashVariable = hashVariable;
requireNonNull(preGroupedVariables, "preGroupedVariables is null");
checkArgument(preGroupedVariables.isEmpty() || groupingSets.getGroupingKeys().containsAll(preGroupedVariables), "Pre-grouped variables must be a subset of the grouping keys");
this.preGroupedVariables = unmodifiableList(new ArrayList<>(preGroupedVariables));
ArrayList keys = new ArrayList<>(groupingSets.getGroupingKeys());
hashVariable.ifPresent(keys::add);
keys.addAll(new ArrayList<>(aggregations.keySet()));
this.outputs = unmodifiableList(keys);
}
/**
* Whether this node corresponds to a DISTINCT operation in SQL
*/
public static boolean isDistinct(AggregationNode node)
{
return node.getAggregations().isEmpty() &&
node.getOutputVariables().size() == node.getGroupingKeys().size() &&
node.getOutputVariables().containsAll(node.getGroupingKeys());
}
public List getGroupingKeys()
{
return groupingSets.getGroupingKeys();
}
@JsonProperty
public GroupingSetDescriptor getGroupingSets()
{
return groupingSets;
}
/**
* @return whether this node should produce default output in case of no input pages.
* For example for query:
*
* SELECT count(*) FROM nation WHERE nationkey < 0
*
* A default output of "0" is expected to be produced by FINAL aggregation operator.
*/
public boolean hasDefaultOutput()
{
return hasEmptyGroupingSet() && (step.isOutputPartial() || step.equals(SINGLE));
}
public boolean hasEmptyGroupingSet()
{
return !groupingSets.getGlobalGroupingSets().isEmpty();
}
public boolean hasNonEmptyGroupingSet()
{
return groupingSets.getGroupingSetCount() > groupingSets.getGlobalGroupingSets().size();
}
@Override
public List getSources()
{
return Collections.singletonList(source);
}
@Override
public LogicalProperties computeLogicalProperties(LogicalPropertiesProvider logicalPropertiesProvider)
{
requireNonNull(logicalPropertiesProvider, "logicalPropertiesProvider cannot be null.");
return logicalPropertiesProvider.getAggregationProperties(this);
}
@Override
public List getOutputVariables()
{
return outputs;
}
@JsonProperty
public Map getAggregations()
{
return aggregations;
}
@JsonProperty
public List getPreGroupedVariables()
{
return preGroupedVariables;
}
public int getGroupingSetCount()
{
return groupingSets.getGroupingSetCount();
}
public Set getGlobalGroupingSets()
{
return groupingSets.getGlobalGroupingSets();
}
@JsonProperty
public PlanNode getSource()
{
return source;
}
@JsonProperty
public Step getStep()
{
return step;
}
@JsonProperty
public Optional getHashVariable()
{
return hashVariable;
}
@JsonProperty
public Optional getGroupIdVariable()
{
return groupIdVariable;
}
public boolean hasOrderings()
{
return aggregations.values().stream()
.map(Aggregation::getOrderBy)
.anyMatch(Optional::isPresent);
}
@Override
public R accept(PlanVisitor visitor, C context)
{
return visitor.visitAggregation(this, context);
}
@Override
public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode)
{
return new AggregationNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable);
}
@Override
public PlanNode replaceChildren(List newChildren)
{
checkArgument(newChildren.size() == 1, "Unexpected number of elements in list newChildren");
return new AggregationNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable);
}
public boolean isStreamable()
{
return !preGroupedVariables.isEmpty()
&& groupingSets.getGroupingSetCount() == 1
&& groupingSets.getGlobalGroupingSets().isEmpty()
&& preGroupedVariables.size() == groupingSets.groupingKeys.size();
}
public boolean isSegmentedAggregationEligible()
{
return !preGroupedVariables.isEmpty()
&& groupingSets.getGroupingSetCount() == 1
&& groupingSets.getGlobalGroupingSets().isEmpty()
&& preGroupedVariables.size() < groupingSets.groupingKeys.size();
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AggregationNode that = (AggregationNode) o;
return Objects.equals(source, that.source) &&
Objects.equals(aggregations, that.aggregations) &&
Objects.equals(groupingSets, that.groupingSets) &&
Objects.equals(preGroupedVariables, that.preGroupedVariables) &&
step == that.step &&
Objects.equals(hashVariable, that.hashVariable) &&
Objects.equals(groupIdVariable, that.groupIdVariable) &&
Objects.equals(outputs, that.outputs);
}
@Override
public int hashCode()
{
return Objects.hash(source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable, outputs);
}
public static GroupingSetDescriptor globalAggregation()
{
return singleGroupingSet(emptyList());
}
public static GroupingSetDescriptor singleGroupingSet(List groupingKeys)
{
Set globalGroupingSets;
if (groupingKeys.isEmpty()) {
globalGroupingSets = Collections.singleton(0);
}
else {
globalGroupingSets = emptySet();
}
return new GroupingSetDescriptor(groupingKeys, 1, globalGroupingSets);
}
public static GroupingSetDescriptor groupingSets(List groupingKeys, int groupingSetCount, Set globalGroupingSets)
{
return new GroupingSetDescriptor(groupingKeys, groupingSetCount, globalGroupingSets);
}
public static class GroupingSetDescriptor
{
private final List groupingKeys;
private final int groupingSetCount;
private final Set globalGroupingSets;
@JsonCreator
public GroupingSetDescriptor(
@JsonProperty("groupingKeys") List groupingKeys,
@JsonProperty("groupingSetCount") int groupingSetCount,
@JsonProperty("globalGroupingSets") Set globalGroupingSets)
{
requireNonNull(globalGroupingSets, "globalGroupingSets is null");
checkArgument(globalGroupingSets.size() <= groupingSetCount, "list of empty global grouping sets must be no larger than grouping set count");
requireNonNull(groupingKeys, "groupingKeys is null");
if (groupingKeys.isEmpty()) {
checkArgument(!globalGroupingSets.isEmpty(), "no grouping keys implies at least one global grouping set, but none provided");
}
this.groupingKeys = unmodifiableList(new ArrayList<>(groupingKeys));
this.groupingSetCount = groupingSetCount;
this.globalGroupingSets = unmodifiableSet(new LinkedHashSet<>(globalGroupingSets));
}
@JsonProperty
public List getGroupingKeys()
{
return groupingKeys;
}
@JsonProperty
public int getGroupingSetCount()
{
return groupingSetCount;
}
@JsonProperty
public Set getGlobalGroupingSets()
{
return globalGroupingSets;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
GroupingSetDescriptor that = (GroupingSetDescriptor) o;
return groupingSetCount == that.groupingSetCount &&
Objects.equals(groupingKeys, that.groupingKeys) &&
Objects.equals(globalGroupingSets, that.globalGroupingSets);
}
@Override
public int hashCode()
{
return Objects.hash(groupingKeys, groupingSetCount, globalGroupingSets);
}
}
public enum Step
{
PARTIAL(true, true),
FINAL(false, false),
INTERMEDIATE(false, true),
SINGLE(true, false);
private final boolean inputRaw;
private final boolean outputPartial;
Step(boolean inputRaw, boolean outputPartial)
{
this.inputRaw = inputRaw;
this.outputPartial = outputPartial;
}
public boolean isInputRaw()
{
return inputRaw;
}
public boolean isOutputPartial()
{
return outputPartial;
}
public static Step partialOutput(Step step)
{
if (step.isInputRaw()) {
return Step.PARTIAL;
}
else {
return Step.INTERMEDIATE;
}
}
public static Step partialInput(Step step)
{
if (step.isOutputPartial()) {
return Step.INTERMEDIATE;
}
else {
return Step.FINAL;
}
}
}
public static class Aggregation
{
private final CallExpression call;
private final Optional filter;
private final Optional orderingScheme;
private final boolean isDistinct;
private final Optional mask;
@JsonCreator
public Aggregation(
@JsonProperty("call") CallExpression call,
@JsonProperty("filter") Optional filter,
@JsonProperty("orderBy") Optional orderingScheme,
@JsonProperty("distinct") boolean isDistinct,
@JsonProperty("mask") Optional mask)
{
this.call = requireNonNull(call, "call is null");
this.filter = requireNonNull(filter, "filter is null");
this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null");
this.isDistinct = isDistinct;
this.mask = requireNonNull(mask, "mask is null");
}
public static AggregationNode.Aggregation removeDistinct(AggregationNode.Aggregation aggregation)
{
checkArgument(aggregation.isDistinct(), "Expected aggregation to have DISTINCT input");
return new AggregationNode.Aggregation(
aggregation.getCall(),
aggregation.getFilter(),
aggregation.getOrderBy(),
false,
aggregation.getMask());
}
@JsonProperty
public CallExpression getCall()
{
return call;
}
@JsonProperty
public FunctionHandle getFunctionHandle()
{
return call.getFunctionHandle();
}
@JsonProperty
public List getArguments()
{
return call.getArguments();
}
@JsonProperty
public Optional getOrderBy()
{
return orderingScheme;
}
@JsonProperty
public Optional getFilter()
{
return filter;
}
@JsonProperty
public boolean isDistinct()
{
return isDistinct;
}
@JsonProperty
public Optional getMask()
{
return mask;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (!(o instanceof Aggregation)) {
return false;
}
Aggregation that = (Aggregation) o;
return isDistinct == that.isDistinct &&
Objects.equals(call, that.call) &&
Objects.equals(filter, that.filter) &&
Objects.equals(orderingScheme, that.orderingScheme) &&
Objects.equals(mask, that.mask);
}
@Override
public String toString()
{
return "Aggregation{" +
"call=" + call +
", filter=" + filter +
", orderingScheme=" + orderingScheme +
", isDistinct=" + isDistinct +
", mask=" + mask +
'}';
}
@Override
public int hashCode()
{
return Objects.hash(call, filter, orderingScheme, isDistinct, mask);
}
}
private static void checkArgument(boolean condition, String message)
{
if (!condition) {
throw new IllegalArgumentException(message);
}
}
}