// org.deeplearning4j.spark.api.TrainingWorker (listing header: Maven / Gradle / Ivy)
package org.deeplearning4j.spark.api;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import java.io.Serializable;
/**
 * TrainingWorker is a small serializable class that can be passed (in serialized form) to each Spark executor
 * for actually conducting training. The results are then passed back to the {@link TrainingMaster} for processing.
 *
 * TrainingWorker implementations provide a layer of abstraction for network learning that should allow for more
 * flexibility/control over how learning is conducted (including, for example, asynchronous communication).
 *
 * NOTE(review): the type parameter is left unbounded because no common result base type is visible in this file;
 * tighten the bound if the project defines one. The type arguments of the {@code Pair} returned by the
 * {@code *WithStats} methods (result first, {@link SparkTrainingStats} second) are reconstructed from the method
 * documentation - confirm against existing implementations.
 *
 * @param <R> Type of the training result this worker produces and returns to the driver
 * @author Alex Black
 */
public interface TrainingWorker<R> extends Serializable {

    /**
     * Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer
     *
     * @return Initial model for this worker
     */
    MultiLayerNetwork getInitialModel();

    /**
     * Get the initial model when training a ComputationGraph/SparkComputationGraph
     *
     * @return Initial model for this worker
     */
    ComputationGraph getInitialModelGraph();

    /**
     * Process (fit) a minibatch for a MultiLayerNetwork
     *
     * @param dataSet Data set to train on
     * @param network Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph using a MultiDataSet
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)} but used when
     * {@link SparkTrainingStats} are being collected
     *
     * @param dataSet Data set to train on
     * @param network Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Pair of the training result and the training stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, ComputationGraph, boolean)} but used when
     * {@link SparkTrainingStats} are being collected
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Pair of the training result and the training stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)} but used when
     * {@link SparkTrainingStats} are being collected
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for
     *                this executor
     * @return Pair of the training result and the training stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Get the final result to be returned to the driver
     *
     * @param network Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(MultiLayerNetwork network);

    /**
     * Get the final result to be returned to the driver
     *
     * @param graph Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(ComputationGraph graph);

    /**
     * Get the final result to be returned to the driver, if no data was available for this executor
     *
     * @return Result to return to the driver
     */
    R getFinalResultNoData();

    /**
     * As per {@link #getFinalResultNoData()} but used when {@link SparkTrainingStats} are being collected
     *
     * @return Pair of the result to return to the driver and the training stats
     */
    Pair<R, SparkTrainingStats> getFinalResultNoDataWithStats();

    /**
     * As per {@link #getFinalResult(MultiLayerNetwork)} but used when {@link SparkTrainingStats} are being collected
     *
     * @param network Current state of the network
     * @return Pair of the result to return to the driver and the training stats
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network);

    /**
     * As per {@link #getFinalResult(ComputationGraph)} but used when {@link SparkTrainingStats} are being collected
     *
     * @param graph Current state of the network
     * @return Pair of the result to return to the driver and the training stats
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph);

    /**
     * Get the {@link WorkerConfiguration} that contains information such as minibatch sizes, etc
     *
     * @return Worker configuration
     */
    WorkerConfiguration getDataConfiguration();
}