
// io.prestosql.sql.planner.optimizations.PruneUnreferencedOutputs Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.optimizations;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.rule.PruneTableScanColumns;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AggregationNode.Aggregation;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.DeleteNode;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExceptNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.ExplainAnalyzeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.IndexJoinNode;
import io.prestosql.sql.planner.plan.IndexSourceNode;
import io.prestosql.sql.planner.plan.IntersectNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.OutputNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SetOperationNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.Row;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Sets.intersection;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.CorrelatedJoinNode.Type.INNER;
import static io.prestosql.sql.planner.plan.CorrelatedJoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.CorrelatedJoinNode.Type.RIGHT;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static java.util.Objects.requireNonNull;
/**
* Removes all computation that is not referenced transitively from the root of the plan.
*
* E.g.,
*
* {@code Output[$0] -> Project[$0 := $1 + $2, $3 := $4 / $5] -> ...}
*
* gets rewritten as
*
* {@code Output[$0] -> Project[$0 := $1 + $2] -> ...}
*/
public class PruneUnreferencedOutputs
implements PlanOptimizer
{
private final Metadata metadata;
/**
 * Creates the optimizer.
 *
 * @param metadata metadata handle that is handed to the plan rewriter on every {@code optimize} call
 * @throws NullPointerException if {@code metadata} is null
 */
public PruneUnreferencedOutputs(Metadata metadata)
{
    requireNonNull(metadata, "metadata is null");
    this.metadata = metadata;
}
/**
 * Runs the pruning rewriter over the whole plan.
 *
 * The root is visited with an empty set of required symbols; each visitor then computes
 * the set its own children have to produce.
 *
 * @throws NullPointerException if {@code plan}, {@code session}, {@code types},
 *         {@code symbolAllocator} or {@code idAllocator} is null
 */
@Override
public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
    // Argument validation, in declaration order (warningCollector is not validated)
    requireNonNull(plan, "plan is null");
    requireNonNull(session, "session is null");
    requireNonNull(types, "types is null");
    requireNonNull(symbolAllocator, "symbolAllocator is null");
    requireNonNull(idAllocator, "idAllocator is null");

    Rewriter pruningRewriter = new Rewriter(metadata, types, session);
    return SimplePlanRewriter.rewriteWith(pruningRewriter, plan, ImmutableSet.of());
}
private static class Rewriter
extends SimplePlanRewriter>
{
private final Metadata metadata;
private final TypeProvider types;
private final Session session;
/**
 * Stores the collaborators used while rewriting. The arguments are kept as given;
 * no null checks are performed here.
 */
public Rewriter(Metadata metadata, TypeProvider types, Session session)
{
    this.session = session;
    this.types = types;
    this.metadata = metadata;
}
@Override
public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext> context)
{
return context.defaultRewrite(node, ImmutableSet.copyOf(node.getSource().getOutputSymbols()));
}
@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext> context)
{
Set expectedOutputSymbols = Sets.newHashSet(context.get());
node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputSymbols::add);
expectedOutputSymbols.addAll(node.getPartitioningScheme().getPartitioning().getColumns());
node.getOrderingScheme().ifPresent(orderingScheme -> expectedOutputSymbols.addAll(orderingScheme.getOrderBy()));
List> inputsBySource = new ArrayList<>(node.getInputs().size());
for (int i = 0; i < node.getInputs().size(); i++) {
inputsBySource.add(new ArrayList<>());
}
List newOutputSymbols = new ArrayList<>(node.getOutputSymbols().size());
for (int i = 0; i < node.getOutputSymbols().size(); i++) {
Symbol outputSymbol = node.getOutputSymbols().get(i);
if (expectedOutputSymbols.contains(outputSymbol)) {
newOutputSymbols.add(outputSymbol);
for (int source = 0; source < node.getInputs().size(); source++) {
inputsBySource.get(source).add(node.getInputs().get(source).get(i));
}
}
}
// newOutputSymbols contains all partition, sort and hash symbols so simply swap the output layout
PartitioningScheme partitioningScheme = new PartitioningScheme(
node.getPartitioningScheme().getPartitioning(),
newOutputSymbols,
node.getPartitioningScheme().getHashColumn(),
node.getPartitioningScheme().isReplicateNullsAndAny(),
node.getPartitioningScheme().getBucketToPartition());
ImmutableList.Builder rewrittenSources = ImmutableList.builder();
for (int i = 0; i < node.getSources().size(); i++) {
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(inputsBySource.get(i));
rewrittenSources.add(context.rewrite(
node.getSources().get(i),
expectedInputs.build()));
}
return new ExchangeNode(
node.getId(),
node.getType(),
node.getScope(),
partitioningScheme,
rewrittenSources.build(),
inputsBySource,
node.getOrderingScheme());
}
@Override
public PlanNode visitJoin(JoinNode node, RewriteContext> context)
{
Set expectedFilterInputs = node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of());
ImmutableSet.Builder leftInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(expectedFilterInputs)
.addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft));
node.getLeftHashSymbol().ifPresent(leftInputs::add);
ImmutableSet.Builder rightInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(expectedFilterInputs)
.addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight));
node.getRightHashSymbol().ifPresent(rightInputs::add);
PlanNode left = context.rewrite(node.getLeft(), leftInputs.build());
PlanNode right = context.rewrite(node.getRight(), rightInputs.build());
List leftOutputSymbols = node.getLeftOutputSymbols().stream()
.filter(context.get()::contains)
.distinct()
.collect(toImmutableList());
List rightOutputSymbols = node.getRightOutputSymbols().stream()
.filter(context.get()::contains)
.distinct()
.collect(toImmutableList());
return new JoinNode(
node.getId(),
node.getType(),
left,
right,
node.getCriteria(),
leftOutputSymbols,
rightOutputSymbols,
node.getFilter(),
node.getLeftHashSymbol(),
node.getRightHashSymbol(),
node.getDistributionType(),
node.isSpillable(),
node.getDynamicFilters(),
node.getReorderJoinStatsAndCost());
}
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext> context)
{
if (!context.get().contains(node.getSemiJoinOutput())) {
return context.rewrite(node.getSource(), context.get());
}
ImmutableSet.Builder sourceInputsBuilder = ImmutableSet.builder();
sourceInputsBuilder.addAll(context.get()).add(node.getSourceJoinSymbol());
if (node.getSourceHashSymbol().isPresent()) {
sourceInputsBuilder.add(node.getSourceHashSymbol().get());
}
Set sourceInputs = sourceInputsBuilder.build();
ImmutableSet.Builder filteringSourceInputBuilder = ImmutableSet.builder();
filteringSourceInputBuilder.add(node.getFilteringSourceJoinSymbol());
if (node.getFilteringSourceHashSymbol().isPresent()) {
filteringSourceInputBuilder.add(node.getFilteringSourceHashSymbol().get());
}
Set filteringSourceInputs = filteringSourceInputBuilder.build();
PlanNode source = context.rewrite(node.getSource(), sourceInputs);
PlanNode filteringSource = context.rewrite(node.getFilteringSource(), filteringSourceInputs);
return new SemiJoinNode(
node.getId(),
source,
filteringSource,
node.getSourceJoinSymbol(),
node.getFilteringSourceJoinSymbol(),
node.getSemiJoinOutput(),
node.getSourceHashSymbol(),
node.getFilteringSourceHashSymbol(),
node.getDistributionType(),
node.getDynamicFilterId());
}
@Override
public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext> context)
{
Set requiredInputs = ImmutableSet.builder()
.addAll(SymbolsExtractor.extractUnique(node.getFilter()))
.addAll(context.get())
.build();
ImmutableSet.Builder leftInputs = ImmutableSet.builder();
node.getLeftPartitionSymbol().map(leftInputs::add);
ImmutableSet.Builder rightInputs = ImmutableSet.builder();
node.getRightPartitionSymbol().map(rightInputs::add);
PlanNode left = context.rewrite(node.getLeft(), leftInputs.addAll(requiredInputs).build());
PlanNode right = context.rewrite(node.getRight(), rightInputs.addAll(requiredInputs).build());
List outputSymbols = node.getOutputSymbols().stream()
.filter(context.get()::contains)
.distinct()
.collect(toImmutableList());
return new SpatialJoinNode(node.getId(), node.getType(), left, right, outputSymbols, node.getFilter(), node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree());
}
@Override
public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext> context)
{
ImmutableSet.Builder probeInputsBuilder = ImmutableSet.builder();
probeInputsBuilder.addAll(context.get())
.addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe));
if (node.getProbeHashSymbol().isPresent()) {
probeInputsBuilder.add(node.getProbeHashSymbol().get());
}
Set probeInputs = probeInputsBuilder.build();
ImmutableSet.Builder indexInputBuilder = ImmutableSet.builder();
indexInputBuilder.addAll(context.get())
.addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getIndex));
if (node.getIndexHashSymbol().isPresent()) {
indexInputBuilder.add(node.getIndexHashSymbol().get());
}
Set indexInputs = indexInputBuilder.build();
PlanNode probeSource = context.rewrite(node.getProbeSource(), probeInputs);
PlanNode indexSource = context.rewrite(node.getIndexSource(), indexInputs);
return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol());
}
@Override
public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext> context)
{
List newOutputSymbols = node.getOutputSymbols().stream()
.filter(context.get()::contains)
.collect(toImmutableList());
Set newLookupSymbols = node.getLookupSymbols().stream()
.filter(context.get()::contains)
.collect(toImmutableSet());
Map newAssignments = newOutputSymbols.stream()
.collect(Collectors.toMap(Function.identity(), node.getAssignments()::get));
return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments);
}
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(node.getGroupingKeys());
if (node.getHashSymbol().isPresent()) {
expectedInputs.add(node.getHashSymbol().get());
}
ImmutableMap.Builder aggregations = ImmutableMap.builder();
for (Map.Entry entry : node.getAggregations().entrySet()) {
Symbol symbol = entry.getKey();
if (context.get().contains(symbol)) {
Aggregation aggregation = entry.getValue();
expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation));
aggregations.put(symbol, aggregation);
}
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new AggregationNode(
node.getId(),
source,
aggregations.build(),
node.getGroupingSets(),
ImmutableList.of(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol());
}
@Override
public PlanNode visitWindow(WindowNode node, RewriteContext> context)
{
Map prunedFunctions = node.getWindowFunctions().entrySet().stream()
.filter(entry -> context.get().contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
if (prunedFunctions.isEmpty()) {
return context.rewrite(node.getSource(), context.get());
}
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(node.getPartitionBy());
node.getOrderingScheme().ifPresent(orderingScheme ->
orderingScheme.getOrderBy()
.forEach(expectedInputs::add));
if (node.getHashSymbol().isPresent()) {
expectedInputs.add(node.getHashSymbol().get());
}
prunedFunctions.values().stream()
.map(SymbolsExtractor::extractUnique)
.forEach(expectedInputs::addAll);
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new WindowNode(
node.getId(),
source,
node.getSpecification(),
prunedFunctions,
node.getHashSymbol(),
node.getPrePartitionedInputs(),
node.getPreSortedOrderPrefix());
}
@Override
public PlanNode visitTableScan(TableScanNode node, RewriteContext> context)
{
return PruneTableScanColumns.pruneColumns(metadata, types, session, node, context.get())
.orElse(node);
}
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext> context)
{
Set expectedInputs = ImmutableSet.builder()
.addAll(SymbolsExtractor.extractUnique(node.getPredicate()))
.addAll(context.get())
.build();
PlanNode source = context.rewrite(node.getSource(), expectedInputs);
return new FilterNode(node.getId(), source, node.getPredicate());
}
@Override
public PlanNode visitGroupId(GroupIdNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder();
List newAggregationArguments = node.getAggregationArguments().stream()
.filter(context.get()::contains)
.collect(Collectors.toList());
expectedInputs.addAll(newAggregationArguments);
ImmutableList.Builder> newGroupingSets = ImmutableList.builder();
Map newGroupingMapping = new HashMap<>();
for (List groupingSet : node.getGroupingSets()) {
ImmutableList.Builder newGroupingSet = ImmutableList.builder();
for (Symbol output : groupingSet) {
if (context.get().contains(output)) {
newGroupingSet.add(output);
newGroupingMapping.putIfAbsent(output, node.getGroupingColumns().get(output));
expectedInputs.add(node.getGroupingColumns().get(output));
}
}
newGroupingSets.add(newGroupingSet.build());
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMapping, newAggregationArguments, node.getGroupIdSymbol());
}
@Override
public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext> context)
{
if (!context.get().contains(node.getMarkerSymbol())) {
return context.rewrite(node.getSource(), context.get());
}
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(node.getDistinctSymbols())
.addAll(context.get().stream()
.filter(symbol -> !symbol.equals(node.getMarkerSymbol()))
.collect(toImmutableList()));
if (node.getHashSymbol().isPresent()) {
expectedInputs.add(node.getHashSymbol().get());
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new MarkDistinctNode(node.getId(), source, node.getMarkerSymbol(), node.getDistinctSymbols(), node.getHashSymbol());
}
@Override
public PlanNode visitUnnest(UnnestNode node, RewriteContext> context)
{
ImmutableSet.Builder contextAndFilterSymbolsBuilder = ImmutableSet.builder()
.addAll(context.get());
node.getFilter().ifPresent(expression -> contextAndFilterSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)));
Set contextAndFilterSymbols = contextAndFilterSymbolsBuilder.build();
List prunedReplicateSymbols = node.getReplicateSymbols().stream()
.filter(contextAndFilterSymbols::contains)
.collect(toImmutableList());
Optional prunedOrdinalitySymbol = node.getOrdinalitySymbol()
.filter(contextAndFilterSymbols::contains);
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(prunedReplicateSymbols);
node.getMappings().stream()
.map(UnnestNode.Mapping::getInput)
.forEach(expectedInputs::add);
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new UnnestNode(node.getId(), source, prunedReplicateSymbols, node.getMappings(), prunedOrdinalitySymbol, node.getJoinType(), node.getFilter());
}
@Override
public PlanNode visitProject(ProjectNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder();
Assignments.Builder builder = Assignments.builder();
node.getAssignments().forEach((symbol, expression) -> {
if (context.get().contains(symbol)) {
expectedInputs.addAll(SymbolsExtractor.extractUnique(expression));
builder.put(symbol, expression);
}
});
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new ProjectNode(node.getId(), source, builder.build());
}
@Override
public PlanNode visitOutput(OutputNode node, RewriteContext> context)
{
Set expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols());
PlanNode source = context.rewrite(node.getSource(), expectedInputs);
return new OutputNode(node.getId(), source, node.getColumnNames(), node.getOutputSymbols());
}
@Override
public PlanNode visitOffset(OffsetNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(context.get());
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new OffsetNode(node.getId(), source, node.getCount());
}
@Override
public PlanNode visitLimit(LimitNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(node.getTiesResolvingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()));
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new LimitNode(node.getId(), source, node.getCount(), node.getTiesResolvingScheme(), node.isPartial());
}
@Override
public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext> context)
{
Set expectedInputs;
if (node.getHashSymbol().isPresent()) {
expectedInputs = ImmutableSet.copyOf(concat(node.getDistinctSymbols(), ImmutableList.of(node.getHashSymbol().get())));
}
else {
expectedInputs = ImmutableSet.copyOf(node.getDistinctSymbols());
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs);
return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getDistinctSymbols(), node.getHashSymbol());
}
@Override
public PlanNode visitTopN(TopNNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(node.getOrderingScheme().getOrderBy());
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new TopNNode(node.getId(), source, node.getCount(), node.getOrderingScheme(), node.getStep());
}
@Override
public PlanNode visitRowNumber(RowNumberNode node, RewriteContext> context)
{
// Remove unused RowNumberNode
if (!context.get().contains(node.getRowNumberSymbol())) {
PlanNode source = context.rewrite(node.getSource(), context.get());
if (node.getMaxRowCountPerPartition().isEmpty()) {
return source;
}
if (node.getPartitionBy().isEmpty()) {
return new LimitNode(node.getId(), source, node.getMaxRowCountPerPartition().get(), false);
}
}
ImmutableSet.Builder inputsBuilder = ImmutableSet.builder();
ImmutableSet.Builder expectedInputs = inputsBuilder
.addAll(context.get())
.addAll(node.getPartitionBy());
if (node.getHashSymbol().isPresent()) {
inputsBuilder.add(node.getHashSymbol().get());
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.isOrderSensitive(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol());
}
@Override
public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(context.get())
.addAll(node.getPartitionBy())
.addAll(node.getOrderingScheme().getOrderBy());
if (node.getHashSymbol().isPresent()) {
expectedInputs.add(node.getHashSymbol().get());
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new TopNRowNumberNode(
node.getId(),
source,
node.getSpecification(),
node.getRowNumberSymbol(),
node.getMaxRowCountPerPartition(),
node.isPartial(),
node.getHashSymbol());
}
@Override
public PlanNode visitSort(SortNode node, RewriteContext> context)
{
Set expectedInputs = ImmutableSet.copyOf(concat(context.get(), node.getOrderingScheme().getOrderBy()));
PlanNode source = context.rewrite(node.getSource(), expectedInputs);
return new SortNode(node.getId(), source, node.getOrderingScheme(), node.isPartial());
}
@Override
public PlanNode visitTableWriter(TableWriterNode node, RewriteContext> context)
{
ImmutableSet.Builder expectedInputs = ImmutableSet.builder()
.addAll(node.getColumns());
if (node.getPartitioningScheme().isPresent()) {
PartitioningScheme partitioningScheme = node.getPartitioningScheme().get();
partitioningScheme.getPartitioning().getColumns().forEach(expectedInputs::add);
partitioningScheme.getHashColumn().ifPresent(expectedInputs::add);
}
if (node.getStatisticsAggregation().isPresent()) {
StatisticAggregations aggregations = node.getStatisticsAggregation().get();
expectedInputs.addAll(aggregations.getGroupingSymbols());
aggregations.getAggregations().values().forEach(aggregation -> expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation)));
}
PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new TableWriterNode(
node.getId(),
source,
node.getTarget(),
node.getRowCountSymbol(),
node.getFragmentSymbol(),
node.getColumns(),
node.getColumnNames(),
node.getNotNullColumnSymbols(),
node.getPartitioningScheme(),
node.getStatisticsAggregation(),
node.getStatisticsAggregationDescriptor());
}
@Override
public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context)
{
PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols()));
return new StatisticsWriterNode(
node.getId(),
source,
node.getTarget(),
node.getRowCountSymbol(),
node.isRowCountEnabled(),
node.getDescriptor());
}
@Override
public PlanNode visitTableFinish(TableFinishNode node, RewriteContext> context)
{
PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols()));
return new TableFinishNode(
node.getId(),
source,
node.getTarget(),
node.getRowCountSymbol(),
node.getStatisticsAggregation(),
node.getStatisticsAggregationDescriptor());
}
@Override
public PlanNode visitDelete(DeleteNode node, RewriteContext> context)
{
PlanNode source = context.rewrite(node.getSource(), ImmutableSet.of(node.getRowId()));
return new DeleteNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getOutputSymbols());
}
@Override
public PlanNode visitUnion(UnionNode node, RewriteContext> context)
{
// Find out which output symbols we need to keep
ImmutableListMultimap.Builder prunedMappingBuilder = ImmutableListMultimap.builder();
for (Symbol symbol : node.getOutputSymbols()) {
if (context.get().contains(symbol)) {
prunedMappingBuilder.putAll(symbol, node.getSymbolMapping().get(symbol));
}
}
ListMultimap prunedSymbolMapping = prunedMappingBuilder.build();
// Find the corresponding input symbols to the remaining output symbols and prune the children
ImmutableList.Builder rewrittenSources = ImmutableList.builder();
for (int i = 0; i < node.getSources().size(); i++) {
ImmutableSet.Builder expectedSourceSymbols = ImmutableSet.builder();
for (Collection symbols : prunedSymbolMapping.asMap().values()) {
expectedSourceSymbols.add(Iterables.get(symbols, i));
}
rewrittenSources.add(context.rewrite(node.getSources().get(i), expectedSourceSymbols.build()));
}
return new UnionNode(node.getId(), rewrittenSources.build(), prunedSymbolMapping, ImmutableList.copyOf(prunedSymbolMapping.keySet()));
}
@Override
public PlanNode visitIntersect(IntersectNode node, RewriteContext> context)
{
return rewriteSetOperationChildren(node, context);
}
@Override
public PlanNode visitExcept(ExceptNode node, RewriteContext> context)
{
return rewriteSetOperationChildren(node, context);
}
private PlanNode rewriteSetOperationChildren(SetOperationNode node, RewriteContext> context)
{
ImmutableList.Builder rewrittenSources = ImmutableList.builder();
for (int i = 0; i < node.getSources().size(); i++) {
rewrittenSources.add(context.rewrite(node.getSources().get(i), ImmutableSet.copyOf(node.sourceOutputLayout(i))));
}
return node.replaceChildren(rewrittenSources.build());
}
@Override
public PlanNode visitValues(ValuesNode node, RewriteContext> context)
{
// nothing to prune: no output symbols and no expressions
if (node.getRows().isEmpty()) {
return node;
}
// handle the case of all output symbols pruned
if (node.getOutputSymbols().stream().noneMatch(context.get()::contains)) {
return new ValuesNode(node.getId(), node.getRowCount());
}
// if any of ValuesNode's rows is specified by expression other than Row, the redundant piece cannot be extracted and pruned
if (!node.getRows().get().stream().allMatch(Row.class::isInstance)) {
return node;
}
ImmutableList.Builder rewrittenOutputSymbolsBuilder = ImmutableList.builder();
ImmutableList.Builder> rowBuildersBuilder = ImmutableList.builder();
// Initialize builder for each row
for (int i = 0; i < node.getRowCount(); i++) {
rowBuildersBuilder.add(ImmutableList.builder());
}
ImmutableList> rowBuilders = rowBuildersBuilder.build();
for (int i = 0; i < node.getOutputSymbols().size(); i++) {
Symbol outputSymbol = node.getOutputSymbols().get(i);
// If output symbol is used
if (context.get().contains(outputSymbol)) {
rewrittenOutputSymbolsBuilder.add(outputSymbol);
// Add the value of the output symbol for each row
for (int j = 0; j < node.getRowCount(); j++) {
rowBuilders.get(j).add(((Row) node.getRows().get().get(j)).getItems().get(i));
}
}
}
List rewrittenRows = rowBuilders.stream()
.map(ImmutableList.Builder::build)
.map(Row::new)
.collect(toImmutableList());
return new ValuesNode(node.getId(), rewrittenOutputSymbolsBuilder.build(), rewrittenRows);
}
@Override
public PlanNode visitApply(ApplyNode node, RewriteContext> context)
{
// remove unused apply nodes
if (intersection(node.getSubqueryAssignments().getSymbols(), context.get()).isEmpty()) {
return context.rewrite(node.getInput(), context.get());
}
// extract symbols required subquery plan
ImmutableSet.Builder subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder();
Assignments.Builder subqueryAssignments = Assignments.builder();
for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) {
Symbol output = entry.getKey();
Expression expression = entry.getValue();
if (context.get().contains(output)) {
subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression));
subqueryAssignments.put(output, expression);
}
}
Set subqueryAssignmentsSymbols = subqueryAssignmentsSymbolsBuilder.build();
PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols);
// prune not used correlation symbols
Set subquerySymbols = SymbolsExtractor.extractUnique(subquery);
List newCorrelation = node.getCorrelation().stream()
.filter(subquerySymbols::contains)
.collect(toImmutableList());
Set inputContext = ImmutableSet.builder()
.addAll(context.get())
.addAll(newCorrelation)
.addAll(subqueryAssignmentsSymbols) // need to include those: e.g: "expr" from "expr IN (SELECT 1)"
.build();
PlanNode input = context.rewrite(node.getInput(), inputContext);
return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation, node.getOriginSubquery());
}
@Override
public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext> context)
{
if (!context.get().contains(node.getIdColumn())) {
return context.rewrite(node.getSource(), context.get());
}
return context.defaultRewrite(node, context.get());
}
@Override
public PlanNode visitCorrelatedJoin(CorrelatedJoinNode node, RewriteContext> context)
{
Set expectedFilterSymbols = SymbolsExtractor.extractUnique(node.getFilter());
Set expectedFilterAndContextSymbols = ImmutableSet.builder()
.addAll(expectedFilterSymbols)
.addAll(context.get())
.build();
PlanNode subquery = context.rewrite(node.getSubquery(), expectedFilterAndContextSymbols);
// remove unused correlated join nodes
if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty()) {
// remove unused subquery of inner join
if (node.getType() == INNER && isScalar(subquery) && node.getFilter().equals(TRUE_LITERAL)) {
return context.rewrite(node.getInput(), context.get());
}
// remove unused subquery of left join
if (node.getType() == LEFT && isAtMostScalar(subquery)) {
return context.rewrite(node.getInput(), context.get());
}
}
// prune not used correlation symbols
Set subquerySymbols = SymbolsExtractor.extractUnique(subquery);
List newCorrelation = node.getCorrelation().stream()
.filter(subquerySymbols::contains)
.collect(toImmutableList());
Set expectedCorrelationAndContextSymbols = ImmutableSet.builder()
.addAll(newCorrelation)
.addAll(context.get())
.build();
Set inputContext = ImmutableSet.builder()
.addAll(expectedCorrelationAndContextSymbols)
.addAll(expectedFilterSymbols)
.build();
PlanNode input = context.rewrite(node.getInput(), inputContext);
// remove unused input nodes
if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), expectedCorrelationAndContextSymbols).isEmpty()) {
// remove unused input of inner join
if (node.getType() == INNER && isScalar(input) && node.getFilter().equals(TRUE_LITERAL)) {
return subquery;
}
// remove unused input of right join
if (node.getType() == RIGHT && isAtMostScalar(input)) {
return subquery;
}
}
return new CorrelatedJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getFilter(), node.getOriginSubquery());
}
}
}