org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.impl.paramavg;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.deeplearning4j.core.loader.DataSetLoader;
import org.deeplearning4j.core.loader.MultiDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.StatsStorageRouterProvider;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.*;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.*;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.core.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;

import java.io.IOException;
import java.io.OutputStream;
import java.util.*;

import static org.nd4j.shade.guava.base.Preconditions.checkArgument;

@Data
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
                "trainingMasterUID"})
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
                "trainingMasterUID"})
@Slf4j
public class ParameterAveragingTrainingMaster
                extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
                implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {

    protected static final int COALESCE_THRESHOLD = 3;


    protected boolean saveUpdater;
    protected Integer numWorkers;
    protected int rddDataSetNumExamples;

    protected int averagingFrequency;
    protected int aggregationDepth;
    protected int prefetchNumBatches;
    protected int iterationCount = 0;

    protected Collection<TrainingHook> trainingHookList;

    protected ParameterAveragingTrainingMaster() {
        // no-arg constructor for Jackson

        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID =
                        System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }

    protected ParameterAveragingTrainingMaster(Builder builder) {
        this.saveUpdater = builder.saveUpdater;
        this.numWorkers = builder.numWorkers;
        this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.averagingFrequency = builder.averagingFrequency;
        this.aggregationDepth = builder.aggregationDepth;
        this.prefetchNumBatches = builder.prefetchNumBatches;
        this.repartition = builder.repartition;
        this.repartitionStrategy = builder.repartitionStrategy;
        this.storageLevel = builder.storageLevel;
        this.storageLevelStreams = builder.storageLevelStreams;
        this.rddTrainingApproach = builder.rddTrainingApproach;
        this.exportDirectory = builder.exportDirectory;
        this.trainingHookList = builder.trainingHooks;
        this.collectTrainingStats = builder.collectTrainingStats;
        if (collectTrainingStats)
            stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();


        if (builder.rngSeed == null) {
            this.rng = new Random();
        } else {
            this.rng = new Random(builder.rngSeed);
        }

        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID =
                        System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
    }

    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches) {
        this(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency, 2,
                        prefetchNumBatches, Repartition.Always, RepartitionStrategy.Balanced, false);
    }

    /**
     * @param saveUpdater           If true: save (and average) the updater state when doing parameter averaging
     * @param numWorkers            Number of workers (executors * threads per executor) for the cluster
     * @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD<DataSet>}
     * @param batchSizePerWorker    Number of examples to use per worker per fit
     * @param averagingFrequency    Frequency (in number of minibatches) with which to average parameters
     * @param aggregationDepth      Number of aggregation levels used in parameter aggregation
     * @param prefetchNumBatches    Number of batches to asynchronously prefetch (0: disable)
     * @param repartition           Set if/when repartitioning should be conducted for the training data
     * @param repartitionStrategy   Repartitioning strategy to use. See {@link RepartitionStrategy}
     * @param collectTrainingStats  If true: collect training statistics for debugging/optimization purposes
     */
    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int aggregationDepth, int prefetchNumBatches,
                    Repartition repartition, RepartitionStrategy repartitionStrategy, boolean collectTrainingStats) {
        this(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency, aggregationDepth,
                        prefetchNumBatches, repartition, repartitionStrategy, StorageLevel.MEMORY_ONLY_SER(),
                        collectTrainingStats);
    }

    public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples,
                    int batchSizePerWorker, int averagingFrequency, int aggregationDepth, int prefetchNumBatches,
                    Repartition repartition, RepartitionStrategy repartitionStrategy, StorageLevel storageLevel,
                    boolean collectTrainingStats) {
        checkArgument(numWorkers > 0, "Invalid number of workers: " + numWorkers + " (must be >= 1)");
        checkArgument(rddDataSetNumExamples > 0,
                        "Invalid rdd data set size: " + rddDataSetNumExamples + " (must be >= 1)");
        checkArgument(averagingFrequency > 0, "Invalid input: averaging frequency must be >= 1");
        checkArgument(aggregationDepth > 0, "Invalid input: tree aggregation channels must be >= 1");

        this.saveUpdater = saveUpdater;
        this.numWorkers = numWorkers;
        this.rddDataSetNumExamples = rddDataSetNumExamples;
        this.batchSizePerWorker = batchSizePerWorker;
        this.averagingFrequency = averagingFrequency;
        this.aggregationDepth = aggregationDepth;
        this.prefetchNumBatches = prefetchNumBatches;
        this.collectTrainingStats = collectTrainingStats;
        this.repartition = repartition;
        this.repartitionStrategy = repartitionStrategy;
        this.storageLevel = storageLevel;
        if (collectTrainingStats)
            stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();

        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID =
                        System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }



    /**
     * Remove a training hook from the worker
     *
     * @param trainingHook the training hook to remove
     */
    @Override
    public void removeHook(TrainingHook trainingHook) {
        if (trainingHookList == null)
            return;
        trainingHookList.remove(trainingHook);
    }

    /**
     * Add a hook for the master for pre and post training
     *
     * @param trainingHook the training hook to add
     */
    @Override
    public void addHook(TrainingHook trainingHook) {
        if (trainingHookList == null) {
            trainingHookList = new ArrayList<>();
        }
        trainingHookList.add(trainingHook);
    }

    @Override
    public String toJson() {
        ObjectMapper om = getJsonMapper();

        try {
            return om.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing JSON representation for ParameterAveragingTrainingMaster", e);
        }
    }

    @Override
    public String toYaml() {
        ObjectMapper om = getYamlMapper();

        try {
            return om.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing YAML representation for ParameterAveragingTrainingMaster", e);
        }
    }

    /**
     * Create a ParameterAveragingTrainingMaster instance by deserializing a JSON string that has been serialized with
     * {@link #toJson()}
     *
     * @param jsonStr ParameterAveragingTrainingMaster configuration serialized as JSON
     */
    public static ParameterAveragingTrainingMaster fromJson(String jsonStr) {
        ObjectMapper om = getJsonMapper();
        try {
            return om.readValue(jsonStr, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
    }

    /**
     * Create a ParameterAveragingTrainingMaster instance by deserializing a YAML string that has been serialized with
     * {@link #toYaml()}
     *
     * @param yamlStr ParameterAveragingTrainingMaster configuration serialized as YAML
     */
    public static ParameterAveragingTrainingMaster fromYaml(String yamlStr) {
        ObjectMapper om = getYamlMapper();
        try {
            return om.readValue(yamlStr, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
    }
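
    /*
     * Illustrative sketch (not part of the upstream class): the JSON/YAML methods above allow a configured
     * TrainingMaster to be round-tripped, e.g. for logging or for rebuilding the same configuration on
     * another driver. Assuming a ParameterAveragingTrainingMaster "tm" built elsewhere:
     *
     *   String json = tm.toJson();
     *   ParameterAveragingTrainingMaster restored = ParameterAveragingTrainingMaster.fromJson(json);
     *   // "restored" carries the same configuration; transient state (stats, listeners, rng, etc.) is not
     *   // serialized, as declared in the @JsonIgnoreProperties annotation on this class.
     */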


    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
        NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(),
                        network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray());

        if (collectTrainingStats)
            stats.logBroadcastStart();
        Broadcast<NetBroadcastTuple> broadcast = network.getSparkContext().broadcast(tuple);
        if (collectTrainingStats)
            stats.logBroadcastEnd();

        WorkerConfiguration configuration = new WorkerConfiguration(false, rddDataSetNumExamples, batchSizePerWorker,
                        averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration, trainingHookList, listeners,
                        getRouterProvider());
    }

    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) {
        NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getConfiguration(),
                        graph.getNetwork().params(), graph.getNetwork().getUpdater().getStateViewArray());

        if (collectTrainingStats)
            stats.logBroadcastStart();
        Broadcast<NetBroadcastTuple> broadcast = graph.getSparkContext().broadcast(tuple);
        if (collectTrainingStats)
            stats.logBroadcastEnd();

        WorkerConfiguration configuration = new WorkerConfiguration(true, rddDataSetNumExamples, batchSizePerWorker,
                        averagingFrequency, prefetchNumBatches, collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration, trainingHookList, listeners,
                        getRouterProvider());
    }

    protected int numObjectsEachWorker(int numExamplesEachRddObject) {
        return batchSizePerWorker * averagingFrequency / numExamplesEachRddObject;
    }

    protected int getNumDataSetObjectsPerSplit(int numExamplesEachRddObject) {
        int dataSetObjectsPerSplit;
        if (numExamplesEachRddObject == 1) {
            dataSetObjectsPerSplit = numWorkers * batchSizePerWorker * averagingFrequency;
        } else {
            int numDataSetObjsReqEachWorker = numObjectsEachWorker(numExamplesEachRddObject);
            if (numDataSetObjsReqEachWorker < 1) {
                //In this case: more examples in a DataSet object than we actually require
                //For example, 100 examples in DataSet, with batchSizePerWorker=50 and averagingFrequency=1
                numDataSetObjsReqEachWorker = 1;
            }

            dataSetObjectsPerSplit = numDataSetObjsReqEachWorker * numWorkers;
        }
        return dataSetObjectsPerSplit;
    }
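
    /*
     * Worked example (illustrative, not part of the upstream class): with batchSizePerWorker = 32,
     * averagingFrequency = 5, rddDataSetNumExamples = 1 and numWorkers = 4, each worker needs
     * 32 * 5 = 160 single-example DataSet objects between averagings, so each split holds
     * 4 * 160 = 640 objects. If instead each RDD object held 10 examples, numObjectsEachWorker
     * would be 32 * 5 / 10 = 16 objects per worker, giving 16 * 4 = 64 objects per split.
     */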

    @Override
    public void executeTraining(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (numWorkers == null)
            numWorkers = network.getSparkContext().defaultParallelism();

        if (rddTrainingApproach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(network, trainingData);
        } else {
            //Export data if required (or, use cached export)
            JavaRDD<String> paths = exportIfRequired(network.getSparkContext(), trainingData);
            executeTrainingPathsHelper(network, null, paths, new SerializedDataSetLoader(), null, batchSizePerWorker); //Originally (pre-export): had rddDataSetNumExamples per DataSet. Now we have batchSizePerWorker per exported DataSet
        }
    }

    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(
                    JavaRDDLike<T, Repr> trainingData) {
        if (collectTrainingStats)
            stats.logCountStart();
        long totalDataSetObjectCount = trainingData.count();
        if (collectTrainingStats)
            stats.logCountEnd();
        return totalDataSetObjectCount;
    }

    protected <T, Repr> JavaPairRDD<T, Repr>[] getSplitRDDs(JavaPairRDD<T, Repr> trainingData,
                    int totalDataSetObjectCount) {
        int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit(rddDataSetNumExamples);

        if (collectTrainingStats)
            stats.logSplitStart();
        JavaPairRDD<T, Repr>[] splits = SparkUtils.balancedRandomSplit(totalDataSetObjectCount, dataSetObjectsPerSplit,
                        trainingData, rng.nextLong());
        if (collectTrainingStats)
            stats.logSplitEnd();
        return splits;
    }

    protected <T> JavaRDD<T>[] getSplitRDDs(JavaRDD<T> trainingData, int totalDataSetObjectCount,
                    int examplesPerDataSetObject) {
        int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit(examplesPerDataSetObject);

        if (collectTrainingStats)
            stats.logSplitStart();
        JavaRDD<T>[] splits = SparkUtils.balancedRandomSplit(totalDataSetObjectCount, dataSetObjectsPerSplit,
                        trainingData, rng.nextLong());
        if (collectTrainingStats)
            stats.logSplitEnd();
        return splits;
    }

    protected void executeTrainingDirect(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
        if (collectTrainingStats)
            stats.logFitStart();
        //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
        // number of minibatches between averagings
        //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
        if (storageLevel != null)
            trainingData.persist(storageLevel);

        long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);
        JavaRDD<DataSet>[] splits = getSplitRDDs(trainingData, (int) totalDataSetObjectCount, rddDataSetNumExamples);

        int splitNum = 1;
        for (JavaRDD<DataSet> split : splits) {
            doIteration(network, split, splitNum++, splits.length);
        }

        if (collectTrainingStats)
            stats.logFitEnd((int) totalDataSetObjectCount);
    }

    @Override
    public void executeTrainingPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader){
        executeTrainingPathsHelper(network, graph, trainingDataPaths, dsLoader, mdsLoader, rddDataSetNumExamples);
    }

    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths,
                                              DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectsNumExamples) {
        if (numWorkers == null)
            numWorkers = (network != null ? network.getSparkContext() : graph.getSparkContext()).defaultParallelism();

        if (collectTrainingStats)
            stats.logFitStart();
        if (storageLevelStreams != null)
            trainingDataPaths.persist(storageLevelStreams);

        long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingDataPaths);
        JavaRDD<String>[] splits =
                        getSplitRDDs(trainingDataPaths, (int) totalDataSetObjectCount, dataSetObjectsNumExamples);

        int splitNum = 1;
        for (JavaRDD<String> split : splits) {
            doIterationPaths(network, graph, split, splitNum++, splits.length, dataSetObjectsNumExamples, dsLoader, mdsLoader);
        }

        if (collectTrainingStats)
            stats.logFitEnd((int) totalDataSetObjectCount);
    }

    @Override
    public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
        if (numWorkers == null)
            numWorkers = graph.getSparkContext().defaultParallelism();

        JavaRDD<MultiDataSet> mdsTrainingData = trainingData.map(new DataSetToMultiDataSetFn());

        executeTrainingMDS(graph, mdsTrainingData);
    }

    @Override
    public void executeTrainingMDS(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        if (numWorkers == null)
            numWorkers = graph.getSparkContext().defaultParallelism();

        if (rddTrainingApproach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(graph, trainingData);
        } else {
            //Export data if required (or, use cached export)
            JavaRDD<String> paths = exportIfRequiredMDS(graph.getSparkContext(), trainingData);
            executeTrainingPathsHelper(null, graph, paths, null, new SerializedMultiDataSetLoader(), batchSizePerWorker);
        }
    }

    protected void executeTrainingDirect(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
        if (collectTrainingStats)
            stats.logFitStart();
        //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
        // number of minibatches between averaging
        //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
        if (storageLevel != null)
            trainingData.persist(storageLevel);

        long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);

        JavaRDD<MultiDataSet>[] splits =
                        getSplitRDDs(trainingData, (int) totalDataSetObjectCount, rddDataSetNumExamples);

        int splitNum = 1;
        for (JavaRDD<MultiDataSet> split : splits) {
            doIteration(graph, split, splitNum++, splits.length);
        }

        if (collectTrainingStats)
            stats.logFitEnd((int) totalDataSetObjectCount);
    }

    @Override
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.collectTrainingStats = collectTrainingStats;
        if (collectTrainingStats) {
            if (this.stats == null)
                this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        } else {
            this.stats = null;
        }
    }

    @Override
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    @Override
    public SparkTrainingStats getTrainingStats() {
        if (stats != null)
            return stats.build();
        return null;
    }

    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        setListeners(null, listeners);
    }

    @Override
    public void setListeners(StatsStorageRouter statsStorage, Collection<TrainingListener> listeners) {
        this.statsStorage = statsStorage;
        this.listeners = listeners == null ? null : new ArrayList<>(listeners);
    }



    protected void doIteration(SparkDl4jMultiLayer network, JavaRDD<DataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats)
            stats.logMapPartitionsStart();

        JavaRDD<DataSet> splitData = split;
        if (collectTrainingStats)
            stats.logRepartitionStart();
        splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = splitData.partitions().size();
        if (collectTrainingStats && repartition != Repartition.Never)
            stats.logRepartitionEnd();


        FlatMapFunction<Iterator<DataSet>, ParameterAveragingTrainingResult> function =
                        new ExecuteWorkerFlatMap<>(getWorkerInstance(network));
        JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
        processResults(network, null, result, splitNum, numSplits);

        if (collectTrainingStats)
            stats.logMapPartitionsEnd(nPartitions);
    }

    @Deprecated
    protected void doIterationPDS(SparkDl4jMultiLayer network, SparkComputationGraph graph,
                    JavaRDD<PortableDataStream> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats)
            stats.logMapPartitionsStart();

        JavaRDD<PortableDataStream> splitData = split;
        if (collectTrainingStats)
            stats.logRepartitionStart();
        splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = splitData.partitions().size();
        if (collectTrainingStats && repartition != Repartition.Never)
            stats.logRepartitionEnd();

        FlatMapFunction<Iterator<PortableDataStream>, ParameterAveragingTrainingResult> function;
        if (network != null)
            function = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(network));
        else
            function = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(graph));

        JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
        processResults(network, graph, result, splitNum, numSplits);

        if (collectTrainingStats)
            stats.logMapPartitionsEnd(nPartitions);
    }

    protected void doIterationPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> split,
                    int splitNum, int numSplits, int dataSetObjectNumExamples, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats)
            stats.logMapPartitionsStart();

        JavaRDD<String> splitData = split;
        if (collectTrainingStats)
            stats.logRepartitionStart();
        splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy,
                        numObjectsEachWorker(dataSetObjectNumExamples), numWorkers);
        int nPartitions = splitData.partitions().size();
        if (collectTrainingStats && repartition != Repartition.Never)
            stats.logRepartitionEnd();

        JavaSparkContext sc = (network != null ? network.getSparkContext() : graph.getSparkContext());
        FlatMapFunction<Iterator<String>, ParameterAveragingTrainingResult> function;
        if (network != null) {
            if(dsLoader != null){
                function = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(network), dsLoader, BroadcastHadoopConfigHolder.get(sc));
            } else {
                function = new ExecuteWorkerPathMDSFlatMap<>(getWorkerInstance(network), mdsLoader, BroadcastHadoopConfigHolder.get(sc));
            }
        } else {
            if(dsLoader != null){
                function = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(graph), dsLoader, BroadcastHadoopConfigHolder.get(sc));
            } else {
                function = new ExecuteWorkerPathMDSFlatMap<>(getWorkerInstance(graph), mdsLoader, BroadcastHadoopConfigHolder.get(sc));
            }
        }

        JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
        processResults(network, graph, result, splitNum, numSplits);

        if (collectTrainingStats)
            stats.logMapPartitionsEnd(nPartitions);
    }

    protected void doIteration(SparkComputationGraph graph, JavaRDD<MultiDataSet> split, int splitNum, int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats)
            stats.logMapPartitionsStart();

        JavaRDD<MultiDataSet> splitData = split;

        splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = split.partitions().size();

        FlatMapFunction<Iterator<MultiDataSet>, ParameterAveragingTrainingResult> function =
                        new ExecuteWorkerMultiDataSetFlatMap<>(getWorkerInstance(graph));
        JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
        processResults(null, graph, result, splitNum, numSplits);

        if (collectTrainingStats)
            stats.logMapPartitionsEnd(nPartitions);
    }

    protected void doIterationPDS_MDS(SparkComputationGraph graph, JavaRDD<PortableDataStream> split, int splitNum,
                    int numSplits) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                        splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
        if (collectTrainingStats)
            stats.logMapPartitionsStart();

        JavaRDD<PortableDataStream> splitData = split;
        if (collectTrainingStats)
            stats.logRepartitionStart();
        splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy,
                        numObjectsEachWorker(rddDataSetNumExamples), numWorkers);
        int nPartitions = splitData.partitions().size();
        if (collectTrainingStats && repartition != Repartition.Never)
            stats.logRepartitionEnd();

        FlatMapFunction<Iterator<PortableDataStream>, ParameterAveragingTrainingResult> function =
                        new ExecuteWorkerPDSMDSFlatMap<>(getWorkerInstance(graph));

        JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
        processResults(null, graph, result, splitNum, numSplits);

        if (collectTrainingStats)
            stats.logMapPartitionsEnd(nPartitions);
    }


    protected void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph,
                    JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
        //Need to do parameter averaging, and where necessary also do averaging of the updaters
        //Let's do all of this in ONE step, such that we don't have extra synchronization costs

        if (collectTrainingStats)
            stats.logAggregateStartTime();
        ParameterAveragingAggregationTuple tuple =
                        results.treeAggregate(null, new ParameterAveragingElementAddFunction(),
                                        new ParameterAveragingElementCombineFunction(), this.aggregationDepth);
        INDArray params = tuple.getParametersSum();
        int aggCount = tuple.getAggregationsCount();
        SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
        if (collectTrainingStats)
            stats.logAggregationEndTime();


        if (collectTrainingStats)
            stats.logProcessParamsUpdaterStart();
        if (params != null) {
            params.divi(aggCount);
            INDArray updaterState = tuple.getUpdaterStateSum();
            if (updaterState != null)
                updaterState.divi(aggCount); //May be null if all SGD updaters, for example

            if (network != null) {
                MultiLayerNetwork net = network.getNetwork();
                net.setParameters(params);
                if (updaterState != null)
                    net.getUpdater().setStateViewArray(null, updaterState, false);

                network.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
            } else {
                ComputationGraph g = graph.getNetwork();
                g.setParams(params);
                if (updaterState != null)
                    g.getUpdater().setStateViewArray(updaterState);

                graph.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
            }
        } else {
            log.info("Skipping imbalanced split with no data for all executors");
        }



        if (collectTrainingStats) {
            stats.logProcessParamsUpdaterEnd();
            stats.addWorkerStats(aggregatedStats);
        }

        if (statsStorage != null) {
            Collection<StorageMetaData> meta = tuple.getListenerMetaData();
            if (meta != null && !meta.isEmpty()) {
                statsStorage.putStorageMetaData(meta);
            }

            Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
            if (staticInfo != null && !staticInfo.isEmpty()) {
                statsStorage.putStaticInfo(staticInfo);
            }

            Collection<Persistable> updates = tuple.getListenerUpdates();
            if (updates != null && !updates.isEmpty()) {
                statsStorage.putUpdate(updates);
            }
        }

        Nd4j.getExecutioner().commit();

        log.info("Completed training of split {} of {}", splitNum, totalSplits);

        if (params != null) {
            //Params may be null for edge case (empty RDD)
            if (network != null) {
                MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
                int numUpdates = averagingFrequency;
                conf.setIterationCount(conf.getIterationCount() + numUpdates);
            } else {
                ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration();
                int numUpdates = averagingFrequency;
                conf.setIterationCount(conf.getIterationCount() + numUpdates);
            }
        }
    }
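
    /*
     * Numeric sketch of the averaging step above (illustrative, not part of the upstream class): if 4 workers
     * each return a parameter vector, treeAggregate produces the element-wise sum with aggregationsCount = 4,
     * and params.divi(4) yields the average that becomes the new network parameters. The updater state (if
     * saved) and the score are averaged the same way, and the network's iteration count is then advanced by
     * averagingFrequency for the processed split.
     */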



    protected StatsStorageRouterProvider getRouterProvider() {
        if (statsStorage == null)
            return null; //Not needed
        return new VanillaStatsStorageRouterProvider();
    }


    public static class Builder {
        protected boolean saveUpdater;
        protected Integer numWorkers;
        protected int rddDataSetNumExamples;
        protected int batchSizePerWorker = 16;
        protected int averagingFrequency = 5;
        protected int aggregationDepth = 2;
        protected int prefetchNumBatches = 0;
        protected Repartition repartition = Repartition.Always;
        protected RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
        protected StorageLevel storageLevel = StorageLevel.MEMORY_ONLY_SER();
        protected StorageLevel storageLevelStreams = StorageLevel.MEMORY_ONLY();
        protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
        protected String exportDirectory = null;
        protected Long rngSeed;
        protected Collection<TrainingHook> trainingHooks;
        protected boolean collectTrainingStats = false;


        /**
         * Adds training hooks to the master.
         * The training master will set up the workers
         * with the desired hooks for training.
         * This can allow for things like parameter servers
         * and async updates as well as collecting statistics.
         *
         * @param trainingHooks the training hooks to add
         * @return
         */
        public Builder trainingHooks(Collection<TrainingHook> trainingHooks) {
            this.trainingHooks = trainingHooks;
            return this;
        }

        /**
         * Adds training hooks to the master.
         * The training master will set up the workers
         * with the desired hooks for training.
         * This can allow for things like parameter servers
         * and async updates as well as collecting statistics.
         *
         * @param hooks the training hooks to add
         * @return
         */
        public Builder trainingHooks(TrainingHook... hooks) {
            this.trainingHooks = Arrays.asList(hooks);
            return this;
        }

        /**
         * Same as {@link #Builder(Integer, int)} but automatically sets the number of workers based on JavaSparkContext.defaultParallelism()
         *
         * @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD<DataSet>}
         */
        public Builder(int rddDataSetNumExamples) {
            this(null, rddDataSetNumExamples);
        }

        /**
         * Create a builder, where the following number of workers (Spark executors * number of threads per executor) are
         * being used.
         * Note: this should match the configuration of the cluster.
         *
         * It is also necessary to specify how many examples are in each DataSet that appears in the {@code RDD<DataSet>}
         * or {@code JavaRDD<DataSet>} used for training.
         * Two most common cases here:
         * (a) Preprocessed data pipelines will often load binary DataSet objects with N > 1 examples in each; in this case,
         * rddDataSetNumExamples should be set to N
         * (b) "In line" data pipelines (for example, CSV String -> record reader -> DataSet just before training) will
         * typically have exactly 1 example in each DataSet object. In this case, rddDataSetNumExamples should be set to 1
         *
         * @param numWorkers            Number of Spark execution threads in the cluster. May be null. If null: number of workers will
         *                              be obtained from JavaSparkContext.defaultParallelism(), which should provide the number of cores
         *                              in the cluster.
         * @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD<DataSet>}
         */
        public Builder(Integer numWorkers, int rddDataSetNumExamples) {
            checkArgument(numWorkers == null || numWorkers > 0,
                            "Invalid number of workers: " + numWorkers + " (must be >= 1)");
            checkArgument(rddDataSetNumExamples > 0,
                            "Invalid rdd data set size: " + rddDataSetNumExamples + " (must be >= 1)");
            this.numWorkers = numWorkers;
            this.rddDataSetNumExamples = rddDataSetNumExamples;
        }

        /**
         * Batch size (in number of examples) per worker, for each fit(DataSet) call.
         *
         * @param batchSizePerWorker Size of each minibatch to use for each worker
         * @return
         */
        public Builder batchSizePerWorker(int batchSizePerWorker) {
            this.batchSizePerWorker = batchSizePerWorker;
            return this;
        }

        /**
         * Frequency with which to average worker parameters.
         * Note: Too high or too low can be bad for different reasons.
         * - Too low (such as 1) can result in a lot of network traffic
         * - Too high (>> 20 or so) can result in accuracy issues or problems with network convergence
         *
         * @param averagingFrequency Frequency (in number of minibatches of size 'batchSizePerWorker') to average parameters
         */
        public Builder averagingFrequency(int averagingFrequency) {
            checkArgument(averagingFrequency > 0, "Invalid input: averaging frequency must be >= 1");
            this.averagingFrequency = averagingFrequency;
            return this;
        }

        /**
         * The number of levels in the aggregation tree for parameter synchronization. (default: 2)
         * Note: For large models trained with many partitions, increasing this number
         * will reduce the load on the driver and help prevent it from becoming a bottleneck.
         *
         * @param aggregationDepth RDD tree aggregation channels when averaging parameter updates.
         */
        public Builder aggregationDepth(int aggregationDepth) {
            checkArgument(aggregationDepth > 0, "Invalid input: tree aggregation channels must be >= 1");
            this.aggregationDepth = aggregationDepth;
            return this;
        }

        /**
         * Set the number of minibatches to asynchronously prefetch in the worker.
         * Default: 0 (no prefetching)
         *
         * @param prefetchNumBatches Number of minibatches (DataSets of size batchSizePerWorker) to fetch
         */
        public Builder workerPrefetchNumBatches(int prefetchNumBatches) {
            this.prefetchNumBatches = prefetchNumBatches;
            return this;
        }

        /**
         * Set whether the updater (i.e., historical state for momentum, adagrad, etc.) should be saved.
         * NOTE: This can double (or more) the amount of network traffic in each direction, but might
         * improve network training performance (and can be more stable for certain updaters such as adagrad).
         *
         * This is enabled by default.
         *
         * @param saveUpdater If true: retain the updater state (default). If false, don't retain (updaters will be
         *                    reinitialized in each worker after averaging).
         */
        public Builder saveUpdater(boolean saveUpdater) {
            this.saveUpdater = saveUpdater;
            return this;
        }

        /**
         * Set if/when repartitioning should be conducted for the training data.
         * Default value: always repartition (if required to guarantee correct number of partitions and correct number
         * of examples in each partition).
         *
         * @param repartition Setting for repartitioning
         */
        public Builder repartionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /**
         * Used in conjunction with {@link #repartionData(Repartition)} (which defines when repartitioning should be
         * conducted), repartitionStrategy defines how the repartitioning should be done. See {@link RepartitionStrategy}
         * for details
         *
         * @param repartitionStrategy Repartitioning strategy to use
         */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /**
         * Set the storage level for {@code RDD<DataSet>}s.
         * Default: StorageLevel.MEMORY_ONLY_SER() - i.e., store in memory, in serialized form
         * To use no RDD persistence, use {@code null}
         *
         * Note: Spark's StorageLevel.MEMORY_ONLY() and StorageLevel.MEMORY_AND_DISK() can be problematic when
         * it comes to off-heap data (which DL4J/ND4J uses extensively). Spark does not account for off-heap memory
         * when deciding if/when to drop blocks to ensure enough free memory; consequently, for DataSet RDDs that are
         * larger than the total amount of (off-heap) memory, this can lead to OOM issues. Put another way: Spark counts
         * the on-heap size of DataSet and INDArray objects only (which is negligible) resulting in a significant
         * underestimate of the true DataSet object sizes. More DataSets are thus kept in memory than we can really afford.
         *
         * @param storageLevel Storage level to use for DataSet RDDs
         */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /**
         * Set the storage level RDDs used when fitting data from Streams: either PortableDataStreams (sc.binaryFiles via
         * {@link SparkDl4jMultiLayer#fit(String)} and {@link SparkComputationGraph#fit(String)}) or String paths
         * (via {@link SparkDl4jMultiLayer#fitPaths(JavaRDD)}, {@link SparkComputationGraph#fitPaths(JavaRDD)} and
         * {@link SparkComputationGraph#fitPathsMultiDataSet(JavaRDD)}).
         *
         * Default storage level is StorageLevel.MEMORY_ONLY() which should be appropriate in most cases.
         *
         * @param storageLevelStreams Storage level to use
         */
        public Builder storageLevelStreams(StorageLevel storageLevelStreams) {
            this.storageLevelStreams = storageLevelStreams;
            return this;
        }

        /**
         * The approach to use when training on a {@code RDD<DataSet>} or {@code RDD<MultiDataSet>}.
         * Default: {@link RDDTrainingApproach#Export}, which exports data to a temporary directory first
         *
         * @param rddTrainingApproach Training approach to use when training from a {@code RDD<DataSet>} or {@code RDD<MultiDataSet>}
         */
        public Builder rddTrainingApproach(RDDTrainingApproach rddTrainingApproach) {
            this.rddTrainingApproach = rddTrainingApproach;
            return this;
        }

        /**
         * When {@link #rddTrainingApproach(RDDTrainingApproach)} is set to {@link RDDTrainingApproach#Export} (as it is by default)
         * the data is exported to a temporary directory first.
         *
         * Default: null. -> use {hadoop.tmp.dir}/dl4j/. In this case, data is exported to {hadoop.tmp.dir}/dl4j/SOME_UNIQUE_ID/
         * If you specify a directory, the directory {exportDirectory}/SOME_UNIQUE_ID/ will be used instead.
         *
         * @param exportDirectory Base directory to export data
         */
        public Builder exportDirectory(String exportDirectory) {
            this.exportDirectory = exportDirectory;
            return this;
        }

        /**
         * Random number generator seed, used mainly for enforcing repeatable splitting on RDDs
         * Default: no seed set (i.e., random seed)
         *
         * @param rngSeed RNG seed
         * @return
         */
        public Builder rngSeed(long rngSeed) {
            this.rngSeed = rngSeed;
            return this;
        }

        /**
         * Whether training stats collection should be enabled (disabled by default).
         *
         * @see ParameterAveragingTrainingMaster#setCollectTrainingStats(boolean)
         * @see org.deeplearning4j.spark.stats.StatsUtils#exportStatsAsHTML(SparkTrainingStats, OutputStream)
         * @param collectTrainingStats
         */
        public Builder collectTrainingStats(boolean collectTrainingStats) {
            this.collectTrainingStats = collectTrainingStats;
            return this;
        }

        public ParameterAveragingTrainingMaster build() {
            return new ParameterAveragingTrainingMaster(this);
        }
    }
}
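
/*
 * Usage sketch (illustrative only, not part of the upstream source): a typical way to configure this
 * TrainingMaster via its Builder and fit a network on Spark. The variables sc (JavaSparkContext),
 * networkConf (MultiLayerConfiguration) and trainingDataRdd (JavaRDD<DataSet>) are assumed to exist;
 * the concrete numbers are example values, not recommendations.
 *
 *   int examplesPerDataSetObject = 1;                      // each DataSet in the RDD holds 1 example
 *   ParameterAveragingTrainingMaster tm =
 *                   new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
 *                                   .batchSizePerWorker(32)        // minibatch size used on each worker
 *                                   .averagingFrequency(5)         // average parameters every 5 minibatches
 *                                   .workerPrefetchNumBatches(2)   // asynchronously prefetch 2 batches per worker
 *                                   .saveUpdater(true)             // also average the updater (momentum/Adam) state
 *                                   .build();
 *
 *   SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, networkConf, tm);
 *   sparkNet.fit(trainingDataRdd);                         // JavaRDD<DataSet>
 */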




