All downloads are free. Search and download functionality uses the official Maven repository.

com.facebook.presto.spark.execution.task.PrestoSparkTaskExecution Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.execution.task;

import com.facebook.airlift.concurrent.SetThreadName;
import com.facebook.presto.event.SplitMonitor;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.SplitRunner;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskStateMachine;
import com.facebook.presto.execution.executor.TaskExecutor;
import com.facebook.presto.execution.executor.TaskHandle;
import com.facebook.presto.operator.Driver;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.DriverFactory;
import com.facebook.presto.operator.DriverStats;
import com.facebook.presto.operator.PipelineContext;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.spi.SplitWeight;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
import com.google.common.base.Joiner;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getInitialSplitsPerNode;
import static com.facebook.presto.SystemSessionProperties.getMaxDriversPerTask;
import static com.facebook.presto.SystemSessionProperties.getSplitConcurrencyAdjustmentInterval;
import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;

/**
 * The PrestoSparkTaskExecution is a simplified version of SqlTaskExecution.
 * It doesn't support grouped execution, which is not needed by Presto on Spark.
 * Unlike SqlTaskExecution, PrestoSparkTaskExecution does not require
 * the output buffer to be drained to mark the task as finished: as soon as
 * all drivers have finished, the task execution is marked as finished. That gives
 * the PrestoSparkTaskExecutor more control over the lifecycle of the output Iterator.
 */
public class PrestoSparkTaskExecution
{
    private final TaskId taskId;
    private final TaskStateMachine taskStateMachine;
    private final TaskContext taskContext;

    private final TaskHandle taskHandle;
    private final TaskExecutor taskExecutor;

    private final Executor notificationExecutor;

    private final SplitMonitor splitMonitor;

    // Table scan plan nodes, in the order their splits are scheduled.
    private final List<PlanNodeId> schedulingOrder;
    // Factories whose drivers are created per split batch, keyed by table scan plan node.
    private final Map<PlanNodeId, DriverSplitRunnerFactory> driverRunnerFactoriesWithSplitLifeCycle;
    // Factories whose drivers are created once per task (pipelines not fed by a table scan source).
    private final List<DriverSplitRunnerFactory> driverRunnerFactoriesWithTaskLifeCycle;

    /**
     * Number of drivers that have been sent to the TaskExecutor that have not finished.
     */
    private final AtomicInteger remainingDrivers = new AtomicInteger();

    private final AtomicBoolean started = new AtomicBoolean();

    /**
     * Indexes the driver factories of the local execution plan, registers a pipeline context for each
     * of them and registers the task with the task executor. No driver is scheduled until
     * {@link #start(List)} is called.
     *
     * @throws NullPointerException if any argument is null (checked before any side effect)
     * @throws IllegalArgumentException if a pipeline without a table scan source is not
     *         {@code UNGROUPED_EXECUTION}, or if some table scan source has no matching driver factory
     */
    public PrestoSparkTaskExecution(
            TaskStateMachine taskStateMachine,
            TaskContext taskContext,
            LocalExecutionPlan localExecutionPlan,
            TaskExecutor taskExecutor,
            SplitMonitor splitMonitor,
            Executor notificationExecutor,
            ScheduledExecutorService memoryUpdateExecutor)
    {
        this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null");
        this.taskId = taskStateMachine.getTaskId();
        this.taskContext = requireNonNull(taskContext, "taskContext is null");

        this.taskExecutor = requireNonNull(taskExecutor, "taskExecutor is null");
        this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null");

        this.splitMonitor = requireNonNull(splitMonitor, "splitMonitor is null");
        requireNonNull(localExecutionPlan, "localExecutionPlan is null");
        // validate before createTaskHandle registers the task with the executor, so a bad argument cannot leak a task handle
        requireNonNull(memoryUpdateExecutor, "memoryUpdateExecutor is null");

        // index driver factories
        this.schedulingOrder = localExecutionPlan.getTableScanSourceOrder();
        Set<PlanNodeId> tableScanSources = ImmutableSet.copyOf(schedulingOrder);
        ImmutableMap.Builder<PlanNodeId, DriverSplitRunnerFactory> splitLifeCycleFactories = ImmutableMap.builder();
        ImmutableList.Builder<DriverSplitRunnerFactory> taskLifeCycleFactories = ImmutableList.builder();
        for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) {
            Optional<PlanNodeId> sourceId = driverFactory.getSourceId();
            if (sourceId.isPresent() && tableScanSources.contains(sourceId.get())) {
                splitLifeCycleFactories.put(sourceId.get(), new DriverSplitRunnerFactory(driverFactory, true));
            }
            else {
                checkArgument(
                        driverFactory.getPipelineExecutionStrategy() == UNGROUPED_EXECUTION,
                        "unexpected pipeline execution strategy: %s",
                        driverFactory.getPipelineExecutionStrategy());
                taskLifeCycleFactories.add(new DriverSplitRunnerFactory(driverFactory, false));
            }
        }
        this.driverRunnerFactoriesWithSplitLifeCycle = splitLifeCycleFactories.build();
        this.driverRunnerFactoriesWithTaskLifeCycle = taskLifeCycleFactories.build();

        checkArgument(this.driverRunnerFactoriesWithSplitLifeCycle.keySet().equals(tableScanSources),
                "Fragment is partitioned, but not all partitioned drivers were found");

        this.taskHandle = createTaskHandle(taskStateMachine, taskContext, localExecutionPlan, taskExecutor);

        // NOTE(review): this refreshes peak memory a single time, one second after construction;
        // it is not a periodic task — confirm this is intended
        memoryUpdateExecutor.schedule(taskContext::updatePeakMemory, 1, SECONDS);
    }

    // this is a separate method to ensure that the `this` reference is not leaked during construction
    // Registers the task with the executor and, once the task reaches a done state, removes it from
    // the executor and calls noMoreDrivers() on every driver factory of the plan.
    private static TaskHandle createTaskHandle(
            TaskStateMachine taskStateMachine,
            TaskContext taskContext,
            LocalExecutionPlan localExecutionPlan,
            TaskExecutor taskExecutor)
    {
        TaskHandle taskHandle = taskExecutor.addTask(
                taskStateMachine.getTaskId(),
                () -> 0,
                getInitialSplitsPerNode(taskContext.getSession()),
                getSplitConcurrencyAdjustmentInterval(taskContext.getSession()),
                getMaxDriversPerTask(taskContext.getSession()));
        taskStateMachine.addStateChangeListener(state -> {
            if (state.isDone()) {
                taskExecutor.removeTask(taskHandle);
                for (DriverFactory factory : localExecutionPlan.getDriverFactories()) {
                    factory.noMoreDrivers();
                }
            }
        });
        return taskHandle;
    }

    /**
     * Schedules every driver of the task: first the task-lifecycle drivers, then one batch of
     * split-lifecycle drivers per table scan source, in scheduling order.
     *
     * @param sources the complete, final set of task sources; each must report {@code isNoMoreSplits()}
     * @throws IllegalArgumentException if any task source is not final (checked before any side effect)
     * @throws IllegalStateException if this execution was already started
     */
    public void start(List<TaskSource> sources)
    {
        requireNonNull(sources, "sources is null");
        // reject non-final sources before flipping `started` or enqueueing any driver
        checkArgument(sources.stream().allMatch(TaskSource::isNoMoreSplits), "All task sources are expected to be final");

        checkState(started.compareAndSet(false, true), "already started");

        scheduleDriversForTaskLifeCycle();
        scheduleDriversForSplitLifeCycle(sources);
        checkTaskCompletion();
    }

    // Creates getDriverInstances() (default 1) runners for each task-lifecycle factory, enqueues them
    // as forced splits, then closes those factories for further runner creation.
    private void scheduleDriversForTaskLifeCycle()
    {
        List<DriverSplitRunner> runners = new ArrayList<>();
        for (DriverSplitRunnerFactory driverRunnerFactory : driverRunnerFactoriesWithTaskLifeCycle) {
            int driverInstances = driverRunnerFactory.getDriverInstances().orElse(1);
            for (int i = 0; i < driverInstances; i++) {
                runners.add(driverRunnerFactory.createDriverRunner(null));
            }
        }
        enqueueDriverSplitRunner(true, runners);
        for (DriverSplitRunnerFactory driverRunnerFactory : driverRunnerFactoriesWithTaskLifeCycle) {
            driverRunnerFactory.noMoreDriverRunner();
            verify(driverRunnerFactory.isNoMoreDriverRunner());
        }
    }

    // Groups the splits by plan node and schedules them per table scan source in scheduling order.
    private synchronized void scheduleDriversForSplitLifeCycle(List<TaskSource> sources)
    {
        checkArgument(sources.stream().allMatch(TaskSource::isNoMoreSplits), "All task sources are expected to be final");

        ListMultimap<PlanNodeId, ScheduledSplit> splits = ArrayListMultimap.create();
        for (TaskSource taskSource : sources) {
            splits.putAll(taskSource.getPlanNodeId(), taskSource.getSplits());
        }

        for (PlanNodeId planNodeId : schedulingOrder) {
            DriverSplitRunnerFactory driverSplitRunnerFactory = driverRunnerFactoriesWithSplitLifeCycle.get(planNodeId);
            List<ScheduledSplit> planNodeSplits = splits.get(planNodeId);
            scheduleTableScanSource(driverSplitRunnerFactory, planNodeSplits);
        }
    }

    // Records the splits in the pipeline stats, enqueues the driver runners for them and then closes
    // the factory for further runner creation.
    private synchronized void scheduleTableScanSource(DriverSplitRunnerFactory factory, List<ScheduledSplit> splits)
    {
        factory.splitsAdded(splits.size(), SplitWeight.rawValueSum(splits, scheduledSplit -> scheduledSplit.getSplit().getSplitWeight()));

        // Enqueue driver runners with split lifecycle for this plan node and driver life cycle combination.
        ImmutableList.Builder<DriverSplitRunner> runners = ImmutableList.builder();
        if (isNativeExecutionEnabled(taskContext.getSession())) {
            // For native execution, process all splits in a single driver.
            // NOTE(review): a runner is created even when `splits` is empty; createDriver then skips updateSource — confirm
            // that the native driver finishes without an explicit no-more-splits signal
            runners.add(factory.createDriverRunner(splits));
        }
        else {
            // For Java execution, process each split in a separate driver.
            for (ScheduledSplit scheduledSplit : splits) {
                runners.add(factory.createDriverRunner(ImmutableList.of(scheduledSplit)));
            }
        }

        enqueueDriverSplitRunner(false, runners.build());

        factory.noMoreDriverRunner();
    }

    // Hands the runners to the task executor and attaches a completion callback (run on
    // notificationExecutor) to each returned future that tracks remaining drivers and fires split events.
    private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List<DriverSplitRunner> runners)
    {
        // schedule drivers to be executed
        List<ListenableFuture<?>> finishedFutures = taskExecutor.enqueueSplits(taskHandle, forceRunSplit, runners);
        checkState(finishedFutures.size() == runners.size(), "Expected %s futures but got %s", runners.size(), finishedFutures.size());

        // when a driver completes, update state and fire events
        for (int i = 0; i < finishedFutures.size(); i++) {
            ListenableFuture<?> finishedFuture = finishedFutures.get(i);
            DriverSplitRunner splitRunner = runners.get(i);

            // record new driver
            remainingDrivers.incrementAndGet();

            Futures.addCallback(finishedFuture, new FutureCallback<Object>()
            {
                @Override
                public void onSuccess(Object result)
                {
                    try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) {
                        // record driver is finished
                        remainingDrivers.decrementAndGet();

                        checkTaskCompletion();

                        splitMonitor.splitCompletedEvent(taskId, getDriverStats());
                    }
                }

                @Override
                public void onFailure(Throwable cause)
                {
                    try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) {
                        taskStateMachine.failed(cause);

                        // record driver is finished
                        remainingDrivers.decrementAndGet();

                        // fire failed event with cause
                        splitMonitor.splitFailedEvent(taskId, getDriverStats(), cause);
                    }
                }

                private DriverStats getDriverStats()
                {
                    DriverContext driverContext = splitRunner.getDriverContext();
                    if (driverContext == null) {
                        // split runner did not start successfully
                        return new DriverStats();
                    }
                    return driverContext.getDriverStats();
                }
            }, notificationExecutor);
        }
    }

    /**
     * Marks the task as finished once every split-lifecycle factory has stopped creating runners and
     * no enqueued driver is still running. Does nothing if the task is already in a done state.
     */
    private synchronized void checkTaskCompletion()
    {
        if (taskStateMachine.getState().isDone()) {
            return;
        }

        // are there more partition splits expected?
        for (DriverSplitRunnerFactory driverSplitRunnerFactory : driverRunnerFactoriesWithSplitLifeCycle.values()) {
            if (!driverSplitRunnerFactory.isNoMoreDriverRunner()) {
                return;
            }
        }
        // do we still have running drivers?
        if (remainingDrivers.get() != 0) {
            return;
        }

        // Cool! All done!
        taskStateMachine.finished();
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("taskId", taskId)
                .add("remainingDrivers", remainingDrivers.get())
                .toString();
    }

    /**
     * Creates driver runners (and, lazily, drivers) for one pipeline and closes the underlying
     * {@link DriverFactory} once no more runners will be created and every created runner has built its driver.
     */
    private class DriverSplitRunnerFactory
    {
        private final DriverFactory driverFactory;
        private final PipelineContext pipelineContext;

        // runners created whose driver has not been created yet
        private final AtomicInteger pendingCreation = new AtomicInteger();
        private final AtomicBoolean noMoreDriverRunner = new AtomicBoolean();
        private final AtomicBoolean closed = new AtomicBoolean();

        private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitioned)
        {
            this.driverFactory = requireNonNull(driverFactory, "driverFactory is null");
            this.pipelineContext = taskContext.addPipelineContext(driverFactory.getPipelineId(), driverFactory.isInputDriver(), driverFactory.isOutputDriver(), partitioned);
        }

        /**
         * Creates a runner for the given splits ({@code null} for task-lifecycle drivers).
         *
         * @throws IllegalStateException if {@link #noMoreDriverRunner()} was already called
         */
        public DriverSplitRunner createDriverRunner(@Nullable List<ScheduledSplit> scheduledSplits)
        {
            checkState(!noMoreDriverRunner.get(), "Cannot create driver for pipeline: %s", pipelineContext.getPipelineId());
            pendingCreation.incrementAndGet();
            // create driver context immediately so the driver existence is recorded in the stats
            // splitWeight can be 0 as we don't load balance the executor based on their load average
            DriverContext driverContext = pipelineContext.addDriverContext(0, Lifespan.taskWide(), driverFactory.getFragmentResultCacheContext());
            return new DriverSplitRunner(this, driverContext, scheduledSplits);
        }

        /**
         * Creates the driver for a runner and feeds it its splits as a final task source.
         *
         * @throws IllegalArgumentException in Java (non-native) execution when more than one split is given
         */
        public Driver createDriver(DriverContext driverContext, @Nullable List<ScheduledSplit> scheduledSplits)
        {
            Driver driver = driverFactory.createDriver(driverContext);
            if (scheduledSplits != null && !scheduledSplits.isEmpty()) {
                boolean nativeExecution = isNativeExecutionEnabled(driver.getDriverContext().getSession());
                if (!nativeExecution && scheduledSplits.size() != 1) {
                    throw new IllegalArgumentException(format("non-native (java) execution requires only one scheduledSplits but [%d] were found [%s]",
                            scheduledSplits.size(),
                            Joiner.on(",").join(scheduledSplits)));
                }
                PlanNodeId sourceNodeId = nativeExecution ? driver.getSourceId().get() : Iterables.getOnlyElement(scheduledSplits).getPlanNodeId();
                // TableScanOperator requires partitioned split to be added before the first call to process
                driver.updateSource(new TaskSource(sourceNodeId, ImmutableSet.copyOf(scheduledSplits), true));
            }

            verify(pendingCreation.get() > 0, "pendingCreation is expected to be greater than zero");
            pendingCreation.decrementAndGet();

            closeDriverFactoryIfFullyCreated();

            return driver;
        }

        /**
         * Forbids further runner creation; idempotent.
         */
        public void noMoreDriverRunner()
        {
            // compareAndSet makes the "first caller closes" transition atomic
            if (noMoreDriverRunner.compareAndSet(false, true)) {
                closeDriverFactoryIfFullyCreated();
            }
        }

        public boolean isNoMoreDriverRunner()
        {
            return noMoreDriverRunner.get();
        }

        // Signals noMoreDrivers to the driver factory exactly once, after runner creation is closed and
        // every created runner has built its driver.
        public void closeDriverFactoryIfFullyCreated()
        {
            if (closed.get()) {
                return;
            }
            if (isNoMoreDriverRunner() && pendingCreation.get() == 0) {
                // ensure noMoreDrivers is called only once
                if (!closed.compareAndSet(false, true)) {
                    return;
                }
                driverFactory.noMoreDrivers(Lifespan.taskWide());
                driverFactory.noMoreDrivers();
            }
        }

        public OptionalInt getDriverInstances()
        {
            return driverFactory.getDriverInstances();
        }

        public void splitsAdded(int count, long weightSum)
        {
            pipelineContext.splitsAdded(count, weightSum);
        }
    }

    /**
     * A {@link SplitRunner} that lazily creates its driver on the first {@link #processFor(Duration)} call.
     */
    private static class DriverSplitRunner
            implements SplitRunner
    {
        private final DriverSplitRunnerFactory driverSplitRunnerFactory;
        private final DriverContext driverContext;

        @GuardedBy("this")
        private boolean closed;

        @Nullable
        private final List<ScheduledSplit> scheduledSplits;

        @GuardedBy("this")
        private Driver driver;

        private DriverSplitRunner(DriverSplitRunnerFactory driverSplitRunnerFactory, DriverContext driverContext, @Nullable List<ScheduledSplit> scheduledSplits)
        {
            this.driverSplitRunnerFactory = requireNonNull(driverSplitRunnerFactory, "driverSplitRunnerFactory is null");
            this.driverContext = requireNonNull(driverContext, "driverContext is null");
            this.scheduledSplits = scheduledSplits;
        }

        /**
         * Returns the context of the created driver, or {@code null} if the driver has not been created yet.
         */
        @Nullable
        public synchronized DriverContext getDriverContext()
        {
            if (driver == null) {
                return null;
            }
            return driver.getDriverContext();
        }

        @Override
        public synchronized boolean isFinished()
        {
            if (closed) {
                return true;
            }

            return driver != null && driver.isFinished();
        }

        @Override
        public ListenableFuture<?> processFor(Duration duration)
        {
            Driver driver;
            synchronized (this) {
                // if close() was called before we get here, there's no point in even creating the driver
                if (closed) {
                    return Futures.immediateFuture(null);
                }

                if (this.driver == null) {
                    this.driver = driverSplitRunnerFactory.createDriver(driverContext, scheduledSplits);
                }

                driver = this.driver;
            }

            // process outside the lock so close() and isFinished() are not blocked by driver work
            return driver.processFor(duration);
        }

        @Override
        public String getInfo()
        {
            return scheduledSplits == null || scheduledSplits.isEmpty() ? "" :
                    format("DriverRunner splitCount=%d [%s]",
                            scheduledSplits.size(),
                            Joiner.on(",").join(scheduledSplits.stream().map(
                                    split -> split.getSplit().getConnectorSplit()).collect(Collectors.toList())));
        }

        @Override
        public void close()
        {
            Driver driver;
            synchronized (this) {
                closed = true;
                driver = this.driver;
            }

            if (driver != null) {
                driver.close();
            }
        }
    }
}