package hex.deeplearning;
import water.Key;
import water.MRTask;
import water.fvec.Frame;
/**
* MRTask-based Deep Learning.
* Every node has access to all the training data which leads to optimal CPU utilization and training accuracy IFF the data fits on every node.
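* <p>
* Minimal usage sketch (illustrative only; {@code jobKey}, {@code train}, {@code modelInfo} and
* {@code iteration} are placeholders assumed to be held by the caller, and the actual call site is
* the DeepLearning training loop):
* <pre>{@code
* DeepLearningTask2 task = new DeepLearningTask2(jobKey, train, modelInfo, 0.5f, iteration);
* task.doAllNodes();                                  // setupLocal() on each node, then reduce()/postGlobal()
* DeepLearningModelInfo updated = task.model_info();  // node-averaged (or time-averaged) model
* }</pre>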
*/
public class DeepLearningTask2 extends MRTask<DeepLearningTask2> {
/**
* Construct a DeepLearningTask2 where every node trains on the entire training dataset
* @param jobKey Job ID
* @param train Frame containing training data
* @param model_info Initial DeepLearningModelInfo (weights + biases)
* @param sync_fraction Fraction of the training data to use for one SGD iteration
* @param iteration Current iteration number (passed through to the per-node DeepLearningTask)
*/
public DeepLearningTask2(Key jobKey, Frame train, DeepLearningModelInfo model_info, float sync_fraction, int iteration) {
assert(sync_fraction > 0);
_jobKey = jobKey;
_fr = train;
_sharedmodel = model_info;
_sync_fraction = sync_fraction;
_iteration = iteration;
}
/**
* Returns the aggregated DeepLearning model that was trained by all nodes (over all the training data)
* @return model_info object
*/
public DeepLearningModelInfo model_info() { return _sharedmodel; }
final private Key _jobKey;
final private Frame _fr;
private DeepLearningModelInfo _sharedmodel;
final private float _sync_fraction;
private DeepLearningTask _res;
private final int _iteration;
/**
* Do the local computation: Perform one DeepLearningTask (with run_local=true) iteration.
* Pass over all the data (it will be replicated by dfork() here), and use a random _sync_fraction of the rows.
* This calls DeepLearningTask's reduce(), which runs between worker threads that update the same local model_info via Hogwild!
* Once the local computation is done, reduce() below will be called between nodes.
*/
@Override
public void setupLocal() {
super.setupLocal();
_res = new DeepLearningTask(_jobKey, _sharedmodel, _sync_fraction, _iteration, this);
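// Register one unit of pending work so this task is not considered complete (and reduce()/postGlobal()
// are not triggered) until the forked child DeepLearningTask below finishes.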
addToPendingCount(1);
_res.dfork(null, _fr, true /*run_local*/);
}
/**
* Reduce between worker nodes, with network traffic (if there is more than one node).
* After all reduce() calls are done, postGlobal() will be called.
* @param drt task to reduce
*/
@Override
public void reduce(DeepLearningTask2 drt) {
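// Merge another node's result into ours: the first result seen becomes the accumulator; later results
// have their node counts and model weights/biases summed (averaging is deferred to postGlobal()).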
if (_res == null) _res = drt._res;
else {
_res._chunk_node_count += drt._res._chunk_node_count;
_res.model_info().add(drt._res.model_info()); //add models, but don't average yet
}
assert(_res.model_info().get_params()._replicate_training_data);
}
/**
* Finish up the work after all nodes have reduced their models via the above reduce() method.
* All we do is average the models and add to the global training sample counter.
* After this returns, model_info() can be queried for the updated model.
*/
@Override
protected void postGlobal() {
assert(_res.model_info().get_params()._replicate_training_data);
super.postGlobal();
// model averaging (DeepLearningTask only computed the per-node models, each on all the data)
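// Example: on a 4-node cluster where each node contributed one locally trained model, _chunk_node_count
// is 4 and div(4) turns the element-wise sum of weights/biases into their mean.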
_res.model_info().div(_res._chunk_node_count);
_res.model_info().add_processed_global(_res.model_info().get_processed_local()); //switch from local counters to global counters
_res.model_info().set_processed_local(0L);
DeepLearningModelInfo nodeAverageModel = _res.model_info();
if (nodeAverageModel.get_params()._elastic_averaging)
_sharedmodel = DeepLearningModelInfo.timeAverage(nodeAverageModel);
else
_sharedmodel = nodeAverageModel;
}
}