All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.facebook.presto.spark.classloader_interface.PrestoSparkNativeExecutionShuffleManager Maven / Gradle / Ivy

There is a newer version: 0.290
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.classloader_interface;

import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.BaseShuffleHandle;
import org.apache.spark.shuffle.ShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.sort.BypassMergeSortShuffleHandle;
import org.apache.spark.storage.BlockManager;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.facebook.presto.spark.launcher.internal.com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * {@link PrestoSparkNativeExecutionShuffleManager} is the shuffle manager implementing the Spark shuffle manager interface specifically for native execution. We need this
 *  dedicated shuffle manager for two reasons:
 * 1. To bypass Spark's Java shuffle writer/reader, since the actual shuffle read/write happens on the C++ side. PrestoSparkNativeExecutionShuffleManager registers
 *    a pair of no-op shuffle reader/writer to hook into the regular Spark shuffle workflow.
 * 2. To capture the shuffle metadata (e.g. {@link ShuffleHandle}) for later use. This metadata is only available internally during shuffle writer creation, which is outside
 *    the Presto-Spark native execution flow. {@link PrestoSparkNativeExecutionShuffleManager} captures and stores this metadata inside the shuffle manager and provides
 *    APIs that give the native execution runtime access to it.
 */
public class PrestoSparkNativeExecutionShuffleManager
        implements ShuffleManager
{
    private final Map partitionIdToShuffleHandle = new ConcurrentHashMap<>();
    private final Map> shuffleIdToBaseShuffleHandle = new ConcurrentHashMap<>();
    private final ShuffleManager fallbackShuffleManager;
    private static final String FALLBACK_SPARK_SHUFFLE_MANAGER = "spark.fallback.shuffle.manager";

    public PrestoSparkNativeExecutionShuffleManager(SparkConf conf)
    {
        fallbackShuffleManager = instantiateClass(conf.get(FALLBACK_SPARK_SHUFFLE_MANAGER), conf);
    }

    // Create an instance of the class with the given name, possibly initializing it with our conf
    private static  T instantiateClass(String className, SparkConf conf)
    {
        try {
            return (T) (Class.forName(className).getConstructor(SparkConf.class).newInstance(conf));
        }
        catch (ClassNotFoundException | InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
            throw new RuntimeException(format("%s class not found", className), e);
        }
    }

    protected void registerShuffleHandle(BaseShuffleHandle handle, int stageId, int mapId)
    {
        partitionIdToShuffleHandle.put(new StageAndMapId(stageId, mapId), handle);
        shuffleIdToBaseShuffleHandle.put(handle.shuffleId(), handle);
    }

    protected void unregisterShuffleHandle(int shuffleId, int stageId, int mapId)
    {
        partitionIdToShuffleHandle.remove(new StageAndMapId(stageId, mapId));
        shuffleIdToBaseShuffleHandle.remove(shuffleId);
    }

    @Override
    public  ShuffleHandle registerShuffle(int shuffleId, int numMaps, ShuffleDependency dependency)
    {
        return fallbackShuffleManager.registerShuffle(shuffleId, numMaps, dependency);
    }

    @Override
    public  ShuffleWriter getWriter(ShuffleHandle handle, int mapId, TaskContext context)
    {
        checkState(
                requireNonNull(handle, "handle is null") instanceof BypassMergeSortShuffleHandle,
                "class %s is not instance of BypassMergeSortShuffleHandle", handle.getClass().getName());
        BaseShuffleHandle baseShuffleHandle = (BaseShuffleHandle) handle;
        int shuffleId = baseShuffleHandle.shuffleId();
        int stageId = context.stageId();
        registerShuffleHandle(baseShuffleHandle, stageId, mapId);
        return new EmptyShuffleWriter<>(
                baseShuffleHandle.dependency().partitioner().numPartitions(),
                () -> unregisterShuffleHandle(shuffleId, stageId, mapId));
    }

    @Override
    public  ShuffleReader getReader(ShuffleHandle handle, int startPartition, int endPartition, TaskContext context)
    {
        return new EmptyShuffleReader<>();
    }

    @Override
    public boolean unregisterShuffle(int shuffleId)
    {
        fallbackShuffleManager.unregisterShuffle(shuffleId);
        return true;
    }

    @Override
    public ShuffleBlockResolver shuffleBlockResolver()
    {
        return fallbackShuffleManager.shuffleBlockResolver();
    }

    @Override
    public void stop()
    {
        fallbackShuffleManager.stop();
    }

    /*
     * This method can only be called inside Rdd's compute method otherwise the shuffleDependencyMap may not contain corresponding ShuffleHandle object.
     * The reason is that in Spark's ShuffleMapTask, it's guaranteed to call writer.getWriter(handle, mapId, context) first before calling the Rdd.compute()
     * method, therefore, the ShuffleHandle object will always be added to shuffleDependencyMap in getWriter before Rdd.compute().
     */
    public Optional getShuffleHandle(int stageId, int mapId)
    {
        return Optional.ofNullable(partitionIdToShuffleHandle.getOrDefault(new StageAndMapId(stageId, mapId), null));
    }

    public boolean hasRegisteredShuffleHandles()
    {
        return !partitionIdToShuffleHandle.isEmpty() || !shuffleIdToBaseShuffleHandle.isEmpty();
    }

    public int getNumOfPartitions(int shuffleId)
    {
        if (!shuffleIdToBaseShuffleHandle.containsKey(shuffleId)) {
            throw new RuntimeException(format("shuffleId=[%s] is not registered", shuffleId));
        }
        return shuffleIdToBaseShuffleHandle.get(shuffleId).dependency().partitioner().numPartitions();
    }

    static class EmptyShuffleReader
            implements ShuffleReader
    {
        @Override
        public Iterator> read()
        {
            return emptyScalaIterator();
        }
    }

    static class EmptyShuffleWriter
            extends ShuffleWriter
    {
        private final long[] mapStatus;
        private final Runnable onStop;
        private static final long DEFAULT_MAP_STATUS = 1L;

        public EmptyShuffleWriter(int totalMapStages, Runnable onStop)
        {
            this.mapStatus = new long[totalMapStages];
            this.onStop = requireNonNull(onStop, "onStop is null");
            Arrays.fill(mapStatus, DEFAULT_MAP_STATUS);
        }

        @Override
        public void write(Iterator> records)
                throws IOException
        {
            if (records.hasNext()) {
                throw new RuntimeException("EmptyShuffleWriter can only take empty write input.");
            }
        }

        @Override
        public Option stop(boolean success)
        {
            onStop.run();
            BlockManager blockManager = SparkEnv.get().blockManager();
            return Option.apply(MapStatus$.MODULE$.apply(blockManager.blockManagerId(), mapStatus));
        }
    }

    public static class StageAndMapId
    {
        private final int stageId;
        private final int mapId;

        public StageAndMapId(int stageId, int mapId)
        {
            this.stageId = stageId;
            this.mapId = mapId;
        }

        public int getStageId()
        {
            return stageId;
        }

        public int getMapId()
        {
            return mapId;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            StageAndMapId that = (StageAndMapId) o;
            return stageId == that.stageId && mapId == that.mapId;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(stageId, mapId);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy