/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner;

import com.google.common.base.VerifyException;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ContiguousSet;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.primitives.Ints;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cache.NonEvictableCache;
import io.trino.client.NodeVersion;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.DynamicFilterConfig;
import io.trino.execution.ExplainAnalyzeContext;
import io.trino.execution.StageId;
import io.trino.execution.TableExecuteContextManager;
import io.trino.execution.TaskId;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.metadata.MergeHandle;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TableExecuteHandle;
import io.trino.metadata.TableHandle;
import io.trino.operator.AggregationOperator.AggregationOperatorFactory;
import io.trino.operator.AssignUniqueIdOperator;
import io.trino.operator.DevNullOperator.DevNullOperatorFactory;
import io.trino.operator.DirectExchangeClientSupplier;
import io.trino.operator.DriverFactory;
import io.trino.operator.DynamicFilterSourceOperator;
import io.trino.operator.DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory;
import io.trino.operator.EnforceSingleRowOperator;
import io.trino.operator.ExchangeOperator.ExchangeOperatorFactory;
import io.trino.operator.ExplainAnalyzeOperator.ExplainAnalyzeOperatorFactory;
import io.trino.operator.FilterAndProjectOperator;
import io.trino.operator.FlatHashStrategyCompiler;
import io.trino.operator.GroupIdOperator;
import io.trino.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import io.trino.operator.HashSemiJoinOperator;
import io.trino.operator.JoinOperatorType;
import io.trino.operator.LeafTableFunctionOperator.LeafTableFunctionOperatorFactory;
import io.trino.operator.LimitOperator.LimitOperatorFactory;
import io.trino.operator.LocalPlannerAware;
import io.trino.operator.MarkDistinctOperator.MarkDistinctOperatorFactory;
import io.trino.operator.MergeOperator.MergeOperatorFactory;
import io.trino.operator.MergeProcessorOperator;
import io.trino.operator.MergeWriterOperator.MergeWriterOperatorFactory;
import io.trino.operator.OperatorFactory;
import io.trino.operator.OrderByOperator.OrderByOperatorFactory;
import io.trino.operator.OutputFactory;
import io.trino.operator.PagesIndex;
import io.trino.operator.PagesSpatialIndexFactory;
import io.trino.operator.PartitionFunction;
import io.trino.operator.RefreshMaterializedViewOperator.RefreshMaterializedViewOperatorFactory;
import io.trino.operator.RetryPolicy;
import io.trino.operator.RowNumberOperator;
import io.trino.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory;
import io.trino.operator.SetBuilderOperator.SetBuilderOperatorFactory;
import io.trino.operator.SetBuilderOperator.SetSupplier;
import io.trino.operator.SimpleTableExecuteOperator.SimpleTableExecuteOperatorOperatorFactory;
import io.trino.operator.SourceOperatorFactory;
import io.trino.operator.SpatialIndexBuilderOperator.SpatialIndexBuilderOperatorFactory;
import io.trino.operator.SpatialIndexBuilderOperator.SpatialPredicate;
import io.trino.operator.SpatialJoinOperator.SpatialJoinOperatorFactory;
import io.trino.operator.StatisticsWriterOperator.StatisticsWriterOperatorFactory;
import io.trino.operator.StreamingAggregationOperator;
import io.trino.operator.TableMutationOperator.TableMutationOperatorFactory;
import io.trino.operator.TableScanOperator.TableScanOperatorFactory;
import io.trino.operator.TaskContext;
import io.trino.operator.TopNOperator;
import io.trino.operator.TopNRankingOperator;
import io.trino.operator.ValuesOperator.ValuesOperatorFactory;
import io.trino.operator.WindowFunctionDefinition;
import io.trino.operator.WindowOperator.WindowOperatorFactory;
import io.trino.operator.aggregation.AccumulatorFactory;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.DistinctAccumulatorFactory;
import io.trino.operator.aggregation.OrderedAccumulatorFactory;
import io.trino.operator.aggregation.partial.PartialAggregationController;
import io.trino.operator.exchange.LocalExchange;
import io.trino.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory;
import io.trino.operator.exchange.LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory;
import io.trino.operator.exchange.LocalMergeSourceOperator.LocalMergeSourceOperatorFactory;
import io.trino.operator.exchange.PageChannelSelector;
import io.trino.operator.function.RegularTableFunctionPartition.PassThroughColumnSpecification;
import io.trino.operator.function.TableFunctionOperator.TableFunctionOperatorFactory;
import io.trino.operator.index.DynamicTupleFilterFactory;
import io.trino.operator.index.FieldSetFilteringRecordSet;
import io.trino.operator.index.IndexBuildDriverFactoryProvider;
import io.trino.operator.index.IndexJoinLookupStats;
import io.trino.operator.index.IndexLookupSourceFactory;
import io.trino.operator.index.IndexManager;
import io.trino.operator.index.IndexSourceOperator;
import io.trino.operator.join.HashBuilderOperator.HashBuilderOperatorFactory;
import io.trino.operator.join.JoinBridgeManager;
import io.trino.operator.join.JoinOperatorFactory;
import io.trino.operator.join.LookupSourceFactory;
import io.trino.operator.join.NestedLoopJoinBridge;
import io.trino.operator.join.NestedLoopJoinPagesSupplier;
import io.trino.operator.join.PartitionedLookupSourceFactory;
import io.trino.operator.join.unspilled.HashBuilderOperator;
import io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputFactory;
import io.trino.operator.output.PositionsAppenderFactory;
import io.trino.operator.output.SkewedPartitionRebalancer;
import io.trino.operator.output.TaskOutputOperator.TaskOutputFactory;
import io.trino.operator.project.CursorProcessor;
import io.trino.operator.project.PageProcessor;
import io.trino.operator.project.PageProjection;
import io.trino.operator.unnest.UnnestOperator;
import io.trino.operator.window.AggregationWindowFunctionSupplier;
import io.trino.operator.window.FrameInfo;
import io.trino.operator.window.PartitionerSupplier;
import io.trino.operator.window.PatternRecognitionPartitionerSupplier;
import io.trino.operator.window.RegularPartitionerSupplier;
import io.trino.operator.window.matcher.IrRowPatternToProgramRewriter;
import io.trino.operator.window.matcher.Matcher;
import io.trino.operator.window.matcher.Program;
import io.trino.operator.window.pattern.ArgumentComputation.ArgumentComputationSupplier;
import io.trino.operator.window.pattern.LabelEvaluator.EvaluationSupplier;
import io.trino.operator.window.pattern.LogicalIndexNavigation;
import io.trino.operator.window.pattern.MatchAggregation.MatchAggregationInstantiator;
import io.trino.operator.window.pattern.MatchAggregationPointer;
import io.trino.operator.window.pattern.MeasureComputation.MeasureComputationSupplier;
import io.trino.operator.window.pattern.PhysicalValueAccessor;
import io.trino.operator.window.pattern.PhysicalValuePointer;
import io.trino.operator.window.pattern.SetEvaluator.SetEvaluatorSupplier;
import io.trino.plugin.base.MappedRecordSet;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlRow;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorIndex;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.DynamicFilter;
import io.trino.spi.connector.RecordSet;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.connector.WriterScalingOptions;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.function.WindowFunctionSupplier;
import io.trino.spi.function.table.TableFunctionProcessorProvider;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.NullableValue;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spiller.PartitioningSpillerFactory;
import io.trino.spiller.SingleStreamSpillerFactory;
import io.trino.spiller.SpillerFactory;
import io.trino.split.PageSinkManager;
import io.trino.split.PageSourceProvider;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.gen.ExpressionCompiler;
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.gen.JoinFilterFunctionCompiler;
import io.trino.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import io.trino.sql.gen.OrderingCompiler;
import io.trino.sql.gen.PageFunctionCompiler;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.optimizations.IndexJoinOptimizer;
import io.trino.sql.planner.plan.AdaptivePlanNode;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.AggregationNode.Step;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.ExplainAnalyzeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.IndexJoinNode;
import io.trino.sql.planner.plan.IndexSourceNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.MergeProcessorNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RefreshMaterializedViewNode;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.SampleNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimpleTableExecuteNode;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.StatisticAggregationsDescriptor;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableDeleteNode;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableFunctionNode;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableUpdateNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.TableWriterNode.MergeTarget;
import io.trino.sql.planner.plan.TableWriterNode.TableExecuteTarget;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.plan.WindowNode.Frame;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.ClassifierValuePointer;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.LogicalIndexPointer;
import io.trino.sql.planner.rowpattern.MatchNumberValuePointer;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.type.BlockTypeOperators;
import io.trino.type.FunctionType;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.google.common.base.Functions.forMap;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.DiscreteDomain.integers;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static com.google.common.collect.Sets.difference;
import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationUniqueRowsRatioThreshold;
import static io.trino.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit;
import static io.trino.SystemSessionProperties.getExchangeCompressionCodec;
import static io.trino.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount;
import static io.trino.SystemSessionProperties.getFilterAndProjectMinOutputPageSize;
import static io.trino.SystemSessionProperties.getPagePartitioningBufferPoolSize;
import static io.trino.SystemSessionProperties.getSkewedPartitionMinDataProcessedRebalanceThreshold;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.SystemSessionProperties.getTaskMaxWriterCount;
import static io.trino.SystemSessionProperties.getTaskMinWriterCount;
import static io.trino.SystemSessionProperties.getWriterScalingMinDataProcessed;
import static io.trino.SystemSessionProperties.isAdaptivePartialAggregationEnabled;
import static io.trino.SystemSessionProperties.isEnableCoordinatorDynamicFiltersDistribution;
import static io.trino.SystemSessionProperties.isEnableLargeDynamicFilters;
import static io.trino.SystemSessionProperties.isForceSpillingOperator;
import static io.trino.SystemSessionProperties.isSpillEnabled;
import static io.trino.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.operator.DistinctLimitOperator.DistinctLimitOperatorFactory;
import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier;
import static io.trino.operator.OperatorFactories.join;
import static io.trino.operator.OperatorFactories.spillingJoin;
import static io.trino.operator.TableFinishOperator.TableFinishOperatorFactory;
import static io.trino.operator.TableFinishOperator.TableFinisher;
import static io.trino.operator.TableWriterOperator.FRAGMENT_CHANNEL;
import static io.trino.operator.TableWriterOperator.ROW_COUNT_CHANNEL;
import static io.trino.operator.TableWriterOperator.STATS_START_CHANNEL;
import static io.trino.operator.TableWriterOperator.TableWriterOperatorFactory;
import static io.trino.operator.WindowFunctionDefinition.window;
import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory;
import static io.trino.operator.join.JoinUtils.isBuildSideReplicated;
import static io.trino.operator.join.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory;
import static io.trino.operator.join.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory;
import static io.trino.operator.output.SkewedPartitionRebalancer.checkCanScalePartitionsRemotely;
import static io.trino.operator.output.SkewedPartitionRebalancer.createPartitionFunction;
import static io.trino.operator.output.SkewedPartitionRebalancer.getMaxWritersBasedOnMemory;
import static io.trino.operator.output.SkewedPartitionRebalancer.getTaskCount;
import static io.trino.operator.window.FrameInfo.Ordering.ASCENDING;
import static io.trino.operator.window.FrameInfo.Ordering.DESCENDING;
import static io.trino.operator.window.pattern.PhysicalValuePointer.CLASSIFIER;
import static io.trino.operator.window.pattern.PhysicalValuePointer.MATCH_NUMBER;
import static io.trino.spi.StandardErrorCode.COMPILER_ERROR;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.spi.type.TypeUtils.writeNativeValue;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory;
import static io.trino.sql.DynamicFilters.extractDynamicFilters;
import static io.trino.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.ir.IrUtils.combineConjuncts;
import static io.trino.sql.planner.ExpressionExtractor.extractExpressions;
import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
import static io.trino.sql.planner.SortExpressionExtractor.extractSortExpression;
import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.FrameBoundType.CURRENT_ROW;
import static io.trino.sql.planner.plan.JoinType.FULL;
import static io.trino.sql.planner.plan.JoinType.INNER;
import static io.trino.sql.planner.plan.JoinType.LEFT;
import static io.trino.sql.planner.plan.JoinType.RIGHT;
import static io.trino.sql.planner.plan.RowsPerMatch.ONE;
import static io.trino.sql.planner.plan.SkipToPosition.LAST;
import static io.trino.sql.planner.plan.TableWriterNode.CreateTarget;
import static io.trino.sql.planner.plan.TableWriterNode.InsertTarget;
import static io.trino.sql.planner.plan.TableWriterNode.WriterTarget;
import static io.trino.sql.planner.plan.WindowFrameType.ROWS;
import static io.trino.util.MoreMath.previousPowerOfTwo;
import static io.trino.util.SpatialJoinUtils.ST_CONTAINS;
import static io.trino.util.SpatialJoinUtils.ST_DISTANCE;
import static io.trino.util.SpatialJoinUtils.ST_INTERSECTS;
import static io.trino.util.SpatialJoinUtils.ST_WITHIN;
import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons;
import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
import static java.lang.Math.ceil;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.HOURS;
import static java.util.stream.Collectors.partitioningBy;
import static java.util.stream.IntStream.range;
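
/**
 * Translates a plan fragment (a {@code PlanNode} tree) into a {@link LocalExecutionPlan}:
 * the list of {@code DriverFactory} pipelines that execute the fragment within a single task
 * on a worker.
 */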
public class LocalExecutionPlanner
{
private static final Logger log = Logger.get(LocalExecutionPlanner.class);
private final PlannerContext plannerContext;
private final Metadata metadata;
    private final Optional<ExplainAnalyzeContext> explainAnalyzeContext;
private final PageSourceProvider pageSourceProvider;
private final IndexManager indexManager;
private final NodePartitioningManager nodePartitioningManager;
private final PageSinkManager pageSinkManager;
private final DirectExchangeClientSupplier directExchangeClientSupplier;
private final ExpressionCompiler expressionCompiler;
private final PageFunctionCompiler pageFunctionCompiler;
private final JoinFilterFunctionCompiler joinFilterFunctionCompiler;
private final DataSize maxIndexMemorySize;
private final IndexJoinLookupStats indexJoinLookupStats;
private final DataSize maxPartialAggregationMemorySize;
private final DataSize maxPagePartitioningBufferSize;
private final DataSize maxLocalExchangeBufferSize;
private final SpillerFactory spillerFactory;
private final SingleStreamSpillerFactory singleStreamSpillerFactory;
private final PartitioningSpillerFactory partitioningSpillerFactory;
private final PagesIndex.Factory pagesIndexFactory;
private final JoinCompiler joinCompiler;
private final FlatHashStrategyCompiler hashStrategyCompiler;
private final OrderingCompiler orderingCompiler;
private final int largeMaxDistinctValuesPerDriver;
private final int largePartitionedMaxDistinctValuesPerDriver;
private final int smallMaxDistinctValuesPerDriver;
private final int smallPartitionedMaxDistinctValuesPerDriver;
private final DataSize largeMaxSizePerDriver;
private final DataSize largePartitionedMaxSizePerDriver;
private final DataSize smallMaxSizePerDriver;
private final DataSize smallPartitionedMaxSizePerDriver;
private final int largeRangeRowLimitPerDriver;
private final int largePartitionedRangeRowLimitPerDriver;
private final int smallRangeRowLimitPerDriver;
private final int smallPartitionedRangeRowLimitPerDriver;
private final DataSize largeMaxSizePerOperator;
private final DataSize largePartitionedMaxSizePerOperator;
private final DataSize smallMaxSizePerOperator;
private final DataSize smallPartitionedMaxSizePerOperator;
private final BlockTypeOperators blockTypeOperators;
private final TypeOperators typeOperators;
private final TableExecuteContextManager tableExecuteContextManager;
private final ExchangeManagerRegistry exchangeManagerRegistry;
private final PositionsAppenderFactory positionsAppenderFactory;
private final NodeVersion version;
private final boolean specializeAggregationLoops;
    private final NonEvictableCache<FunctionKey, AccumulatorFactory> accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterAccess(1, HOURS));
    private final NonEvictableCache<FunctionKey, AggregationWindowFunctionSupplier> aggregationWindowFunctionSupplierCache = buildNonEvictableCache(CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterAccess(1, HOURS));
@Inject
public LocalExecutionPlanner(
PlannerContext plannerContext,
            Optional<ExplainAnalyzeContext> explainAnalyzeContext,
PageSourceProvider pageSourceProvider,
IndexManager indexManager,
NodePartitioningManager nodePartitioningManager,
PageSinkManager pageSinkManager,
DirectExchangeClientSupplier directExchangeClientSupplier,
ExpressionCompiler expressionCompiler,
PageFunctionCompiler pageFunctionCompiler,
JoinFilterFunctionCompiler joinFilterFunctionCompiler,
IndexJoinLookupStats indexJoinLookupStats,
TaskManagerConfig taskManagerConfig,
SpillerFactory spillerFactory,
SingleStreamSpillerFactory singleStreamSpillerFactory,
PartitioningSpillerFactory partitioningSpillerFactory,
PagesIndex.Factory pagesIndexFactory,
JoinCompiler joinCompiler,
FlatHashStrategyCompiler hashStrategyCompiler,
OrderingCompiler orderingCompiler,
DynamicFilterConfig dynamicFilterConfig,
BlockTypeOperators blockTypeOperators,
TypeOperators typeOperators,
TableExecuteContextManager tableExecuteContextManager,
ExchangeManagerRegistry exchangeManagerRegistry,
NodeVersion version,
CompilerConfig compilerConfig)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.metadata = plannerContext.getMetadata();
this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null");
this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
this.indexManager = requireNonNull(indexManager, "indexManager is null");
this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
        this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null");
this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
this.expressionCompiler = requireNonNull(expressionCompiler, "expressionCompiler is null");
this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null");
this.joinFilterFunctionCompiler = requireNonNull(joinFilterFunctionCompiler, "joinFilterFunctionCompiler is null");
this.indexJoinLookupStats = requireNonNull(indexJoinLookupStats, "indexJoinLookupStats is null");
this.maxIndexMemorySize = taskManagerConfig.getMaxIndexMemoryUsage();
this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null");
this.singleStreamSpillerFactory = requireNonNull(singleStreamSpillerFactory, "singleStreamSpillerFactory is null");
this.partitioningSpillerFactory = requireNonNull(partitioningSpillerFactory, "partitioningSpillerFactory is null");
this.maxPartialAggregationMemorySize = taskManagerConfig.getMaxPartialAggregationMemoryUsage();
this.maxPagePartitioningBufferSize = taskManagerConfig.getMaxPagePartitioningBufferSize();
this.maxLocalExchangeBufferSize = taskManagerConfig.getMaxLocalExchangeBufferSize();
this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
this.hashStrategyCompiler = requireNonNull(hashStrategyCompiler, "hashStrategyCompiler is null");
this.orderingCompiler = requireNonNull(orderingCompiler, "orderingCompiler is null");
this.largeMaxDistinctValuesPerDriver = dynamicFilterConfig.getLargeMaxDistinctValuesPerDriver();
this.smallMaxDistinctValuesPerDriver = dynamicFilterConfig.getSmallMaxDistinctValuesPerDriver();
this.smallPartitionedMaxDistinctValuesPerDriver = dynamicFilterConfig.getSmallPartitionedMaxDistinctValuesPerDriver();
this.largeMaxSizePerDriver = dynamicFilterConfig.getLargeMaxSizePerDriver();
this.largePartitionedMaxSizePerDriver = dynamicFilterConfig.getLargePartitionedMaxSizePerDriver();
this.smallMaxSizePerDriver = dynamicFilterConfig.getSmallMaxSizePerDriver();
this.smallPartitionedMaxSizePerDriver = dynamicFilterConfig.getSmallPartitionedMaxSizePerDriver();
this.largeRangeRowLimitPerDriver = dynamicFilterConfig.getLargeRangeRowLimitPerDriver();
this.largePartitionedRangeRowLimitPerDriver = dynamicFilterConfig.getLargePartitionedRangeRowLimitPerDriver();
this.smallRangeRowLimitPerDriver = dynamicFilterConfig.getSmallRangeRowLimitPerDriver();
this.smallPartitionedRangeRowLimitPerDriver = dynamicFilterConfig.getSmallPartitionedRangeRowLimitPerDriver();
this.largeMaxSizePerOperator = dynamicFilterConfig.getLargeMaxSizePerOperator();
this.largePartitionedMaxSizePerOperator = dynamicFilterConfig.getLargePartitionedMaxSizePerOperator();
this.smallMaxSizePerOperator = dynamicFilterConfig.getSmallMaxSizePerOperator();
this.smallPartitionedMaxSizePerOperator = dynamicFilterConfig.getSmallPartitionedMaxSizePerOperator();
this.largePartitionedMaxDistinctValuesPerDriver = dynamicFilterConfig.getLargePartitionedMaxDistinctValuesPerDriver();
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
this.typeOperators = requireNonNull(typeOperators, "typeOperators is null");
this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null");
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
this.positionsAppenderFactory = new PositionsAppenderFactory(blockTypeOperators);
this.version = requireNonNull(version, "version is null");
this.specializeAggregationLoops = compilerConfig.isSpecializeAggregationLoops();
}
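
    /**
     * Plans task output. For system partitionings that do not hash rows to specific output
     * partitions (broadcast, arbitrary, scaled writer round-robin, single, coordinator-only),
     * pages can be written to the output buffer as-is; any other partitioning requires a
     * partitioned output factory built from the partitioning scheme below.
     */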
public LocalExecutionPlan plan(
TaskContext taskContext,
PlanNode plan,
PartitioningScheme partitioningScheme,
            List<PlanNodeId> partitionedSourceOrder,
OutputBuffer outputBuffer)
{
        List<Symbol> outputLayout = partitioningScheme.getOutputLayout();
if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(SINGLE_DISTRIBUTION) ||
partitioningScheme.getPartitioning().getHandle().equals(COORDINATOR_DISTRIBUTION)) {
return plan(taskContext, plan, outputLayout, partitionedSourceOrder, new TaskOutputFactory(outputBuffer));
}
// We can convert the symbols directly into channels, because the root must be a sink and therefore the layout is fixed
        List<Integer> partitionChannels;
        List<Optional<NullableValue>> partitionConstants;
        List<Type> partitionChannelTypes;
if (partitioningScheme.getHashColumn().isPresent()) {
partitionChannels = ImmutableList.of(outputLayout.indexOf(partitioningScheme.getHashColumn().get()));
partitionConstants = ImmutableList.of(Optional.empty());
partitionChannelTypes = ImmutableList.of(BIGINT);
}
else {
partitionChannels = partitioningScheme.getPartitioning().getArguments().stream()
.map(argument -> {
if (argument.isConstant()) {
return -1;
}
return outputLayout.indexOf(argument.getColumn());
})
.collect(toImmutableList());
partitionConstants = partitioningScheme.getPartitioning().getArguments().stream()
.map(argument -> {
if (argument.isConstant()) {
return Optional.of(argument.getConstant());
}
return Optional.empty();
})
.collect(toImmutableList());
partitionChannelTypes = partitioningScheme.getPartitioning().getArguments().stream()
.map(argument -> {
if (argument.isConstant()) {
return argument.getConstant().getType();
}
return argument.getColumn().type();
})
.collect(toImmutableList());
}
PartitionFunction partitionFunction;
        Optional<SkewedPartitionRebalancer> skewedPartitionRebalancer = Optional.empty();
int taskCount = getTaskCount(partitioningScheme);
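        // If partitioned writes can be scaled out across tasks, wrap the partition function in a
        // SkewedPartitionRebalancer so data for skewed (hot) partitions is spread over more writers.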
if (checkCanScalePartitionsRemotely(taskContext.getSession(), taskCount, partitioningScheme.getPartitioning().getHandle(), nodePartitioningManager)) {
partitionFunction = createPartitionFunction(taskContext.getSession(), nodePartitioningManager, partitioningScheme, partitionChannelTypes);
int partitionedWriterCount = getPartitionedWriterCountBasedOnMemory(taskContext.getSession());
// Keep the task bucket count to 50% of total local writers
int taskBucketCount = (int) ceil(0.5 * partitionedWriterCount);
skewedPartitionRebalancer = Optional.of(new SkewedPartitionRebalancer(
partitionFunction.partitionCount(),
taskCount,
taskBucketCount,
getWriterScalingMinDataProcessed(taskContext.getSession()).toBytes(),
getSkewedPartitionMinDataProcessedRebalanceThreshold(taskContext.getSession()).toBytes()));
}
else {
partitionFunction = nodePartitioningManager.getPartitionFunction(taskContext.getSession(), partitioningScheme, partitionChannelTypes);
}
OptionalInt nullChannel = OptionalInt.empty();
        Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns();
// partitioningColumns expected to have one column in the normal case, and zero columns when partitioning on a constant
checkArgument(!partitioningScheme.isReplicateNullsAndAny() || partitioningColumns.size() <= 1);
if (partitioningScheme.isReplicateNullsAndAny() && partitioningColumns.size() == 1) {
nullChannel = OptionalInt.of(outputLayout.indexOf(getOnlyElement(partitioningColumns)));
}
return plan(
taskContext,
plan,
outputLayout,
partitionedSourceOrder,
new PartitionedOutputFactory(
partitionFunction,
partitionChannels,
partitionConstants,
partitioningScheme.isReplicateNullsAndAny(),
nullChannel,
outputBuffer,
maxPagePartitioningBufferSize,
positionsAppenderFactory,
taskContext.getSession().getExchangeEncryptionKey(),
taskContext.newAggregateMemoryContext(),
getPagePartitioningBufferPoolSize(taskContext.getSession()),
skewedPartitionRebalancer));
}
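
    /**
     * Core planning entry point: walks the plan with {@link Visitor} to build the operator
     * pipelines, appends the output operator, and finally notifies all {@code LocalPlannerAware}
     * operator factories that local planning has completed.
     */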
public LocalExecutionPlan plan(
TaskContext taskContext,
PlanNode plan,
            List<Symbol> outputLayout,
            List<PlanNodeId> partitionedSourceOrder,
OutputFactory outputOperatorFactory)
{
Session session = taskContext.getSession();
LocalExecutionPlanContext context = new LocalExecutionPlanContext(taskContext);
PhysicalOperation physicalOperation = plan.accept(new Visitor(session), context);
        Function<Page, Page> pagePreprocessor = enforceLoadedLayoutProcessor(outputLayout, physicalOperation.getLayout());
        List<Type> outputTypes = outputLayout.stream()
.map(Symbol::type)
.collect(toImmutableList());
context.addDriverFactory(
true,
new PhysicalOperation(
outputOperatorFactory.createOutputOperator(
context.getNextOperatorId(),
plan.getId(),
outputTypes,
pagePreprocessor,
new PagesSerdeFactory(plannerContext.getBlockEncodingSerde(), getExchangeCompressionCodec(session))),
physicalOperation),
context);
// notify operator factories that planning has completed
context.getDriverFactories().stream()
.map(DriverFactory::getOperatorFactories)
.flatMap(List::stream)
.filter(LocalPlannerAware.class::isInstance)
.map(LocalPlannerAware.class::cast)
.forEach(LocalPlannerAware::localPlannerComplete);
return new LocalExecutionPlan(context.getDriverFactories(), partitionedSourceOrder);
}
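
    /**
     * Mutable state threaded through planning: accumulates DriverFactories, hands out pipeline
     * ids (shared across sub-contexts) and per-pipeline operator ids, and tracks whether the
     * pipeline being built reads task input.
     */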
private static class LocalExecutionPlanContext
{
private final TaskContext taskContext;
        private final List<DriverFactory> driverFactories;
        private final Optional<IndexSourceContext> indexSourceContext;
// this is shared with all subContexts
private final AtomicInteger nextPipelineId;
private int nextOperatorId;
private boolean inputDriver = true;
private OptionalInt driverInstanceCount = OptionalInt.empty();
public LocalExecutionPlanContext(TaskContext taskContext)
{
this(
taskContext,
new ArrayList<>(),
Optional.empty(),
new AtomicInteger(0));
}
private LocalExecutionPlanContext(
TaskContext taskContext,
                List<DriverFactory> driverFactories,
                Optional<IndexSourceContext> indexSourceContext,
AtomicInteger nextPipelineId)
{
this.taskContext = taskContext;
this.driverFactories = driverFactories;
this.indexSourceContext = indexSourceContext;
this.nextPipelineId = nextPipelineId;
}
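
        // Wraps a completed pipeline's operator chain into a DriverFactory, first splitting off
        // any extra pipelines needed to emit the unmatched build-side rows of outer joins.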
public void addDriverFactory(boolean outputDriver, PhysicalOperation physicalOperation, LocalExecutionPlanContext context)
{
boolean inputDriver = context.isInputDriver();
OptionalInt driverInstances = context.getDriverInstanceCount();
            List<OperatorFactory> operatorFactories = physicalOperation.getOperatorFactories();
addLookupOuterDrivers(outputDriver, operatorFactories);
addDriverFactory(inputDriver, outputDriver, operatorFactories, driverInstances);
}
        private void addLookupOuterDrivers(boolean isOutputDriver, List<OperatorFactory> operatorFactories)
{
// For an outer join on the lookup side (RIGHT or FULL) add an additional
// driver to output the unused rows in the lookup source
for (int i = 0; i < operatorFactories.size(); i++) {
OperatorFactory operatorFactory = operatorFactories.get(i);
if (!(operatorFactory instanceof JoinOperatorFactory lookupJoin)) {
continue;
}
                Optional<OperatorFactory> outerOperatorFactoryResult = lookupJoin.createOuterOperatorFactory();
if (outerOperatorFactoryResult.isPresent()) {
// Add a new driver to output the unmatched rows in an outer join.
// We duplicate all of the factories above the JoinOperator (the ones reading from the joins),
// and replace the JoinOperator with the OuterOperator (the one that produces unmatched rows).
                    ImmutableList.Builder<OperatorFactory> newOperators = ImmutableList.builder();
newOperators.add(outerOperatorFactoryResult.get());
operatorFactories.subList(i + 1, operatorFactories.size()).stream()
.map(OperatorFactory::duplicate)
.forEach(newOperators::add);
addDriverFactory(false, isOutputDriver, newOperators.build(), OptionalInt.of(1));
}
}
}
        private void addDriverFactory(boolean inputDriver, boolean outputDriver, List<OperatorFactory> operatorFactories, OptionalInt driverInstances)
{
driverFactories.add(new DriverFactory(getNextPipelineId(), inputDriver, outputDriver, operatorFactories, driverInstances));
}
        private List<DriverFactory> getDriverFactories()
{
return ImmutableList.copyOf(driverFactories);
}
public StageId getStageId()
{
return taskContext.getTaskId().getStageId();
}
public TaskId getTaskId()
{
return taskContext.getTaskId();
}
public LocalDynamicFiltersCollector getDynamicFiltersCollector()
{
return taskContext.getLocalDynamicFiltersCollector();
}
        private void registerCoordinatorDynamicFilters(List<DynamicFilters.Descriptor> dynamicFilters)
{
if (!isEnableCoordinatorDynamicFiltersDistribution(taskContext.getSession())) {
return;
}
            Set<DynamicFilterId> consumedFilterIds = dynamicFilters.stream()
.map(DynamicFilters.Descriptor::getId)
.collect(toImmutableSet());
LocalDynamicFiltersCollector dynamicFiltersCollector = getDynamicFiltersCollector();
// Don't repeat registration of node-local filters or those already registered by another scan (e.g. co-located joins)
dynamicFiltersCollector.register(
difference(consumedFilterIds, dynamicFiltersCollector.getRegisteredDynamicFilterIds()));
}
private TaskContext getTaskContext()
{
return taskContext;
}
        public Optional<IndexSourceContext> getIndexSourceContext()
{
return indexSourceContext;
}
private int getNextPipelineId()
{
return nextPipelineId.getAndIncrement();
}
private int getNextOperatorId()
{
return nextOperatorId++;
}
private boolean isInputDriver()
{
return inputDriver;
}
private void setInputDriver(boolean inputDriver)
{
this.inputDriver = inputDriver;
}
public LocalExecutionPlanContext createSubContext()
{
checkState(indexSourceContext.isEmpty(), "index build plan cannot have sub-contexts");
return new LocalExecutionPlanContext(taskContext, driverFactories, indexSourceContext, nextPipelineId);
}
public LocalExecutionPlanContext createIndexSourceSubContext(IndexSourceContext indexSourceContext)
{
return new LocalExecutionPlanContext(taskContext, driverFactories, Optional.of(indexSourceContext), nextPipelineId);
}
public OptionalInt getDriverInstanceCount()
{
return driverInstanceCount;
}
public void setDriverInstanceCount(int driverInstanceCount)
{
checkArgument(driverInstanceCount > 0, "driverInstanceCount must be > 0");
if (this.driverInstanceCount.isPresent()) {
checkState(this.driverInstanceCount.getAsInt() == driverInstanceCount, "driverInstance count already set to " + this.driverInstanceCount.getAsInt());
}
this.driverInstanceCount = OptionalInt.of(driverInstanceCount);
}
}
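
    /**
     * Planning context for the source side of an index join: maps each index lookup symbol
     * to the probe-side input channels that feed it.
     */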
private static class IndexSourceContext
{
        private final SetMultimap<Symbol, Integer> indexLookupToProbeInput;
        public IndexSourceContext(SetMultimap<Symbol, Integer> indexLookupToProbeInput)
{
this.indexLookupToProbeInput = ImmutableSetMultimap.copyOf(requireNonNull(indexLookupToProbeInput, "indexLookupToProbeInput is null"));
}
        private SetMultimap<Symbol, Integer> getIndexLookupToProbeInput()
{
return indexLookupToProbeInput;
}
}
public static class LocalExecutionPlan
{
        private final List<DriverFactory> driverFactories;
        private final List<PlanNodeId> partitionedSourceOrder;
        public LocalExecutionPlan(List<DriverFactory> driverFactories, List<PlanNodeId> partitionedSourceOrder)
{
this.driverFactories = ImmutableList.copyOf(requireNonNull(driverFactories, "driverFactories is null"));
this.partitionedSourceOrder = ImmutableList.copyOf(requireNonNull(partitionedSourceOrder, "partitionedSourceOrder is null"));
}
        public List<DriverFactory> getDriverFactories()
{
return driverFactories;
}
        public List<PlanNodeId> getPartitionedSourceOrder()
{
return partitionedSourceOrder;
}
}
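
    /**
     * Translates each PlanNode into a {@code PhysicalOperation}: an operator factory chain plus
     * the symbol-to-channel layout of its output. Children are visited first, so every node
     * builds on the physical layout produced by its source.
     */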
    private class Visitor
            extends PlanVisitor<PhysicalOperation, LocalExecutionPlanContext>
{
private final Session session;
private Visitor(Session session)
{
this.session = session;
}
@Override
public PhysicalOperation visitRemoteSource(RemoteSourceNode node, LocalExecutionPlanContext context)
{
if (node.getOrderingScheme().isPresent()) {
return createMergeSource(node, context);
}
return createRemoteSource(node, context);
}
private PhysicalOperation createMergeSource(RemoteSourceNode node, LocalExecutionPlanContext context)
{
checkArgument(node.getOrderingScheme().isPresent(), "orderingScheme is absent");
checkArgument(node.getRetryPolicy() == RetryPolicy.NONE, "unexpected retry policy: %s", node.getRetryPolicy());
// merging remote source must have a single driver
context.setDriverInstanceCount(1);
OrderingScheme orderingScheme = node.getOrderingScheme().get();
            ImmutableMap<Symbol, Integer> layout = makeLayout(node);
            List<Integer> sortChannels = getChannelsForSymbols(orderingScheme.orderBy(), layout);
            List<SortOrder> sortOrder = orderingScheme.orderingList();
            List<Type> types = getSourceOperatorTypes(node);
            ImmutableList<Integer> outputChannels = IntStream.range(0, types.size())
.boxed()
.collect(toImmutableList());
OperatorFactory operatorFactory = new MergeOperatorFactory(
context.getNextOperatorId(),
node.getId(),
directExchangeClientSupplier,
new PagesSerdeFactory(plannerContext.getBlockEncodingSerde(), getExchangeCompressionCodec(session)),
orderingCompiler,
types,
outputChannels,
sortChannels,
sortOrder);
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
private PhysicalOperation createRemoteSource(RemoteSourceNode node, LocalExecutionPlanContext context)
{
if (context.getDriverInstanceCount().isEmpty()) {
context.setDriverInstanceCount(getTaskConcurrency(session));
}
OperatorFactory operatorFactory = new ExchangeOperatorFactory(
context.getNextOperatorId(),
node.getId(),
directExchangeClientSupplier,
new PagesSerdeFactory(plannerContext.getBlockEncodingSerde(), getExchangeCompressionCodec(session)),
node.getRetryPolicy(),
exchangeManagerRegistry);
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
@Override
public PhysicalOperation visitExplainAnalyze(ExplainAnalyzeNode node, LocalExecutionPlanContext context)
{
ExplainAnalyzeContext analyzeContext = explainAnalyzeContext
.orElseThrow(() -> new IllegalStateException("ExplainAnalyze can only run on coordinator"));
PhysicalOperation source = node.getSource().accept(this, context);
OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(
context.getNextOperatorId(),
node.getId(),
analyzeContext.getQueryPerformanceFetcher(),
metadata,
plannerContext.getFunctionManager(),
node.isVerbose(),
version);
return new PhysicalOperation(operatorFactory, makeLayout(node), source);
}
@Override
public PhysicalOperation visitOutput(OutputNode node, LocalExecutionPlanContext context)
{
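            // OutputNode adds no operator of its own; the output operator factory is appended later in plan(...)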
return node.getSource().accept(this, context);
}
@Override
public PhysicalOperation visitRowNumber(RowNumberNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
            List<Symbol> partitionBySymbols = node.getPartitionBy();
            List<Integer> partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout());
            List<Type> partitionTypes = partitionChannels.stream()
.map(channel -> source.getTypes().get(channel))
.collect(toImmutableList());
            ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
for (int i = 0; i < source.getTypes().size(); i++) {
outputChannels.add(i);
}
// compute the layout of the output from the window operator
            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
outputMappings.putAll(source.getLayout());
// row number function goes in the last channel
int channel = source.getTypes().size();
outputMappings.put(node.getRowNumberSymbol(), channel);
            Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source));
OperatorFactory operatorFactory = new RowNumberOperator.RowNumberOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
outputChannels.build(),
partitionChannels,
partitionTypes,
node.getMaxRowCountPerPartition(),
hashChannel,
10_000,
hashStrategyCompiler);
return new PhysicalOperation(operatorFactory, outputMappings.buildOrThrow(), source);
}
@Override
public PhysicalOperation visitTopNRanking(TopNRankingNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
            List<Symbol> partitionBySymbols = node.getPartitionBy();
            List<Integer> partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout());
            List<Type> partitionTypes = partitionChannels.stream()
                    .map(channel -> source.getTypes().get(channel))
                    .collect(toImmutableList());
            List<Symbol> orderBySymbols = node.getOrderingScheme().orderBy();
            List<Integer> sortChannels = getChannelsForSymbols(orderBySymbols, source.getLayout());
            List<SortOrder> sortOrder = orderBySymbols.stream()
.map(symbol -> node.getOrderingScheme().ordering(symbol))
.collect(toImmutableList());
            ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
for (int i = 0; i < source.getTypes().size(); i++) {
outputChannels.add(i);
}
// compute the layout of the output from the window operator
            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
outputMappings.putAll(source.getLayout());
if (!node.isPartial() || !partitionChannels.isEmpty()) {
// ranking function goes in the last channel
int channel = source.getTypes().size();
outputMappings.put(node.getRankingSymbol(), channel);
}
            Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source));
boolean isPartial = node.isPartial();
            Optional<DataSize> maxPartialTopNMemorySize = isPartial ? Optional.of(SystemSessionProperties.getMaxPartialTopNMemory(session)).filter(
                    maxSize -> maxSize.compareTo(DataSize.ofBytes(0)) > 0) : Optional.empty();
OperatorFactory operatorFactory = new TopNRankingOperator.TopNRankingOperatorFactory(
context.getNextOperatorId(),
node.getId(),
node.getRankingType(),
source.getTypes(),
outputChannels.build(),
partitionChannels,
partitionTypes,
sortChannels,
sortOrder,
node.getMaxRankingPerPartition(),
isPartial,
hashChannel,
1000,
maxPartialTopNMemorySize,
hashStrategyCompiler,
plannerContext.getTypeOperators(),
blockTypeOperators);
return new PhysicalOperation(operatorFactory, makeLayout(node), source);
}
@Override
public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
            List<Symbol> partitionBySymbols = node.getPartitionBy();
            List<Integer> partitionChannels = ImmutableList.copyOf(getChannelsForSymbols(partitionBySymbols, source.getLayout()));
            List<Integer> preGroupedChannels = ImmutableList.copyOf(getChannelsForSymbols(ImmutableList.copyOf(node.getPrePartitionedInputs()), source.getLayout()));
            List<Integer> sortChannels = ImmutableList.of();
            List<SortOrder> sortOrder = ImmutableList.of();
if (node.getOrderingScheme().isPresent()) {
OrderingScheme orderingScheme = node.getOrderingScheme().get();
sortChannels = getChannelsForSymbols(orderingScheme.orderBy(), source.getLayout());
sortOrder = orderingScheme.orderingList();
}
            ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
for (int i = 0; i < source.getTypes().size(); i++) {
outputChannels.add(i);
}
            ImmutableList.Builder<WindowFunctionDefinition> windowFunctionsBuilder = ImmutableList.builder();
            ImmutableList.Builder<Symbol> windowFunctionOutputSymbolsBuilder = ImmutableList.builder();
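            // For each window function: resolve frame bound and sort-key symbols to physical input
            // channels, then build a WindowFunctionDefinition from the resolved implementation and
            // its non-lambda argument channels.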
            for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) {
                Optional<Integer> frameStartChannel = Optional.empty();
                Optional<Integer> sortKeyChannelForStartComparison = Optional.empty();
                Optional<Integer> frameEndChannel = Optional.empty();
                Optional<Integer> sortKeyChannelForEndComparison = Optional.empty();
                Optional<Integer> sortKeyChannel = Optional.empty();
                Optional<FrameInfo.Ordering> ordering = Optional.empty();
Frame frame = entry.getValue().getFrame();
if (frame.getStartValue().isPresent()) {
frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get()));
}
if (frame.getSortKeyCoercedForFrameStartComparison().isPresent()) {
sortKeyChannelForStartComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameStartComparison().get()));
}
if (frame.getEndValue().isPresent()) {
frameEndChannel = Optional.of(source.getLayout().get(frame.getEndValue().get()));
}
if (frame.getSortKeyCoercedForFrameEndComparison().isPresent()) {
sortKeyChannelForEndComparison = Optional.of(source.getLayout().get(frame.getSortKeyCoercedForFrameEndComparison().get()));
}
if (node.getOrderingScheme().isPresent()) {
sortKeyChannel = Optional.of(sortChannels.get(0));
ordering = Optional.of(sortOrder.get(0).isAscending() ? ASCENDING : DESCENDING);
}
FrameInfo frameInfo = new FrameInfo(
frame.getType(),
frame.getStartType(),
frameStartChannel,
sortKeyChannelForStartComparison,
frame.getEndType(),
frameEndChannel,
sortKeyChannelForEndComparison,
sortKeyChannel,
ordering);
WindowNode.Function function = entry.getValue();
ResolvedFunction resolvedFunction = function.getResolvedFunction();
                ImmutableList.Builder<Integer> arguments = ImmutableList.builder();
for (Expression argument : function.getArguments()) {
if (!(argument instanceof Lambda)) {
Symbol argumentSymbol = Symbol.from(argument);
arguments.add(source.getLayout().get(argumentSymbol));
}
}
Symbol symbol = entry.getKey();
WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction);
Type type = resolvedFunction.signature().getReturnType();
                List<Lambda> lambdas = function.getArguments().stream()
.filter(Lambda.class::isInstance)
.map(Lambda.class::cast)
.collect(toImmutableList());
                List<FunctionType> functionTypes = resolvedFunction.signature().getArgumentTypes().stream()
.filter(FunctionType.class::isInstance)
.map(FunctionType.class::cast)
.collect(toImmutableList());
                List<Supplier<Object>> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes);
windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, arguments.build()));
windowFunctionOutputSymbolsBuilder.add(symbol);
}
            List<Symbol> windowFunctionOutputSymbols = windowFunctionOutputSymbolsBuilder.build();
// compute the layout of the output from the window operator
            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
for (Symbol symbol : node.getSource().getOutputSymbols()) {
outputMappings.put(symbol, source.getLayout().get(symbol));
}
// window functions go in remaining channels starting after the last channel from the source operator, one per channel
int channel = source.getTypes().size();
for (Symbol symbol : windowFunctionOutputSymbols) {
outputMappings.put(symbol, channel);
channel++;
}
OperatorFactory operatorFactory = new WindowOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
outputChannels.build(),
windowFunctionsBuilder.build(),
partitionChannels,
preGroupedChannels,
sortChannels,
sortOrder,
node.getPreSortedOrderPrefix(),
10_000,
pagesIndexFactory,
isSpillEnabled(session),
spillerFactory,
orderingCompiler,
ImmutableList.of(),
new RegularPartitionerSupplier());
return new PhysicalOperation(operatorFactory, outputMappings.buildOrThrow(), source);
}
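
        // Aggregate functions used in a window context are adapted through
        // AggregationWindowFunctionSupplier; suppliers are cached by function id and bound
        // signature to avoid rebuilding them for repeated invocations of the same function.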
private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
if (resolvedFunction.functionKind() == FunctionKind.AGGREGATE) {
return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature()), () -> {
AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction);
return new AggregationWindowFunctionSupplier(
resolvedFunction.signature(),
aggregationImplementation,
resolvedFunction.functionNullability());
});
}
return plannerContext.getFunctionManager().getWindowFunctionSupplier(resolvedFunction);
}
@Override
public PhysicalOperation visitPatternRecognition(PatternRecognitionNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
            List<Symbol> partitionBySymbols = node.getPartitionBy();
            List<Integer> partitionChannels = ImmutableList.copyOf(getChannelsForSymbols(partitionBySymbols, source.getLayout()));
            List<Integer> preGroupedChannels = ImmutableList.copyOf(getChannelsForSymbols(ImmutableList.copyOf(node.getPrePartitionedInputs()), source.getLayout()));
            List<Integer> sortChannels = ImmutableList.of();
            List<SortOrder> sortOrder = ImmutableList.of();
if (node.getOrderingScheme().isPresent()) {
OrderingScheme orderingScheme = node.getOrderingScheme().get();
sortChannels = getChannelsForSymbols(orderingScheme.orderBy(), source.getLayout());
sortOrder = orderingScheme.orderingList();
}
// The output order for pattern recognition operation is defined as follows:
// - for ONE ROW PER MATCH: partition by symbols, then measures,
// - for ALL ROWS PER MATCH: partition by symbols, order by symbols, measures, remaining input symbols,
// - for WINDOW: all input symbols, then window functions (including measures).
// The operator produces output in the following order:
// - for ONE ROW PER MATCH: partition by symbols, then measures,
// - otherwise all input symbols, then window functions and measures.
// There is no need to shuffle channels for output. Any upstream operator will pick them in preferred order using output mappings.
// input channels to be passed directly to output
            ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
// all output symbols mapped to output channels
            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
int nextOutputChannel;
if (node.getRowsPerMatch() == ONE) {
outputChannels.addAll(partitionChannels);
nextOutputChannel = partitionBySymbols.size();
for (int i = 0; i < partitionBySymbols.size(); i++) {
outputMappings.put(partitionBySymbols.get(i), i);
}
}
else {
outputChannels.addAll(IntStream.range(0, source.getTypes().size())
.boxed()
.collect(toImmutableList()));
nextOutputChannel = source.getTypes().size();
outputMappings.putAll(source.getLayout());
}
// measures go in remaining channels starting after the last channel from the source operator, one per channel
            for (Map.Entry<Symbol, Measure> measure : node.getMeasures().entrySet()) {
outputMappings.put(measure.getKey(), nextOutputChannel);
nextOutputChannel++;
}
// process window functions
            ImmutableList.Builder<WindowFunctionDefinition> windowFunctionsBuilder = ImmutableList.builder();
            for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) {
// window functions outputs go in remaining channels starting after the last measure channel
outputMappings.put(entry.getKey(), nextOutputChannel);
nextOutputChannel++;
WindowNode.Function function = entry.getValue();
ResolvedFunction resolvedFunction = function.getResolvedFunction();
                ImmutableList.Builder<Integer> arguments = ImmutableList.builder();
for (Expression argument : function.getArguments()) {
if (!(argument instanceof Lambda)) {
Symbol argumentSymbol = Symbol.from(argument);
arguments.add(source.getLayout().get(argumentSymbol));
}
}
WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction);
Type type = resolvedFunction.signature().getReturnType();
                List<Lambda> lambdas = function.getArguments().stream()
.filter(Lambda.class::isInstance)
.map(Lambda.class::cast)
.collect(toImmutableList());
                List<FunctionType> functionTypes = resolvedFunction.signature().getArgumentTypes().stream()
.filter(FunctionType.class::isInstance)
.map(FunctionType.class::cast)
.collect(toImmutableList());
                List<Supplier<Object>> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes);
windowFunctionsBuilder.add(window(windowFunctionSupplier, type, function.isIgnoreNulls(), lambdaProviders, arguments.build()));
}
// prepare structures specific to PatternRecognitionNode
// 1. establish a two-way mapping of IrLabels to `int`
List<IrLabel> primaryLabels = ImmutableList.copyOf(node.getVariableDefinitions().keySet());
ImmutableList.Builder<String> labelNamesBuilder = ImmutableList.builder();
ImmutableMap.Builder<IrLabel, Integer> mappingBuilder = ImmutableMap.builder();
for (int i = 0; i < primaryLabels.size(); i++) {
IrLabel label = primaryLabels.get(i);
labelNamesBuilder.add(label.getName());
mappingBuilder.put(label, i);
}
Map<IrLabel, Integer> mapping = mappingBuilder.buildOrThrow();
List<String> labelNames = labelNamesBuilder.build();
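// Illustrative example (hypothetical pattern, not from the original source): for PATTERN (A B+ C),
// the primary labels A, B, C map to integers 0, 1, 2, and labelNames is ["A", "B", "C"].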
// 2. rewrite pattern to program
Program program = IrRowPatternToProgramRewriter.rewrite(node.getPattern(), mapping);
// 3. prepare common base frame for pattern matching in window
Optional<FrameInfo> frame = node.getCommonBaseFrame()
.map(baseFrame -> {
checkArgument(
baseFrame.getType() == ROWS &&
baseFrame.getStartType() == CURRENT_ROW,
"invalid base frame");
return new FrameInfo(
baseFrame.getType(),
baseFrame.getStartType(),
Optional.empty(),
Optional.empty(),
baseFrame.getEndType(),
baseFrame.getEndValue().map(source.getLayout()::get),
Optional.empty(),
Optional.empty(),
Optional.empty());
});
ConnectorSession connectorSession = session.toConnectorSession();
// 4. prepare label evaluations (LabelEvaluator is to be instantiated once per Partition)
// during pattern matching, each thread will have a list of aggregations necessary for label evaluations.
// the list of aggregations for a thread will be produced at thread creation time from this supplier list, respecting the order.
// pointers in LabelEvaluator and ThreadEquivalence will access aggregations by position in list.
int matchAggregationIndex = 0;
ImmutableList.Builder<MatchAggregationInstantiator> labelEvaluationsAggregations = ImmutableList.builder();
// runtime-evaluated aggregation arguments will appear in additional channels after all source channels
int firstUnusedChannel = source.getLayout().values().stream().mapToInt(Integer::intValue).max().orElse(-1) + 1;
ImmutableList.Builder<ArgumentComputationSupplier> labelEvaluationsAggregationArguments = ImmutableList.builder();
ImmutableList.Builder<List<PhysicalValueAccessor>> evaluationsValuePointers = ImmutableList.builder();
ImmutableList.Builder<MatchAggregationLabelDependency> aggregationsLabelDependencies = ImmutableList.builder();
ImmutableList.Builder<EvaluationSupplier> evaluationsBuilder = ImmutableList.builder();
for (ExpressionAndValuePointers expressionAndValuePointers : node.getVariableDefinitions().values()) {
// compile the rewritten expression
Supplier<PageProjection> pageProjectionSupplier = prepareProjection(expressionAndValuePointers);
// prepare physical value accessors to provide input for the expression
ValueAccessors valueAccessors = preparePhysicalValuePointers(expressionAndValuePointers, mapping, source, connectorSession, firstUnusedChannel, matchAggregationIndex);
firstUnusedChannel = valueAccessors.getFirstUnusedChannel();
matchAggregationIndex = valueAccessors.getAggregationIndex();
// record aggregations
labelEvaluationsAggregations.addAll(valueAccessors.getAggregations());
// record aggregation argument computations
labelEvaluationsAggregationArguments.addAll(valueAccessors.getAggregationArguments());
// record aggregation label dependencies and value accessors for ThreadEquivalence
aggregationsLabelDependencies.addAll(valueAccessors.getLabelDependencies());
evaluationsValuePointers.add(valueAccessors.getValueAccessors());
// build label evaluation
evaluationsBuilder.add(new EvaluationSupplier(pageProjectionSupplier, valueAccessors.getValueAccessors(), labelNames, connectorSession));
}
List<EvaluationSupplier> labelEvaluations = evaluationsBuilder.build();
// 5. prepare measures computations
matchAggregationIndex = 0;
ImmutableList.Builder<MatchAggregationInstantiator> measureComputationsAggregations = ImmutableList.builder();
// runtime-evaluated aggregation arguments will appear in additional channels after all source channels
// measure computations will use a different instance of WindowIndex than the label evaluations
firstUnusedChannel = source.getLayout().values().stream().mapToInt(Integer::intValue).max().orElse(-1) + 1;
ImmutableList.Builder<ArgumentComputationSupplier> measureComputationsAggregationArguments = ImmutableList.builder();
ImmutableList.Builder<MeasureComputationSupplier> measuresBuilder = ImmutableList.builder();
for (Measure measure : node.getMeasures().values()) {
ExpressionAndValuePointers expressionAndValuePointers = measure.getExpressionAndValuePointers();
// compile the rewritten expression
Supplier<PageProjection> pageProjectionSupplier = prepareProjection(expressionAndValuePointers);
// prepare physical value accessors to provide input for the expression
ValueAccessors valueAccessors = preparePhysicalValuePointers(expressionAndValuePointers, mapping, source, connectorSession, firstUnusedChannel, matchAggregationIndex);
firstUnusedChannel = valueAccessors.getFirstUnusedChannel();
matchAggregationIndex = valueAccessors.getAggregationIndex();
// record aggregations
measureComputationsAggregations.addAll(valueAccessors.getAggregations());
// record aggregation argument computations
measureComputationsAggregationArguments.addAll(valueAccessors.getAggregationArguments());
// build measure computation
measuresBuilder.add(new MeasureComputationSupplier(pageProjectionSupplier, valueAccessors.getValueAccessors(), measure.getType(), labelNames, connectorSession));
}
List<MeasureComputationSupplier> measureComputations = measuresBuilder.build();
// 6. prepare SKIP TO navigation
Optional<LogicalIndexNavigation> skipToNavigation = Optional.empty();
if (!node.getSkipToLabels().isEmpty()) {
boolean last = node.getSkipToPosition().equals(LAST);
skipToNavigation = Optional.of(new LogicalIndexPointer(node.getSkipToLabels(), last, false, 0, 0).toLogicalIndexNavigation(mapping));
}
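// Illustrative example (hypothetical clause, not from the original source): AFTER MATCH SKIP TO LAST B
// yields a LogicalIndexPointer over label B with last = true, while SKIP TO FIRST B would use last = false.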
// 7. pass additional info like: rowsPerMatch, skipToPosition, initial to the WindowPartition factory supplier
PartitionerSupplier partitionerSupplier = new PatternRecognitionPartitionerSupplier(
measureComputations,
measureComputationsAggregations.build(),
measureComputationsAggregationArguments.build(),
frame,
node.getRowsPerMatch(),
skipToNavigation,
node.getSkipToPosition(),
node.isInitial(),
new Matcher(program, evaluationsValuePointers.build(), aggregationsLabelDependencies.build(), labelEvaluationsAggregations.build()),
labelEvaluations,
labelEvaluationsAggregationArguments.build(),
labelNames);
OperatorFactory operatorFactory = new WindowOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
outputChannels.build(),
windowFunctionsBuilder.build(),
partitionChannels,
preGroupedChannels,
sortChannels,
sortOrder,
node.getPreSortedOrderPrefix(),
10_000,
pagesIndexFactory,
isSpillEnabled(session),
spillerFactory,
orderingCompiler,
node.getMeasures().values().stream()
.map(Measure::getType)
.collect(toImmutableList()),
partitionerSupplier);
return new PhysicalOperation(operatorFactory, outputMappings.buildOrThrow(), source);
}
private Supplier<PageProjection> prepareProjection(ExpressionAndValuePointers expressionAndValuePointers)
{
Expression rewritten = expressionAndValuePointers.getExpression();
// prepare input layout and type provider for compilation
ImmutableMap.Builder<Symbol, Type> inputTypes = ImmutableMap.builder();
ImmutableMap.Builder<Symbol, Integer> inputLayout = ImmutableMap.builder();
List<ExpressionAndValuePointers.Assignment> assignments = expressionAndValuePointers.getAssignments();
for (int i = 0; i < assignments.size(); i++) {
ExpressionAndValuePointers.Assignment assignment = assignments.get(i);
inputLayout.put(assignment.symbol(), i);
inputTypes.put(
assignment.symbol(),
switch (assignment.valuePointer()) {
case AggregationValuePointer pointer -> pointer.getFunction().signature().getReturnType();
case ClassifierValuePointer pointer -> VARCHAR;
case MatchNumberValuePointer pointer -> BIGINT;
case ScalarValuePointer pointer -> pointer.getInputSymbol().type();
});
}
// compile expression using input layout and input types
RowExpression rowExpression = toRowExpression(rewritten, inputLayout.buildOrThrow());
return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty());
}
private ValueAccessors preparePhysicalValuePointers(
ExpressionAndValuePointers expressionAndValuePointers,
Map<IrLabel, Integer> mapping,
PhysicalOperation source,
ConnectorSession connectorSession,
int firstUnusedChannel,
int matchAggregationIndex)
{
Map<Symbol, Integer> sourceLayout = source.getLayout();
ImmutableList.Builder<MatchAggregationInstantiator> matchAggregations = ImmutableList.builder();
// runtime-evaluated aggregation arguments mapped to free channel slots
ImmutableList.Builder<ArgumentComputationSupplier> aggregationArguments = ImmutableList.builder();
// for thread equivalence
ImmutableList.Builder<MatchAggregationLabelDependency> labelDependencies = ImmutableList.builder();
ImmutableList.Builder<PhysicalValueAccessor> valueAccessors = ImmutableList.builder();
for (ExpressionAndValuePointers.Assignment assignment : expressionAndValuePointers.getAssignments()) {
switch (assignment.valuePointer()) {
case ClassifierValuePointer pointer -> {
valueAccessors.add(new PhysicalValuePointer(
CLASSIFIER,
VARCHAR,
pointer.getLogicalIndexPointer().toLogicalIndexNavigation(mapping)));
}
case MatchNumberValuePointer pointer -> {
valueAccessors.add(new PhysicalValuePointer(MATCH_NUMBER, BIGINT, LogicalIndexNavigation.NO_OP));
}
case ScalarValuePointer pointer -> {
valueAccessors.add(new PhysicalValuePointer(
getOnlyElement(getChannelsForSymbols(ImmutableList.of(pointer.getInputSymbol()), sourceLayout)),
pointer.getInputSymbol().type(),
pointer.getLogicalIndexPointer().toLogicalIndexNavigation(mapping)));
}
case AggregationValuePointer pointer -> {
boolean classifierInvolved = false;
ResolvedFunction resolvedFunction = pointer.getFunction();
AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(pointer.getFunction());
ImmutableList.Builder<Map.Entry<Expression, Type>> builder = ImmutableList.builder();
List<Type> signatureTypes = resolvedFunction.signature().getArgumentTypes();
for (int i = 0; i < pointer.getArguments().size(); i++) {
builder.add(new SimpleEntry<>(pointer.getArguments().get(i), signatureTypes.get(i)));
}
Map<Boolean, List<Map.Entry<Expression, Type>>> arguments = builder.build().stream()
.collect(partitioningBy(entry -> entry.getKey() instanceof Lambda));
// handle lambda arguments
List<Lambda> lambdas = arguments.get(true).stream()
.map(Map.Entry::getKey)
.map(Lambda.class::cast)
.collect(toImmutableList());
List<FunctionType> functionTypes = resolvedFunction.signature().getArgumentTypes().stream()
.filter(FunctionType.class::isInstance)
.map(FunctionType.class::cast)
.collect(toImmutableList());
// TODO when we support lambda arguments: lambda cannot have runtime-evaluated symbols -- add check in the Analyzer
List<Supplier<Object>> lambdaProviders = makeLambdaProviders(lambdas, aggregationImplementation.getLambdaInterfaces(), functionTypes);
// handle non-lambda arguments
List<Integer> valueChannels = new ArrayList<>();
Optional<Symbol> classifierArgumentSymbol = pointer.getClassifierSymbol();
Optional<Symbol> matchNumberArgumentSymbol = pointer.getMatchNumberSymbol();
Set<Symbol> runtimeEvaluatedSymbols = ImmutableSet.of(classifierArgumentSymbol, matchNumberArgumentSymbol).stream()
.flatMap(Optional::stream)
.collect(toImmutableSet());
for (Map.Entry<Expression, Type> argumentWithType : arguments.get(false)) {
Expression argument = argumentWithType.getKey();
boolean isRuntimeEvaluated = !(argument instanceof Reference) || runtimeEvaluatedSymbols.contains(Symbol.from(argument));
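// Illustrative example (hypothetical expressions, not from the original source): an argument like
// lower(CLASSIFIER()) is not a plain Reference (and CLASSIFIER() maps to a runtime-evaluated symbol),
// so it must be projected at runtime; a bare column reference is read directly from its source channel.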
if (isRuntimeEvaluated) {
List<Symbol> argumentInputSymbols = ImmutableList.copyOf(SymbolsExtractor.extractUnique(argument));
Supplier<PageProjection> argumentProjectionSupplier = prepareArgumentProjection(argument, argumentInputSymbols);
List<Integer> argumentInputChannels = new ArrayList<>();
for (Symbol symbol : argumentInputSymbols) {
if (classifierArgumentSymbol.isPresent() && symbol.equals(classifierArgumentSymbol.get())) {
classifierInvolved = true;
argumentInputChannels.add(CLASSIFIER);
}
else if (matchNumberArgumentSymbol.isPresent() && symbol.equals(matchNumberArgumentSymbol.get())) {
argumentInputChannels.add(MATCH_NUMBER);
}
else {
argumentInputChannels.add(sourceLayout.get(symbol));
}
}
Type argumentType = argumentWithType.getValue();
ArgumentComputationSupplier argumentComputationSupplier = new ArgumentComputationSupplier(argumentProjectionSupplier, argumentType, argumentInputChannels, connectorSession);
aggregationArguments.add(argumentComputationSupplier);
// the runtime-evaluated argument will appear in an extra channel after all input channels
valueChannels.add(firstUnusedChannel);
firstUnusedChannel++;
}
else {
valueChannels.add(sourceLayout.get(Symbol.from(argument)));
}
}
AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier = uncheckedCacheGet(
aggregationWindowFunctionSupplierCache,
new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature()),
() -> new AggregationWindowFunctionSupplier(
resolvedFunction.signature(),
aggregationImplementation,
resolvedFunction.functionNullability()));
matchAggregations.add(new MatchAggregationInstantiator(
resolvedFunction.signature(),
aggregationWindowFunctionSupplier,
valueChannels,
lambdaProviders,
new SetEvaluatorSupplier(pointer.getSetDescriptor(), mapping)));
labelDependencies.add(new MatchAggregationLabelDependency(
pointer.getSetDescriptor().getLabels().stream()
.map(mapping::get)
.collect(toImmutableSet()),
classifierInvolved));
valueAccessors.add(new MatchAggregationPointer(matchAggregationIndex));
matchAggregationIndex++;
}
}
}
return new ValueAccessors(valueAccessors.build(), matchAggregations.build(), matchAggregationIndex, aggregationArguments.build(), firstUnusedChannel, labelDependencies.build());
}
private Supplier<PageProjection> prepareArgumentProjection(Expression argument, List<Symbol> inputSymbols)
{
// prepare input layout and type provider for compilation
ImmutableMap.Builder<Symbol, Integer> inputLayout = ImmutableMap.builder();
for (int i = 0; i < inputSymbols.size(); i++) {
inputLayout.put(inputSymbols.get(i), i);
}
// compile expression using input layout and input types
RowExpression rowExpression = toRowExpression(argument, inputLayout.buildOrThrow());
return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty());
}
@Override
public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecutionPlanContext context)
{
throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName()));
}
@Override
public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context)
{
TableFunctionProcessorProvider processorProvider = plannerContext.getFunctionManager().getTableFunctionProcessorProvider(node.getHandle());
if (node.getSource().isEmpty()) {
OperatorFactory operatorFactory = new LeafTableFunctionOperatorFactory(
context.getNextOperatorId(),
node.getId(),
node.getHandle().catalogHandle(),
processorProvider,
node.getHandle().functionHandle());
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
PhysicalOperation source = node.getSource().orElseThrow().accept(this, context);
int properChannelsCount = node.getProperOutputs().size();
long passThroughSourcesCount = node.getPassThroughSpecifications().stream()
.filter(PassThroughSpecification::declaredAsPassThrough)
.count();
List<List<Integer>> requiredChannels = node.getRequiredSymbols().stream()
.map(list -> getChannelsForSymbols(list, source.getLayout()))
.collect(toImmutableList());
Optional<Map<Integer, Integer>> markerChannels = node.getMarkerSymbols()
.map(map -> map.entrySet().stream()
.collect(toImmutableMap(entry -> source.getLayout().get(entry.getKey()), entry -> source.getLayout().get(entry.getValue()))));
int channel = properChannelsCount;
ImmutableList.Builder<PassThroughColumnSpecification> passThroughColumnSpecifications = ImmutableList.builder();
for (PassThroughSpecification specification : node.getPassThroughSpecifications()) {
// the table function produces one index channel for each source declared as pass-through. They are laid out after the proper channels.
int indexChannel = specification.declaredAsPassThrough() ? channel++ : -1;
for (PassThroughColumn column : specification.columns()) {
passThroughColumnSpecifications.add(new PassThroughColumnSpecification(column.isPartitioningColumn(), source.getLayout().get(column.symbol()), indexChannel));
}
}
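// Illustrative example (hypothetical layout, not from the original source): with 2 proper outputs and two
// sources of which only the first is declared as pass-through, source 0 gets index channel 2 (the first
// channel after the proper outputs) and source 1 gets indexChannel = -1 (no pass-through index).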
List<Integer> partitionChannels = node.getSpecification()
.map(DataOrganizationSpecification::partitionBy)
.map(list -> getChannelsForSymbols(list, source.getLayout()))
.orElse(ImmutableList.of());
List<Integer> sortChannels = ImmutableList.of();
List<SortOrder> sortOrders = ImmutableList.of();
if (node.getSpecification().flatMap(DataOrganizationSpecification::orderingScheme).isPresent()) {
OrderingScheme orderingScheme = node.getSpecification().flatMap(DataOrganizationSpecification::orderingScheme).orElseThrow();
sortChannels = getChannelsForSymbols(orderingScheme.orderBy(), source.getLayout());
sortOrders = orderingScheme.orderingList();
}
OperatorFactory operator = new TableFunctionOperatorFactory(
context.getNextOperatorId(),
node.getId(),
processorProvider,
node.getHandle().catalogHandle(),
node.getHandle().functionHandle(),
properChannelsCount,
toIntExact(passThroughSourcesCount),
requiredChannels,
markerChannels,
passThroughColumnSpecifications.build(),
node.isPruneWhenEmpty(),
partitionChannels,
getChannelsForSymbols(ImmutableList.copyOf(node.getPrePartitioned()), source.getLayout()),
sortChannels,
sortOrders,
node.getPreSorted(),
source.getTypes(),
10_000,
pagesIndexFactory);
ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
for (int i = 0; i < node.getProperOutputs().size(); i++) {
outputMappings.put(node.getProperOutputs().get(i), i);
}
List<Symbol> passThroughSymbols = node.getPassThroughSpecifications().stream()
.map(PassThroughSpecification::columns)
.flatMap(Collection::stream)
.map(PassThroughColumn::symbol)
.collect(toImmutableList());
int outputChannel = properChannelsCount;
for (Symbol passThroughSymbol : passThroughSymbols) {
outputMappings.put(passThroughSymbol, outputChannel++);
}
return new PhysicalOperation(operator, outputMappings.buildOrThrow(), source);
}
@Override
public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
List<Symbol> orderBySymbols = node.getOrderingScheme().orderBy();
List<Integer> sortChannels = new ArrayList<>();
List<SortOrder> sortOrders = new ArrayList<>();
for (Symbol symbol : orderBySymbols) {
sortChannels.add(source.getLayout().get(symbol));
sortOrders.add(node.getOrderingScheme().ordering(symbol));
}
OperatorFactory operator = TopNOperator.createOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
(int) node.getCount(),
sortChannels,
sortOrders,
plannerContext.getTypeOperators());
return new PhysicalOperation(operator, source.getLayout(), source);
}
@Override
public PhysicalOperation visitSort(SortNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
List<Symbol> orderBySymbols = node.getOrderingScheme().orderBy();
List<Integer> orderByChannels = getChannelsForSymbols(orderBySymbols, source.getLayout());
ImmutableList.Builder<SortOrder> sortOrder = ImmutableList.builder();
for (Symbol symbol : orderBySymbols) {
sortOrder.add(node.getOrderingScheme().ordering(symbol));
}
ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
for (int i = 0; i < source.getTypes().size(); i++) {
outputChannels.add(i);
}
boolean spillEnabled = isSpillEnabled(session);
OperatorFactory operator = new OrderByOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
outputChannels.build(),
10_000,
orderByChannels,
sortOrder.build(),
pagesIndexFactory,
spillEnabled,
Optional.of(spillerFactory),
orderingCompiler);
return new PhysicalOperation(operator, source.getLayout(), source);
}
@Override
public PhysicalOperation visitLimit(LimitNode node, LocalExecutionPlanContext context)
{
// Limit with ties should be rewritten at this point
checkState(node.getTiesResolvingScheme().isEmpty(), "Limit with ties not supported");
PhysicalOperation source = node.getSource().accept(this, context);
OperatorFactory operatorFactory = new LimitOperatorFactory(context.getNextOperatorId(), node.getId(), node.getCount());
return new PhysicalOperation(operatorFactory, source.getLayout(), source);
}
@Override
public PhysicalOperation visitDistinctLimit(DistinctLimitNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source));
List<Integer> distinctChannels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout());
OperatorFactory operatorFactory = new DistinctLimitOperatorFactory(
context.getNextOperatorId(),
node.getId(),
source.getTypes(),
distinctChannels,
node.getLimit(),
hashChannel,
hashStrategyCompiler);
return new PhysicalOperation(operatorFactory, makeLayout(node), source);
}
@Override
public PhysicalOperation visitGroupId(GroupIdNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
Map<Symbol, Integer> newLayout = new HashMap<>();
ImmutableList.Builder<Type> outputTypes = ImmutableList.builder();
int outputChannel = 0;
for (Symbol output : node.getDistinctGroupingSetSymbols()) {
newLayout.put(output, outputChannel++);
outputTypes.add(source.getTypes().get(source.getLayout().get(node.getGroupingColumns().get(output))));
}
Map<Symbol, Integer> argumentMappings = new HashMap<>();
for (Symbol output : node.getAggregationArguments()) {
int inputChannel = source.getLayout().get(output);
newLayout.put(output, outputChannel++);
outputTypes.add(source.getTypes().get(inputChannel));
argumentMappings.put(output, inputChannel);
}
// for every grouping set, create a mapping of all output to input channels (including arguments)
ImmutableList.Builder<Map<Integer, Integer>> mappings = ImmutableList.builder();
for (List<Symbol> groupingSet : node.getGroupingSets()) {
ImmutableMap.Builder<Integer, Integer> setMapping = ImmutableMap.builder();
for (Symbol output : groupingSet) {
setMapping.put(newLayout.get(output), source.getLayout().get(node.getGroupingColumns().get(output)));
}
for (Symbol output : argumentMappings.keySet()) {
setMapping.put(newLayout.get(output), argumentMappings.get(output));
}
mappings.add(setMapping.buildOrThrow());
}
newLayout.put(node.getGroupIdSymbol(), outputChannel);
outputTypes.add(BIGINT);
OperatorFactory groupIdOperatorFactory = new GroupIdOperator.GroupIdOperatorFactory(context.getNextOperatorId(),
node.getId(),
outputTypes.build(),
mappings.build());
return new PhysicalOperation(groupIdOperatorFactory, newLayout, source);
}
@Override
public PhysicalOperation visitAggregation(AggregationNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
if (node.getGroupingKeys().isEmpty()) {
return planGlobalAggregation(node, source, context);
}
boolean spillEnabled = isSpillEnabled(session);
DataSize unspillMemoryLimit = getAggregationOperatorUnspillMemoryLimit(session);
return planGroupByAggregation(node, source, spillEnabled, unspillMemoryLimit, context);
}
@Override
public PhysicalOperation visitMarkDistinct(MarkDistinctNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
List<Integer> channels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout());
Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source));
MarkDistinctOperatorFactory operator = new MarkDistinctOperatorFactory(context.getNextOperatorId(), node.getId(), source.getTypes(), channels, hashChannel, hashStrategyCompiler);
return new PhysicalOperation(operator, makeLayout(node), source);
}
@Override
public PhysicalOperation visitSample(SampleNode node, LocalExecutionPlanContext context)
{
// For system sample, the splits are already filtered out, so no specific action needs to be taken here
if (node.getSampleType() == SampleNode.Type.SYSTEM) {
return node.getSource().accept(this, context);
}
throw new UnsupportedOperationException("not yet implemented: " + node);
}
@Override
public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext context)
{
PlanNode sourceNode = node.getSource();
if (node.getSource() instanceof TableScanNode && getStaticFilter(node.getPredicate()).isEmpty()) {
// filter node contains only dynamic filter, fallback to normal table scan
return visitTableScan(node.getId(), (TableScanNode) node.getSource(), node.getPredicate(), context);
}
Expression filterExpression = node.getPredicate();
List<Symbol> outputSymbols = node.getOutputSymbols();
return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols);
}
@Override
public PhysicalOperation visitProject(ProjectNode node, LocalExecutionPlanContext context)
{
PlanNode sourceNode;
Optional<Expression> filterExpression = Optional.empty();
if (node.getSource() instanceof FilterNode filterNode) {
sourceNode = filterNode.getSource();
filterExpression = Optional.of(filterNode.getPredicate());
}
else {
sourceNode = node.getSource();
}
List<Symbol> outputSymbols = node.getOutputSymbols();
return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), outputSymbols);
}
// TODO: This should be refactored, so that there's an optimizer that merges scan-filter-project into a single PlanNode
private PhysicalOperation visitScanFilterAndProject(
LocalExecutionPlanContext context,
PlanNodeId planNodeId,
PlanNode sourceNode,
Optional<Expression> filterExpression,
Assignments assignments,
List<Symbol> outputSymbols)
{
// if source is a table scan we fold it directly into the filter and project
// otherwise we plan it as a normal operator
Map<Symbol, Integer> sourceLayout;
TableHandle table = null;
List<ColumnHandle> columns = null;
PhysicalOperation source = null;
if (sourceNode instanceof TableScanNode tableScanNode) {
table = tableScanNode.getTable();
// extract the column handles and channel to type mapping
sourceLayout = new LinkedHashMap<>();
columns = new ArrayList<>();
int channel = 0;
for (Symbol symbol : tableScanNode.getOutputSymbols()) {
columns.add(tableScanNode.getAssignments().get(symbol));
Integer input = channel;
sourceLayout.put(symbol, input);
channel++;
}
}
//TODO: This is a simple hack, it will be replaced when we add ability to push down sampling into connectors.
// SYSTEM sampling is performed in the coordinator by dropping some random splits so the SamplingNode can be skipped here.
else if (sourceNode instanceof SampleNode sampleNode) {
checkArgument(sampleNode.getSampleType() == SampleNode.Type.SYSTEM, "%s sampling is not supported", sampleNode.getSampleType());
return visitScanFilterAndProject(context,
planNodeId,
sampleNode.getSource(),
filterExpression,
assignments,
outputSymbols);
}
else {
// plan source
source = sourceNode.accept(this, context);
sourceLayout = source.getLayout();
}
// build output mapping
ImmutableMap.Builder<Symbol, Integer> outputMappingsBuilder = ImmutableMap.builder();
for (int i = 0; i < outputSymbols.size(); i++) {
Symbol symbol = outputSymbols.get(i);
outputMappingsBuilder.put(symbol, i);
}
Map<Symbol, Integer> outputMappings = outputMappingsBuilder.buildOrThrow();
Optional<Expression> staticFilters = filterExpression.flatMap(this::getStaticFilter);
DynamicFilter dynamicFilter = filterExpression
.filter(expression -> sourceNode instanceof TableScanNode)
.map(expression -> getDynamicFilter((TableScanNode) sourceNode, expression, context))
.orElse(DynamicFilter.EMPTY);
List<Expression> projections = new ArrayList<>();
for (Symbol symbol : outputSymbols) {
projections.add(assignments.get(symbol));
}
Optional<RowExpression> translatedFilter = staticFilters.map(filter -> toRowExpression(filter, sourceLayout));
List<RowExpression> translatedProjections = projections.stream()
.map(expression -> toRowExpression(expression, sourceLayout))
.collect(toImmutableList());
try {
if (columns != null) {
Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(translatedFilter, translatedProjections, sourceNode.getId());
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory(
context.getNextOperatorId(),
planNodeId,
sourceNode.getId(),
pageSourceProvider,
cursorProcessor,
pageProcessor,
table,
columns,
dynamicFilter,
getTypes(projections),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));
return new PhysicalOperation(operatorFactory, outputMappings);
}
Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId));
OperatorFactory operatorFactory = FilterAndProjectOperator.createOperatorFactory(
context.getNextOperatorId(),
planNodeId,
pageProcessor,
getTypes(projections),
getFilterAndProjectMinOutputPageSize(session),
getFilterAndProjectMinOutputPageRowCount(session));
return new PhysicalOperation(operatorFactory, outputMappings, source);
}
catch (TrinoException e) {
throw e;
}
catch (RuntimeException e) {
throw new TrinoException(
COMPILER_ERROR,
"Compiler failed. Possible reasons include: the query may have too many or too complex expressions, " +
"or the underlying tables may have too many columns",
e);
}
}
private RowExpression toRowExpression(Expression expression, Map<Symbol, Integer> layout)
{
return SqlToRowExpressionTranslator.translate(expression, layout, metadata, plannerContext.getTypeManager());
}
@Override
public PhysicalOperation visitTableScan(TableScanNode node, LocalExecutionPlanContext context)
{
return visitTableScan(node.getId(), node, TRUE, context);
}
private PhysicalOperation visitTableScan(PlanNodeId planNodeId, TableScanNode node, Expression filterExpression, LocalExecutionPlanContext context)
{
List<ColumnHandle> columns = new ArrayList<>();
for (Symbol symbol : node.getOutputSymbols()) {
columns.add(node.getAssignments().get(symbol));
}
DynamicFilter dynamicFilter = getDynamicFilter(node, filterExpression, context);
OperatorFactory operatorFactory = new TableScanOperatorFactory(context.getNextOperatorId(), planNodeId, node.getId(), pageSourceProvider, node.getTable(), columns, dynamicFilter);
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
private Optional<Expression> getStaticFilter(Expression filterExpression)
{
DynamicFilters.ExtractResult extractDynamicFilterResult = extractDynamicFilters(filterExpression);
Expression staticFilter = combineConjuncts(extractDynamicFilterResult.getStaticConjuncts());
if (staticFilter.equals(TRUE)) {
return Optional.empty();
}
return Optional.of(staticFilter);
}
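// Illustrative example (hypothetical predicate, not from the original source): for a filter like
// `x > 5 AND $dynamic_filter(y)`, the static conjunct `x > 5` is returned here, while the dynamic
// conjunct is handled separately by getDynamicFilter below.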
private DynamicFilter getDynamicFilter(
TableScanNode tableScanNode,
Expression filterExpression,
LocalExecutionPlanContext context)
{
DynamicFilters.ExtractResult extractDynamicFilterResult = extractDynamicFilters(filterExpression);
List<DynamicFilters.Descriptor> dynamicFilters = extractDynamicFilterResult.getDynamicConjuncts();
if (dynamicFilters.isEmpty()) {
return DynamicFilter.EMPTY;
}
log.debug("[TableScan] Dynamic filters: %s", dynamicFilters);
context.registerCoordinatorDynamicFilters(dynamicFilters);
return context.getDynamicFiltersCollector().createDynamicFilter(
dynamicFilters,
tableScanNode.getAssignments(),
plannerContext);
}
@Override
public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext context)
{
// a values node must have a single driver
context.setDriverInstanceCount(1);
if (node.getRowCount() == 0) {
OperatorFactory operatorFactory = new ValuesOperatorFactory(context.getNextOperatorId(), node.getId(), ImmutableList.of());
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
List<Type> outputTypes = getSymbolTypes(node.getOutputSymbols());
PageBuilder pageBuilder = new PageBuilder(node.getRowCount(), outputTypes);
for (int i = 0; i < node.getRowCount(); i++) {
// declare position for every row
pageBuilder.declarePosition();
// evaluate values for non-empty rows
if (node.getRows().isPresent()) {
Expression row = node.getRows().get().get(i);
checkState(row.type() instanceof RowType, "unexpected type of Values row: %s", row.type());
// evaluate the literal value
SqlRow result = (SqlRow) new IrExpressionInterpreter(row, plannerContext, session).evaluate();
int rawIndex = result.getRawIndex();
for (int j = 0; j < outputTypes.size(); j++) {
// divide row into fields
Block fieldBlock = result.getRawFieldBlock(j);
writeNativeValue(outputTypes.get(j), pageBuilder.getBlockBuilder(j), readNativeValue(outputTypes.get(j), fieldBlock, rawIndex));
}
}
}
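// Illustrative example (hypothetical query, not from the original source): VALUES (1, 'a'), (2, 'b')
// yields a single page with two declared positions; each row is a RowType constant evaluated above and
// split into per-column blocks via writeNativeValue.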
OperatorFactory operatorFactory = new ValuesOperatorFactory(context.getNextOperatorId(), node.getId(), ImmutableList.of(pageBuilder.build()));
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
@Override
public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext context)
{
PhysicalOperation source = node.getSource().accept(this, context);
ImmutableList.Builder<Type> replicateTypes = ImmutableList.builder();
for (Symbol symbol : node.getReplicateSymbols()) {
replicateTypes.add(symbol.type());
}
List<Symbol> unnestSymbols = node.getMappings().stream()
.map(UnnestNode.Mapping::getInput)
.collect(toImmutableList());
ImmutableList.Builder<Type> unnestTypes = ImmutableList.builder();
for (Symbol symbol : unnestSymbols) {
unnestTypes.add(symbol.type());
}
Optional<Symbol> ordinalitySymbol = node.getOrdinalitySymbol();
Optional<Type> ordinalityType = ordinalitySymbol.map(Symbol::type);
ordinalityType.ifPresent(type -> checkState(type.equals(BIGINT), "Type of ordinalitySymbol must always be BIGINT."));
List<Integer> replicateChannels = getChannelsForSymbols(node.getReplicateSymbols(), source.getLayout());
List<Integer> unnestChannels = getChannelsForSymbols(unnestSymbols, source.getLayout());
// Source channels are always laid out first, followed by the unnested symbols
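// Illustrative example (hypothetical query, not from the original source): for UNNEST(arr) AS t(u)
// WITH ORDINALITY replicating column r, the layout is: channel 0 -> r, channel 1 -> u, channel 2 -> ordinality.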
ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
int channel = 0;
for (Symbol symbol : node.getReplicateSymbols()) {
outputMappings.put(symbol, channel);
channel++;
}
for (UnnestNode.Mapping mapping : node.getMappings()) {
for (Symbol unnestedSymbol : mapping.getOutputs()) {
outputMappings.put(unnestedSymbol, channel);
channel++;
}
}
if (ordinalitySymbol.isPresent()) {
outputMappings.put(ordinalitySymbol.get(), channel);
channel++;
}
boolean outer = node.getJoinType() == LEFT || node.getJoinType() == FULL;
OperatorFactory operatorFactory = new UnnestOperator.UnnestOperatorFactory(
context.getNextOperatorId(),
node.getId(),
replicateChannels,
replicateTypes.build(),
unnestChannels,
unnestTypes.build(),
ordinalityType.isPresent(),
outer);
return new PhysicalOperation(operatorFactory, outputMappings.buildOrThrow(), source);
}
private ImmutableMap<Symbol, Integer> makeLayout(PlanNode node)
{
return makeLayoutFromOutputSymbols(node.getOutputSymbols());
}
private ImmutableMap<Symbol, Integer> makeLayoutFromOutputSymbols(List<Symbol> outputSymbols)
{
ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
int channel = 0;
for (Symbol symbol : outputSymbols) {
outputMappings.put(symbol, channel);
channel++;
}
return outputMappings.buildOrThrow();
}
@Override
public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPlanContext context)
{
checkState(context.getIndexSourceContext().isPresent(), "Must be in an index source context");
IndexSourceContext indexSourceContext = context.getIndexSourceContext().get();
SetMultimap<Symbol, Integer> indexLookupToProbeInput = indexSourceContext.getIndexLookupToProbeInput();
checkState(indexLookupToProbeInput.keySet().equals(node.getLookupSymbols()));
// Finalize the symbol lookup layout for the index source
List<Symbol> lookupSymbolSchema = ImmutableList.copyOf(node.getLookupSymbols());
// Identify how to remap the probe key Input to match the source index lookup layout
ImmutableList.Builder<Integer> remappedProbeKeyChannelsBuilder = ImmutableList.builder();
// Identify overlapping fields that can produce the same lookup symbol.
// We will filter incoming keys to ensure that overlapping fields will have the same value.
ImmutableList.Builder<Set<Integer>> overlappingFieldSetsBuilder = ImmutableList.builder();
for (Symbol lookupSymbol : lookupSymbolSchema) {
Set<Integer> potentialProbeInputs = indexLookupToProbeInput.get(lookupSymbol);
checkState(!potentialProbeInputs.isEmpty(), "Must have at least one source from the probe input");
if (potentialProbeInputs.size() > 1) {
overlappingFieldSetsBuilder.add(ImmutableSet.copyOf(potentialProbeInputs));
}
remappedProbeKeyChannelsBuilder.add(Iterables.getFirst(potentialProbeInputs, null));
}
List<Set<Integer>> overlappingFieldSets = overlappingFieldSetsBuilder.build();
List<Integer> remappedProbeKeyChannels = remappedProbeKeyChannelsBuilder.build();
Function<RecordSet, RecordSet> probeKeyNormalizer = recordSet -> {
if (!overlappingFieldSets.isEmpty()) {
recordSet = new FieldSetFilteringRecordSet(plannerContext.getTypeOperators(), recordSet, overlappingFieldSets);
}
return new MappedRecordSet(recordSet, remappedProbeKeyChannels);
};
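// Illustrative example (hypothetical channels, not from the original source): if probe inputs {0, 2} both
// feed the same lookup symbol, the FieldSetFilteringRecordSet keeps only rows where fields 0 and 2 are
// equal, and the MappedRecordSet then remaps channel 0 into the index lookup layout.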
// Declare the input and output schemas for the index and acquire the actual Index
List<ColumnHandle> lookupSchema = Lists.transform(lookupSymbolSchema, forMap(node.getAssignments()));
List<ColumnHandle> outputSchema = Lists.transform(node.getOutputSymbols(), forMap(node.getAssignments()));
ConnectorIndex index = indexManager.getIndex(session, node.getIndexHandle(), lookupSchema, outputSchema);
OperatorFactory operatorFactory = new IndexSourceOperator.IndexSourceOperatorFactory(context.getNextOperatorId(), node.getId(), index, probeKeyNormalizer);
return new PhysicalOperation(operatorFactory, makeLayout(node));
}
/**
* This method creates a mapping from each index source lookup symbol (directly applied to the index)
* to the corresponding probe key Input
*/
private SetMultimap<Symbol, Integer> mapIndexSourceLookupSymbolToProbeKeyInput(IndexJoinNode node, Map<Symbol, Integer> probeKeyLayout)
{
Set<Symbol> indexJoinSymbols = node.getCriteria().stream()
.map(IndexJoinNode.EquiJoinClause::getIndex)
.collect(toImmutableSet());
// Trace the index join symbols to the index source lookup symbols
// Map: Index join symbol => Index source lookup symbol
Map<Symbol, Symbol> indexKeyTrace = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), indexJoinSymbols);
// Map the index join symbols to the probe key Input
Multimap<Symbol, Integer> indexToProbeKeyInput = HashMultimap.create();
for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
indexToProbeKeyInput.put(clause.getIndex(), probeKeyLayout.get(clause.getProbe()));
}
// Create the mapping from index source look up symbol to probe key Input
ImmutableSetMultimap.Builder<Symbol, Integer> builder = ImmutableSetMultimap.builder();
for (Map.Entry<Symbol, Symbol> entry : indexKeyTrace.entrySet()) {
Symbol indexJoinSymbol = entry.getKey();
Symbol indexLookupSymbol = entry.getValue();
builder.putAll(indexLookupSymbol, indexToProbeKeyInput.get(indexJoinSymbol));
}
return builder.build();
}
@Override
public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanContext context)
{
List<IndexJoinNode.EquiJoinClause> clauses = node.getCriteria();
List<Symbol> probeSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe);
List<Symbol> indexSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex);
// Plan probe side
PhysicalOperation probeSource = node.getProbeSource().accept(this, context);
List<Integer> probeChannels = getChannelsForSymbols(probeSymbols, probeSource.getLayout());
OptionalInt probeHashChannel = node.getProbeHashSymbol().map(channelGetter(probeSource))
.map(OptionalInt::of).orElse(OptionalInt.empty());
// The probe key channels will be handed to the index according to probeSymbol order
Map<Symbol, Integer> probeKeyLayout = new HashMap<>();
for (int i = 0; i < probeSymbols.size(); i++) {
// Duplicate symbols can appear and we only need to take one of the Inputs
probeKeyLayout.put(probeSymbols.get(i), i);
}
// Plan the index source side
SetMultimap<Symbol, Integer> indexLookupToProbeInput = mapIndexSourceLookupSymbolToProbeKeyInput(node, probeKeyLayout);
LocalExecutionPlanContext indexContext = context.createIndexSourceSubContext(new IndexSourceContext(indexLookupToProbeInput));
PhysicalOperation indexSource = node.getIndexSource().accept(this, indexContext);
List<Integer> indexOutputChannels = getChannelsForSymbols(indexSymbols, indexSource.getLayout());
OptionalInt indexHashChannel = node.getIndexHashSymbol().map(channelGetter(indexSource))
.map(OptionalInt::of).orElse(OptionalInt.empty());
// Identify just the join keys/channels needed for lookup by the index source (does not have to use all of them).
Set<Symbol> indexSymbolsNeededBySource = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), ImmutableSet.copyOf(indexSymbols)).keySet();
Set<Integer> lookupSourceInputChannels = node.getCriteria().stream()
.filter(equiJoinClause -> indexSymbolsNeededBySource.contains(equiJoinClause.getIndex()))
.map(IndexJoinNode.EquiJoinClause::getProbe)
.map(probeKeyLayout::get)
.collect(toImmutableSet());
Optional<DynamicTupleFilterFactory> dynamicTupleFilterFactory = Optional.empty();
if (lookupSourceInputChannels.size() < probeKeyLayout.values().size()) {
int[] nonLookupInputChannels = Ints.toArray(node.getCriteria().stream()
.filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex()))
.map(IndexJoinNode.EquiJoinClause::getProbe)
.map(probeKeyLayout::get)
.collect(toImmutableList()));
int[] nonLookupOutputChannels = Ints.toArray(node.getCriteria().stream()
.filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex()))
.map(IndexJoinNode.EquiJoinClause::getIndex)
.map(indexSource.getLayout()::get)
.collect(toImmutableList()));
int filterOperatorId = indexContext.getNextOperatorId();
dynamicTupleFilterFactory = Optional.of(new DynamicTupleFilterFactory(
filterOperatorId,
node.getId(),
nonLookupInputChannels,
nonLookupOutputChannels,
indexSource.getTypes(),
pageFunctionCompiler,
blockTypeOperators));
}
IndexBuildDriverFactoryProvider indexBuildDriverFactoryProvider = new IndexBuildDriverFactoryProvider(
indexContext.getNextPipelineId(),
indexContext.getNextOperatorId(),
node.getId(),
indexContext.isInputDriver(),
indexSource.getTypes(),
indexSource.getOperatorFactories(),
dynamicTupleFilterFactory);
IndexLookupSourceFactory indexLookupSourceFactory = new IndexLookupSourceFactory(
lookupSourceInputChannels,
indexOutputChannels,
indexHashChannel,
indexSource.getTypes(),
indexBuildDriverFactoryProvider,
maxIndexMemorySize,
indexJoinLookupStats,
SystemSessionProperties.isShareIndexLoading(session),
pagesIndexFactory,
hashStrategyCompiler,
blockTypeOperators);
indexLookupSourceFactory.setTaskContext(context.taskContext);
JoinBridgeManager<LookupSourceFactory> lookupSourceFactoryManager = new JoinBridgeManager<>(
false,
indexLookupSourceFactory,
indexLookupSourceFactory.getOutputTypes());
ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
outputMappings.putAll(probeSource.getLayout());
// inputs from index side of the join are laid out following the input from the probe side,
// so adjust the channel ids but keep the field layouts intact
int offset = probeSource.getTypes().size();
for (Map.Entry entry : indexSource.getLayout().entrySet()) {
Integer input = entry.getValue();
outputMappings.put(entry.getKey(), offset + input);
}
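// Illustrative example (hypothetical channels, not from the original source): with 3 probe channels,
// an index-side symbol at channel 1 is mapped to output channel 3 + 1 = 4.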
OperatorFactory lookupJoinOperatorFactory;
OptionalInt totalOperatorsCount = context.getDriverInstanceCount();
// We use the spilling join operator since the non-spilling one does not support index lookup sources
lookupJoinOperatorFactory = switch (node.getType()) {
case INNER -> spillingJoin(
JoinOperatorType.innerJoin(false, false),
context.getNextOperatorId(),
node.getId(),
lookupSourceFactoryManager,
false,
probeSource.getTypes(),
probeChannels,
probeHashChannel,
Optional.empty(),
totalOperatorsCount,
unsupportedPartitioningSpillerFactory(),
typeOperators);
case SOURCE_OUTER -> spillingJoin(
JoinOperatorType.probeOuterJoin(false),
context.getNextOperatorId(),
node.getId(),
lookupSourceFactoryManager,
false,
probeSource.getTypes(),
probeChannels,
probeHashChannel,
Optional.empty(),
totalOperatorsCount,
unsupportedPartitioningSpillerFactory(),
typeOperators);
};
return new PhysicalOperation(lookupJoinOperatorFactory, outputMappings.buildOrThrow(), probeSource);
}
@Override
public PhysicalOperation visitJoin(JoinNode node, LocalExecutionPlanContext context)
{
// Register dynamic filters, allowing the scan operators to wait for the collection completion.
// Skip dynamic filters that are not used locally (e.g. in case of distributed joins).
Set<DynamicFilterId> localDynamicFilters = node.getDynamicFilters().keySet().stream()
.filter(getConsumedDynamicFilterIds(node.getLeft())::contains)
.collect(toImmutableSet());
context.getDynamicFiltersCollector().register(localDynamicFilters);
if (node.isCrossJoin()) {
return createNestedLoopJoin(node, localDynamicFilters, context);
}
List<JoinNode.EquiJoinClause> clauses = node.getCriteria();
List<Symbol> leftSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft);
List<Symbol> rightSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight);
return switch (node.getType()) {
case INNER, LEFT, RIGHT, FULL ->
createLookupJoin(node, node.getLeft(), leftSymbols, node.getLeftHashSymbol(), node.getRight(), rightSymbols, node.getRightHashSymbol(), localDynamicFilters, context);
};
}
@Override
public PhysicalOperation visitSpatialJoin(SpatialJoinNode node, LocalExecutionPlanContext context)
{
Expression filterExpression = node.getFilter();
List<Call> spatialFunctions = extractSupportedSpatialFunctions(filterExpression);
for (Call spatialFunction : spatialFunctions) {
Optional<PhysicalOperation> operation = tryCreateSpatialJoin(context, node, removeExpressionFromFilter(filterExpression, spatialFunction), spatialFunction, Optional.empty(), Optional.empty());
if (operation.isPresent()) {
return operation.get();
}
}
List<Comparison> spatialComparisons = extractSupportedSpatialComparisons(filterExpression);
for (Comparison spatialComparison : spatialComparisons) {
if (spatialComparison.operator() == LESS_THAN || spatialComparison.operator() == LESS_THAN_OR_EQUAL) {
// ST_Distance(a, b) <= r
Expression radius = spatialComparison.right();
if (radius instanceof Reference && getSymbolReferences(node.getRight().getOutputSymbols()).contains(radius) || radius instanceof Constant) {
Call spatialFunction = (Call) spatialComparison.left();
Optional<PhysicalOperation> operation = tryCreateSpatialJoin(context, node, removeExpressionFromFilter(filterExpression, spatialComparison), spatialFunction, Optional.of(radius), Optional.of(spatialComparison.operator()));
if (operation.isPresent()) {
return operation.get();
}
}
}
}
throw new VerifyException("No valid spatial relationship found for spatial join");
}
private Optional<PhysicalOperation> tryCreateSpatialJoin(
LocalExecutionPlanContext context,
SpatialJoinNode node,
Optional<Expression> filterExpression,
Call spatialFunction,
Optional<Expression> radius,
Optional<Comparison.Operator> comparisonOperator)
{
List<Expression> arguments = spatialFunction.arguments();
verify(arguments.size() == 2);
if (!(arguments.get(0) instanceof Reference firstSymbol) || !(arguments.get(1) instanceof Reference secondSymbol)) {
return Optional.empty();
}
PlanNode probeNode = node.getLeft();
Set<Reference> probeSymbols = getSymbolReferences(probeNode.getOutputSymbols());
PlanNode buildNode = node.getRight();
Set<Reference> buildSymbols = getSymbolReferences(buildNode.getOutputSymbols());
Optional<Symbol> radiusSymbol = Optional.empty();
OptionalDouble constantRadius = OptionalDouble.empty();
if (radius.isPresent()) {
Expression expression = radius.get();
if (expression instanceof Reference reference) {
radiusSymbol = Optional.of(Symbol.from(reference));
}
else if (expression instanceof Constant constant) {
constantRadius = OptionalDouble.of((Double) constant.value());
}
else {
throw new IllegalArgumentException("Unexpected expression for radius: " + expression);
}
}
if (probeSymbols.contains(firstSymbol) && buildSymbols.contains(secondSymbol)) {
return Optional.of(createSpatialLookupJoin(
node,
probeNode,
Symbol.from(firstSymbol),
buildNode,
Symbol.from(secondSymbol),
radiusSymbol,
constantRadius,
spatialTest(spatialFunction, true, comparisonOperator),
filterExpression,
context));
}
if (probeSymbols.contains(secondSymbol) && buildSymbols.contains(firstSymbol)) {
return Optional.of(createSpatialLookupJoin(
node,
probeNode,
Symbol.from(secondSymbol),
buildNode,
Symbol.from(firstSymbol),
radiusSymbol,
constantRadius,
spatialTest(spatialFunction, false, comparisonOperator),
filterExpression,
context));
}
return Optional.empty();
}
private Optional<Expression> removeExpressionFromFilter(Expression filter, Expression expression)
{
Expression updatedJoinFilter = replaceExpression(filter, ImmutableMap.of(expression, TRUE));
return updatedJoinFilter.equals(TRUE) ? Optional.empty() : Optional.of(updatedJoinFilter);
}
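// Illustrative example (hypothetical predicate, not from the original source): removing ST_Contains(a, b)
// from the filter `ST_Contains(a, b) AND x > 5` leaves `x > 5`; removing it from `ST_Contains(a, b)` alone
// yields Optional.empty(), since the remainder is TRUE.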
private SpatialPredicate spatialTest(Call call, boolean probeFirst, Optional<Comparison.Operator> comparisonOperator)
{
CatalogSchemaFunctionName functionName = call.function().name();
if (functionName.equals(builtinFunctionName(ST_CONTAINS))) {
if (probeFirst) {
return (buildGeometry, probeGeometry, radius) -> probeGeometry.contains(buildGeometry);
}
return (buildGeometry, probeGeometry, radius) -> buildGeometry.contains(probeGeometry);
}
if (functionName.equals(builtinFunctionName(ST_WITHIN))) {
if (probeFirst) {
return (buildGeometry, probeGeometry, radius) -> probeGeometry.within(buildGeometry);
}
return (buildGeometry, probeGeometry, radius) -> buildGeometry.within(probeGeometry);
}
if (functionName.equals(builtinFunctionName(ST_INTERSECTS))) {
return (buildGeometry, probeGeometry, radius) -> buildGeometry.intersects(probeGeometry);
}
if (functionName.equals(builtinFunctionName(ST_DISTANCE))) {
if (comparisonOperator.orElseThrow() == LESS_THAN) {
return (buildGeometry, probeGeometry, radius) -> buildGeometry.distance(probeGeometry) < radius.getAsDouble();
}
if (comparisonOperator.get() == LESS_THAN_OR_EQUAL) {
return (buildGeometry, probeGeometry, radius) -> buildGeometry.distance(probeGeometry) <= radius.getAsDouble();
}
throw new UnsupportedOperationException("Unsupported comparison operator: " + comparisonOperator.get());
}
throw new UnsupportedOperationException("Unsupported spatial function: " + functionName);
}
private Set<Reference> getSymbolReferences(Collection<Symbol> symbols)
{
return symbols.stream().map(Symbol::toSymbolReference).collect(toImmutableSet());
}
private PhysicalOperation createNestedLoopJoin(JoinNode node, Set<DynamicFilterId> localDynamicFilters, LocalExecutionPlanContext context)
{
PhysicalOperation probeSource = node.getLeft().accept(this, context);
LocalExecutionPlanContext buildContext = context.createSubContext();
PhysicalOperation buildSource = node.getRight().accept(this, buildContext);
checkArgument(node.getType() == INNER, "NestedLoopJoin is only used for inner join");
JoinBridgeManager<NestedLoopJoinBridge> nestedLoopJoinBridgeManager = new JoinBridgeManager<>(
false,
new NestedLoopJoinPagesSupplier(),
buildSource.getTypes());
NestedLoopBuildOperatorFactory nestedLoopBuildOperatorFactory = new NestedLoopBuildOperatorFactory(
buildContext.getNextOperatorId(),
node.getId(),
nestedLoopJoinBridgeManager);
int partitionCount = buildContext.getDriverInstanceCount().orElse(1);
checkArgument(partitionCount == 1, "Expected local execution to not be parallel");
int operatorId = buildContext.getNextOperatorId();
boolean partitioned = !isBuildSideReplicated(node);
Optional