package hex.deepwater;

import hex.*;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.schemas.DeepWaterModelV3;
import hex.util.LinearAlgebraUtils;
import water.*;
import water.api.schemas3.ModelSchemaV3;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.udf.CFuncRef;
import water.util.FrameUtils;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.RandomUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Random;

import static hex.ModelMetrics.calcVarImp;
import static water.H2O.technote;

/**
 * The Deep Water model.
 * It contains a DeepWaterModelInfo with the most up-to-date model state,
 * a scoring history, as well as some helpers to indicate training progress.
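 *
 * A minimal usage sketch (hypothetical key and frame names, not taken from this file):
 * <pre>
 *   DeepWaterModel model = DKV.getGet(modelKey);  // fetch a trained model from the DKV
 *   Frame preds = model.score(testFrame);         // standard hex.Model scoring entry point
 *   preds.remove();                                // clean up the prediction frame when done
 * </pre>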
 */
public class DeepWaterModel extends Model<DeepWaterModel,DeepWaterParameters,DeepWaterModelOutput> implements Model.DeepFeatures {

  @Override public DeepwaterMojoWriter getMojo() { return new DeepwaterMojoWriter(this); }

  // Default publicly visible Schema is V3
  public ModelSchemaV3 schema() { return new DeepWaterModelV3(); }

  void set_model_info(DeepWaterModelInfo mi) {
    model_info = mi;
  }

  final public DeepWaterModelInfo model_info() { return model_info; }

  @Override public ToEigenVec getToEigenVec() { return LinearAlgebraUtils.toEigen; }

//  final public VarImp varImp() { return _output.errors.variable_importances; }

  private volatile DeepWaterModelInfo model_info;

  // timing
  private long total_checkpointed_run_time_ms; //time spent in previous models
  private long total_training_time_ms; //total time spent running (training+scoring, including all previous models)
  private long total_scoring_time_ms; //total time spent scoring (including all previous models)
  long total_setup_time_ms; //total time spent setting up (including all previous models)
  private long time_of_start_ms; //start time for this model (this cp restart)

  // auto-tuning
  long actual_train_samples_per_iteration;
  long time_for_iteration_overhead_ms; //helper for auto-tuning: time in milliseconds for the per-iteration collective bcast/reduce of the model

  // helpers for diagnostics
  double epoch_counter;
  int iterations;
  private boolean stopped_early;
  long training_rows;
  long validation_rows;

  // Keep the best model so far, based on a single criterion (overall class. error or MSE)
  private float _bestLoss = Float.POSITIVE_INFINITY;

  Key<DeepWaterModel> actual_best_model_key;

  static final String unstable_msg = technote(4,
      "\n\nTrying to predict with an unstable model." +
          "\nJob was aborted due to observed numerical instability (exponential growth)."
          + "\nEither the weights or the bias values are unreasonably large or lead to large activation values."
          + "\nTry a different network architecture, a bounded activation function (tanh), adding regularization"
          + "\n(via dropout) or use a smaller learning rate and/or momentum.");

  public DeepWaterScoringInfo last_scored() { return (DeepWaterScoringInfo) super.last_scored(); }


  /**
   * Get the parameters actually used for model building, not the user-given ones (_parms).
   * They might differ since some defaults are filled in, and some invalid combinations are auto-disabled in modifyParms.
   * @return actually used parameters
   */
  public final DeepWaterParameters get_params() { return model_info.get_params(); }

  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    switch(_output.getModelCategory()) {
      case Binomial:    return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain);
      case Regression:  return new ModelMetricsRegression.MetricBuilderRegression();
      case AutoEncoder: return new ModelMetricsAutoEncoder.MetricBuilderAutoEncoder(_output.nfeatures());
      default: throw H2O.unimpl("Invalid ModelCategory " + _output.getModelCategory());
    }
  }

  static DataInfo makeDataInfo(Frame train, Frame valid, DeepWaterParameters parms) {
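    // Probe the link function with an arbitrary value: if link(x) == x, the distribution uses the identity link,
    // in which case the numeric response of a regression problem gets standardized as well (see the response transform below).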
    double x = 0.782347234;
    boolean identityLink = new Distribution(parms).link(x) == x;
    return new DataInfo(
        train,
        valid,
        parms._autoencoder ? 0 : 1, //nResponses
        parms._autoencoder || parms._use_all_factor_levels, //use all FactorLevels for auto-encoder
        parms._standardize ? (parms._autoencoder ? DataInfo.TransformType.NORMALIZE : parms._sparse ? DataInfo.TransformType.DESCALE : DataInfo.TransformType.STANDARDIZE) : DataInfo.TransformType.NONE, //transform predictors
        !parms._standardize || train.lastVec().isCategorical() ? DataInfo.TransformType.NONE : identityLink ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, //transform response for regression with identity link
        parms._missing_values_handling == DeepWaterParameters.MissingValuesHandling.Skip, //whether to skip missing
        false, // do not replace NAs in numeric cols with mean
        true,  // always add a bucket for missing values
        parms._weights_column != null, // observation weights
        parms._offset_column != null,
        parms._fold_column != null
    );
  }

  /** Constructor to restart from a checkpointed model
   *  @param destKey New destination key for the model
   *  @param parms User-given parameters for checkpoint restart
   *  @param cp Checkpoint to restart from
   *  @param dataInfo Data info (expected frame layout) to use for the restarted model
   */
  public DeepWaterModel(final Key<DeepWaterModel> destKey, final DeepWaterParameters parms, final DeepWaterModel cp, final DataInfo dataInfo) {
    super(destKey, parms == null ? (DeepWaterParameters)cp._parms.clone() : IcedUtils.deepCopy(parms), (DeepWaterModelOutput)cp._output.clone());
    DeepWaterParameters.Sanity.modifyParms(_parms, _parms, cp._output.nclasses()); //sanitize the model_info's parameters
    assert(_parms != cp._parms); //make sure we have a clone
    assert (_parms._checkpoint == cp._key);
    model_info = IcedUtils.deepCopy(cp.model_info);
    model_info._dataInfo = dataInfo;
    assert(model_info._network != null);
    assert(model_info._modelparams != null);
    model_info.javaToNative();
    _dist = new Distribution(get_params());
    assert(_dist.distribution != DistributionFamily.AUTO); // Note: must use sanitized parameters via get_params(), as this._parms can still have defaults (AUTO, etc.)
    actual_best_model_key = cp.actual_best_model_key;
    if (actual_best_model_key.get() == null) {
      DeepWaterModel best = IcedUtils.deepCopy(cp);
      //best.model_info.data_info = model_info.data_info; // Note: we currently DO NOT use the checkpoint's data info - as data may change during checkpoint restarts
      actual_best_model_key = Key.make(H2O.SELF);
      DKV.put(actual_best_model_key, best);
    }
    time_of_start_ms = cp.time_of_start_ms;
    total_training_time_ms = cp.total_training_time_ms;
    total_checkpointed_run_time_ms = cp.total_training_time_ms;
    total_scoring_time_ms = cp.total_scoring_time_ms;
    total_setup_time_ms = cp.total_setup_time_ms;
    training_rows = cp.training_rows; //copy the value to display the right number on the model page before training has started
    validation_rows = cp.validation_rows; //copy the value to display the right number on the model page before training has started
    _bestLoss = cp._bestLoss;
    epoch_counter = cp.epoch_counter;
    iterations = cp.iterations;

    // deep clone scoring history
    scoringInfo = cp.scoringInfo.clone();
    for (int i=0; i< scoringInfo.length;++i)
      scoringInfo[i] = IcedUtils.deepCopy(cp.scoringInfo[i]);
    _output.errors = last_scored();
    _output._scoring_history = DeepWaterScoringInfo.createScoringHistoryTable(scoringInfo, (null != get_params()._valid), false, _output.getModelCategory(), _output.isAutoencoder());
    _output._variable_importances = calcVarImp(last_scored().variable_importances);
    if (dataInfo!=null) {
      _output.setNames(dataInfo._adaptedFrame.names());
      _output._domains = dataInfo._adaptedFrame.domains();
    }
    assert(_key.equals(destKey));
  }

  private void setDataInfoToOutput(DataInfo dinfo) {
    if (dinfo == null) return;
    // update the model's expected frame format - needed for train/test adaptation
    _output.setNames(dinfo._adaptedFrame.names());
    _output._domains = dinfo._adaptedFrame.domains();
    _output._nums = dinfo._nums;
    _output._cats = dinfo._cats;
    _output._catOffsets = dinfo._catOffsets;
    _output._normMul = dinfo._normMul;
    _output._normSub = dinfo._normSub;
    _output._normRespMul = dinfo._normRespMul;
    _output._normRespSub = dinfo._normRespSub;
    _output._useAllFactorLevels = dinfo._useAllFactorLevels;
  }

  /**
   * Regular constructor (from scratch)
   * @param destKey destination key
   * @param params Deep Water parameters
   * @param output Deep Water model output
   * @param train Training frame
   * @param valid Validation frame (may be null)
   * @param nClasses Number of classes (1 for regression or autoencoder)
   */
  public DeepWaterModel(final Key<DeepWaterModel> destKey, final DeepWaterParameters params, final DeepWaterModelOutput output, Frame train, Frame valid, int nClasses) {
    super(destKey, params, output);
    if (H2O.getCloudSize() != 1)
      throw new IllegalArgumentException("Deep Water currently only supports execution on a single node.");
    _output._origNames = params._train.get().names();
    _output._origDomains = params._train.get().domains();

    DeepWaterParameters parms = (DeepWaterParameters) params.clone(); //make a copy, don't change model's parameters
    DeepWaterParameters.Sanity.modifyParms(parms, parms, nClasses); //sanitize the model_info's parameters
    DataInfo dinfo = null;
    if (parms._problem_type == DeepWaterParameters.ProblemType.dataset) {
      dinfo = makeDataInfo(train, valid, parms);
      DKV.put(dinfo);
      setDataInfoToOutput(dinfo);
      // either provide no image_shape (i.e., (0,0)), or provide both values and channels >= 1 (to turn it into an image problem)
      if (parms._image_shape != null && parms._image_shape[0] != 0) {
        if (parms._image_shape[0] < 0) {
          throw new IllegalArgumentException("image_shape must either have both values == 0 or both values >= 1 for " + parms._problem_type.getClass().toString() + "=" + parms._problem_type.toString());
        }
        if (parms._image_shape[1] <= 0) {
          throw new IllegalArgumentException("image_shape must either have both values == 0 or both values >= 1 for " + parms._problem_type.getClass().toString() + "=" + parms._problem_type.toString());
        }
        if (parms._channels <= 0) {
          throw new IllegalArgumentException("channels must be >= 1 when image_shape is provided for " + parms._problem_type.getClass().toString() + "=" + parms._problem_type.toString());
        }
        if (dinfo.fullN() != parms._image_shape[0] * parms._image_shape[1] * parms._channels) {
          throw new IllegalArgumentException("Data input size mismatch: Expect image_shape[0] x image_shape[1] x channels == #cols(H2OFrame), but got: "
              + parms._image_shape[0] + " x " + parms._image_shape[1] + " x " + parms._channels + " != " + dinfo.fullN() + ". Check these parameters, or disable ignore_const_cols.");
        }
      }
    }
    model_info = new DeepWaterModelInfo(parms, nClasses, dinfo != null ? dinfo.fullN() : -1);
    model_info._dataInfo = dinfo;
    if (dinfo!=null) {
      FrameUtils.printTopCategoricalLevels(dinfo._adaptedFrame, dinfo.fullN() > 10000, 10);
      Log.info("Building the model on " + dinfo.numNums() + " numeric features and " + dinfo.numCats() + " (one-hot encoded) categorical features.");
    }

    // now, parms is get_params();
    _dist = new Distribution(get_params());
    assert(_dist.distribution != DistributionFamily.AUTO); // Note: must use sanitized parameters via get_params(), as this._parms can still have defaults (AUTO, etc.)
    actual_best_model_key = Key.make(H2O.SELF);
    if (get_params()._nfolds != 0) actual_best_model_key = null;
    if (!get_params()._autoencoder) {
      scoringInfo = new DeepWaterScoringInfo[1];
      scoringInfo[0] = new DeepWaterScoringInfo();
      scoringInfo[0].validation = (get_params()._valid != null);
      scoringInfo[0].time_stamp_ms = System.currentTimeMillis();
      _output.errors = last_scored();
      _output._scoring_history = DeepWaterScoringInfo.createScoringHistoryTable(scoringInfo, (null != get_params()._valid), false, _output.getModelCategory(), _output.isAutoencoder());
      _output._variable_importances = calcVarImp(last_scored().variable_importances);
    }
    time_of_start_ms = System.currentTimeMillis();
    assert _key.equals(destKey);
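    // Serialize the model once up-front to check that it fits into a single DKV value (at most Value.MAX bytes).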
    boolean fail = false;
    long byte_size = 0;
    try {
      byte_size = new AutoBuffer().put(this).buf().length;
    } catch(Throwable t) {
      fail = true;
    }
    if (byte_size > Value.MAX || fail)
      throw new IllegalArgumentException(technote(5, "Model is too large to fit into the DKV (larger than " + PrettyPrint.bytes(Value.MAX) + ")."));
  }

  long _timeLastIterationEnter;
  private long _timeLastScoreStart; //start actual scoring
  private long _timeLastScoreEnd;  //finished actual scoring
  private long _timeLastPrintStart;

  private void checkTimingConsistency() {
    assert(total_scoring_time_ms <= total_training_time_ms);
    assert(total_setup_time_ms <= total_training_time_ms);
    assert(total_setup_time_ms+total_scoring_time_ms <= total_training_time_ms);
    assert(total_training_time_ms >= total_checkpointed_run_time_ms);
    assert(total_checkpointed_run_time_ms >= 0);
    assert(total_training_time_ms >= 0);
    assert(total_scoring_time_ms >= 0);
  }

  private void updateTiming(Key<Job> job_key) {
    final long now = System.currentTimeMillis();
    long start_time_current_model = job_key.get().start_time();
    total_training_time_ms = total_checkpointed_run_time_ms + (now - start_time_current_model);
    checkTimingConsistency();
  }

  /**
   * Score this DeepWater model
   * @param fTrain potentially downsampled training data for scoring
   * @param fValid  potentially downsampled validation data for scoring
   * @param jobKey key of the owning job
   * @param iteration Map/Reduce iteration count
   * @return true if model building is ongoing
   */
  boolean doScoring(Frame fTrain, Frame fValid, Key<Job> jobKey, int iteration, boolean finalScoring) {
    final long now = System.currentTimeMillis();
    final double time_since_last_iter = now - _timeLastIterationEnter;
    updateTiming(jobKey);
    _timeLastIterationEnter = now;
    epoch_counter = (double)model_info().get_processed_total()/training_rows;

    boolean keep_running;
    // Auto-tuning
    // If multi-node and auto-tuning and at least 10 ms of per-iteration communication overhead (to avoid doing this for multiple JVMs on the same node),
    // then adjust the auto-tuning parameter 'actual_train_samples_per_iteration' such that the targeted ratio of communication to computation is achieved.
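    // Worked example (illustrative numbers): with 100 ms of per-iteration overhead and a 2000 ms iteration,
    // comm_to_work_ratio = 0.05; for a target ratio of 0.02 the raw correction is 0.4, clamped to 0.5,
    // so actual_train_samples_per_iteration is divided by 0.5, i.e. doubled (more work per communication round).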
    if (get_params()._train_samples_per_iteration == -2 && iteration > 1) {
      Log.debug("Auto-tuning train_samples_per_iteration.");
      if (time_for_iteration_overhead_ms > 10) {
        Log.debug("  Time taken for per-iteration comm overhead: " + PrettyPrint.msecs(time_for_iteration_overhead_ms, true));
        Log.debug("  Time taken for Map/Reduce iteration: " + PrettyPrint.msecs((long) time_since_last_iter, true));
        final double comm_to_work_ratio = time_for_iteration_overhead_ms / time_since_last_iter;
        Log.debug("  Ratio of per-iteration comm overhead to computation: " + String.format("%.5f", comm_to_work_ratio));
        Log.debug("  target_comm_to_work: " + get_params()._target_ratio_comm_to_comp);
        Log.debug("Old value of train_samples_per_iteration: " + actual_train_samples_per_iteration);
        double correction = get_params()._target_ratio_comm_to_comp / comm_to_work_ratio;
        correction = Math.max(0.5,Math.min(2, correction)); //it's ok to train up to 2x more training rows per iteration, but not fewer than half.
        if (Math.abs(correction) < 0.8 || Math.abs(correction) > 1.2) { //don't correct unless it's significant (avoid slow drift)
          actual_train_samples_per_iteration /= correction;
          actual_train_samples_per_iteration = Math.max(1, actual_train_samples_per_iteration);
          Log.debug("New value of train_samples_per_iteration: " + actual_train_samples_per_iteration);
        } else {
          Log.debug("Keeping value of train_samples_per_iteration the same (would deviate too little from previous value): " + actual_train_samples_per_iteration);
        }
      } else {
        Log.debug("Iteration overhead is faster than 10 ms. Not modifying train_samples_per_iteration: " + actual_train_samples_per_iteration);
      }
    }

    keep_running = (epoch_counter < get_params()._epochs) && !stopped_early;
    final long sinceLastScore = now -_timeLastScoreStart;

    // this is potentially slow - only do every so often
    if( !keep_running || get_params()._score_each_iteration ||
        (sinceLastScore > get_params()._score_interval *1000 //don't score too often
            &&(double)(_timeLastScoreEnd-_timeLastScoreStart)/sinceLastScore < get_params()._score_duty_cycle) ) { //duty cycle
      Log.info(logNvidiaStats());
      jobKey.get().update(0,"Scoring on " + fTrain.numRows() + " training samples" +(fValid != null ? (", " + fValid.numRows() + " validation samples") : ""));
      final boolean printme = !get_params()._quiet_mode;
      _timeLastScoreStart = System.currentTimeMillis();
      DeepWaterScoringInfo scoringInfo = new DeepWaterScoringInfo();
      scoringInfo.time_stamp_ms = _timeLastScoreStart;
      updateTiming(jobKey);
      scoringInfo.total_training_time_ms = total_training_time_ms;
      scoringInfo.total_scoring_time_ms = total_scoring_time_ms;
      scoringInfo.total_setup_time_ms = total_setup_time_ms;
      scoringInfo.epoch_counter = epoch_counter;
      scoringInfo.iterations = iterations;
      scoringInfo.training_samples = (double)model_info().get_processed_total();
      scoringInfo.validation = fValid != null;
      scoringInfo.score_training_samples = fTrain.numRows();
      scoringInfo.score_validation_samples = get_params()._score_validation_samples;
      scoringInfo.is_classification = _output.isClassifier();
      scoringInfo.is_autoencoder = _output.isAutoencoder();

      if (printme) Log.info("Scoring the model.");
      // compute errors
      final String m = model_info().toString();
      if (m.length() > 0) Log.info(m);

      // For GainsLift and Huber, we need the full predictions to compute the model metrics
      boolean needPreds = _output.nclasses() == 2 /* gains/lift table requires predictions */ || get_params()._distribution==DistributionFamily.huber;

      // Scoring on training data
      ModelMetrics mtrain;
      Frame preds = null;
      if (needPreds) {
        // allocate predictions since they are needed
        preds = score(fTrain);
        mtrain = ModelMetrics.getFromDKV(this, fTrain);
      } else {
        // no need to allocate predictions
        ModelMetrics.MetricBuilder mb = scoreMetrics(fTrain);
        mtrain = mb.makeModelMetrics(this,fTrain,fTrain,null);
      }
      if (preds!=null) preds.remove();
      _output._training_metrics = mtrain;
      scoringInfo.scored_train = new ScoreKeeper(mtrain);
      ModelMetricsSupervised mm1 = (ModelMetricsSupervised)mtrain;
      if (mm1 instanceof ModelMetricsBinomial) {
        ModelMetricsBinomial mm = (ModelMetricsBinomial)(mm1);
        scoringInfo.training_AUC = mm._auc;
      }
      if (fTrain.numRows() != training_rows) {
        _output._training_metrics._description = "Metrics reported on temporary training frame with " + fTrain.numRows() + " samples";
      } else if (fTrain._key != null && fTrain._key.toString().contains("chunks")){
        _output._training_metrics._description = "Metrics reported on temporary (load-balanced) training frame";
      } else {
        _output._training_metrics._description = "Metrics reported on full training frame";
      }

      // Scoring on validation data
      ModelMetrics mvalid;
      if (fValid != null) {
        preds = null;
        if (needPreds) {
          // allocate predictions since they are needed
          preds = score(fValid);
          mvalid = ModelMetrics.getFromDKV(this, fValid);
        } else {
          // no need to allocate predictions
          ModelMetrics.MetricBuilder mb = scoreMetrics(fValid);
          mvalid = mb.makeModelMetrics(this, fValid, fValid,null);
        }
        if (preds!=null) preds.remove();
        _output._validation_metrics = mvalid;
        scoringInfo.scored_valid = new ScoreKeeper(mvalid);
        if (mvalid != null) {
          if (mvalid instanceof ModelMetricsBinomial) {
            ModelMetricsBinomial mm = (ModelMetricsBinomial) mvalid;
            scoringInfo.validation_AUC = mm._auc;
          }
          if (fValid.numRows() != validation_rows) {
            _output._validation_metrics._description = "Metrics reported on temporary validation frame with " + fValid.numRows() + " samples";
          } else if (fValid._key != null && fValid._key.toString().contains("chunks")){
            _output._validation_metrics._description = "Metrics reported on temporary (load-balanced) validation frame";
          } else {
            _output._validation_metrics._description = "Metrics reported on full validation frame";
          }
        }
      }
//      if (get_params()._variable_importances) {
//        if (!get_params()._quiet_mode) Log.info("Computing variable importances.");
//        throw H2O.unimpl();
//        final float[] vi = model_info().computeVariableImportances();
//        scoringInfo.variable_importances = new VarImp(vi, Arrays.copyOfRange(model_info().data_info().coefNames(), 0, vi.length));
//      }

      _timeLastScoreEnd = System.currentTimeMillis();
      long scoringTime = _timeLastScoreEnd - _timeLastScoreStart;
      total_scoring_time_ms += scoringTime;
      updateTiming(jobKey);
      // update the scoringInfo object to report proper speed
      scoringInfo.total_training_time_ms = total_training_time_ms;
      scoringInfo.total_scoring_time_ms = total_scoring_time_ms;
      scoringInfo.this_scoring_time_ms = scoringTime;
      // enlarge the error array by one, push latest score back
      if (this.scoringInfo == null) {
        this.scoringInfo = new DeepWaterScoringInfo[]{scoringInfo};
      } else {
        DeepWaterScoringInfo[] err2 = new DeepWaterScoringInfo[this.scoringInfo.length + 1];
        System.arraycopy(this.scoringInfo, 0, err2, 0, this.scoringInfo.length);
        err2[err2.length - 1] = scoringInfo;
        this.scoringInfo = err2;
      }
      _output.errors = last_scored();
      _output._scoring_history = DeepWaterScoringInfo.createScoringHistoryTable(this.scoringInfo, (null != get_params()._valid), false, _output.getModelCategory(), _output.isAutoencoder());
      _output._variable_importances = calcVarImp(last_scored().variable_importances);
      _output._model_summary = model_info.createSummaryTable();

      // always keep a copy of the best model so far (based on the following criterion)
      if (!finalScoring) {
        if (actual_best_model_key != null && get_params()._overwrite_with_best_model && (
                // if we have a best_model in DKV, then compare against its error() (unless it's a different model as judged by the network size)
                (DKV.get(actual_best_model_key) != null && !(loss() >= DKV.get(actual_best_model_key).<DeepWaterModel>get().loss()))
                        ||
                        // otherwise, compare against our own _bestError
                        (DKV.get(actual_best_model_key) == null && loss() < _bestLoss)
        ) ) {
          _bestLoss = loss();
          model_info.nativeToJava();
          putMeAsBestModel(actual_best_model_key);
        }
        // print the freshly scored model to ASCII
        if (keep_running && printme)
          Log.info(toString());
        if (ScoreKeeper.stopEarly(ScoringInfo.scoreKeepers(scoring_history()),
                get_params()._stopping_rounds, _output.isClassifier(), get_params()._stopping_metric, get_params()._stopping_tolerance, "model's last", true
        )) {
          Log.info("Convergence detected based on simple moving average of the loss function for the past " + get_params()._stopping_rounds + " scoring events. Model building completed.");
          stopped_early = true;
        }
        if (printme) Log.info("Time taken for scoring and diagnostics: " + PrettyPrint.msecs(scoringInfo.this_scoring_time_ms, true));
      }
    }
    if (stopped_early) {
      // pretend we finished all epochs so that the progress bar looks right (especially for N-fold and grid search)
      ((Job) DKV.getGet(jobKey)).update((long) (get_params()._epochs * training_rows));
      update(jobKey);
      return false;
    }
    progressUpdate(jobKey, keep_running);
    //update(jobKey);
    return keep_running;
  }

  private void putMeAsBestModel(Key<DeepWaterModel> bestModelKey) {
    DKV.put(bestModelKey, IcedUtils.deepCopy(this));
    assert DKV.get(bestModelKey) != null;
    assert ((DeepWaterModel)DKV.getGet(bestModelKey)).compareTo(this) <= 0;
  }

  private void progressUpdate(Key<Job> job_key, boolean keep_running) {
    updateTiming(job_key);
    Job job = job_key.get();
    double progress = job.progress();
//    Log.info("2nd speed: (samples: " + model_info().get_processed_total() + ", total_run_time: " + total_training_time_ms + ", total_scoring_time: " + total_scoring_time_ms + ", total_setup_time: " + total_setup_time_ms + ")");
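    // Effective training speed: samples processed per second of pure training time (total run time minus scoring and setup time).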
    float speed = (float)(model_info().get_processed_total() * 1000. / (total_training_time_ms -total_scoring_time_ms-total_setup_time_ms));
    assert(speed >= 0) : "negative speed computed! (total_run_time: " + total_training_time_ms + ", total_scoring_time: " + total_scoring_time_ms + ", total_setup_time: " + total_setup_time_ms + ")";
    String msg =
            "Iterations: " + String.format("%,d", iterations)
            + ". Epochs: " + String.format("%g", epoch_counter)
            + ". Speed: " + (speed>10 ? String.format("%d", (int)speed) : String.format("%g", speed)) + " samples/sec."
            + (progress == 0 ? "" : " Estimated time left: " + PrettyPrint.msecs((long) (total_training_time_ms * (1. - progress) / progress), true));
    job.update(actual_train_samples_per_iteration,msg); //mark the amount of work done for the progress bar
    long now = System.currentTimeMillis();
    long sinceLastPrint = now -_timeLastPrintStart;
    if (!keep_running || sinceLastPrint > get_params()._score_interval * 1000) { //print this after every score_interval, not considering duty cycle
      _timeLastPrintStart = now;
      if (!get_params()._quiet_mode) {
        Log.info(
                "Training time: " + PrettyPrint.msecs(total_training_time_ms, true) + " (scoring: " + PrettyPrint.msecs(total_scoring_time_ms, true) + "). "
                + "Processed " + String.format("%,d", model_info().get_processed_total()) + " samples" + " (" + String.format("%.3f", epoch_counter) + " epochs).\n");
        Log.info(msg);
      }
    }
  }

  private int backendCount = 0;
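  // backendCount is a reference count of the scoring threads currently sharing the native backend:
  // setupBigScorePredict() initializes the backend for the first thread, closeBigScorePredict() tears it down when the last thread finishes.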

  @Override
  protected void setupBigScorePredict() {
    synchronized (model_info()) {
      backendCount++;
      // Initial init of backend + model, backend is shared across threads
      if (null == model_info()._backend) {
        model_info().javaToNative();
      }
      // Backend already initialized, initialize model per thread
      if (null == model_info().getModel().get()) {
        model_info().initModel();
      }
    }
  }

  @Override
  protected void closeBigScorePredict() {
    synchronized (model_info()) {
      if (0 == --backendCount) {
        // No more threads using the backend, nuke backend + model
        model_info().nukeBackend();
      } else if (null != model_info().getModel().get()) {
        // Backend still used by other threads, nuke only model
        model_info().nukeModel();
      }
    }
  }

  /**
   * Single-instance scoring - slow, not optimized for mini-batches - do not use unless you know what you're doing
   * @param data One single observation unrolled into a double[], with a length equal to the number of input neurons
   * @param preds Array to store the predictions in (nclasses+1)
   * @return vector of [0, p0, p1, p2, etc.]
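   *         (e.g., for a 3-class problem: [predicted label, p(class 0), p(class 1), p(class 2)])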
   */
  @Override protected double[] score0(double[] data, double[] preds) {
    //allocate a big enough array for the model to be able to score with mini_batch
    float[] f = new float[_parms._mini_batch_size * data.length];
    for (int i = 0; i < data.length; ++i)
      f[i] = (float) data[i]; //only the first observation of the mini-batch is filled
    float[] predFloats = model_info().predict(f); //assumption: model_info() exposes a predict(float[]) helper; the original call was lost in extraction
    if (_output.nclasses() >= 2) {
      for (int i = 1; i < _output.nclasses()+1; ++i) preds[i] = predFloats[i];
    } else {
      preds[0] = predFloats[0];
    }
    return preds;
  }

  @Override public double[] score0(double[] data, double[] preds, double offset) {
    assert(offset==0);
    return score0(data, preds);
  }

  @Override protected long checksum_impl() {
    return super.checksum_impl() * _output._run_time + model_info().hashCode();
  }

  @Override
  public Frame scoreAutoEncoder(Frame frame, Key destination_key, boolean reconstruction_error_per_feature) {
    throw H2O.unimpl();
  }

  @Override
  public Frame scoreDeepFeatures(Frame frame, int layer) {
    throw H2O.unimpl();
  }

  @Override
  public Frame scoreDeepFeatures(Frame frame, int layer, Job j) {
    throw H2O.unimpl();
  }

  @Override
  public Frame scoreDeepFeatures(Frame frame, String layer, Job job) {
    if (layer == null)
      throw new H2OIllegalArgumentException("must give hidden layer (symbol) name to extract - cannot be null");
    if (isSupervised()) {
      int ridx = frame.find(_output.responseName());
      if (ridx != -1) { // drop the response for scoring!
        frame = new Frame(frame);
        frame.remove(ridx);
      }
    }
    Frame adaptFrm = new Frame(frame);
    Scope.enter();
    adaptTestForTrain(adaptFrm, true, false);
    Frame _fr = adaptFrm;

    DataInfo di = model_info()._dataInfo;
    if (di != null) {
      di = IcedUtils.deepCopy(di);
      di._adaptedFrame = _fr; //dinfo logic on _adaptedFrame is what we'll need for extracting standardized features from the data for scoring
    }
    final int dataIdx = 0; //FIXME
    final int weightIdx =_fr.find(get_params()._weights_column);
    final int batch_size = get_params()._mini_batch_size;

    ArrayList score_data = new ArrayList(); //for binary data (path to image/text data) or row indices (dataset problems)
    ArrayList<Integer> skipped = new ArrayList<>();

    // randomly add more rows to fill up to a multiple of batch_size
    long seed = 0xDECAF + 0xD00D * model_info().get_processed_global();
    Random rng = RandomUtils.getRNG(seed);

    //make predictions for all rows - even those with weights 0 for now (easier to deal with minibatch)
    BufferedString bs = new BufferedString();
    if ((int)_fr.numRows() != _fr.numRows()) {
      throw new IllegalArgumentException("Cannot handle datasets with more than 2 billion rows.");
    }
    for (int i=0; i<_fr.numRows(); ++i) {
      double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i);
      if (weight == 0) { //don't send observations with weight 0 to the GPU
        skipped.add(i);
        continue;
      }
      if (model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.image
          || model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.text) {
        BufferedString file = _fr.vec(dataIdx).atStr(bs, i);
        if (file!=null)
          score_data.add(file.toString());
      } else if (model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.dataset) {
        score_data.add(i);
      } else throw H2O.unimpl();
    }

    while (score_data.size() % batch_size != 0) {
      int pick = rng.nextInt(score_data.size());
      score_data.add(score_data.get(pick));
    }

    assert(isSupervised()); //not yet implemented for autoencoder
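    // Convert the Java-side model state to the native backend representation if the backend hasn't been initialized yet.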
    final boolean makeNative = model_info()._backend ==null;
    if (makeNative) model_info().javaToNative();

    Frame _predFrame = null;
    DeepWaterIterator iter;
    try {
      // first, figure out hidden layer dimensionality - do this the hard way
      int cols;
      {
        if (model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.image) {
          int width = model_info()._width;
          int height = model_info()._height;
          int channels = model_info()._channels;
          iter = new DeepWaterImageIterator(score_data, null /*no labels*/, model_info()._meanData, batch_size, width, height, channels, model_info().get_params()._cache_data);
        } else if (model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.dataset) {
          iter = new DeepWaterDatasetIterator(score_data, null /*no labels*/, di, batch_size, model_info().get_params()._cache_data);
        } else if (model_info().get_params()._problem_type == DeepWaterParameters.ProblemType.text) {
          iter = new DeepWaterTextIterator(score_data, null /*no labels*/, batch_size, 56 /*FIXME*/, model_info().get_params()._cache_data);
        } else {
          throw H2O.unimpl();
        }
        float[] data = iter.getData();
        float[] predFloats = model_info().extractLayer(layer, data); //just to see how big this gets
        if (predFloats.length == 0) {
          throw new IllegalArgumentException(model_info().listAllLayers());
        }
        cols = predFloats.length;
        assert (cols % batch_size == 0);
        cols /= batch_size;
      }

      // allocate the predictions Vec/Frame
      Vec[] predVecs = new Vec[cols];
      for (int i = 0; i < cols; ++i)
        predVecs[i] = _fr.anyVec().makeZero();
      _predFrame = new Frame(predVecs);
      String[] names = new String[cols];
      for (int j = 0; j < cols; ++j)
        names[j] = layer + ".C" + (j + 1); // assumed (hypothetical) column-naming scheme; the original statement is cut off at this point
      // NOTE: the remainder of this method, and of the class, is truncated in this copy of the source.



