// io.prestosql.execution.scheduler.SqlQueryScheduler
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.execution.scheduler;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.SetThreadName;
import io.airlift.stats.TimeStat;
import io.airlift.units.Duration;
import io.prestosql.Session;
import io.prestosql.connector.CatalogName;
import io.prestosql.execution.BasicStageStats;
import io.prestosql.execution.NodeTaskMap;
import io.prestosql.execution.QueryState;
import io.prestosql.execution.QueryStateMachine;
import io.prestosql.execution.RemoteTask;
import io.prestosql.execution.RemoteTaskFactory;
import io.prestosql.execution.SqlStageExecution;
import io.prestosql.execution.StageId;
import io.prestosql.execution.StageInfo;
import io.prestosql.execution.StageState;
import io.prestosql.execution.TaskStatus;
import io.prestosql.execution.buffer.OutputBuffers;
import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId;
import io.prestosql.failuredetector.FailureDetector;
import io.prestosql.metadata.InternalNode;
import io.prestosql.server.DynamicFilterService.StageDynamicFilters;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.ConnectorPartitionHandle;
import io.prestosql.split.SplitSource;
import io.prestosql.sql.planner.NodePartitionMap;
import io.prestosql.sql.planner.NodePartitioningManager;
import io.prestosql.sql.planner.PartitioningHandle;
import io.prestosql.sql.planner.StageExecutionPlan;
import io.prestosql.sql.planner.plan.PlanFragmentId;
import io.prestosql.sql.planner.plan.PlanNodeId;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.tryGetFutureValue;
import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static io.prestosql.SystemSessionProperties.getConcurrentLifespansPerNode;
import static io.prestosql.SystemSessionProperties.getWriterMinSize;
import static io.prestosql.connector.CatalogName.isInternalSystemConnector;
import static io.prestosql.execution.BasicStageStats.aggregateBasicStageStats;
import static io.prestosql.execution.SqlStageExecution.createSqlStageExecution;
import static io.prestosql.execution.StageState.ABORTED;
import static io.prestosql.execution.StageState.CANCELED;
import static io.prestosql.execution.StageState.FAILED;
import static io.prestosql.execution.StageState.FINISHED;
import static io.prestosql.execution.StageState.RUNNING;
import static io.prestosql.execution.StageState.SCHEDULED;
import static io.prestosql.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler;
import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.prestosql.spi.StandardErrorCode.NO_NODES_AVAILABLE;
import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED;
import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static io.prestosql.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
import static io.prestosql.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static io.prestosql.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static io.prestosql.util.Failures.checkCondition;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
/**
 * Drives the distributed scheduling of a single query: builds one
 * {@link SqlStageExecution} per plan fragment, wires parent/child stages
 * together, and runs the scheduling loop until every stage is scheduled.
 */
public class SqlQueryScheduler
{
    // Overall query state; stage state changes are propagated into it
    private final QueryStateMachine queryStateMachine;
    private final ExecutionPolicy executionPolicy;
    // All stages of the query, keyed by stage id (immutable after construction)
    private final Map<StageId, SqlStageExecution> stages;
    private final ExecutorService executor;
    private final StageId rootStageId;
    // Per-stage split scheduler; each is closed when the scheduling loop exits
    private final Map<StageId, StageScheduler> stageSchedulers;
    // Parent/child wiring used to propagate new tasks and output buffers between stages
    private final Map<StageId, StageLinkage> stageLinkages;
    private final SplitSchedulerStats schedulerStats;
    private final boolean summarizeTaskInfo;
    // Ensures the scheduling loop is submitted at most once (see start())
    private final AtomicBoolean started = new AtomicBoolean();
    /**
     * Static factory: constructs the scheduler and then calls {@code initialize()}
     * so that listener registration never leaks the {@code this} reference from
     * the constructor.
     */
    public static SqlQueryScheduler createSqlQueryScheduler(
            QueryStateMachine queryStateMachine,
            StageExecutionPlan plan,
            NodePartitioningManager nodePartitioningManager,
            NodeScheduler nodeScheduler,
            RemoteTaskFactory remoteTaskFactory,
            Session session,
            boolean summarizeTaskInfo,
            int splitBatchSize,
            ExecutorService queryExecutor,
            ScheduledExecutorService schedulerExecutor,
            FailureDetector failureDetector,
            OutputBuffers rootOutputBuffers,
            NodeTaskMap nodeTaskMap,
            ExecutionPolicy executionPolicy,
            SplitSchedulerStats schedulerStats)
    {
        SqlQueryScheduler sqlQueryScheduler = new SqlQueryScheduler(
                queryStateMachine,
                plan,
                nodePartitioningManager,
                nodeScheduler,
                remoteTaskFactory,
                session,
                summarizeTaskInfo,
                splitBatchSize,
                queryExecutor,
                schedulerExecutor,
                failureDetector,
                rootOutputBuffers,
                nodeTaskMap,
                executionPolicy,
                schedulerStats);
        // Registers listeners on the fully-constructed instance
        sqlQueryScheduler.initialize();
        return sqlQueryScheduler;
    }
private SqlQueryScheduler(
QueryStateMachine queryStateMachine,
StageExecutionPlan plan,
NodePartitioningManager nodePartitioningManager,
NodeScheduler nodeScheduler,
RemoteTaskFactory remoteTaskFactory,
Session session,
boolean summarizeTaskInfo,
int splitBatchSize,
ExecutorService queryExecutor,
ScheduledExecutorService schedulerExecutor,
FailureDetector failureDetector,
OutputBuffers rootOutputBuffers,
NodeTaskMap nodeTaskMap,
ExecutionPolicy executionPolicy,
SplitSchedulerStats schedulerStats)
{
this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null");
this.executionPolicy = requireNonNull(executionPolicy, "schedulerPolicyFactory is null");
this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
this.summarizeTaskInfo = summarizeTaskInfo;
// todo come up with a better way to build this, or eliminate this map
ImmutableMap.Builder stageSchedulers = ImmutableMap.builder();
ImmutableMap.Builder stageLinkages = ImmutableMap.builder();
// Only fetch a distribution once per query to assure all stages see the same machine assignments
Map partitioningCache = new HashMap<>();
OutputBufferId rootBufferId = Iterables.getOnlyElement(rootOutputBuffers.getBuffers().keySet());
List stages = createStages(
(fragmentId, tasks, noMoreExchangeLocations) -> updateQueryOutputLocations(queryStateMachine, rootBufferId, tasks, noMoreExchangeLocations),
new AtomicInteger(),
plan.withBucketToPartition(Optional.of(new int[1])),
nodeScheduler,
remoteTaskFactory,
session,
splitBatchSize,
partitioningHandle -> partitioningCache.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(session, handle)),
nodePartitioningManager,
queryExecutor,
schedulerExecutor,
failureDetector,
nodeTaskMap,
stageSchedulers,
stageLinkages);
SqlStageExecution rootStage = stages.get(0);
rootStage.setOutputBuffers(rootOutputBuffers);
this.rootStageId = rootStage.getStageId();
this.stages = stages.stream()
.collect(toImmutableMap(SqlStageExecution::getStageId, identity()));
this.stageSchedulers = stageSchedulers.build();
this.stageLinkages = stageLinkages.build();
this.executor = queryExecutor;
}
    // this is a separate method to ensure that the `this` reference is not leaked during construction
    private void initialize()
    {
        // Root stage completion drives the terminal query transitions
        SqlStageExecution rootStage = stages.get(rootStageId);
        rootStage.addStateChangeListener(state -> {
            if (state == FINISHED) {
                queryStateMachine.transitionToFinishing();
            }
            else if (state == CANCELED) {
                // output stage was canceled
                queryStateMachine.transitionToCanceled();
            }
        });

        // Any stage failing/aborting fails the whole query; the first task
        // appearing while STARTING moves the query to RUNNING
        for (SqlStageExecution stage : stages.values()) {
            stage.addStateChangeListener(state -> {
                if (queryStateMachine.isDone()) {
                    return;
                }
                if (state == FAILED) {
                    queryStateMachine.transitionToFailed(stage.getStageInfo().getFailureCause().toException());
                }
                else if (state == ABORTED) {
                    // this should never happen, since abort can only be triggered in query clean up after the query is finished
                    queryStateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "Query stage was aborted"));
                }
                else if (queryStateMachine.getQueryState() == QueryState.STARTING) {
                    // if the stage has at least one task, we are running
                    if (stage.hasTasks()) {
                        queryStateMachine.transitionToRunning();
                    }
                }
            });
        }

        // when query is done or any time a stage completes, attempt to transition query to "final query info ready"
        queryStateMachine.addStateChangeListener(newState -> {
            if (newState.isDone()) {
                queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()));
            }
        });
        for (SqlStageExecution stage : stages.values()) {
            stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())));
        }
    }
private static void updateQueryOutputLocations(QueryStateMachine queryStateMachine, OutputBufferId rootBufferId, Set tasks, boolean noMoreExchangeLocations)
{
Set bufferLocations = tasks.stream()
.map(task -> task.getTaskStatus().getSelf())
.map(location -> uriBuilderFrom(location).appendPath("results").appendPath(rootBufferId.toString()).build())
.collect(toImmutableSet());
queryStateMachine.updateOutputLocations(bufferLocations, noMoreExchangeLocations);
}
private List createStages(
ExchangeLocationsConsumer parent,
AtomicInteger nextStageId,
StageExecutionPlan plan,
NodeScheduler nodeScheduler,
RemoteTaskFactory remoteTaskFactory,
Session session,
int splitBatchSize,
Function partitioningCache,
NodePartitioningManager nodePartitioningManager,
ExecutorService queryExecutor,
ScheduledExecutorService schedulerExecutor,
FailureDetector failureDetector,
NodeTaskMap nodeTaskMap,
ImmutableMap.Builder stageSchedulers,
ImmutableMap.Builder stageLinkages)
{
ImmutableList.Builder stages = ImmutableList.builder();
StageId stageId = new StageId(queryStateMachine.getQueryId(), nextStageId.getAndIncrement());
SqlStageExecution stage = createSqlStageExecution(
stageId,
plan.getFragment(),
plan.getTables(),
remoteTaskFactory,
session,
summarizeTaskInfo,
nodeTaskMap,
queryExecutor,
failureDetector,
schedulerStats);
stages.add(stage);
Optional bucketToPartition;
PartitioningHandle partitioningHandle = plan.getFragment().getPartitioning();
if (partitioningHandle.equals(SOURCE_DISTRIBUTION)) {
// nodes are selected dynamically based on the constraints of the splits and the system load
Entry entry = Iterables.getOnlyElement(plan.getSplitSources().entrySet());
PlanNodeId planNodeId = entry.getKey();
SplitSource splitSource = entry.getValue();
Optional catalogName = Optional.of(splitSource.getCatalogName())
.filter(catalog -> !isInternalSystemConnector(catalog));
NodeSelector nodeSelector = nodeScheduler.createNodeSelector(catalogName);
SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stage::getAllTasks);
checkArgument(!plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution());
stageSchedulers.put(stageId, newSourcePartitionedSchedulerAsStageScheduler(stage, planNodeId, splitSource, placementPolicy, splitBatchSize));
bucketToPartition = Optional.of(new int[1]);
}
else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
bucketToPartition = Optional.of(new int[1]);
}
else {
Map splitSources = plan.getSplitSources();
if (!splitSources.isEmpty()) {
// contains local source
List schedulingOrder = plan.getFragment().getPartitionedSources();
Optional catalogName = partitioningHandle.getConnectorId();
catalogName.orElseThrow(() -> new IllegalArgumentException("No connector ID for partitioning handle: " + partitioningHandle));
List connectorPartitionHandles;
boolean groupedExecutionForStage = plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution();
if (groupedExecutionForStage) {
connectorPartitionHandles = nodePartitioningManager.listPartitionHandles(session, partitioningHandle);
checkState(!ImmutableList.of(NOT_PARTITIONED).equals(connectorPartitionHandles));
}
else {
connectorPartitionHandles = ImmutableList.of(NOT_PARTITIONED);
}
BucketNodeMap bucketNodeMap;
List stageNodeList;
if (plan.getFragment().getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) {
// no remote source
boolean dynamicLifespanSchedule = plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule();
bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, dynamicLifespanSchedule);
// verify execution is consistent with planner's decision on dynamic lifespan schedule
verify(bucketNodeMap.isDynamic() == dynamicLifespanSchedule);
stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(catalogName).allNodes());
Collections.shuffle(stageNodeList);
bucketToPartition = Optional.empty();
}
else {
// cannot use dynamic lifespan schedule
verify(!plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule());
// remote source requires nodePartitionMap
NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning());
if (groupedExecutionForStage) {
checkState(connectorPartitionHandles.size() == nodePartitionMap.getBucketToPartition().length);
}
stageNodeList = nodePartitionMap.getPartitionToNode();
bucketNodeMap = nodePartitionMap.asBucketNodeMap();
bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition());
}
stageSchedulers.put(stageId, new FixedSourcePartitionedScheduler(
stage,
splitSources,
plan.getFragment().getStageExecutionDescriptor(),
schedulingOrder,
stageNodeList,
bucketNodeMap,
splitBatchSize,
getConcurrentLifespansPerNode(session),
nodeScheduler.createNodeSelector(catalogName),
connectorPartitionHandles));
}
else {
// all sources are remote
NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning());
List partitionToNode = nodePartitionMap.getPartitionToNode();
// todo this should asynchronously wait a standard timeout period before failing
checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available");
stageSchedulers.put(stageId, new FixedCountScheduler(stage, partitionToNode));
bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition());
}
}
ImmutableSet.Builder childStagesBuilder = ImmutableSet.builder();
for (StageExecutionPlan subStagePlan : plan.getSubStages()) {
List subTree = createStages(
stage::addExchangeLocations,
nextStageId,
subStagePlan.withBucketToPartition(bucketToPartition),
nodeScheduler,
remoteTaskFactory,
session,
splitBatchSize,
partitioningCache,
nodePartitioningManager,
queryExecutor,
schedulerExecutor,
failureDetector,
nodeTaskMap,
stageSchedulers,
stageLinkages);
stages.addAll(subTree);
SqlStageExecution childStage = subTree.get(0);
childStagesBuilder.add(childStage);
}
Set childStages = childStagesBuilder.build();
stage.addStateChangeListener(newState -> {
if (newState.isDone()) {
childStages.forEach(SqlStageExecution::cancel);
}
});
stageLinkages.put(stageId, new StageLinkage(plan.getFragment().getId(), parent, childStages));
if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
Supplier> sourceTasksProvider = () -> childStages.stream()
.map(SqlStageExecution::getAllTasks)
.flatMap(Collection::stream)
.map(RemoteTask::getTaskStatus)
.collect(toList());
Supplier> writerTasksProvider = () -> stage.getAllTasks().stream()
.map(RemoteTask::getTaskStatus)
.collect(toList());
ScaledWriterScheduler scheduler = new ScaledWriterScheduler(
stage,
sourceTasksProvider,
writerTasksProvider,
nodeScheduler.createNodeSelector(Optional.empty()),
schedulerExecutor,
getWriterMinSize(session));
whenAllStages(childStages, StageState::isDone)
.addListener(scheduler::finish, directExecutor());
stageSchedulers.put(stageId, scheduler);
}
return stages.build();
}
public BasicStageStats getBasicStageStats()
{
List stageStats = stages.values().stream()
.map(SqlStageExecution::getBasicStageStats)
.collect(toImmutableList());
return aggregateBasicStageStats(stageStats);
}
public StageInfo getStageInfo()
{
Map stageInfos = stages.values().stream()
.map(SqlStageExecution::getStageInfo)
.collect(toImmutableMap(StageInfo::getStageId, identity()));
return buildStageInfo(rootStageId, stageInfos);
}
public List getStageDynamicFilters()
{
return stages.values().stream()
.map(SqlStageExecution::getStageDynamicFilters)
.collect(toImmutableList());
}
private StageInfo buildStageInfo(StageId stageId, Map stageInfos)
{
StageInfo parent = stageInfos.get(stageId);
checkArgument(parent != null, "No stageInfo for %s", parent);
List childStages = stageLinkages.get(stageId).getChildStageIds().stream()
.map(childStageId -> buildStageInfo(childStageId, stageInfos))
.collect(toImmutableList());
if (childStages.isEmpty()) {
return parent;
}
return new StageInfo(
parent.getStageId(),
parent.getState(),
parent.getPlan(),
parent.getTypes(),
parent.getStageStats(),
parent.getTasks(),
childStages,
parent.getTables(),
parent.getFailureCause());
}
public long getUserMemoryReservation()
{
return stages.values().stream()
.mapToLong(SqlStageExecution::getUserMemoryReservation)
.sum();
}
public long getTotalMemoryReservation()
{
return stages.values().stream()
.mapToLong(SqlStageExecution::getTotalMemoryReservation)
.sum();
}
public Duration getTotalCpuTime()
{
long millis = stages.values().stream()
.mapToLong(stage -> stage.getTotalCpuTime().toMillis())
.sum();
return new Duration(millis, MILLISECONDS);
}
public void start()
{
if (started.compareAndSet(false, true)) {
executor.submit(this::schedule);
}
}
private void schedule()
{
try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
Set completedStages = new HashSet<>();
ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values());
while (!executionSchedule.isFinished()) {
List> blockedStages = new ArrayList<>();
for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) {
stage.beginScheduling();
// perform some scheduling work
ScheduleResult result = stageSchedulers.get(stage.getStageId())
.schedule();
// modify parent and children based on the results of the scheduling
if (result.isFinished()) {
stage.schedulingComplete();
}
else if (!result.getBlocked().isDone()) {
blockedStages.add(result.getBlocked());
}
stageLinkages.get(stage.getStageId())
.processScheduleResults(stage.getState(), result.getNewTasks());
schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled());
if (result.getBlockedReason().isPresent()) {
switch (result.getBlockedReason().get()) {
case WRITER_SCALING:
// no-op
break;
case WAITING_FOR_SOURCE:
schedulerStats.getWaitingForSource().update(1);
break;
case SPLIT_QUEUES_FULL:
schedulerStats.getSplitQueuesFull().update(1);
break;
case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE:
case NO_ACTIVE_DRIVER_GROUP:
break;
default:
throw new UnsupportedOperationException("Unknown blocked reason: " + result.getBlockedReason().get());
}
}
}
// make sure to update stage linkage at least once per loop to catch async state changes (e.g., partial cancel)
for (SqlStageExecution stage : stages.values()) {
if (!completedStages.contains(stage.getStageId()) && stage.getState().isDone()) {
stageLinkages.get(stage.getStageId())
.processScheduleResults(stage.getState(), ImmutableSet.of());
completedStages.add(stage.getStageId());
}
}
// wait for a state change and then schedule again
if (!blockedStages.isEmpty()) {
try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) {
tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS);
}
for (ListenableFuture blockedStage : blockedStages) {
blockedStage.cancel(true);
}
}
}
for (SqlStageExecution stage : stages.values()) {
StageState state = stage.getState();
if (state != SCHEDULED && state != RUNNING && !state.isDone()) {
throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state));
}
}
}
catch (Throwable t) {
queryStateMachine.transitionToFailed(t);
throw t;
}
finally {
RuntimeException closeError = new RuntimeException();
for (StageScheduler scheduler : stageSchedulers.values()) {
try {
scheduler.close();
}
catch (Throwable t) {
queryStateMachine.transitionToFailed(t);
// Self-suppression not permitted
if (closeError != t) {
closeError.addSuppressed(t);
}
}
}
if (closeError.getSuppressed().length > 0) {
throw closeError;
}
}
}
public void cancelStage(StageId stageId)
{
try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
SqlStageExecution sqlStageExecution = stages.get(stageId);
SqlStageExecution stage = requireNonNull(sqlStageExecution, () -> format("Stage '%s' does not exist", stageId));
stage.cancel();
}
}
public void abort()
{
try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
stages.values().forEach(SqlStageExecution::abort);
}
}
private static ListenableFuture whenAllStages(Collection stages, Predicate predicate)
{
checkArgument(!stages.isEmpty(), "stages is empty");
Set stageIds = newConcurrentHashSet(stages.stream()
.map(SqlStageExecution::getStageId)
.collect(toSet()));
SettableFuture future = SettableFuture.create();
for (SqlStageExecution stage : stages) {
stage.addStateChangeListener(state -> {
if (predicate.test(state) && stageIds.remove(stage.getStageId()) && stageIds.isEmpty()) {
future.set(null);
}
});
}
return future;
}
private interface ExchangeLocationsConsumer
{
void addExchangeLocations(PlanFragmentId fragmentId, Set tasks, boolean noMoreExchangeLocations);
}
private static class StageLinkage
{
private final PlanFragmentId currentStageFragmentId;
private final ExchangeLocationsConsumer parent;
private final Set childOutputBufferManagers;
private final Set childStageIds;
public StageLinkage(PlanFragmentId fragmentId, ExchangeLocationsConsumer parent, Set children)
{
this.currentStageFragmentId = fragmentId;
this.parent = parent;
this.childOutputBufferManagers = children.stream()
.map(childStage -> {
PartitioningHandle partitioningHandle = childStage.getFragment().getPartitioningScheme().getPartitioning().getHandle();
if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) {
return new BroadcastOutputBufferManager(childStage::setOutputBuffers);
}
else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
return new ScaledOutputBufferManager(childStage::setOutputBuffers);
}
else {
int partitionCount = Ints.max(childStage.getFragment().getPartitioningScheme().getBucketToPartition().get()) + 1;
return new PartitionedOutputBufferManager(partitioningHandle, partitionCount, childStage::setOutputBuffers);
}
})
.collect(toImmutableSet());
this.childStageIds = children.stream()
.map(SqlStageExecution::getStageId)
.collect(toImmutableSet());
}
public Set getChildStageIds()
{
return childStageIds;
}
public void processScheduleResults(StageState newState, Set newTasks)
{
boolean noMoreTasks = !newState.canScheduleMoreTasks();
// Add an exchange location to the parent stage for each new task
parent.addExchangeLocations(currentStageFragmentId, newTasks, noMoreTasks);
if (!childOutputBufferManagers.isEmpty()) {
// Add an output buffer to the child stages for each new task
List newOutputBuffers = newTasks.stream()
.map(task -> new OutputBufferId(task.getTaskId().getId()))
.collect(toImmutableList());
for (OutputBufferManager child : childOutputBufferManagers) {
child.addOutputBuffers(newOutputBuffers, noMoreTasks);
}
}
}
}
}