All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.facebook.presto.spark.execution.PrestoSparkRemoteSourceFactory Maven / Gradle / Ivy

There is a newer version: 0.289
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.execution;

import com.facebook.presto.Session;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.operator.SourceOperatorFactory;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
import com.facebook.presto.spark.execution.PrestoSparkRemoteSourceOperator.SparkRemoteSourceOperatorFactory;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleInput;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShufflePageInput;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.sql.planner.RemoteSourceFactory;
import com.google.common.collect.ImmutableMap;
import org.apache.spark.util.CollectionAccumulator;

import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static com.facebook.presto.spark.PrestoSparkSessionProperties.isStorageBasedBroadcastJoinEnabled;
import static com.facebook.presto.spark.util.PrestoSparkUtils.createPagesSerde;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class PrestoSparkRemoteSourceFactory
        implements RemoteSourceFactory
{
    private final BlockEncodingManager blockEncodingManager;
    private final Map> shuffleInputsMap;
    private final Map>> pageInputsMap;
    private final Map> broadcastInputsMap;
    private final int taskId;
    private final CollectionAccumulator shuffleStatsCollector;
    private final TempStorage tempStorage;
    private final TempDataOperationContext tempDataOperationContext;
    private final PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager;
    private final StageId stageId;

    public PrestoSparkRemoteSourceFactory(
            BlockEncodingManager blockEncodingManager,
            Map> shuffleInputsMap,
            Map>> pageInputsMap,
            Map> broadcastInputsMap,
            int taskId,
            CollectionAccumulator shuffleStatsCollector,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext,
            PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
            StageId stageId)
    {
        this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null");
        this.shuffleInputsMap = ImmutableMap.copyOf(requireNonNull(shuffleInputsMap, "shuffleInputsMap is null"));
        this.pageInputsMap = ImmutableMap.copyOf(requireNonNull(pageInputsMap, "pageInputs is null"));
        this.broadcastInputsMap = ImmutableMap.copyOf(requireNonNull(broadcastInputsMap, "broadcastInputsMap is null"));
        this.taskId = taskId;
        this.shuffleStatsCollector = requireNonNull(shuffleStatsCollector, "shuffleStatsCollector is null");
        this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null");
        this.tempStorage = requireNonNull(tempStorage, "tempStorage is null");
        this.prestoSparkBroadcastTableCacheManager = requireNonNull(prestoSparkBroadcastTableCacheManager, "prestoSparkBroadcastTableCacheManager is null");
        this.stageId = requireNonNull(stageId, "stageId is null");
    }

    @Override
    public SourceOperatorFactory createRemoteSource(Session session, int operatorId, PlanNodeId planNodeId, List types)
    {
        List shuffleInputs = shuffleInputsMap.get(planNodeId);
        List> pageInputs = pageInputsMap.get(planNodeId);
        List broadcastInputs = broadcastInputsMap.get(planNodeId);
        checkArgument(shuffleInputs != null || pageInputs != null || broadcastInputs != null, "input not found for plan node with id %s", planNodeId);
        checkArgument(shuffleInputs == null || pageInputs == null, "single remote source cannot accept both, shuffle and page inputs");
        if (broadcastInputs != null) {
            if (isStorageBasedBroadcastJoinEnabled(session)) {
                List> diskPageInputs =
                        broadcastInputs.stream()
                                .map(input -> ((List) input).stream().map(PrestoSparkStorageHandle.class::cast).collect(toImmutableList()))
                                .collect(toImmutableList());
                return new SparkRemoteSourceOperatorFactory(
                        operatorId,
                        planNodeId,
                        new PrestoSparkDiskPageInput(
                                createPagesSerde(blockEncodingManager),
                                tempStorage,
                                tempDataOperationContext,
                                prestoSparkBroadcastTableCacheManager,
                                stageId,
                                planNodeId,
                                diskPageInputs));
            }
            else {
                List> serializedPageInputs =
                        broadcastInputs.stream()
                                .map(input -> ((List) input).stream().map(PrestoSparkSerializedPage.class::cast).iterator())
                                .collect(toImmutableList());
                return new SparkRemoteSourceOperatorFactory(
                        operatorId,
                        planNodeId,
                        new PrestoSparkSerializedPageInput(createPagesSerde(blockEncodingManager), serializedPageInputs));
            }
        }

        if (pageInputs != null) {
            return new SparkRemoteSourceOperatorFactory(
                    operatorId,
                    planNodeId,
                    new PrestoSparkSerializedPageInput(createPagesSerde(blockEncodingManager), pageInputs));
        }

        return new SparkRemoteSourceOperatorFactory(
                operatorId,
                planNodeId,
                new PrestoSparkShufflePageInput(types, shuffleInputs, taskId, shuffleStatsCollector));
    }

    @Override
    public SourceOperatorFactory createMergeRemoteSource(
            Session session,
            int operatorId,
            PlanNodeId planNodeId,
            List types,
            List outputChannels,
            List sortChannels,
            List sortOrder)
    {
        throw new UnsupportedOperationException();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy