All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.api.TrainingMaster Maven / Gradle / Ivy

There is a newer version: 0.6.0
Show newest version
package org.deeplearning4j.spark.api;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.input.PortableDataStream;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.io.OutputStream;
import java.util.Collection;

/**
 * A TrainingMaster controls how distributed training is executed in practice
* In principle, a large number of different approches can be used in distributed training (synchronous vs. asynchronous, * parameter vs. gradient averaging, etc). Each of these different approaches would be implemented as a TrainingMaster; * this allows {@link SparkDl4jMultiLayer} and {@link SparkComputationGraph} to be used with different training methods. * * @author Alex Black */ public interface TrainingMaster> { /** * Get the worker instance for this training master * * @param network Current SparkDl4jMultiLayer * @return Worker instance */ W getWorkerInstance(SparkDl4jMultiLayer network); /** * Get the worker instance for this training master * * @param graph Current SparkComputationGraph * @return Worker instance */ W getWorkerInstance(SparkComputationGraph graph); /** * Train the SparkDl4jMultiLayer with the specified data set * * @param network Current network state * @param trainingData Data to train on */ void executeTraining(SparkDl4jMultiLayer network, JavaRDD trainingData); /** * Train the SparkDl4jMultiLayer with the specified serialized DataSet objects. The assumption * here is that the PortableDataStreams are for DataSet objects, one per file. * * @param network Current network state * @param trainingData Data to train on */ void executeTraining(SparkDl4jMultiLayer network, JavaPairRDD trainingData); /** * Train the SparkComputationGraph with the specified data set * * @param graph Current network state * @param trainingData Data to train on */ void executeTraining(SparkComputationGraph graph, JavaRDD trainingData); /** * Train the SparkComputationGraph with the specified serialized DataSet objects. 
The assumption * here is that the PortableDataStreams are for DataSet objects, one per file, and that these have been * serialized using {@link DataSet#save(OutputStream)} * * @param network Current network state * @param trainingData Data to train on */ void executeTraining(SparkComputationGraph network, JavaPairRDD trainingData); /** * Train the SparkComputationGraph with the specified data set * * @param graph Current network state * @param trainingData Data to train on */ void executeTrainingMDS(SparkComputationGraph graph, JavaRDD trainingData); /** * Train the SparkComputationGraph with the specified serialized MultiDataSet objects. The assumption * here is that the PortableDataStreams are for MultiDataSet objects, one per file. * * @param network Current network state * @param trainingData Data to train on */ void executeTrainingMDS(SparkComputationGraph network, JavaPairRDD trainingData); /** * Set whether the training statistics should be collected. Training statistics may include things like per-epoch run times, * time spent waiting for data, etc. *

* These statistics are primarily used for debugging and optimization, in order to gain some insight into what aspects * of network training are taking the most time. * * @param collectTrainingStats If true: collecting training statistics will be */ void setCollectTrainingStats(boolean collectTrainingStats); /** * Get the current setting for collectTrainingStats */ boolean getIsCollectTrainingStats(); /** * Return the training statistics. Note that this may return null, unless setCollectTrainingStats has been set first * * @return Training statistics */ SparkTrainingStats getTrainingStats(); /** * Set the iteration listeners. These should be called after every averaging (or similar) operation in the TrainingMaster, * though the exact behaviour may be dependent on each IterationListener * * @param listeners Listeners to set */ void setListeners(Collection listeners); }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy