All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.deepwater.DeepWaterTask Maven / Gradle / Ivy

package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendTrain;
import hex.FrameTask;
import water.Futures;
import water.H2O;
import water.Job;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

public class DeepWaterTask extends FrameTask {
  private DeepWaterModelInfo _localmodel; //per-node state (to be reduced)
  private DeepWaterModelInfo _sharedmodel; //input/output
  private int _chunk_node_count = 1;
  private float _useFraction;
  private boolean _shuffle;
  private final Job _job;

  /**
   * Returns the shared (input/output) model state held by this task.
   * With assertions enabled, fails if the shared model reference is currently {@code null}
   * (it is cleared while a local model is being worked on and restored in {@code postGlobal()}).
   * @return the shared model info object
   */
  public final DeepWaterModelInfo model_info() {
    final DeepWaterModelInfo shared = _sharedmodel;
    assert shared != null;
    return shared;
  }

  /**
   * Creates a task that trains on (a fraction of) the training frame, starting from the given model state.
   * @param inputModel Initial model state; also supplies the data info handed to the superclass
   *                   and the shuffle_training_data setting
   * @param fraction Fraction of rows of the training data to train with (may exceed 1, see setupLocal())
   * @param job Job whose key is passed to the superclass and whose stop request is polled during training
   */
  DeepWaterTask(DeepWaterModelInfo inputModel, float fraction, Job job) {
    super(job._key, inputModel._dataInfo);
    this._job = job;
    this._useFraction = fraction;
    this._sharedmodel = inputModel;
    // the shuffle flag is read once here, straight from the model parameters
    this._shuffle = inputModel.get_params()._shuffle_training_data;
  }

  /**
   * Transfers ownership from the global (shared) model to the local model and trains the local model
   * for one iteration.
   * <p>
   * Steps performed:
   * <ol>
   *   <li>Move the shared model reference into the local slot and reset its local sample counter.</li>
   *   <li>Build two parallel lists: training samples (file/text strings for image and text problems,
   *       row indices for dataset problems) and their labels. {@code (int)_useFraction} full passes
   *       over the frame are added first; after that, randomly drawn rows are added until at least
   *       {@code _useFraction * numRows} samples exist and the count is a multiple of the mini batch
   *       size. Rows with weight 0 are always skipped.</li>
   *   <li>If shuffling is enabled, shuffle both lists with the same seed.</li>
   *   <li>Wrap the lists in the iterator matching the problem type and, for every batch, update
   *       learning rate and momentum, check predictions for NaN and submit a native training task.</li>
   * </ol>
   * An {@link IOException} raised along the way is logged and training for this call ends
   * (best effort, e.g. when files cannot be found).
   *
   * @throws UnsupportedOperationException if the predictions for a batch sum to NaN (unstable model)
   * @throws IllegalStateException if the problem type is none of image, text or dataset
   */
  @Override protected void setupLocal(){
    assert(_localmodel == null);
    _localmodel = _sharedmodel;
    _sharedmodel = null;
    _localmodel.set_processed_local(0);

    final DeepWaterParameters params = _localmodel.get_params();
    final boolean isImage = params._problem_type == DeepWaterParameters.ProblemType.image;
    final boolean isText = params._problem_type == DeepWaterParameters.ProblemType.text;
    final boolean isDataset = params._problem_type == DeepWaterParameters.ProblemType.dataset;

    final int weightIdx = _fr.find(params._weights_column); // -1 means "no weights column" below
    final int respIdx = _fr.find(params._response_column);
    final int batchSize = params._mini_batch_size;

    DeepWaterIterator iter = null;
    // the seed depends on the number of globally processed samples
    long seed = 0xDECAF + 0xD00D * _localmodel.get_processed_global();
    Random rng = RandomUtils.getRNG(seed);

    if (_fr.numRows()>Integer.MAX_VALUE) {
      throw H2O.unimpl("Need to implement batching into int-sized chunks.");
    }
    final int len = (int)_fr.numRows();
    int j=0;
    Futures fs = new Futures();
    // Kept as raw lists on purpose: trainData holds Strings or Integers depending on the problem type,
    // and the element types expected by the iterator constructors are not visible here.
    ArrayList trainLabels = new ArrayList<>();
    ArrayList trainData = new ArrayList<>();

    try {
      // Binary data (Images/Documents/etc.)
      if (isImage || isText) {
        int dataIdx = 0; //must be the first column //FIXME
        Log.debug("Using column " + _fr.name(dataIdx) + " for " +
            (isImage ? "path to image data" : "text data"));
        BufferedString bs = new BufferedString();

        // full passes over the data
        // Example: train_samples_per_iteration = 4700, and train.numRows()=1000 -> _useFraction = 4.7 -> fullpasses = 4
        int fullpasses = (int)_useFraction;
        while (j++ < fullpasses) {
          for (int i=0; i<len; ++i) {
            double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i);
            if (weight == 0)
              continue;
            BufferedString file = _fr.vec(dataIdx).atStr(bs, i);
            // A row without data must contribute neither a sample nor a label;
            // otherwise the two parallel lists go out of alignment.
            if (file == null)
              continue;
            float response = (float) _fr.vec(respIdx).at(i);
            trainData.add(file.toString());
            trainLabels.add(response);
          }
        }

        // fractional passes // 0.7
        // NOTE(review): this loop does not terminate if no row has both a non-zero weight and non-null data.
        while (trainData.size() < _useFraction*len || trainData.size() % batchSize != 0) {
          assert(_shuffle);
          int i = rng.nextInt(len);
          double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i);
          if (weight == 0)
            continue;
          BufferedString file = _fr.vec(dataIdx).atStr(bs, i);
          if (file == null) // same rule as above: skip sample AND label together
            continue;
          float response = (float) _fr.vec(respIdx).at(i);
          trainData.add(file.toString());
          trainLabels.add(response);
        }
      }

      // Numeric data (H2O Frame full with numeric columns)
      else if (isDataset) {
        // response normalization: label = (raw - sub) / mul, identity when no normalization is configured
        double mul = _localmodel._dataInfo._normRespMul!=null ? _localmodel._dataInfo._normRespMul[0] : 1;
        double sub = _localmodel._dataInfo._normRespSub!=null ? _localmodel._dataInfo._normRespSub[0] : 0;

        // full passes over the data
        int fullpasses = (int) _useFraction;
        while (j++ < fullpasses) {
          for (int i = 0; i < len; ++i) {
            double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i);
            if (weight == 0)
              continue;
            float response = (float)((_fr.vec(respIdx).at(i) - sub) / mul);
            trainData.add(i); // the sample is the row index
            trainLabels.add(response);
          }
        }

        // fractional passes
        // NOTE(review): this loop does not terminate if every row has weight 0.
        while (trainData.size() < _useFraction * len || trainData.size() % batchSize != 0) {
          int i = rng.nextInt(len);
          double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i);
          if (weight == 0)
            continue;
          float response = (float)((_fr.vec(respIdx).at(i) - sub) / mul);
          trainData.add(i);
          trainLabels.add(response);
        }
      }

      // shuffle the (global) list
      // Both lists are shuffled after re-seeding with the same seed so that samples and labels
      // presumably receive the same permutation; this relies on the lists having equal length.
      if (_shuffle) {
        rng.setSeed(seed);
        Collections.shuffle(trainLabels, rng);
        rng.setSeed(seed);
        Collections.shuffle(trainData, rng);
      }

      if (isImage) {
        iter = new DeepWaterImageIterator(trainData, trainLabels, _localmodel._meanData, batchSize, _localmodel._width, _localmodel._height, _localmodel._channels, params._cache_data);
      }
      else if (isDataset) {
        assert (_localmodel._dataInfo != null);
        iter = new DeepWaterDatasetIterator(trainData, trainLabels, _localmodel._dataInfo, batchSize, params._cache_data);
      }
      else if (isText) {
        iter = new DeepWaterTextIterator(trainData, trainLabels, batchSize, 56/*FIXME*/, params._cache_data);
      }
      if (iter == null) {
        // fail with a clear message instead of a NullPointerException in the loop below
        throw new IllegalStateException("Unsupported problem type: " + params._problem_type);
      }

      while (iter.Next(fs) && !_job.isStopping()) {
        long n = _localmodel.get_processed_total();

        _localmodel._backend.setParameter(_localmodel._model, "learning_rate", params.learningRate((double) n));
        _localmodel._backend.setParameter(_localmodel._model, "momentum", params.momentum((double) n));

        // predictions are only used to detect numerical instability (NaN) before training on the batch
        float[] preds = _localmodel._backend.predict(_localmodel._model, iter.getData());
        if (Float.isNaN(ArrayUtils.sum(preds))) {
          Log.err(DeepWaterModel.unstable_msg);
          throw new UnsupportedOperationException(DeepWaterModel.unstable_msg);
        }

        //fork off GPU work, but let the iterator.Next() wait on completion before swapping again
        NativeTrainTask ntt = new NativeTrainTask(_localmodel._backend, _localmodel._model, iter.getData(), iter.getLabel());
        fs.add(H2O.submitTask(ntt));
        _localmodel.add_processed_local(iter._batch_size);
      }
      fs.blockForPending();
    } catch (IOException e) {
      // gracefully continue if we can't find files etc., but report through the logger
      // NOTE(review): tasks already added to fs are not awaited on this path - confirm that is intended
      Log.err(e);
    }
  }
  /**
   * Intentionally empty: this task does no work in map; the training loop in this class
   * lives in setupLocal().
   */
  @Override
  public void map(Chunk[] chunks, NewChunk[] outputs) {
    // nothing to do
  }

  /**
   * Task that runs a single backend training call on one batch of data and labels
   * and records the wall-clock time spent in that call.
   */
  private static class NativeTrainTask extends H2O.H2OCountedCompleter {

    // accumulated wall-clock milliseconds spent inside _backend.train(...)
    long _timeInMillis;
    final BackendTrain _backend;
    final BackendModel _model;

    float[] _data;
    float[] _labels;

    /**
     * @param backend backend used to perform the training call
     * @param model backend model to train
     * @param data batch data handed to the backend
     * @param label batch labels handed to the backend
     */
    NativeTrainTask(BackendTrain backend, BackendModel model, float[] data, float[] label) {
      this._backend = backend;
      this._model = model;
      this._data = data;
      this._labels = label;
    }

    /** Trains on the stored batch, adds the elapsed time to the timer and signals completion. */
    @Override
    public void compute2() {
      final long startedAt = System.currentTimeMillis();
      _backend.train(_model, _data, _labels); // whatever train() returns is ignored
      _timeInMillis += System.currentTimeMillis() - startedAt;
      tryComplete();
    }
  }

  /**
   * Clears the reference to the shared model. No other work is done here.
   */
  @Override protected void closeLocal() {
    this._sharedmodel = null;
  }

  /**
   * Combines the local model of another task with the local model of this task.
   * Nothing happens unless both tasks hold a local model, the other task processed at least one
   * sample, and the two tasks hold different model objects (tasks that operate on the same
   * model object have nothing to merge).
   * If this task has not processed any samples itself, the other task's model simply replaces
   * this one; otherwise the other model is added to this one and the node counts are summed
   * (the division by the node count happens in postGlobal()).
   * @param other Other DeepWaterTask to reduce
   */
  @Override public void reduce(DeepWaterTask other){
    final DeepWaterModelInfo theirs = other._localmodel;
    if (_localmodel == null || theirs == null) {
      return;
    }
    // the other task must have been active, otherwise its model_info carries nothing to average
    final boolean otherDidWork = theirs.get_processed_local() > 0;
    if (!otherDidWork) {
      return;
    }
    // same object on both sides: nothing to merge
    if (theirs == _localmodel) {
      return;
    }

    final boolean nothingProcessedHere = _localmodel.get_processed_local() == 0;
    if (nothingProcessedHere) {
      // avoid adding remote model info to unprocessed local data, still random
      // (this can happen if we have no chunks on the master node)
      _localmodel = theirs;
      _chunk_node_count = other._chunk_node_count;
      return;
    }
    _localmodel.add(theirs);
    _chunk_node_count += other._chunk_node_count;
  }


  private static long _lastWarn;
  private static long _warnCount;
  /**
   * Final clean-up step after the reductions.
   * <ul>
   *   <li>On a multi-node cloud without training data replication, warns (rate-limited: at most
   *       3 times and not more often than every 5 seconds) when fewer nodes than the cloud size
   *       contributed to the model.</li>
   *   <li>When not running locally, moves the local sample count into the global count and, if more
   *       than one node contributed, divides the summed model by the node count (model averaging).</li>
   *   <li>When running locally (i.e. inside a DeepWaterTask2, which does the reduction between
   *       replicated data workers), only hands the local model over as the shared model.</li>
   * </ul>
   * In all cases the shared model ends up set and the local model reference is cleared.
   */
  @Override protected void postGlobal(){
    final DeepWaterParameters params = _localmodel.get_params();
    final int cloudSize = H2O.CLOUD.size();

    if (cloudSize > 1 && !params._replicate_training_data) {
      final long now = System.currentTimeMillis();
      final boolean someNodesIdle = _chunk_node_count < cloudSize;
      final boolean mayWarn = (now - _lastWarn > 5000) && _warnCount < 3;
      if (someNodesIdle && mayWarn) {
        Log.warn((cloudSize - _chunk_node_count) + " node(s) (out of " + cloudSize
            + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes).");
        _lastWarn = now;
        _warnCount++;
      }
    }

    // Check that we're not inside a DeepWaterTask2
    assert ((!params._replicate_training_data || cloudSize == 1) == !_run_local);

    if (_run_local) {
      // Get ready for reduction in DeepWaterTask2: just swap the local and global models
      _sharedmodel = _localmodel;
    } else {
      // move local sample counts to global ones
      _localmodel.add_processed_global(_localmodel.get_processed_local());
      _localmodel.set_processed_local(0L);
      // model averaging
      if (_chunk_node_count > 1) {
        _localmodel.div(_chunk_node_count);
      }
    }

    if (_sharedmodel == null) {
      _sharedmodel = _localmodel;
    }
    _localmodel = null;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy