/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;

import java.util.List;
import java.util.Optional;
import java.util.function.ToDoubleFunction;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinPartitionCount;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinPartitionCountForWrite;
import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
import static io.trino.SystemSessionProperties.getMinHashPartitionCount;
import static io.trino.SystemSessionProperties.getMinHashPartitionCountForWrite;
import static io.trino.SystemSessionProperties.getMinInputRowsPerTask;
import static io.trino.SystemSessionProperties.getMinInputSizePerTask;
import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static io.trino.SystemSessionProperties.isDeterminePartitionCountForWriteEnabled;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith;
import static java.lang.Double.isNaN;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;

/**
 * This rule looks at the amount of data read and processed by the query to determine the value of partition count
 * used for remote partitioned exchanges. It helps to increase the concurrency of the engine in the case of a large cluster.
 * This rule is also cautious about missing or incorrect statistics, and therefore skips input-multiplying nodes like
 * CROSS JOIN or UNNEST.
 * <p>
 * E.g. 1:
 * Given query: SELECT count(column_a) FROM table_with_stats_a group by column_b
 * config:
 * MIN_INPUT_SIZE_PER_TASK: 500 MB
 * Input table data size: 1000 MB
 * Estimated partition count: Input table data size / MIN_INPUT_SIZE_PER_TASK => 2
 * <p>
 * E.g. 2:
 * Given query: SELECT * FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_b = b.column_b
 * config:
 * MIN_INPUT_SIZE_PER_TASK: 500 MB
 * Input tables data size: 1000 MB
 * Join output data size: 5000 MB
 * Estimated partition count: max((Input tables data size / MIN_INPUT_SIZE_PER_TASK), (Join output data size / MIN_INPUT_SIZE_PER_TASK)) => 10
 */
public class DeterminePartitionCount
        implements PlanOptimizer
{
    private static final Logger log = Logger.get(DeterminePartitionCount.class);
    private static final List<Class<? extends PlanNode>> INSERT_NODES =
            ImmutableList.of(TableExecuteNode.class, TableWriterNode.class, MergeWriterNode.class);

    private final StatsCalculator statsCalculator;
    private final TaskCountEstimator taskCountEstimator;

    public DeterminePartitionCount(StatsCalculator statsCalculator, TaskCountEstimator taskCountEstimator)
    {
        this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
        this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override
    public PlanNode optimize(
            PlanNode plan,
            Session session,
            TypeProvider types,
            SymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator,
            WarningCollector warningCollector,
            PlanOptimizersStatsCollector planOptimizersStatsCollector,
            TableStatsProvider tableStatsProvider)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(tableStatsProvider, "tableStatsProvider is null");
        requireNonNull(taskCountEstimator, "taskCountEstimator is null");

        // Skip partition count determination if no partitioned remote exchanges exist in the plan anyway
        if (!isEligibleRemoteExchangePresent(plan)) {
            return plan;
        }

        // Unless enabled, skip for write nodes since writing partitioned data with a small number of nodes could cause
        // memory related issues even when the amount of data is small.
        boolean isWriteQuery = PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches();
        if (isWriteQuery && !isDeterminePartitionCountForWriteEnabled(session)) {
            return plan;
        }

        try {
            return determinePartitionCount(plan, session, types, tableStatsProvider, isWriteQuery)
                    .map(partitionCount -> rewriteWith(new Rewriter(partitionCount), plan))
                    .orElse(plan);
        }
        catch (RuntimeException e) {
            log.warn(e, "Error occurred when determining hash partition count for query %s", session.getQueryId());
        }
        return plan;
    }

    private Optional<Integer> determinePartitionCount(
            PlanNode plan,
            Session session,
            TypeProvider types,
            TableStatsProvider tableStatsProvider,
            boolean isWriteQuery)
    {
        long minInputSizePerTask = getMinInputSizePerTask(session).toBytes();
        long minInputRowsPerTask = getMinInputRowsPerTask(session);
        if (minInputSizePerTask == 0 || minInputRowsPerTask == 0) {
            return Optional.empty();
        }

        // Skip for expanding plan nodes like CROSS JOIN or UNNEST which can substantially increase the amount of data.
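        // (For instance, a cross join of an n-row input with an m-row input produces n * m output rows, so an
        // estimate derived from the input statistics alone could be far below the amount of data actually exchanged.)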
        if (isInputMultiplyingPlanNodePresent(plan)) {
            return Optional.empty();
        }

        int minPartitionCount;
        int maxPartitionCount;
        if (getRetryPolicy(session).equals(RetryPolicy.TASK)) {
            if (isWriteQuery) {
                minPartitionCount = getFaultTolerantExecutionMinPartitionCountForWrite(session);
            }
            else {
                minPartitionCount = getFaultTolerantExecutionMinPartitionCount(session);
            }
            maxPartitionCount = getFaultTolerantExecutionMaxPartitionCount(session);
        }
        else {
            if (isWriteQuery) {
                minPartitionCount = getMinHashPartitionCountForWrite(session);
            }
            else {
                minPartitionCount = getMinHashPartitionCount(session);
            }
            maxPartitionCount = getMaxHashPartitionCount(session);
        }
        verify(minPartitionCount <= maxPartitionCount, "minPartitionCount %s larger than maxPartitionCount %s", minPartitionCount, maxPartitionCount);

        StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types, tableStatsProvider);
        long queryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes();

        // Calculate partition count based on the nodes' output data size and row count
        Optional<Integer> partitionCountBasedOnOutputSize = getPartitionCountBasedOnOutputSize(
                plan, statsProvider, types, minInputSizePerTask, queryMaxMemoryPerNode);
        Optional<Integer> partitionCountBasedOnRows = getPartitionCountBasedOnRows(plan, statsProvider, minInputRowsPerTask);
        if (partitionCountBasedOnOutputSize.isEmpty() || partitionCountBasedOnRows.isEmpty()) {
            return Optional.empty();
        }

        int partitionCount = max(
                // Consider both output size and row count to estimate the partition count. This is essential
                // because a huge number of small rows can be CPU intensive for some operators. On the other
                // hand, a small number of rows with considerable size in bytes can be memory intensive.
                max(partitionCountBasedOnOutputSize.get(), partitionCountBasedOnRows.get()),
                minPartitionCount);
        if (partitionCount >= maxPartitionCount) {
            return Optional.empty();
        }

        if (partitionCount * 2 >= taskCountEstimator.estimateHashedTaskCount(session) && !getRetryPolicy(session).equals(RetryPolicy.TASK)) {
            // Do not cap partition count if it's already close to the possible number of tasks.
            return Optional.empty();
        }

        log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
        return Optional.of(partitionCount);
    }

    private static Optional<Integer> getPartitionCountBasedOnOutputSize(
            PlanNode plan,
            StatsProvider statsProvider,
            TypeProvider types,
            long minInputSizePerTask,
            long queryMaxMemoryPerNode)
    {
        double sourceTablesOutputSize = getSourceNodesOutputStats(
                plan,
                node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
        double expandingNodesMaxOutputSize = getExpandingNodesMaxOutputStats(
                plan,
                node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
        if (isNaN(sourceTablesOutputSize) || isNaN(expandingNodesMaxOutputSize)) {
            return Optional.empty();
        }

        int partitionCountBasedOnOutputSize = getPartitionCount(
                max(sourceTablesOutputSize, expandingNodesMaxOutputSize), minInputSizePerTask);
        // Calculate partition count based on maximum memory usage. This is based on the assumption that
        // generally operators won't keep data in memory more than the size of input data.
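        // (Hypothetical worked example of the formula below: with 10 GB of estimated input and 4 GB of query
        // memory per node, (10 GB * 2) / 4 GB yields a minimum of 5 partitions.)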
        int partitionCountBasedOnMemory = (int) ((max(sourceTablesOutputSize, expandingNodesMaxOutputSize) * 2) / queryMaxMemoryPerNode);
        return Optional.of(max(partitionCountBasedOnOutputSize, partitionCountBasedOnMemory));
    }

    private static Optional<Integer> getPartitionCountBasedOnRows(PlanNode plan, StatsProvider statsProvider, long minInputRowsPerTask)
    {
        double sourceTablesRowCount = getSourceNodesOutputStats(plan, node -> statsProvider.getStats(node).getOutputRowCount());
        double expandingNodesMaxRowCount = getExpandingNodesMaxOutputStats(plan, node -> statsProvider.getStats(node).getOutputRowCount());
        if (isNaN(sourceTablesRowCount) || isNaN(expandingNodesMaxRowCount)) {
            return Optional.empty();
        }

        return Optional.of(getPartitionCount(max(sourceTablesRowCount, expandingNodesMaxRowCount), minInputRowsPerTask));
    }

    private static int getPartitionCount(double outputStats, long minInputStatsPerTask)
    {
        return max((int) (outputStats / minInputStatsPerTask), 1);
    }

    private static boolean isInputMultiplyingPlanNodePresent(PlanNode root)
    {
        return PlanNodeSearcher.searchFrom(root)
                .where(DeterminePartitionCount::isInputMultiplyingPlanNode)
                .matches();
    }

    private static boolean isInputMultiplyingPlanNode(PlanNode node)
    {
        if (node instanceof UnnestNode) {
            return true;
        }
        if (node instanceof JoinNode joinNode) {
            // Skip for cross join
            if (joinNode.isCrossJoin()) {
                // If either input node is scalar then there's no need to skip the cross join
                return !isAtMostScalar(joinNode.getRight()) && !isAtMostScalar(joinNode.getLeft());
            }
            // Skip for joins with multiple keys since the output row count estimate can be wrong due to
            // low correlation between the join keys.
            return joinNode.getCriteria().size() > 1;
        }
        return false;
    }

    private static double getExpandingNodesMaxOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper)
    {
        List<PlanNode> expandingNodes = PlanNodeSearcher.searchFrom(root)
                .where(DeterminePartitionCount::isExpandingPlanNode)
                .findAll();
        return expandingNodes.stream()
                .mapToDouble(statsMapper)
                .max()
                .orElse(0);
    }

    private static boolean isExpandingPlanNode(PlanNode node)
    {
        return node instanceof JoinNode
                // consider union nodes and exchange nodes with multiple sources as expanding since they merge the rows
                // from two different sources, thus more data is transferred over the network.
                || node instanceof UnionNode
                || (node instanceof ExchangeNode && node.getSources().size() > 1);
    }

    private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper)
    {
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(root)
                .whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class)
                .findAll();
        return sourceNodes.stream()
                .mapToDouble(statsMapper)
                .sum();
    }

    private static boolean isEligibleRemoteExchangePresent(PlanNode root)
    {
        return PlanNodeSearcher.searchFrom(root)
                .where(node -> node instanceof ExchangeNode exchangeNode && isEligibleRemoteExchange(exchangeNode))
                .matches();
    }

    private static boolean isEligibleRemoteExchange(ExchangeNode exchangeNode)
    {
        if (exchangeNode.getScope() != REMOTE || exchangeNode.getType() != REPARTITION) {
            return false;
        }
        PartitioningHandle partitioningHandle = exchangeNode.getPartitioningScheme().getPartitioning().getHandle();
        return !partitioningHandle.isScaleWriters()
                && !partitioningHandle.isSingleNode()
                && partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle;
    }

    private static class Rewriter
            extends SimplePlanRewriter<Void>
    {
        private final int partitionCount;

        private Rewriter(int partitionCount)
        {
            this.partitionCount = partitionCount;
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
        {
            List<PlanNode> sources = node.getSources().stream()
                    .map(context::rewrite)
                    .collect(toImmutableList());

            PartitioningScheme partitioningScheme = node.getPartitioningScheme();
            if (isEligibleRemoteExchange(node)) {
                partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount));
            }

            return new ExchangeNode(
                    node.getId(),
                    node.getType(),
                    node.getScope(),
                    partitioningScheme,
                    sources,
                    node.getInputs(),
                    node.getOrderingScheme());
        }
    }
}
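/*
 * The class below is NOT part of the Trino source above. It is a minimal, self-contained sketch that
 * reproduces the estimation arithmetic described in the class Javadoc ("E.g. 2"), under the assumption
 * MIN_INPUT_SIZE_PER_TASK = 500 MB. The class and method names are hypothetical, for illustration only.
 */
class PartitionCountArithmeticSketch
{
    // Assumed config value, matching the Javadoc examples: 500 MB of input per task.
    private static final double MIN_INPUT_SIZE_PER_TASK_BYTES = 500.0 * 1024 * 1024;

    // Mirrors getPartitionCount above: output stats divided by the per-task minimum, floored, at least 1.
    static int partitionCount(double outputSizeBytes)
    {
        return Math.max((int) (outputSizeBytes / MIN_INPUT_SIZE_PER_TASK_BYTES), 1);
    }

    public static void main(String[] args)
    {
        double inputTablesSize = 1000.0 * 1024 * 1024; // 1000 MB scanned from the source tables
        double joinOutputSize = 5000.0 * 1024 * 1024; // 5000 MB of estimated join output
        // The rule takes the larger of the source-level and expanding-node estimates:
        int estimate = Math.max(partitionCount(inputTablesSize), partitionCount(joinOutputSize));
        System.out.println(estimate); // prints 10, matching "E.g. 2" in the Javadoc
    }
}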