package org.deeplearning4j.spark.impl.paramavg;
import lombok.Data;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.*;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;
/**
* ParameterAveragingTrainingMaster: A {@link TrainingMaster} implementation for training networks on Spark.
* This is standard parameter averaging with a configurable averaging period.
*
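* <p>
* Typical usage (an illustrative sketch, not part of the original javadoc; it assumes a {@code JavaSparkContext sc},
* a {@code MultiLayerConfiguration conf} and a {@code JavaRDD<DataSet> trainingData} have already been created, and
* that SparkDl4jMultiLayer exposes a (JavaSparkContext, MultiLayerConfiguration, TrainingMaster) constructor):
* <pre>{@code
* ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(4, 32)
*         .batchSizePerWorker(32)
*         .averagingFrequency(5)
*         .workerPrefetchNumBatches(2)
*         .build();
* SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
* sparkNet.fit(trainingData);    //Delegates to executeTraining(...) on this TrainingMaster
* }</pre>
*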
* @author Alex Black
*/
@Data
public class ParameterAveragingTrainingMaster implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);
private static final int COALESCE_THRESHOLD = 3;
private boolean saveUpdater;
private Integer numWorkers;
private int rddDataSetNumExamples;
private int batchSizePerWorker;
private int averagingFrequency;
private int prefetchNumBatches;
private boolean collectTrainingStats;
private ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
private Collection<IterationListener> listeners;
private int iterationCount = 0;
private Repartition repartition;
private RepartitionStrategy repartitionStrategy;
private ParameterAveragingTrainingMaster(Builder builder) {
this.saveUpdater = builder.saveUpdater;
this.numWorkers = builder.numWorkers;
this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
this.batchSizePerWorker = builder.batchSizePerWorker;
this.averagingFrequency = builder.averagingFrequency;
this.prefetchNumBatches = builder.prefetchNumBatches;
this.repartition = builder.repartition;
this.repartitionStrategy = builder.repartitionStrategy;
}
public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker,
int averagingFrequency, int prefetchNumBatches) {
this(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency, prefetchNumBatches,
Repartition.Always, RepartitionStrategy.Balanced, false);
}
/**
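* Illustrative example (added for documentation; not in the original source): the settings used by the
* six-argument constructor above are equivalent to calling
* {@code new ParameterAveragingTrainingMaster(saveUpdater, numWorkers, rddDataSetNumExamples, batchSizePerWorker, averagingFrequency, prefetchNumBatches, Repartition.Always, RepartitionStrategy.Balanced, false)}
*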
* @param saveUpdater If true: save (and average) the updater state when doing parameter averaging
* @param numWorkers Number of workers (executors * threads per executor) for the cluster
* @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD}
* @param batchSizePerWorker Number of examples to use per worker per fit
* @param averagingFrequency Frequency (in number of minibatches) with which to average parameters
* @param prefetchNumBatches Number of batches to asynchronously prefetch (0: disable)
* @param repartition If/when repartitioning should be conducted for the training data
* @param repartitionStrategy Repartitioning strategy to use when repartitioning is conducted
* @param collectTrainingStats If true: collect training statistics for debugging/optimization purposes
*/
public ParameterAveragingTrainingMaster(boolean saveUpdater, Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker,
int averagingFrequency, int prefetchNumBatches, Repartition repartition,
RepartitionStrategy repartitionStrategy, boolean collectTrainingStats) {
if (numWorkers <= 0)
throw new IllegalArgumentException("Invalid number of workers: " + numWorkers + " (must be >= 1)");
if (rddDataSetNumExamples <= 0)
throw new IllegalArgumentException("Invalid rdd data set size: " + rddDataSetNumExamples + " (must be >= 1)");
this.saveUpdater = saveUpdater;
this.numWorkers = numWorkers;
this.rddDataSetNumExamples = rddDataSetNumExamples;
this.batchSizePerWorker = batchSizePerWorker;
this.averagingFrequency = averagingFrequency;
this.prefetchNumBatches = prefetchNumBatches;
this.collectTrainingStats = collectTrainingStats;
this.repartition = repartition;
this.repartitionStrategy = repartitionStrategy;
if (collectTrainingStats)
stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
}
@Override
public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) {
NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(),
network.getNetwork().params(),
network.getNetwork().getUpdater().getStateViewArray());
if (collectTrainingStats) stats.logBroadcastStart();
Broadcast<NetBroadcastTuple> broadcast = network.getSparkContext().broadcast(tuple);
if (collectTrainingStats) stats.logBroadcastEnd();
WorkerConfiguration configuration = new WorkerConfiguration(false, batchSizePerWorker, averagingFrequency, prefetchNumBatches, collectTrainingStats);
return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration);
}
@Override
public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) {
NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getConfiguration(),
graph.getNetwork().params(),
graph.getNetwork().getUpdater().getStateViewArray());
if (collectTrainingStats) stats.logBroadcastStart();
Broadcast<NetBroadcastTuple> broadcast = graph.getSparkContext().broadcast(tuple);
if (collectTrainingStats) stats.logBroadcastEnd();
WorkerConfiguration configuration = new WorkerConfiguration(true, batchSizePerWorker, averagingFrequency, prefetchNumBatches, collectTrainingStats);
return new ParameterAveragingTrainingWorker(broadcast, saveUpdater, configuration);
}
private int numObjectsEachWorker() {
return batchSizePerWorker * averagingFrequency / rddDataSetNumExamples;
}
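//Worked example (explanatory comment added here, not in the original source): with numWorkers=4,
// batchSizePerWorker=32 and averagingFrequency=5, each worker should see 32*5 = 160 examples between
// averaging steps. With rddDataSetNumExamples=1 that is 4*32*5 = 640 DataSet objects per split; with
// rddDataSetNumExamples=32, numObjectsEachWorker() = 32*5/32 = 5, giving 5*4 = 20 DataSet objects
// (still 640 examples) per split. A training RDD of 6400 single-example DataSets would therefore be
// processed as 10 splits, i.e. 10 parameter averaging rounds per call to executeTraining().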
private int getNumDataSetObjectsPerSplit() {
int dataSetObjectsPerSplit;
if (rddDataSetNumExamples == 1) {
dataSetObjectsPerSplit = numWorkers * batchSizePerWorker * averagingFrequency;
} else {
int numDataSetObjsReqEachWorker = numObjectsEachWorker();
if (numDataSetObjsReqEachWorker < 1) {
//In this case: more examples in a DataSet object than we actually require
//For example, 100 examples in DataSet, with batchSizePerWorker=50 and averagingFrequency=1
numDataSetObjsReqEachWorker = 1;
}
dataSetObjectsPerSplit = numDataSetObjsReqEachWorker * numWorkers;
}
return dataSetObjectsPerSplit;
}
@Override
public void executeTraining(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
if (numWorkers == null) numWorkers = network.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
//For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
// number of minibatches between averagings
//But to do that, we need to know: (a) the number of examples, and (b) the number of workers
trainingData.persist(StorageLevel.MEMORY_ONLY());
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingData.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaRDD<DataSet>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingData);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaRDD<DataSet> split : splits) {
doIteration(network, split, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTraining(SparkDl4jMultiLayer network, JavaPairRDD<String, PortableDataStream> trainingData) {
if (numWorkers == null) numWorkers = network.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
int origNumPartitions = trainingData.partitions().size();
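//Coalesce here (explanatory note added, not in the original source) because RDDs of PortableDataStreams are
// often created with roughly one partition per file; far more partitions than workers mainly adds scheduling
// overhead, so reduce to numWorkers partitions once the count exceeds COALESCE_THRESHOLD * numWorkers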
if (origNumPartitions >= COALESCE_THRESHOLD * numWorkers) {
log.info("Coalesing PortableDataStreams from {} to {} partitions", origNumPartitions, numWorkers);
trainingData = trainingData.coalesce(numWorkers);
}
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingData.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingData);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaPairRDD<String, PortableDataStream> split : splits) {
JavaRDD<PortableDataStream> streams = split.values();
doIterationPDS(network, null, streams, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTrainingPaths(SparkDl4jMultiLayer network, JavaRDD<String> trainingDataPaths) {
if (numWorkers == null) numWorkers = network.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingDataPaths.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingDataPaths);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaRDD<String> split : splits) {
doIterationPaths(network, null, split, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
if (numWorkers == null) numWorkers = graph.getSparkContext().defaultParallelism();
JavaRDD<MultiDataSet> mdsTrainingData = trainingData.map(new DataSetToMultiDataSetFn());
executeTrainingMDS(graph, mdsTrainingData);
}
@Override
public void executeTrainingMDS(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
if (numWorkers == null) numWorkers = graph.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
//For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
// number of minibatches between averagings
//But to do that, we need to know: (a) the number of examples, and (b) the number of workers
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingData.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaRDD<MultiDataSet>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingData);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaRDD<MultiDataSet> split : splits) {
doIteration(graph, split, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTraining(SparkComputationGraph graph, JavaPairRDD<String, PortableDataStream> trainingData) {
if (numWorkers == null) numWorkers = graph.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
//For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
// number of minibatches between averagings
//But to do that, we need to know: (a) the number of examples, and (b) the number of workers
int origNumPartitions = trainingData.partitions().size();
if (origNumPartitions >= COALESCE_THRESHOLD * numWorkers) {
log.info("Coalesing streams from {} to {} partitions", origNumPartitions, numWorkers);
trainingData = trainingData.coalesce(numWorkers);
}
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingData.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingData, new Random().nextLong());
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaPairRDD<String, PortableDataStream> split : splits) {
JavaRDD<PortableDataStream> streams = split.values();
doIterationPDS(null, graph, streams, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTrainingMDS(SparkComputationGraph graph, JavaPairRDD<String, PortableDataStream> trainingData) {
if (numWorkers == null) numWorkers = graph.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingData.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaPairRDD<String, PortableDataStream>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingData, new Random().nextLong());
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaPairRDD<String, PortableDataStream> split : splits) {
JavaRDD<PortableDataStream> streams = split.values();
if (collectTrainingStats) stats.logRepartitionStart();
streams = SparkUtils.repartition(streams, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
doIterationPDS_MDS(graph, streams, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTrainingPaths(SparkComputationGraph network, JavaRDD<String> trainingDataPaths) {
if (numWorkers == null) numWorkers = network.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingDataPaths.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingDataPaths);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaRDD<String> split : splits) {
doIterationPaths(null, network, split, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void executeTrainingPathsMDS(SparkComputationGraph network, JavaRDD<String> trainingMultiDataPaths) {
if (numWorkers == null) numWorkers = network.getSparkContext().defaultParallelism();
if (collectTrainingStats) stats.logFitStart();
if (collectTrainingStats) stats.logCountStart();
long totalDataSetObjectCount = trainingMultiDataPaths.count();
if (collectTrainingStats) stats.logCountEnd();
int dataSetObjectsPerSplit = getNumDataSetObjectsPerSplit();
if (collectTrainingStats) stats.logSplitStart();
JavaRDD<String>[] splits = SparkUtils.balancedRandomSplit((int) totalDataSetObjectCount, dataSetObjectsPerSplit, trainingMultiDataPaths);
if (collectTrainingStats) stats.logSplitEnd();
int splitNum = 1;
for (JavaRDD<String> split : splits) {
doIterationPathsMDS(network, split, splitNum++, splits.length);
}
if (collectTrainingStats) stats.logFitEnd((int) totalDataSetObjectCount);
}
@Override
public void setCollectTrainingStats(boolean collectTrainingStats) {
this.collectTrainingStats = collectTrainingStats;
if (collectTrainingStats) {
if (this.stats == null)
this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
} else {
this.stats = null;
}
}
@Override
public boolean getIsCollectTrainingStats() {
return collectTrainingStats;
}
@Override
public SparkTrainingStats getTrainingStats() {
if (stats != null) return stats.build();
return null;
}
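//Note (summary comment added for clarity; not in the original source): each doIteration* method below
// follows the same per-split pattern: (1) repartition the split so each worker receives roughly
// numObjectsEachWorker() objects, (2) run the matching ExecuteWorker*FlatMap via mapPartitions so each
// partition trains locally on its minibatches, and (3) hand the resulting per-worker
// ParameterAveragingTrainingResults to processResults() for averaging.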
private void doIteration(SparkDl4jMultiLayer network, JavaRDD<DataSet> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<DataSet> splitData = split;
if (collectTrainingStats) stats.logRepartitionStart();
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
FlatMapFunction<Iterator<DataSet>, ParameterAveragingTrainingResult> function = new ExecuteWorkerFlatMap<>(getWorkerInstance(network));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(network, null, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void doIterationPDS(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<PortableDataStream> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<PortableDataStream> splitData = split;
if (collectTrainingStats) stats.logRepartitionStart();
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
FlatMapFunction<Iterator<PortableDataStream>, ParameterAveragingTrainingResult> function;
if (network != null) function = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(network));
else function = new ExecuteWorkerPDSFlatMap<>(getWorkerInstance(graph));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(network, graph, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void doIterationPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<String> splitData = split;
if (collectTrainingStats) stats.logRepartitionStart();
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
FlatMapFunction<Iterator<String>, ParameterAveragingTrainingResult> function;
if (network != null) function = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(network));
else function = new ExecuteWorkerPathFlatMap<>(getWorkerInstance(graph));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(network, graph, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void doIterationPathsMDS(SparkComputationGraph graph, JavaRDD<String> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<String> splitData = split;
if (collectTrainingStats) stats.logRepartitionStart();
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
FlatMapFunction<Iterator<String>, ParameterAveragingTrainingResult> function
= new ExecuteWorkerPathMDSFlatMap<>(getWorkerInstance(graph));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(null, graph, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void doIteration(SparkComputationGraph graph, JavaRDD<MultiDataSet> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<MultiDataSet> splitData = split;
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
FlatMapFunction<Iterator<MultiDataSet>, ParameterAveragingTrainingResult> function = new ExecuteWorkerMultiDataSetFlatMap<>(getWorkerInstance(graph));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(null, graph, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void doIterationPDS_MDS(SparkComputationGraph graph, JavaRDD<PortableDataStream> split, int splitNum, int numSplits) {
log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
splitNum, numSplits, batchSizePerWorker, averagingFrequency, numWorkers);
if (collectTrainingStats) stats.logMapPartitionsStart();
JavaRDD<PortableDataStream> splitData = split;
if (collectTrainingStats) stats.logRepartitionStart();
splitData = SparkUtils.repartition(splitData, repartition, repartitionStrategy, numObjectsEachWorker(), numWorkers);
int nPartitions = splitData.partitions().size();
if (collectTrainingStats && repartition != Repartition.Never) stats.logRepartitionEnd();
FlatMapFunction<Iterator<PortableDataStream>, ParameterAveragingTrainingResult> function = new ExecuteWorkerPDSMDSFlatMap<>(getWorkerInstance(graph));
JavaRDD<ParameterAveragingTrainingResult> result = splitData.mapPartitions(function);
processResults(null, graph, result, splitNum, numSplits);
if (collectTrainingStats) stats.logMapPartitionsEnd(nPartitions);
}
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
//Need to do parameter averaging, and where necessary also do averaging of the updaters
//Let's do all of this in ONE step, such that we don't have extra synchronization costs
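//In symbols (explanatory note added here, not in the original source):
//    params       = (sum of per-worker parameter vectors) / aggCount
//    updaterState = (sum of per-worker updater state)     / aggCount   (only when updater state is present)
//    score        = (sum of per-worker scores)            / aggCount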
if (collectTrainingStats) stats.logAggregateStartTime();
ParameterAveragingAggregationTuple tuple = results.aggregate(null,
new ParameterAveragingElementAddFunction(),
new ParameterAveragingElementCombineFunction());
INDArray params = tuple.getParametersSum();
int aggCount = tuple.getAggregationsCount();
SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
if (collectTrainingStats) stats.logAggregationEndTime();
if (collectTrainingStats) stats.logProcessParamsUpdaterStart();
params.divi(aggCount);
INDArray updaterState = tuple.getUpdaterStateSum();
if (updaterState != null) updaterState.divi(aggCount); //May be null if all SGD updaters, for example
if (network != null) {
MultiLayerNetwork net = network.getNetwork();
net.setParameters(params);
if (updaterState != null) net.getUpdater().setStateViewArray(null, updaterState, false);
network.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
} else {
ComputationGraph g = graph.getNetwork();
g.setParams(params);
if (updaterState != null) g.getUpdater().setStateViewArray(updaterState);
graph.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
}
if (collectTrainingStats) {
stats.logProcessParamsUpdaterEnd();
stats.addWorkerStats(aggregatedStats);
}
log.info("Completed training of split {} of {}", splitNum, totalSplits);
if (listeners != null) {
if (network != null) {
MultiLayerNetwork net = network.getNetwork();
net.setScore(network.getScore());
for (IterationListener il : listeners) {
il.iterationDone(net, iterationCount);
}
} else {
ComputationGraph g = graph.getNetwork();
g.setScore(graph.getScore());
for (IterationListener il : listeners) {
il.iterationDone(g, iterationCount);
}
}
}
iterationCount++;
}
public static class Builder {
private boolean saveUpdater = true;
private Integer numWorkers;
private int rddDataSetNumExamples;
private int batchSizePerWorker = 16;
private int averagingFrequency = 5;
private int prefetchNumBatches = 0;
private Repartition repartition = Repartition.Always;
private RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
/**
* Same as {@link #Builder(Integer, int)} but automatically set number of workers based on JavaSparkContext.defaultParallelism()
*
* @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD}
*/
public Builder(int rddDataSetNumExamples) {
this(null, rddDataSetNumExamples);
}
/**
* Create a builder for the given number of workers (Spark executors * number of threads per executor).
* Note: this should match the configuration of the cluster.
*
* It is also necessary to specify how many examples are in each DataSet that appears in the {@code RDD}
* or {@code JavaRDD} used for training.
* The two most common cases are:
* (a) Preprocessed data pipelines will often load binary DataSet objects with N > 1 examples in each; in this case,
* rddDataSetNumExamples should be set to N
* (b) "In line" data pipelines (for example, CSV String -> record reader -> DataSet just before training) will
* typically have exactly 1 example in each DataSet object. In this case, rddDataSetNumExamples should be set to 1
*
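* For example (an illustrative addition to this javadoc): a preprocessed pipeline that saved DataSet
* objects of 32 examples each, running on a cluster with 8 execution threads, could use
* {@code new ParameterAveragingTrainingMaster.Builder(8, 32)}; an in-line pipeline producing
* single-example DataSet objects could use {@code new ParameterAveragingTrainingMaster.Builder(1)}
* and let the worker count be taken from the default parallelism.
*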
* @param numWorkers Number of Spark execution threads in the cluster. May be null. If null: number of workers will
* be obtained from JavaSparkContext.defaultParallelism(), which should provide the number of cores
* in the cluster.
* @param rddDataSetNumExamples Number of examples in each DataSet object in the {@code RDD}
*/
public Builder(Integer numWorkers, int rddDataSetNumExamples) {
if (numWorkers != null && numWorkers <= 0)
throw new IllegalArgumentException("Invalid number of workers: " + numWorkers + " (must be >= 1)");
if (rddDataSetNumExamples <= 0)
throw new IllegalArgumentException("Invalid rdd data set size: " + rddDataSetNumExamples + " (must be >= 1)");
this.numWorkers = numWorkers;
this.rddDataSetNumExamples = rddDataSetNumExamples;
}
/**
* Batch size (in number of examples) per worker, for each fit(DataSet) call.
*
* @param batchSizePerWorker Size of each minibatch to use for each worker
* @return This builder, for chaining configuration calls
*/
public Builder batchSizePerWorker(int batchSizePerWorker) {
this.batchSizePerWorker = batchSizePerWorker;
return this;
}
/**
* Frequency with which to average worker parameters.
* Note: Too high or too low can be bad for different reasons.
* - Too low (such as 1) can result in a lot of network traffic
* - Too high (>> 20 or so) can result in accuracy issues or problems with network convergence
*
* @param averagingFrequency Frequency (in number of minibatches of size 'batchSizePerWorker') to average parameters
*/
public Builder averagingFrequency(int averagingFrequency) {
if (averagingFrequency <= 0)
throw new IllegalArgumentException("Ivalid input: averaging frequency must be >= 1");
this.averagingFrequency = averagingFrequency;
return this;
}
/**
* Set the number of minibatches to asynchronously prefetch in the worker.
*
* Default: 0 (no prefetching)
*
* @param prefetchNumBatches Number of minibatches (DataSets of size batchSizePerWorker) to fetch
*/
public Builder workerPrefetchNumBatches(int prefetchNumBatches) {
this.prefetchNumBatches = prefetchNumBatches;
return this;
}
/**
* Set whether the updater (i.e., historical state for momentum, adagrad, etc.) should be saved.
* NOTE: This can double (or more) the amount of network traffic in each direction, but might
* improve network training performance (and can be more stable for certain updaters such as adagrad).
*
* This is enabled by default.
*
* @param saveUpdater If true: retain the updater state (default). If false, don't retain (updaters will be
* reinitialized in each worker after averaging).
*/
public Builder saveUpdater(boolean saveUpdater) {
this.saveUpdater = saveUpdater;
return this;
}
/**
* Set if/when repartitioning should be conducted for the training data.
* Default value: always repartition (if required to guarantee correct number of partitions and correct number
* of examples in each partition).
*
* @param repartition Setting for repartitioning
*/
public Builder repartionData(Repartition repartition) {
this.repartition = repartition;
return this;
}
/**
* Used in conjunction with {@link #repartionData(Repartition)} (which defines when repartitioning should be
* conducted), repartitionStrategy defines how the repartitioning should be done. See {@link RepartitionStrategy}
* for details
*
* @param repartitionStrategy Repartitioning strategy to use
*/
public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
this.repartitionStrategy = repartitionStrategy;
return this;
}
public ParameterAveragingTrainingMaster build() {
return new ParameterAveragingTrainingMaster(this);
}
}
}