com.facebook.presto.spark.execution.task.PrestoSparkTaskExecution Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution.task;
import com.facebook.airlift.concurrent.SetThreadName;
import com.facebook.presto.event.SplitMonitor;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.SplitRunner;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskStateMachine;
import com.facebook.presto.execution.executor.TaskExecutor;
import com.facebook.presto.execution.executor.TaskHandle;
import com.facebook.presto.operator.Driver;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.DriverFactory;
import com.facebook.presto.operator.DriverStats;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.spi.SplitWeight;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
import com.google.common.base.Joiner;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.getInitialSplitsPerNode;
import static com.facebook.presto.SystemSessionProperties.getMaxDriversPerTask;
import static com.facebook.presto.SystemSessionProperties.getSplitConcurrencyAdjustmentInterval;
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;
/**
 * The PrestoSparkTaskExecution is a simplified version of SqlTaskExecution.
 * It doesn't support grouped execution, which is not needed by Presto on Spark.
 * Unlike SqlTaskExecution, PrestoSparkTaskExecution does not require
 * the output buffer to be drained to mark the task as finished. As soon as
 * all drivers have finished, the task execution is marked as finished. This gives
 * the PrestoSparkTaskExecutor more control over the lifecycle of the output Iterator.
 */
public class PrestoSparkTaskExecution
{
private static final int MAX_JAVA_DRIVERS_FOR_NATIVE_TASK = 1;

private final TaskId taskId;
private final TaskStateMachine taskStateMachine;
private final TaskContext taskContext;
private final TaskHandle taskHandle;
private final TaskExecutor taskExecutor;
private final Executor notificationExecutor;
private final SplitMonitor splitMonitor;

// Table scan plan nodes, in the order in which their splits are scheduled.
private final List<PlanNodeId> schedulingOrder;
// Runner factories for the table scan pipelines, keyed by table scan plan node id.
// Drivers of these pipelines are created from splits.
private final Map<PlanNodeId, DriverSplitRunnerFactory> driverRunnerFactoriesWithSplitLifeCycle;
// Runner factories for every other pipeline; their drivers are created without splits.
private final List<DriverSplitRunnerFactory> driverRunnerFactoriesWithTaskLifeCycle;

/**
 * Number of drivers that have been sent to the TaskExecutor that have not finished.
 */
private final AtomicInteger remainingDrivers = new AtomicInteger();
// Flipped once by start() so that a second start() call is rejected.
private final AtomicBoolean started = new AtomicBoolean();
/**
 * Creates the execution for one task and registers it with the {@code taskExecutor}.
 * <p>
 * Every driver factory of the plan whose source is a table scan becomes a split-lifecycle
 * runner factory. All other factories become task-lifecycle runner factories and must use
 * ungrouped execution. All arguments are validated before the task is registered with the
 * {@code taskExecutor}, so an invalid argument never leaves a registered task behind.
 *
 * @param taskStateMachine state machine of the task; supplies the task id and state changes
 * @param taskContext context of the task; supplies the session and tracks peak memory
 * @param localExecutionPlan plan providing the driver factories and the table scan order
 * @param taskExecutor executor the task is registered with and its drivers are enqueued on
 * @param splitMonitor stored for use by this execution
 * @param notificationExecutor stored for use by this execution
 * @param memoryUpdateExecutor used to schedule a peak memory update one second after construction
 * @throws NullPointerException if any argument is null
 * @throws IllegalArgumentException if a non table scan pipeline is not ungrouped, or if not
 *         every table scan source has a driver factory
 */
public PrestoSparkTaskExecution(
        TaskStateMachine taskStateMachine,
        TaskContext taskContext,
        LocalExecutionPlan localExecutionPlan,
        TaskExecutor taskExecutor,
        SplitMonitor splitMonitor,
        Executor notificationExecutor,
        ScheduledExecutorService memoryUpdateExecutor)
{
    this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null");
    this.taskId = taskStateMachine.getTaskId();
    this.taskContext = requireNonNull(taskContext, "taskContext is null");
    requireNonNull(localExecutionPlan, "localExecutionPlan is null");
    this.taskExecutor = requireNonNull(taskExecutor, "taskExecutor is null");
    this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null");
    this.splitMonitor = requireNonNull(splitMonitor, "splitMonitor is null");
    // Checked up front: createTaskHandle() below registers the task with the TaskExecutor,
    // and a later failure would leave that registration behind.
    requireNonNull(memoryUpdateExecutor, "memoryUpdateExecutor is null");

    // index driver factories
    this.schedulingOrder = localExecutionPlan.getTableScanSourceOrder();
    Set<PlanNodeId> tableScanSources = ImmutableSet.copyOf(schedulingOrder);
    ImmutableMap.Builder<PlanNodeId, DriverSplitRunnerFactory> splitLifeCycleFactories = ImmutableMap.builder();
    ImmutableList.Builder<DriverSplitRunnerFactory> taskLifeCycleFactories = ImmutableList.builder();
    for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) {
        Optional<PlanNodeId> sourceId = driverFactory.getSourceId();
        if (sourceId.isPresent() && tableScanSources.contains(sourceId.get())) {
            splitLifeCycleFactories.put(sourceId.get(), new DriverSplitRunnerFactory(driverFactory, true));
        }
        else {
            checkArgument(
                    driverFactory.getPipelineExecutionStrategy() == UNGROUPED_EXECUTION,
                    "unexpected pipeline execution strategy: %s",
                    driverFactory.getPipelineExecutionStrategy());
            taskLifeCycleFactories.add(new DriverSplitRunnerFactory(driverFactory, false));
        }
    }
    this.driverRunnerFactoriesWithSplitLifeCycle = splitLifeCycleFactories.build();
    this.driverRunnerFactoriesWithTaskLifeCycle = taskLifeCycleFactories.build();

    checkArgument(this.driverRunnerFactoriesWithSplitLifeCycle.keySet().equals(tableScanSources),
            "Fragment is partitioned, but not all partitioned drivers were found");

    taskHandle = createTaskHandle(taskStateMachine, taskContext, localExecutionPlan, taskExecutor);

    // NOTE(review): schedule() runs this exactly once, one second from now. If periodic peak
    // memory updates are intended, scheduleWithFixedDelay would be needed; confirm intent.
    memoryUpdateExecutor.schedule(taskContext::updatePeakMemory, 1, SECONDS);
}
// this is a separate method to ensure that the `this` reference is not leaked during construction
/**
 * Registers the task with {@code taskExecutor} and returns its handle. The split concurrency
 * settings are read from the task's session.
 * <p>
 * A state-change listener is attached to {@code taskStateMachine}. Once the task reaches a
 * done state, the listener removes the task from the executor and calls
 * {@code noMoreDrivers()} on every driver factory of the plan. The method is static so that
 * no {@code this} reference escapes the constructor.
 */
private static TaskHandle createTaskHandle(
        TaskStateMachine taskStateMachine,
        TaskContext taskContext,
        LocalExecutionPlan localExecutionPlan,
        TaskExecutor taskExecutor)
{
    TaskId id = taskStateMachine.getTaskId();
    int initialSplitConcurrency = getInitialSplitsPerNode(taskContext.getSession());
    Duration concurrencyAdjustmentInterval = getSplitConcurrencyAdjustmentInterval(taskContext.getSession());
    OptionalInt maxDrivers = getMaxDriversPerTask(taskContext.getSession());

    TaskHandle handle = taskExecutor.addTask(id, () -> 0, initialSplitConcurrency, concurrencyAdjustmentInterval, maxDrivers);

    taskStateMachine.addStateChangeListener(newState -> {
        if (!newState.isDone()) {
            return;
        }
        taskExecutor.removeTask(handle);
        localExecutionPlan.getDriverFactories().forEach(DriverFactory::noMoreDrivers);
    });
    return handle;
}
/**
 * Starts the task. In order, this:
 * <ol>
 * <li>enqueues the drivers of the task-lifecycle pipelines,</li>
 * <li>enqueues the drivers of each table scan source for the given splits,</li>
 * <li>calls {@code checkTaskCompletion()}.</li>
 * </ol>
 * May be called only once.
 *
 * @param sources task sources carrying all splits of this task; each must be final
 *         ({@code isNoMoreSplits()})
 * @throws NullPointerException if {@code sources} is null
 * @throws IllegalStateException if this execution was already started
 * @throws IllegalArgumentException if any source is not final
 */
public void start(List<TaskSource> sources)
{
    requireNonNull(sources, "sources is null");
    checkState(started.compareAndSet(false, true), "already started");
    scheduleDriversForTaskLifeCycle();
    scheduleDriversForSplitLifeCycle(sources);
    checkTaskCompletion();
}
/**
 * Enqueues the drivers of every task-lifecycle pipeline, then marks each of those factories
 * as having no more driver runners.
 * <p>
 * These runners are created without splits. Each factory contributes as many runners as its
 * driver instance count, or a single runner when no count is set. All runners are enqueued
 * in one call with {@code forceRunSplit = true}.
 */
private void scheduleDriversForTaskLifeCycle()
{
    List<DriverSplitRunner> taskLifeCycleRunners = new ArrayList<>();
    for (DriverSplitRunnerFactory factory : driverRunnerFactoriesWithTaskLifeCycle) {
        int instanceCount = factory.getDriverInstances().orElse(1);
        for (int instance = 0; instance < instanceCount; instance++) {
            taskLifeCycleRunners.add(factory.createDriverRunner(null));
        }
    }
    enqueueDriverSplitRunner(true, taskLifeCycleRunners);

    // Seal every task-lifecycle factory; the verify guards the factory's own bookkeeping.
    for (DriverSplitRunnerFactory factory : driverRunnerFactoriesWithTaskLifeCycle) {
        factory.noMoreDriverRunner();
        verify(factory.isNoMoreDriverRunner());
    }
}
/**
 * Groups the splits of all task sources by plan node. Each table scan source is then
 * scheduled in {@code schedulingOrder}, including sources that received no splits (they get
 * an empty list).
 * <p>
 * NOTE(review): splits for a plan node id that is not in {@code schedulingOrder} are never
 * passed on by this method and are effectively ignored. Confirm that callers only send
 * table scan sources.
 *
 * @param sources final task sources
 * @throws IllegalArgumentException if any source is not final
 */
private synchronized void scheduleDriversForSplitLifeCycle(List<TaskSource> sources)
{
    checkArgument(sources.stream().allMatch(TaskSource::isNoMoreSplits), "All task sources are expected to be final");

    ListMultimap<PlanNodeId, ScheduledSplit> splitsByPlanNode = ArrayListMultimap.create();
    for (TaskSource taskSource : sources) {
        splitsByPlanNode.putAll(taskSource.getPlanNodeId(), taskSource.getSplits());
    }

    for (PlanNodeId planNodeId : schedulingOrder) {
        // The constructor verified that every table scan source has a split-lifecycle factory.
        DriverSplitRunnerFactory driverSplitRunnerFactory = driverRunnerFactoriesWithSplitLifeCycle.get(planNodeId);
        // ListMultimap.get() returns an empty list, never null, for a key with no splits.
        List<ScheduledSplit> planNodeSplits = splitsByPlanNode.get(planNodeId);
        scheduleTableScanSource(driverSplitRunnerFactory, planNodeSplits);
    }
}
/**
 * Creates and enqueues the driver runners for one table scan source, then marks its factory
 * as having no more driver runners.
 * <p>
 * When native execution is enabled for the session, all splits go to a single driver runner;
 * that runner is created even when {@code splits} is empty. Otherwise each split gets its own
 * driver runner. The runners are enqueued with {@code forceRunSplit = false}.
 *
 * @param factory split-lifecycle runner factory of the table scan source
 * @param splits all splits of that source
 */
private synchronized void scheduleTableScanSource(DriverSplitRunnerFactory factory, List<ScheduledSplit> splits)
{
    // Snapshot the input: the caller may pass a live view (e.g. a multimap value list), and the
    // native runner keeps the whole list.
    List<ScheduledSplit> sourceSplits = ImmutableList.copyOf(splits);
    factory.splitsAdded(sourceSplits.size(), SplitWeight.rawValueSum(sourceSplits, scheduledSplit -> scheduledSplit.getSplit().getSplitWeight()));

    // Enqueue driver runners with split lifecycle for this plan node and driver life cycle combination.
    ImmutableList.Builder<DriverSplitRunner> runners = ImmutableList.builder();
    if (isNativeExecutionEnabled(this.taskContext.getSession())) {
        // For native execution, process all splits in a single driver.
        runners.add(factory.createDriverRunner(sourceSplits));
    }
    else {
        // For Java execution, process each split in a separate driver.
        for (ScheduledSplit scheduledSplit : sourceSplits) {
            runners.add(factory.createDriverRunner(ImmutableList.of(scheduledSplit)));
        }
    }
    enqueueDriverSplitRunner(false, runners.build());
    factory.noMoreDriverRunner();
}
private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List runners)
{
// schedule driver to be executed
List> finishedFutures = taskExecutor.enqueueSplits(taskHandle, forceRunSplit, runners);
checkState(finishedFutures.size() == runners.size(), "Expected %s futures but got %s", runners.size(), finishedFutures.size());
// when driver completes, update state and fire events
for (int i = 0; i < finishedFutures.size(); i++) {
ListenableFuture> finishedFuture = finishedFutures.get(i);
final DriverSplitRunner splitRunner = runners.get(i);
// record new driver
remainingDrivers.incrementAndGet();
Futures.addCallback(finishedFuture, new FutureCallback
© 2015 - 2025 Weber Informatics LLC | Privacy Policy