
package hex.deeplearning;


import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.SupervisedModelBuilder;
import hex.schemas.DeepLearningV3;
import hex.schemas.ModelBuilderSchema;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.init.Linpack;
import water.init.NetworkTest;
import water.util.*;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

import static water.util.MRUtils.sampleFrame;
import static water.util.MRUtils.sampleFrameStratified;

/**
 * Deep Learning Neural Net implementation based on MRTask
 */
public class DeepLearning extends SupervisedModelBuilder<DeepLearningModel,DeepLearningModel.DeepLearningParameters,DeepLearningModel.DeepLearningModelOutput> {
  @Override
  public ModelCategory[] can_build() {
    return new ModelCategory[]{
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial,
    };
  }

  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }

  @Override
  public boolean isSupervised() {
    return !_parms._autoencoder;
  }

  public DeepLearning( DeepLearningModel.DeepLearningParameters parms ) {
    super("DeepLearning", parms); init(false);
  }

  public ModelBuilderSchema schema() { return new DeepLearningV3(); }

  /** Start the DeepLearning training Job on an F/J thread. */
  @Override public Job<DeepLearningModel> trainModel() {
    // We look at _train before init(true) is called, so step around that here:
    long work = 1;
    if (null != _train)
      work = (long)_parms._epochs * _train.numRows();
    return start(new DeepLearningDriver(), work);
  }
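
  // Usage sketch (illustrative, not part of the original source): how a caller
  // might launch DeepLearning training. Assumes an H2O cloud is running and the
  // training frame `trainFrame` (hypothetical) is already in the DKV; field
  // names follow DeepLearningParameters.
  //
  //   DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
  //   p._train = trainFrame._key;
  //   p._response_column = "label";
  //   p._hidden = new int[]{200, 200};
  //   p._epochs = 10;
  //   Job job = new DeepLearning(p).trainModel();
  //   DeepLearningModel model = DKV.getGet(job.dest()); // fetch the result once the job completes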

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the large number of arguments in the DL Parameters object directly. */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    _parms.validate(this, expensive);
    if (expensive && error_count() == 0) checkMemoryFootPrint();
  }

  /**
   * Helper to create the DataInfo object from training/validation frames and the DL parameters
   * @param train Training frame
   * @param valid Validation frame
   * @param parms Model parameters
   * @return DataInfo object describing the expanded feature layout used for model building
   */
  static DataInfo makeDataInfo(Frame train, Frame valid, DeepLearningModel.DeepLearningParameters parms) {
    return new DataInfo(
            Key.make(), //dest key
            train,
            valid,
            parms._autoencoder ? 0 : 1, //nResponses
            parms._autoencoder || parms._use_all_factor_levels, //use all FactorLevels for auto-encoder
            parms._autoencoder ? DataInfo.TransformType.NORMALIZE : DataInfo.TransformType.STANDARDIZE, //transform predictors
            train.lastVec().isEnum() ? DataInfo.TransformType.NONE : DataInfo.TransformType.STANDARDIZE, //transform response (only used if nResponses > 0)
            parms._missing_values_handling == DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip, //whether to skip missing
            true); //always add a bucket for missing values
  }
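
  // Intuition for the expansion performed by DataInfo (sketch, for orientation):
  // with use_all_factor_levels disabled, a categorical column with 5 levels expands
  // into 4 indicator columns (one level is dropped as the reference); for autoencoders
  // all 5 levels are kept, since there is no intercept to absorb the reference level.
  // Numeric predictors are standardized to mean 0 / variance 1 (STANDARDIZE), or
  // rescaled to a bounded range (NORMALIZE) for autoencoders, and an extra bucket
  // for missing values is always added (last constructor argument above).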

  @Override
  protected void checkMemoryFootPrint() {
    if (_parms._checkpoint != null) return;
    long p = _train.degreesOfFreedom() - (_parms._autoencoder ? 0 : _train.lastVec().cardinality());
    String[][] dom = _train.domains();
    // hack: add the factor levels for the NAs
    for (int i=0; i<_train.numCols()-(_parms._autoencoder ? 0 : 1); ++i) {
      if (dom[i] != null) {
        p++;
      }
    }
//    assert(makeDataInfo(_train, _valid, _parms).fullN() == p);
    long output = _parms._autoencoder ? p : Math.abs(_train.lastVec().cardinality());
    // weights
    long model_size = p * _parms._hidden[0];
    int layer=1;
    for (; layer < _parms._hidden.length; ++layer)
      model_size += _parms._hidden[layer-1] * _parms._hidden[layer];
    model_size += _parms._hidden[layer-1] * output;

    // biases
    for (layer=0; layer < _parms._hidden.length; ++layer)
      model_size += _parms._hidden[layer];
    model_size += output;

    if (model_size > 1e8) {
      String msg = "Model is too large: " + model_size + " parameters. Try reducing the number of neurons in the hidden layers (or reduce the number of categorical factors).";
      error("_hidden", msg);
      cancel(msg);
    }
  }
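
  // Worked example (illustrative numbers): p = 100 expanded predictors,
  // _hidden = {200, 200}, 3-class response:
  //   weights: 100*200 + 200*200 + 200*3 = 60,600
  //   biases : 200 + 200 + 3             = 403
  //   model_size = 61,003 parameters -- well below the 1e8 limit enforced above.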

  public class DeepLearningDriver extends H2O.H2OCountedCompleter<DeepLearningDriver> {
    @Override protected void compute2() {
      try {
        byte[] cs = new AutoBuffer().put(_parms).buf();

        Scope.enter();
        // Init parameters
        init(true);
        // Read lock input
        _parms.read_lock_frames(DeepLearning.this);
        // Something went wrong: report validation errors and bail out
        if (error_count() > 0){
          DeepLearning.this.updateValidationMessages();
          throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DeepLearning.this);
        }


        buildModel();

        //check that _parms isn't changed during DL model training
        byte[] cs2 = new AutoBuffer().put(_parms).buf();
        assert(Arrays.equals(cs, cs2));

        done();                 // Job done!
//      if (n_folds > 0) CrossValUtils.crossValidate(this);
      } catch( Throwable t ) {
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          failed(t);
          throw t;
        }
      } finally {
        _parms.read_unlock_frames(DeepLearning.this);
        Scope.exit();
      }
      tryComplete();
    }

    Key self() { return _key; }

    // the following parameters can be modified when restarting from a checkpoint
    transient final String [] cp_modifiable = new String[] {
            "_seed",
            "_epochs",
            "_score_interval",
            "_train_samples_per_iteration",
            "_target_ratio_comm_to_comp",
            "_score_duty_cycle",
            "_score_training_samples",
            "_classification_stop",
            "_regression_stop",
            "_quiet_mode",
            "_max_confusion_matrix_size",
            "_max_hit_ratio_k",
            "_diagnostics",
            "_variable_importances",
            "_force_load_balance",
            "_replicate_training_data",
            "_shuffle_training_data",
            "_single_node_mode",
            "_fast_mode",
            // Allow modification of the regularization parameters after a checkpoint restart
            "_l1",
            "_l2",
            "_max_w2",
            "_input_dropout_ratio",
            "_hidden_dropout_ratios",
            "_loss",
            "_overwrite_with_best_model",
            "_missing_values_handling",
            "_reproducible",
            "_export_weights_and_biases"
    };
    // the following parameters must not be modified when restarting from a checkpoint
    transient final String [] cp_not_modifiable = new String[] {
            "_drop_na20_cols",
            "_response_column",
            "_activation",
//            "_hidden", //this must be checked via Arrays.equals(a,b), not via String.equals()
//            "_ignored_columns", //this must be checked via Arrays.equals(a,b), not via String.equals()
            "_use_all_factor_levels",
            "_adaptive_rate",
            "_autoencoder",
            "_rho",
            "_epsilon",
            "_sparse",
            "_sparsity_beta",
            "_col_major",
            "_rate",
            "_momentum_start",
            "_momentum_ramp",
            "_momentum_stable",
            "_nesterov_accelerated_gradient",
            "_ignore_const_cols",
            "_max_categorical_features"
    };
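
    // Checkpoint-restart sketch (illustrative, not part of the original source):
    // resume training with more epochs. Only fields listed in cp_modifiable may
    // change; `previousModel` is a hypothetical trained DeepLearningModel.
    //
    //   DeepLearningModel.DeepLearningParameters p2 = new DeepLearningModel.DeepLearningParameters();
    //   p2._checkpoint = previousModel._key;
    //   p2._train = previousModel._parms._train;
    //   p2._response_column = previousModel._parms._response_column;
    //   p2._hidden = previousModel._parms._hidden; // must match exactly (verified below)
    //   p2._epochs = 20;                           // modifiable across restarts
    //   new DeepLearning(p2).trainModel();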

    /**
     * Train a Deep Learning model; assumes that all members are populated.
     * If checkpoint == null, start training a new model; otherwise continue from the given checkpoint.
     */
    public final void buildModel() {
      Scope.enter();
      DeepLearningModel cp = null;
      if (_parms._checkpoint == null) {
        cp = new DeepLearningModel(dest(), _parms, new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this), _train, _valid);
        cp.model_info().initializeMembers();
      } else {
        final DeepLearningModel previous = DKV.getGet(_parms._checkpoint);
        if (previous == null) throw new IllegalArgumentException("Checkpoint not found.");
        Log.info("Resuming from checkpoint.");

        if( isClassifier() != previous._output.isClassifier() )
          throw new IllegalArgumentException("Response type must be the same as for the checkpointed model.");
        if( isSupervised() != previous._output.isSupervised() )
          throw new IllegalArgumentException("Model type must be the same as for the checkpointed model.");

        // check the user-given arguments for consistency
        DeepLearningModel.DeepLearningParameters oldP = previous._parms; //user-given parameters for checkpointed model
        DeepLearningModel.DeepLearningParameters newP = _parms; //user-given parameters for restart

        new ProgressUpdate("Resuming from checkpoint").fork(_progressKey);
        if (newP.getNumFolds() != 0)
          throw new UnsupportedOperationException("n_folds must be 0: Cross-validation is not supported during checkpoint restarts.");
        if ((_parms._valid == null) != (previous._parms._valid == null)
                || (_parms._valid != null  && !_parms._valid.equals(previous._parms._valid))) {
          throw new IllegalArgumentException("Validation dataset must be the same as for the checkpointed model.");
        }
        if (!newP._autoencoder && (newP._response_column == null || !newP._response_column.equals(oldP._response_column))) {
          throw new IllegalArgumentException("Response column (" + newP._response_column + ") is not the same as for the checkpointed model: " + oldP._response_column);
        }
        if (!Arrays.equals(newP._hidden, oldP._hidden)) {
          throw new IllegalArgumentException("Hidden layers (" + Arrays.toString(newP._hidden) + ") is not the same as for the checkpointed model: " + Arrays.toString(oldP._hidden));
        }
        if (!Arrays.equals(newP._ignored_columns, oldP._ignored_columns)) {
          throw new IllegalArgumentException("Predictor columns must be the same as for the checkpointed model. Check ignored columns.");
        }

        //compare the user-given parameters before and after and check that they are not changed
        for (Field fBefore : oldP.getClass().getDeclaredFields()) {
          if (ArrayUtils.contains(cp_not_modifiable, fBefore.getName())) {
            for (Field fAfter : newP.getClass().getDeclaredFields()) {
              if (fBefore.equals(fAfter)) {
                try {
                  if (fAfter.get(newP) == null || fBefore.get(oldP) == null || !fBefore.get(oldP).toString().equals(fAfter.get(newP).toString())) { // if either of the two parameters is null, skip the toString()
                    if (fBefore.get(oldP) == null && fAfter.get(newP) == null) continue; //if both parameters are null, we don't need to do anything
                    throw new IllegalArgumentException("Cannot change parameter: '" + fBefore.getName() + "': " + fBefore.get(oldP) + " -> " + fAfter.get(newP));
                  }
                } catch (IllegalAccessException e) {
                  e.printStackTrace();
                }
              }
            }
          }
        }
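
        // The comparison pattern above, in isolation (sketch): reflection-based
        // field-by-field comparison of two instances of one class via toString().
        //
        //   for (Field f : a.getClass().getDeclaredFields()) {
        //     Object va = f.get(a), vb = f.get(b); // may throw IllegalAccessException
        //     if (va == null && vb == null) continue;            // both unset
        //     if (va == null || vb == null || !va.toString().equals(vb.toString()))
        //       throw new IllegalArgumentException("Cannot change parameter: " + f.getName());
        //   }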

        try {
          final DataInfo dinfo = makeDataInfo(_train, _valid, _parms);
          DKV.put(dinfo._key,dinfo);
          cp = new DeepLearningModel(dest(), _parms, previous, false, dinfo);
          cp.write_lock(self());

          // these are the mutable parameters that are to be used by the model (stored in model_info._parms)
          final DeepLearningModel.DeepLearningParameters actualNewP = cp.model_info().get_params(); //actually used parameters for model building (defaults filled in, etc.)
          assert(actualNewP != previous.model_info().get_params());
          assert(actualNewP != newP);
          assert(actualNewP != oldP);

          if (!Arrays.equals(cp._output._names, previous._output._names)) {
            throw new IllegalArgumentException("Predictor columns of the training data must be the same as for the checkpointed model. Check ignored columns.");
          }
          if (!Arrays.deepEquals(cp._output._domains, previous._output._domains)) {
            throw new IllegalArgumentException("Categorical factor levels of the training data must be the same as for the checkpointed model.");
          }
          if (dinfo.fullN() != previous.model_info().data_info().fullN()) {
            throw new IllegalArgumentException("Total number of predictors is different than for the checkpointed model.");
          }

          for (Field fBefore : actualNewP.getClass().getDeclaredFields()) {
            if (ArrayUtils.contains(cp_modifiable, fBefore.getName())) {
              for (Field fAfter : newP.getClass().getDeclaredFields()) {
                if (fBefore.equals(fAfter)) {
                  try {
                    if (fAfter.get(newP) == null || fBefore.get(actualNewP) == null || !fBefore.get(actualNewP).toString().equals(fAfter.get(newP).toString())) { // if either of the two parameters is null, skip the toString()
                      if (fBefore.get(actualNewP) == null && fAfter.get(newP) == null) continue; //if both parameters are null, we don't need to do anything
                      Log.info("Applying user-requested modification of '" + fBefore.getName() + "': " + fBefore.get(actualNewP) + " -> " + fAfter.get(newP));
                      fBefore.set(actualNewP, fAfter.get(newP));
                    }
                  } catch (IllegalAccessException e) {
                    e.printStackTrace();
                  }
                }
              }
            }
          }
          // update parameters in place to set defaults etc.
          DeepLearningModel.modifyParms(actualNewP, actualNewP, isClassifier());

          actualNewP._epochs += previous.epoch_counter; //add new epochs to existing model
          Log.info("Adding " + String.format("%.3f", previous.epoch_counter) + " epochs from the checkpointed model.");

          if (actualNewP.getNumFolds() != 0) {
            Log.info("Disabling cross-validation: Not supported when resuming training from a checkpoint.");

            H2O.unimpl("writing to n_folds field needs to be uncommented");
            // actualNewP._n_folds = 0;
          }
          cp.update(self());
        } finally {
          if (cp != null) cp.unlock(self());
        }
      }
      trainModel(cp);

      // clean up, but don't delete the model and the (last) model metrics
      List<Key> keep = new ArrayList<>();
      keep.add(dest());
      if (cp._output._model_metrics.length != 0) keep.add(cp._output._model_metrics[cp._output._model_metrics.length-1]);
      if (cp._output.weights != null && cp._output.biases != null) {
        for (Key k : Arrays.asList(cp._output.weights)) {
          keep.add(k);
          for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
            keep.add(vk._key);
          }
        }
        for (Key k : Arrays.asList(cp._output.biases)) {
          keep.add(k);
          for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
            keep.add(vk._key);
          }
        }
      }
      Scope.exit(keep.toArray(new Key[0]));
    }


    /**
     * Train a Deep Learning neural net model
     * @param model Input model (e.g., from initModel(), or from a previous training run)
     * @return Trained model
     */
    public final DeepLearningModel trainModel(DeepLearningModel model) {
      Frame validScoreFrame = null;
      Frame train, trainScoreFrame;
      try {
//      if (checkpoint == null && !quiet_mode) logStart(); //if checkpoint is given, some Job's params might be uninitialized (but the restarted model's parameters are correct)
        if (model == null) {
          model = DKV.get(dest()).get();
        }
        Log.info("Model category: " + (_parms._autoencoder ? "Auto-Encoder" : isClassifier() ? "Classification" : "Regression"));
        final long model_size = model.model_info().size();
        Log.info("Number of model parameters (weights/biases): " + String.format("%,d", model_size));
        model.write_lock(self());
        new ProgressUpdate("Setting up training data...").fork(_progressKey);
        final DeepLearningModel.DeepLearningParameters mp = model.model_info().get_params();
        Frame tra_fr = new Frame(Key.make(mp.train()._key.toString() + ".temporary"), _train.names(), _train.vecs());
        Frame val_fr = _valid != null ? new Frame(Key.make(mp.valid()._key.toString() + ".temporary"), _valid.names(), _valid.vecs()) : null;

        train = tra_fr;
        if (mp._force_load_balance) {
          new ProgressUpdate("Load balancing training data...").fork(_progressKey);
          train = reBalance(train, mp._replicate_training_data /*rebalance into only 4*cores per node*/, mp._train.toString() + "." + model._key.toString() + ".train");
        }
        if (model._output.isClassifier() && mp._balance_classes) {
          new ProgressUpdate("Balancing class distribution of training data...").fork(_progressKey);
          float[] trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below
          if (mp._class_sampling_factors != null) {
            if (mp._class_sampling_factors.length != train.lastVec().domain().length)
              throw new IllegalArgumentException("class_sampling_factors must have " + train.lastVec().domain().length + " elements");
            trainSamplingFactors = mp._class_sampling_factors.clone(); //clone: don't modify the original
          }
          train = sampleFrameStratified(
                  train, train.lastVec(), trainSamplingFactors, (long)(mp._max_after_balance_size*train.numRows()), mp._seed, true, false);
          model._output._modelClassDist = new MRUtils.ClassDist(train.lastVec()).doAll(train.lastVec()).rel_dist();
        }
        model._output.autoencoder = _parms._autoencoder;
        model.training_rows = train.numRows();
        trainScoreFrame = sampleFrame(train, mp._score_training_samples, mp._seed); //training scoring dataset is always sampled uniformly from the training dataset

        if (!_parms._quiet_mode) Log.info("Number of chunks of the training data: " + train.anyVec().nChunks());
        if (val_fr != null) {
          model.validation_rows = val_fr.numRows();
          // validation scoring dataset can be sampled in multiple ways from the given validation dataset
          if (model._output.isClassifier() && mp._balance_classes && mp._score_validation_sampling == DeepLearningModel.DeepLearningParameters.ClassSamplingMethod.Stratified) {
            new ProgressUpdate("Sampling validation data (stratified)...").fork(_progressKey);
            validScoreFrame = sampleFrameStratified(val_fr, val_fr.lastVec(), null,
                    mp._score_validation_samples > 0 ? mp._score_validation_samples : val_fr.numRows(), mp._seed +1, false /* no oversampling */, false);
          } else {
            new ProgressUpdate("Sampling validation data...").fork(_progressKey);
            validScoreFrame = sampleFrame(val_fr, mp._score_validation_samples, mp._seed +1);
          }
          if (mp._force_load_balance) {
            new ProgressUpdate("Balancing class distribution of validation data...").fork(_progressKey);
            validScoreFrame = reBalance(validScoreFrame, false /*always split up globally since scoring should be distributed*/, mp._valid.toString() + "." + model._key.toString() + ".valid");
          }
          if (!_parms._quiet_mode) Log.info("Number of chunks of the validation data: " + validScoreFrame.anyVec().nChunks());
        }

        // Set train_samples_per_iteration size (cannot be done earlier since this depends on whether stratified sampling is done)
        model.actual_train_samples_per_iteration = computeTrainSamplesPerIteration(mp, train.numRows(), model);
        // Determine whether shuffling is enforced
        if(mp._replicate_training_data && (model.actual_train_samples_per_iteration == train.numRows()*(mp._single_node_mode ?1:H2O.CLOUD.size())) && !mp._shuffle_training_data && H2O.CLOUD.size() > 1 && !mp._reproducible) {
          Log.info("Enabling training data shuffling, because all nodes train on the full dataset (replicated training data).");
          mp._shuffle_training_data = true;
        }
        if(!mp._shuffle_training_data && mp._balance_classes && !mp._reproducible) {
          Log.info("Enabling training data shuffling, because balance_classes is enabled.");
          mp._shuffle_training_data = true;
        }

        if (!mp._quiet_mode && mp._diagnostics) Log.info("Initial model:\n" + model.model_info());
        if (_parms._autoencoder) {
          new ProgressUpdate("Scoring null model of autoencoder...").fork(_progressKey);
          model.doScoring(trainScoreFrame, validScoreFrame, self(), null); //get the null model reconstruction error
        }
        // put the initial version of the model into DKV
        model.update(self());
        model._timeLastScoreEnter = System.currentTimeMillis(); //to keep track of time per iteration, must be called before first call to doScoring
        Log.info("Starting to train the Deep Learning model.");

        //main loop
        do {
          DeepLearningModel.DeepLearningModelInfo mi = model.model_info();
          final String speed = (model.run_time!=0 ? (" at " + mi.get_processed_total() * 1000 / model.run_time + " samples/s..."): "...");
          final String etl = model.run_time == 0 ? "" : " Estimated time left: " + PrettyPrint.msecs((long)(model.run_time*(1.-progress())/progress()), true);
          new ProgressUpdate("Training" + speed + etl).fork(_progressKey);
          model.set_model_info(mp._epochs == 0 ? mi : H2O.CLOUD.size() > 1 && mp._replicate_training_data ? (mp._single_node_mode ?
                  new DeepLearningTask2(self(), train, mi, rowFraction(train, mp, model)).doAll(Key.make()).model_info() : //replicated data + single node mode
                  new DeepLearningTask2(self(), train, mi, rowFraction(train, mp, model)).doAllNodes().model_info()) : //replicated data + multi-node mode
                  new DeepLearningTask(self(), mi, rowFraction(train, mp, model)).doAll(train).model_info()); //distributed data (always in multi-node mode)
          update(model.actual_train_samples_per_iteration); //update progress
        }
        while (model.doScoring(trainScoreFrame, validScoreFrame, self(), _progressKey));

        // replace the model with the best model so far (if it's better)
        if (!isCancelledOrCrashed() && _parms._overwrite_with_best_model && model.actual_best_model_key != null && _parms.getNumFolds() == 0) {
          DeepLearningModel best_model = DKV.getGet(model.actual_best_model_key);
          if (best_model != null && best_model.error() < model.error() && Arrays.equals(best_model.model_info().units, model.model_info().units)) {
            Log.info("Setting the model to be the best model so far (based on scoring history).");
            DeepLearningModel.DeepLearningModelInfo mi = best_model.model_info().deep_clone();
            // Don't cheat - count full amount of training samples, since that's the amount of training it took to train (without finding anything better)
            mi.set_processed_global(model.model_info().get_processed_global());
            mi.set_processed_local(model.model_info().get_processed_local());
            model.set_model_info(mi);
            model.update(self());
            model.doScoring(trainScoreFrame, validScoreFrame, self(), _progressKey);
            assert(best_model.error() == model.error());
          }
        }

        Log.info("==============================================================================================================================================================================");
        Log.info("Finished training the Deep Learning model.");
        Log.info(model);
        Log.info("==============================================================================================================================================================================");
      }
      catch(Throwable ex) {
        model = DKV.get(dest()).get();
        Log.info("Deep Learning model building was cancelled.");
        throw new RuntimeException(ex);
      }
      finally {
        if (model != null) {
          model.unlock(self());
          if (model.actual_best_model_key != null) {
            assert (model.actual_best_model_key != model._key);
            DKV.remove(model.actual_best_model_key);
          }
        }
        for (Frame f : _delete_me) f.delete(); //delete internally rebalanced frames
      }
      return model;
    }
    transient HashSet<Frame> _delete_me = new HashSet<>();

    /**
     * Rebalance a frame for load balancing
     * @param fr Input frame
     * @param local whether to create only enough chunks to max out all cores on a single node
     * @param name name used to build the key of the rebalanced frame
     * @return Frame that has potentially more chunks
     */
    private Frame reBalance(final Frame fr, boolean local, String name) {
      int chunks = (int)Math.min( 4 * H2O.NUMCPUS * (local ? 1 : H2O.CLOUD.size()), fr.numRows());
      if (fr.anyVec().nChunks() > chunks && !_parms._reproducible) {
        Log.info("Dataset already contains " + fr.anyVec().nChunks() + " chunks. No need to rebalance.");
        return fr;
      } else if (_parms._reproducible) {
        Log.warn("Reproducibility enforced - using only 1 thread - can be slow.");
        chunks = 1;
      }
      if (!_parms._quiet_mode) Log.info("ReBalancing dataset into (at least) " + chunks + " chunks.");
      Key newKey = Key.make(name + ".chunks" + chunks);
      RebalanceDataSet rb = new RebalanceDataSet(fr, newKey, chunks);
      H2O.submitTask(rb);
      rb.join();
      Frame f = DKV.get(newKey).get();
      _delete_me.add(f);
      return f;
    }
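
    // Worked example (illustrative numbers): 32 cores per node on a 4-node cloud,
    // local == false, 1,000,000 rows:
    //   chunks = min(4 * 32 * 4, 1000000) = 512
    // so every core on every node can work on its own chunk during the map phase.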

    /**
     * Compute the actual train_samples_per_iteration size from the user-given parameter
     * @param mp Model parameters (DeepLearningParameters object)
     * @param numRows number of training rows
     * @param model DL model
     * @return The total number of training rows to be processed per iteration (summed over all nodes)
     */
    private long computeTrainSamplesPerIteration(final DeepLearningModel.DeepLearningParameters mp, final long numRows, DeepLearningModel model) {
      long tspi = mp._train_samples_per_iteration;
      assert(tspi == 0 || tspi == -1 || tspi == -2 || tspi >= 1);
      if (tspi == 0 || (!mp._replicate_training_data && tspi == -1) ) {
        tspi = numRows;
        if (!mp._quiet_mode) Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to one epoch: #rows (" + tspi + ").");
      }
      else if (tspi == -1) {
        tspi = (mp._single_node_mode ? 1 : H2O.CLOUD.size()) * numRows;
        if (!mp._quiet_mode) Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to #nodes x #rows (" + tspi + ").");
      } else if (tspi == -2) {
        // automatic tuning based on CPU speed, network speed and model size

        // measure cpu speed
        double total_gflops = 0;
        for (H2ONode h2o : H2O.CLOUD._memary) {
          HeartBeat hb = h2o._heartbeat;
          total_gflops += hb._gflops;
        }
        if (mp._single_node_mode) total_gflops /= H2O.CLOUD.size();
        if (total_gflops == 0) {
          total_gflops = Linpack.run(H2O.SELF._heartbeat._cpus_allowed) * (mp._single_node_mode ? 1 : H2O.CLOUD.size());
        }

        final long model_size = model.model_info().size();
        int[] msg_sizes = new int[]{ (int)(model_size*4) == (model_size*4) ? (int)(model_size*4) : Integer.MAX_VALUE };
        double[] microseconds_collective = new double[msg_sizes.length];
        NetworkTest.NetworkTester nt = new NetworkTest.NetworkTester(msg_sizes,null,microseconds_collective,model_size>1e6 ? 1 : 5 /*repeats*/,false,true /*only collectives*/);
        nt.compute2();

        // length of the network traffic queue based on log-tree rollup (2 log(nodes))
        int network_queue_length = mp._single_node_mode || H2O.CLOUD.size() == 1 ? 1 : 2*(int)Math.floor(Math.log(H2O.CLOUD.size())/Math.log(2));

        // heuristics
        double flops_overhead_per_row = 30;
        if (mp._activation == DeepLearningModel.DeepLearningParameters.Activation.Maxout || mp._activation == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout) {
          flops_overhead_per_row *= 8;
        } else if (mp._activation == DeepLearningModel.DeepLearningParameters.Activation.Tanh || mp._activation == DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout) {
          flops_overhead_per_row *= 5;
        }

        // target fraction of comm vs cpu time: 5%
        double fraction = mp._single_node_mode || H2O.CLOUD.size() == 1 ? 1e-3 : 0.05; //in single-node mode, there's no model averaging effect, so there is less need to shorten the M/R iteration

        // estimate the time for communication (network) and training (compute)
        model.time_for_communication_us = (H2O.CLOUD.size() == 1 ? 1e4 /* add 10ms for single-node */ : 0) + network_queue_length * microseconds_collective[0];
        double time_per_row_us = flops_overhead_per_row * model_size / (total_gflops * 1e9) / H2O.SELF._heartbeat._cpus_allowed * 1e6;

        // compute the optimal number of training rows per iteration
        // fraction := time_comm_us / (time_comm_us + tspi * time_per_row_us)  ==>  tspi = (time_comm_us/fraction - time_comm_us)/time_per_row_us
        tspi = (long)((model.time_for_communication_us / fraction - model.time_for_communication_us) / time_per_row_us);

        tspi = Math.min(tspi, (mp._single_node_mode ? 1 : H2O.CLOUD.size()) * numRows * 10); //not more than 10x of what train_samples_per_iteration=-1 would do

        // if the number is close to a multiple of epochs, use that -> prettier scoring
        if (tspi > numRows && Math.abs(tspi % numRows)/(double)numRows < 0.2)  tspi = tspi - tspi % numRows;
        tspi = Math.min(tspi, (long)(mp._epochs * numRows / 10)); //limit to number of epochs desired, but at least 10 iterations total
        tspi = Math.max(1, tspi); //at least 1 point

        if (!mp._quiet_mode) {
          Log.info("Auto-tuning parameter 'train_samples_per_iteration':");
          Log.info("Estimated compute power : " + (int)total_gflops + " GFlops");
          Log.info("Estimated time for comm : " + PrettyPrint.usecs((long) model.time_for_communication_us));
          Log.info("Estimated time per row  : " + ((long)time_per_row_us > 0 ? PrettyPrint.usecs((long) time_per_row_us) : time_per_row_us + " usecs"));
          Log.info("Estimated training speed: " + (int)(1e6/time_per_row_us) + " rows/sec");
          Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to auto-tuned value: " + tspi);
        }

      } else {
        // limit user-given value to number of epochs desired
        tspi = Math.min(tspi, (long)(mp._epochs * numRows));
      }
      assert(tspi != 0 && tspi != -1 && tspi != -2 && tspi >= 1);
      return tspi;
    }
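
    // Worked example for the auto-tuning (-2) branch (illustrative numbers):
    // 4-node cloud => network_queue_length = 2*log2(4) = 4; measured collective
    // latency 5,000 us; 1e6-parameter model; 50 GFlops total; 8 cores per node:
    //   time_comm_us    = 4 * 5,000 = 20,000
    //   time_per_row_us = 30 * 1e6 / (50 * 1e9) / 8 * 1e6 = 75
    //   tspi            = (20,000 / 0.05 - 20,000) / 75 =~ 5,066 rows per iteration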

    /**
     * Compute the fraction of rows that need to be used for training during one iteration
     * @param numRows number of training rows
     * @param train_samples_per_iteration number of training rows to be processed per iteration
     * @param replicate_training_data whether or not the training data is replicated on each node
     * @return fraction of rows to be used for training during one iteration
     */
    private float computeRowUsageFraction(final long numRows, final long train_samples_per_iteration, final boolean replicate_training_data) {
      float rowUsageFraction = (float)train_samples_per_iteration / numRows;
      if (replicate_training_data) rowUsageFraction /= H2O.CLOUD.size();
      assert(rowUsageFraction > 0);
      return rowUsageFraction;
    }
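
    // Example (illustrative): train_samples_per_iteration = 10,000 with
    // numRows = 1,000,000 gives a fraction of 0.01; with replicated training data
    // on a 4-node cloud this is further divided by 4, so each node trains on
    // 0.25% of its local copy of the data per map/reduce iteration.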
    private float rowFraction(Frame train, DeepLearningModel.DeepLearningParameters p, DeepLearningModel m) {
      return computeRowUsageFraction(train.numRows(), m.actual_train_samples_per_iteration, p._replicate_training_data);
    }
  }
}