// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.deeplearning4j.spark.api.TrainingWorker Maven / Gradle / Ivy
//
// The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.api;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;

/**
 * A worker that fits minibatches on an executor for either a {@link MultiLayerNetwork} or a
 * {@link ComputationGraph}, and produces a result of type {@code R} to be returned to the driver.
 * <p>
 * Each {@code process*} / {@code getFinalResult*} method has a {@code *WithStats} counterpart that is
 * used when {@link SparkTrainingStats} are being collected.
 *
 * @param <R> type of the training result returned to the driver
 */
// NOTE(review): the type parameter declaration was missing (R was used but never declared, which does
// not compile). Upstream DL4J may bound R (e.g. R extends TrainingResult); confirm and restore the bound if so.
public interface TrainingWorker<R> extends Serializable {

    /**
     * Remove a training hook from the worker.
     *
     * @param trainingHook the training hook to remove
     */
    void removeHook(TrainingHook trainingHook);

    /**
     * Add a training hook to be used during training of the worker.
     *
     * @param trainingHook the training hook to add
     */
    void addHook(TrainingHook trainingHook);

    /**
     * Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer.
     *
     * @return Initial model for this worker
     */
    MultiLayerNetwork getInitialModel();

    /**
     * Get the initial model when training a ComputationGraph/SparkComputationGraph.
     *
     * @return Initial model for this worker
     */
    ComputationGraph getInitialModelGraph();

    /**
     * Process (fit) a minibatch for a MultiLayerNetwork.
     *
     * @param dataSet Data set to train on
     * @param network Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph.
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Process (fit) a minibatch for a ComputationGraph using a MultiDataSet.
     *
     * @param dataSet Data set to train on
     * @param graph   Network to train
     * @param isLast  If true: last data set currently available. If false: more data sets will be processed for this executor
     * @return Null, or a training result if training should be terminated immediately.
     */
    R processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)} but used when
     * {@link SparkTrainingStats} are being collected.
     *
     * @return the training result (as per the non-stats variant) paired with the collected stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast);

    /**
     * As per {@link #processMinibatch(DataSet, ComputationGraph, boolean)} but used when
     * {@link SparkTrainingStats} are being collected.
     *
     * @return the training result (as per the non-stats variant) paired with the collected stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * As per {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)} but used when
     * {@link SparkTrainingStats} are being collected.
     *
     * @return the training result (as per the non-stats variant) paired with the collected stats
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast);

    /**
     * Get the final result to be returned to the driver.
     *
     * @param network Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(MultiLayerNetwork network);

    /**
     * Get the final result to be returned to the driver.
     *
     * @param graph Current state of the network
     * @return Result to return to the driver
     */
    R getFinalResult(ComputationGraph graph);

    /**
     * Get the final result to be returned to the driver, if no data was available for this executor.
     *
     * @return Result to return to the driver
     */
    R getFinalResultNoData();

    /**
     * As per {@link #getFinalResultNoData()} but used when {@link SparkTrainingStats} are being collected.
     *
     * @return the final result paired with the collected stats
     */
    Pair<R, SparkTrainingStats> getFinalResultNoDataWithStats();

    /**
     * As per {@link #getFinalResult(MultiLayerNetwork)} but used when {@link SparkTrainingStats} are being collected.
     *
     * @return the final result paired with the collected stats
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network);

    /**
     * As per {@link #getFinalResult(ComputationGraph)} but used when {@link SparkTrainingStats} are being collected.
     *
     * @return the final result paired with the collected stats
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph);

    /**
     * Get the {@link WorkerConfiguration} that contains information such as minibatch sizes, etc.
     *
     * @return Worker configuration
     */
    WorkerConfiguration getDataConfiguration();

    /**
     * Get the identifier of this worker instance.
     * <p>
     * NOTE(review): presumably used to distinguish worker instances; confirm uniqueness guarantees
     * against the implementations.
     *
     * @return the instance id of this worker
     */
    long getInstanceId();
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy