/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.output;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.execution.resourcegroups.IndexedPriorityQueue;
import io.trino.operator.PartitionFunction;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.type.Type;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.SystemPartitioningHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.stream.IntStream;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getMaxMemoryPerPartitionWriter;
import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode;
import static io.trino.sql.planner.PartitioningHandle.isScaledWriterHashDistribution;
import static java.lang.Double.isNaN;
import static java.lang.Math.ceil;
import static java.lang.Math.floorMod;
import static java.lang.Math.max;
/**
* Helps in distributing big or skewed partitions across available tasks to improve the performance of
* partitioned writes.
*
* This rebalancer initializes a set of buckets for each task based on a given taskBucketCount and then tries to
* uniformly distribute partitions across those buckets. This helps to mitigate two problems:
* 1. Mitigate skewness across tasks.
* 2. Scale a few big partitions across tasks even if there's no skewness among them. This essentially
* speeds up local scaling without significantly impacting overall resource utilization.
*
* Example:
*
* <pre>
* Before: 3 tasks, 3 buckets per task, and 2 skewed partitions
*
*       Task1                 Task2                 Task3
* Bucket1 (Part 1)      Bucket1 (Part 2)      Bucket1
* Bucket2               Bucket2               Bucket2
* Bucket3               Bucket3               Bucket3
*
* After rebalancing:
*
*       Task1                 Task2                 Task3
* Bucket1 (Part 1)      Bucket1 (Part 2)      Bucket1 (Part 1)
* Bucket2 (Part 2)      Bucket2 (Part 1)      Bucket2 (Part 2)
* Bucket3               Bucket3               Bucket3
* </pre>
*/
@ThreadSafe
public class SkewedPartitionRebalancer
{
private static final Logger log = Logger.get(SkewedPartitionRebalancer.class);
// Keep the scale writers partition count big enough so that we can rebalance skewed partitions
// at a finer granularity, thus leading to lower resource utilization at the writer stage.
private static final int SCALE_WRITERS_PARTITION_COUNT = 4096;
// If the relative difference in processed bytes since the last rebalance between the task buckets
// with the maximum and minimum processed bytes is above 0.7 (70%), then we consider them skewed.
private static final double TASK_BUCKET_SKEWNESS_THRESHOLD = 0.7;
private final int partitionCount;
private final int taskCount;
private final int taskBucketCount;
private final long minPartitionDataProcessedRebalanceThreshold;
private final long minDataProcessedRebalanceThreshold;
private final int maxPartitionsToRebalance;
private final AtomicLongArray partitionRowCount;
private final AtomicLong dataProcessed;
private final AtomicLong dataProcessedAtLastRebalance;
private final AtomicInteger numOfRebalancedPartitions;
@GuardedBy("this")
private final long[] partitionDataSize;
@GuardedBy("this")
private final long[] partitionDataSizeAtLastRebalance;
@GuardedBy("this")
private final long[] partitionDataSizeSinceLastRebalancePerTask;
@GuardedBy("this")
private final long[] estimatedTaskBucketDataSizeSinceLastRebalance;
private final List<List<TaskBucket>> partitionAssignments;
public static boolean checkCanScalePartitionsRemotely(Session session, int taskCount, PartitioningHandle partitioningHandle, NodePartitioningManager nodePartitioningManager)
{
// In the case of connector partitioning, check whether the bucket to node mapping is fixed. If it is
// fixed, then we can't distribute a bucket across multiple tasks.
boolean hasFixedNodeMapping = partitioningHandle.getCatalogHandle()
.map(handle -> nodePartitioningManager.getConnectorBucketNodeMap(session, partitioningHandle)
.map(ConnectorBucketNodeMap::hasFixedMapping)
.orElse(false))
.orElse(false);
// Use the skewed partition rebalancer only when there is more than one task
return taskCount > 1 && !hasFixedNodeMapping && isScaledWriterHashDistribution(partitioningHandle);
}
public static PartitionFunction createPartitionFunction(
Session session,
NodePartitioningManager nodePartitioningManager,
PartitioningScheme scheme,
List<Type> partitionChannelTypes)
{
PartitioningHandle handle = scheme.getPartitioning().getHandle();
// In the case of SystemPartitioningHandle we can use an arbitrary bucket count so that skewness
// mitigation is more granular.
// In the case of connector partitioning, however, we have to use the connector-provided bucketCount,
// otherwise buckets will get mapped to tasks incorrectly, which could affect skewness handling.
//
// For example: if there are 2 hive buckets, 2 tasks, and 10 artificial bucketCount then this
// could be how actual hive buckets are mapped to artificial buckets and tasks.
//
// hive bucket artificial bucket tasks
// 0 0, 2, 4, 6, 8 0, 0, 0, 0, 0
// 1 1, 3, 5, 7, 9 1, 1, 1, 1, 1
//
// Here rebalancing will happen slowly even if there's skewness at task 0 or hive bucket 0, because
// five artificial buckets map to the first hive bucket. Therefore, each of these artificial buckets
// has to write minPartitionDataProcessedRebalanceThreshold before it gets scaled to task 1, which is
// slow compared to only a single hive bucket reaching the min limit.
int bucketCount = (handle.getConnectorHandle() instanceof SystemPartitioningHandle)
? SCALE_WRITERS_PARTITION_COUNT
: nodePartitioningManager.getBucketNodeMap(session, handle).getBucketCount();
return nodePartitioningManager.getPartitionFunction(
session,
scheme,
partitionChannelTypes,
IntStream.range(0, bucketCount).toArray());
}
public static int getMaxWritersBasedOnMemory(Session session)
{
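// e.g. (illustrative values): with query max memory per node = 1 GB and max memory per partition
// writer = 128 MB, ceil(1024 / 128) = 8 writers.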
return (int) ceil((double) getQueryMaxMemoryPerNode(session).toBytes() / getMaxMemoryPerPartitionWriter(session).toBytes());
}
public static int getScaleWritersMaxSkewedPartitions(Session session)
{
// Set the value of maxSkewedPartitions to scale to 60% of maximum number of writers possible per node.
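// e.g. (illustrative): with 8 possible writers per node, (int) (8 * 0.60) = 4 skewed partitions.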
return (int) (getMaxWritersBasedOnMemory(session) * 0.60);
}
public static int getTaskCount(PartitioningScheme partitioningScheme)
{
// Todo: Handle skewness if there are more nodes/tasks than the buckets coming from connector
// https://github.com/trinodb/trino/issues/17254
int[] bucketToPartition = partitioningScheme.getBucketToPartition()
.orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before calculating taskCount"));
// Buckets can be greater than the actual partitions or tasks. Therefore, use max to find the actual taskCount.
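// e.g. (illustrative mapping): bucketToPartition = [0, 1, 2, 0, 1, 2] has a maximum partition id
// of 2, hence taskCount = 3.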
return IntStream.of(bucketToPartition).max().getAsInt() + 1;
}
public SkewedPartitionRebalancer(
int partitionCount,
int taskCount,
int taskBucketCount,
long minPartitionDataProcessedRebalanceThreshold,
long maxDataProcessedRebalanceThreshold,
int maxPartitionsToRebalance)
{
this.partitionCount = partitionCount;
this.taskCount = taskCount;
this.taskBucketCount = taskBucketCount;
this.minPartitionDataProcessedRebalanceThreshold = minPartitionDataProcessedRebalanceThreshold;
this.minDataProcessedRebalanceThreshold = max(minPartitionDataProcessedRebalanceThreshold, maxDataProcessedRebalanceThreshold);
this.maxPartitionsToRebalance = maxPartitionsToRebalance;
this.partitionRowCount = new AtomicLongArray(partitionCount);
this.dataProcessed = new AtomicLong();
this.dataProcessedAtLastRebalance = new AtomicLong();
this.numOfRebalancedPartitions = new AtomicInteger();
this.partitionDataSize = new long[partitionCount];
this.partitionDataSizeAtLastRebalance = new long[partitionCount];
this.partitionDataSizeSinceLastRebalancePerTask = new long[partitionCount];
this.estimatedTaskBucketDataSizeSinceLastRebalance = new long[taskCount * taskBucketCount];
int[] taskBucketIds = new int[taskCount];
ImmutableList.Builder<List<TaskBucket>> partitionAssignments = ImmutableList.builder();
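// Assign partitions to tasks round-robin, rotating through each task's buckets as it receives
// partitions. e.g. (illustrative): taskCount = 2, taskBucketCount = 2 gives partition 0 -> (task 0,
// bucket 0), partition 1 -> (task 1, bucket 0), partition 2 -> (task 0, bucket 1),
// partition 3 -> (task 1, bucket 1).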
for (int partition = 0; partition < partitionCount; partition++) {
int taskId = partition % taskCount;
int bucketId = taskBucketIds[taskId]++ % taskBucketCount;
partitionAssignments.add(new CopyOnWriteArrayList<>(ImmutableList.of(new TaskBucket(taskId, bucketId))));
}
this.partitionAssignments = partitionAssignments.build();
}
@VisibleForTesting
List<List<Integer>> getPartitionAssignments()
{
ImmutableList.Builder<List<Integer>> assignedTasks = ImmutableList.builder();
for (List<TaskBucket> partitionAssignment : partitionAssignments) {
List<Integer> tasks = partitionAssignment.stream()
.map(taskBucket -> taskBucket.taskId)
.collect(toImmutableList());
assignedTasks.add(tasks);
}
return assignedTasks.build();
}
public int getTaskCount()
{
return taskCount;
}
public int getTaskId(int partitionId, long index)
{
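// Callers cycle through the tasks assigned to this partition, e.g. (illustrative): if the
// partition is assigned to taskBuckets on tasks [0, 2], indexes 0, 1, 2, 3 map to tasks 0, 2, 0, 2.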
List<TaskBucket> taskIds = partitionAssignments.get(partitionId);
return taskIds.get(floorMod(index, taskIds.size())).taskId;
}
public void addDataProcessed(long dataSize)
{
dataProcessed.addAndGet(dataSize);
}
public void addPartitionRowCount(int partition, long rowCount)
{
partitionRowCount.addAndGet(partition, rowCount);
}
public void rebalance()
{
long currentDataProcessed = dataProcessed.get();
if (shouldRebalance(currentDataProcessed)) {
rebalancePartitions(currentDataProcessed);
}
}
private boolean shouldRebalance(long dataProcessed)
{
// Rebalance only when the total bytes processed since the last rebalance exceed the rebalance
// threshold, and the number of rebalanced partitions is still less than maxPartitionsToRebalance.
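// e.g. (illustrative): with a 50 MB threshold, at least 50 MB must have been processed since the
// last rebalance before we rebalance again.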
return (dataProcessed - dataProcessedAtLastRebalance.get()) >= minDataProcessedRebalanceThreshold
&& numOfRebalancedPartitions.get() < maxPartitionsToRebalance;
}
private synchronized void rebalancePartitions(long dataProcessed)
{
if (!shouldRebalance(dataProcessed)) {
return;
}
calculatePartitionDataSize(dataProcessed);
// Initialize partitionDataSizeSinceLastRebalancePerTask
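// e.g. (illustrative): a partition that grew from 100 to 400 estimated bytes while assigned to
// 2 tasks contributes (400 - 100) / 2 = 150 bytes per task since the last rebalance.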
for (int partition = 0; partition < partitionCount; partition++) {
int totalAssignedTasks = partitionAssignments.get(partition).size();
long dataSize = partitionDataSize[partition];
partitionDataSizeSinceLastRebalancePerTask[partition] =
(dataSize - partitionDataSizeAtLastRebalance[partition]) / totalAssignedTasks;
partitionDataSizeAtLastRebalance[partition] = dataSize;
}
// Initialize taskBucketMaxPartitions
List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions = new ArrayList<>(taskCount * taskBucketCount);
for (int taskId = 0; taskId < taskCount; taskId++) {
for (int bucketId = 0; bucketId < taskBucketCount; bucketId++) {
taskBucketMaxPartitions.add(new IndexedPriorityQueue<>());
}
}
for (int partition = 0; partition < partitionCount; partition++) {
List<TaskBucket> taskAssignments = partitionAssignments.get(partition);
for (TaskBucket taskBucket : taskAssignments) {
IndexedPriorityQueue<Integer> queue = taskBucketMaxPartitions.get(taskBucket.id);
queue.addOrUpdate(partition, partitionDataSizeSinceLastRebalancePerTask[partition]);
}
}
// Initialize maxTaskBuckets and minTaskBuckets
IndexedPriorityQueue<TaskBucket> maxTaskBuckets = new IndexedPriorityQueue<>();
IndexedPriorityQueue<TaskBucket> minTaskBuckets = new IndexedPriorityQueue<>();
for (int taskId = 0; taskId < taskCount; taskId++) {
for (int bucketId = 0; bucketId < taskBucketCount; bucketId++) {
TaskBucket taskBucket = new TaskBucket(taskId, bucketId);
estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] =
calculateTaskBucketDataSizeSinceLastRebalance(taskBucketMaxPartitions.get(taskBucket.id));
maxTaskBuckets.addOrUpdate(taskBucket, estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id]);
minTaskBuckets.addOrUpdate(taskBucket, Long.MAX_VALUE - estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id]);
}
}
rebalanceBasedOnTaskBucketSkewness(maxTaskBuckets, minTaskBuckets, taskBucketMaxPartitions);
dataProcessedAtLastRebalance.set(dataProcessed);
}
private void calculatePartitionDataSize(long dataProcessed)
{
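// Apportion the total bytes processed across partitions in proportion to their observed row counts.
// e.g. (illustrative): row counts [100, 300] with dataProcessed = 400 yield estimated sizes [100, 300].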
long totalPartitionRowCount = 0;
for (int partition = 0; partition < partitionCount; partition++) {
totalPartitionRowCount += partitionRowCount.get(partition);
}
for (int partition = 0; partition < partitionCount; partition++) {
partitionDataSize[partition] = (partitionRowCount.get(partition) * dataProcessed) / totalPartitionRowCount;
}
}
private long calculateTaskBucketDataSizeSinceLastRebalance(IndexedPriorityQueue<Integer> maxPartitions)
{
long estimatedDataSizeSinceLastRebalance = 0;
for (int partition : maxPartitions) {
estimatedDataSizeSinceLastRebalance += partitionDataSizeSinceLastRebalancePerTask[partition];
}
return estimatedDataSizeSinceLastRebalance;
}
private void rebalanceBasedOnTaskBucketSkewness(
IndexedPriorityQueue<TaskBucket> maxTaskBuckets,
IndexedPriorityQueue<TaskBucket> minTaskBuckets,
List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions)
{
List<Integer> scaledPartitions = new ArrayList<>();
while (true) {
TaskBucket maxTaskBucket = maxTaskBuckets.poll();
if (maxTaskBucket == null) {
break;
}
IndexedPriorityQueue<Integer> maxPartitions = taskBucketMaxPartitions.get(maxTaskBucket.id);
if (maxPartitions.isEmpty()) {
continue;
}
List<TaskBucket> minSkewedTaskBuckets = findSkewedMinTaskBuckets(maxTaskBucket, minTaskBuckets);
if (minSkewedTaskBuckets.isEmpty()) {
break;
}
while (true) {
Integer maxPartition = maxPartitions.poll();
if (maxPartition == null) {
break;
}
// Rebalance a partition only once within a single rebalancing cycle. Otherwise, rebalancing would
// happen quite aggressively in the early stage of the write, when it is not yet required. That can
// impact output file sizes and resource usage such that produced files end up small and memory
// usage ends up higher.
if (scaledPartitions.contains(maxPartition)) {
continue;
}
int totalAssignedTasks = partitionAssignments.get(maxPartition).size();
if (partitionDataSize[maxPartition] >= (minPartitionDataProcessedRebalanceThreshold * totalAssignedTasks)) {
for (TaskBucket minTaskBucket : minSkewedTaskBuckets) {
if (rebalancePartition(maxPartition, minTaskBucket, maxTaskBuckets, minTaskBuckets)) {
scaledPartitions.add(maxPartition);
break;
}
}
}
else {
break;
}
}
}
}
private List<TaskBucket> findSkewedMinTaskBuckets(TaskBucket maxTaskBucket, IndexedPriorityQueue<TaskBucket> minTaskBuckets)
{
ImmutableList.Builder<TaskBucket> minSkewedTaskBuckets = ImmutableList.builder();
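// Skewness is the relative size gap between the max bucket and a candidate min bucket,
// e.g. (illustrative): (1000 - 200) / 1000 = 0.8, which is above the 0.7 threshold, so the pair is
// considered skewed. Since minTaskBuckets is prioritized by inverted size, iteration visits the
// least-loaded buckets first, and we can stop at the first one below the threshold.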
for (TaskBucket minTaskBucket : minTaskBuckets) {
double skewness =
((double) (estimatedTaskBucketDataSizeSinceLastRebalance[maxTaskBucket.id]
- estimatedTaskBucketDataSizeSinceLastRebalance[minTaskBucket.id]))
/ estimatedTaskBucketDataSizeSinceLastRebalance[maxTaskBucket.id];
if (skewness <= TASK_BUCKET_SKEWNESS_THRESHOLD || isNaN(skewness)) {
break;
}
if (maxTaskBucket.taskId != minTaskBucket.taskId) {
minSkewedTaskBuckets.add(minTaskBucket);
}
}
return minSkewedTaskBuckets.build();
}
private boolean rebalancePartition(
int partitionId,
TaskBucket toTaskBucket,
IndexedPriorityQueue<TaskBucket> maxTasks,
IndexedPriorityQueue<TaskBucket> minTasks)
{
List<TaskBucket> assignments = partitionAssignments.get(partitionId);
if (assignments.stream().anyMatch(taskBucket -> taskBucket.taskId == toTaskBucket.taskId)) {
return false;
}
// Assign the partition to the task only if the number of rebalanced partitions is still less than
// maxPartitionsToRebalance.
if (numOfRebalancedPartitions.get() >= maxPartitionsToRebalance) {
return false;
}
assignments.add(toTaskBucket);
int newTaskCount = assignments.size();
int oldTaskCount = newTaskCount - 1;
// Since the partition is rebalanced from the max to the min skewed taskBucket, shift part of its
// estimated data size from the existing taskBuckets to the new one and update both priority queues.
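// e.g. (illustrative): with a per-task share of 300 bytes and a 2 -> 3 task split, the new
// taskBucket gains (300 * 2) / 3 = 200 while each existing one gives up 300 / 3 = 100, so all
// three end up accounting for 200 bytes of this partition.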
for (TaskBucket taskBucket : assignments) {
if (taskBucket.equals(toTaskBucket)) {
estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] +=
(partitionDataSizeSinceLastRebalancePerTask[partitionId] * oldTaskCount) / newTaskCount;
}
else {
estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] -=
partitionDataSizeSinceLastRebalancePerTask[partitionId] / newTaskCount;
}
maxTasks.addOrUpdate(taskBucket, estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id]);
minTasks.addOrUpdate(taskBucket, Long.MAX_VALUE - estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id]);
}
// Increment the number of rebalanced partitions.
numOfRebalancedPartitions.incrementAndGet();
log.debug("Rebalanced partition %s to task %s with taskCount %s", partitionId, toTaskBucket.taskId, assignments.size());
return true;
}
private final class TaskBucket
{
private final int taskId;
private final int id;
private TaskBucket(int taskId, int bucketId)
{
this.taskId = taskId;
// Unique id for this task and bucket
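// e.g. (illustrative): with taskBucketCount = 3, (taskId 2, bucketId 1) -> id = 2 * 3 + 1 = 7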
this.id = (taskId * taskBucketCount) + bucketId;
}
@Override
public int hashCode()
{
return Objects.hash(taskId, id);
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TaskBucket that = (TaskBucket) o;
return that.id == id;
}
}
}