All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.api.TrainingMaster Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.api;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.core.loader.DataSetLoader;
import org.deeplearning4j.core.loader.MultiDataSetLoader;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.util.Collection;

public interface TrainingMaster> {


    /**
     * Remove a training hook from the worker
     * @param trainingHook the training hook to remove
     */
    void removeHook(TrainingHook trainingHook);

    /**
     * Add a hook for the master for pre and post training
     * @param trainingHook the training hook to add
     */
    void addHook(TrainingHook trainingHook);

    /**
     * Get the TrainingMaster configuration as JSON
     */
    String toJson();

    /**
     * Get the TrainingMaster configuration as YAML
     */
    String toYaml();

    /**
     * Get the worker instance for this training master
     *
     * @param network Current SparkDl4jMultiLayer
     * @return Worker instance
     */
    W getWorkerInstance(SparkDl4jMultiLayer network);

    /**
     * Get the worker instance for this training master
     *
     * @param graph Current SparkComputationGraph
     * @return Worker instance
     */
    W getWorkerInstance(SparkComputationGraph graph);

    /**
     * Train the SparkDl4jMultiLayer with the specified data set
     *
     * @param network      Current network state
     * @param trainingData Data to train on
     */
    void executeTraining(SparkDl4jMultiLayer network, JavaRDD trainingData);


    /**
     * Fit the network using a list of paths for serialized DataSet objects.
     *
     * @param network           Current network state
     * @param trainingDataPaths Data to train on
     */
    void executeTrainingPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD trainingDataPaths, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader);

    /**
     * Train the SparkComputationGraph with the specified data set
     *
     * @param graph        Current network state
     * @param trainingData Data to train on
     */
    void executeTraining(SparkComputationGraph graph, JavaRDD trainingData);

    /**
     * Train the SparkComputationGraph with the specified data set
     *
     * @param graph        Current network state
     * @param trainingData Data to train on
     */
    void executeTrainingMDS(SparkComputationGraph graph, JavaRDD trainingData);

    /**
     * Set whether the training statistics should be collected. Training statistics may include things like per-epoch run times,
     * time spent waiting for data, etc.
     * 

* These statistics are primarily used for debugging and optimization, in order to gain some insight into what aspects * of network training are taking the most time. * * @param collectTrainingStats If true: collecting training statistics will be */ void setCollectTrainingStats(boolean collectTrainingStats); /** * Get the current setting for collectTrainingStats */ boolean getIsCollectTrainingStats(); /** * Return the training statistics. Note that this may return null, unless setCollectTrainingStats has been set first * * @return Training statistics */ SparkTrainingStats getTrainingStats(); /** * Set the iteration listeners. These should be called after every averaging (or similar) operation in the TrainingMaster, * though the exact behaviour may be dependent on each TrainingListener * * @param listeners Listeners to set */ void setListeners(Collection listeners); /** * Set the iteration listeners and the StatsStorageRouter. This is typically used for UI functionality: for example, * setListeners(new FileStatsStorage(myFile), Collections.singletonList(new StatsListener(null))). This will pass a * StatsListener to each worker, and then shuffle the results back to the specified FileStatsStorage instance (which * can then be attached to the UI or loaded later) * * @param router StatsStorageRouter in which to place the results * @param listeners Listeners */ void setListeners(StatsStorageRouter router, Collection listeners); /** * Attempt to delete any temporary files generated by this TrainingMaster. * Depending on the configuration, no temporary files may be generated. * * @param sc JavaSparkContext (used to access HDFS etc file systems, when required) * @return True if deletion was successful (or, no files to delete); false otherwise. */ boolean deleteTempFiles(JavaSparkContext sc); /** * Attempt to delete any temporary files generated by this TrainingMaster. * Depending on the configuration, no temporary files may be generated. * * @param sc SparkContext (used to access HDFS etc file systems, when required) * @return True if deletion was successful (or, no files to delete); false otherwise. */ boolean deleteTempFiles(SparkContext sc); }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy