All downloads are free. The search and download functionality uses the official Maven repository.

hex.deeplearning.DeepLearningTask2 Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.deeplearning;

import water.Key;
import water.MRTask;
import water.fvec.Frame;

/**
 * MRTask-based Deep Learning.
 * Every node has access to all the training data which leads to optimal CPU utilization IF the data fits on every node.
 * <p>
 * Each node runs one local {@link DeepLearningTask} over the full frame (see {@link #setupLocal()}),
 * the per-node results are summed in {@link #reduce(DeepLearningTask2)}, and the sum is turned into an
 * average in {@link #postGlobal()}.
 */
public class DeepLearningTask2 extends MRTask<DeepLearningTask2> {
  /**
   * Construct a DeepLearningTask2 where every node trains on the entire training dataset
   * @param jobKey Job key; passed through unchanged to the per-node DeepLearningTask
   * @param train Frame containing training data
   * @param model_info Initial DeepLearningModelInfo (weights + biases)
   * @param sync_fraction Fraction of the training data to use for one SGD iteration; must be &gt; 0
   *                      (only checked when assertions are enabled)
   */
  public DeepLearningTask2(Key jobKey, Frame train, DeepLearningModel.DeepLearningModelInfo model_info, float sync_fraction) {
    assert(sync_fraction > 0);
    _jobKey = jobKey;
    _fr = train;
    _model_info = model_info;
    _sync_fraction = sync_fraction;
  }

  /**
   * Returns the aggregated DeepLearning model that was trained by all nodes (over all the training data)
   * @return model_info object held by the reduced per-node task
   */
  public DeepLearningModel.DeepLearningModelInfo model_info() {
    return _res.model_info();
  }

  // Constructor inputs, kept so that setupLocal() can build the per-node DeepLearningTask.
  private final Key _jobKey;
  private final Frame _fr;
  private final DeepLearningModel.DeepLearningModelInfo _model_info;
  private final float _sync_fraction;
  // Per-node task created in setupLocal(); after reduce() it carries the summed model and node count.
  private DeepLearningTask _res;

  /**
   * Do the local computation: Perform one DeepLearningTask (with run_local=true) iteration.
   * Pass over all the data (will be replicated in dfork() here), and use _sync_fraction random rows.
   * This calls DeepLearningTask's reduce() between worker threads that update the same local model_info via Hogwild!
   * Once the computation is done, reduce() will be called
   */
  @Override
  public void setupLocal() {
    super.setupLocal();
    _res = new DeepLearningTask(_jobKey, _model_info, _sync_fraction);
    // Register one extra pending completion before making this task the completer of _res;
    // presumably this keeps this task from completing until the nested task has finished.
    addToPendingCount(1);
    _res.setCompleter(this);
    _res.asyncExec(0, _fr, true /*run_local*/);
  }

  /**
   * Reduce between worker nodes, with network traffic (if greater than 1 nodes)
   * After all reduce()'s are done, postGlobal() will be called
   * @param drt task to reduce
   */
  @Override
  public void reduce(DeepLearningTask2 drt) {
    if (_res == null) {
      // Nothing accumulated on this side yet: adopt the other side's result as-is.
      _res = drt._res;
    } else {
      _res._chunk_node_count += drt._res._chunk_node_count;
      _res.model_info().add(drt._res.model_info()); //add models, but don't average yet
    }
    assert(_res.model_info().get_params()._replicate_training_data);
  }

  /**
   * Finish up the work after all nodes have reduced their models via the above reduce() method.
   * All we do is average the models and add to the global training sample counter.
   * After this returns, model_info() can be queried for the updated model.
   */
  @Override
  protected void postGlobal() {
    assert(_res.model_info().get_params()._replicate_training_data);
    super.postGlobal();
    _res.model_info().div(_res._chunk_node_count); //model averaging
    _res.model_info().add_processed_global(_res.model_info().get_processed_local()); //switch from local counters to global counters
    _res.model_info().set_processed_local(0L);
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy