All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.automl.AutoML Maven / Gradle / Ivy

The newest version!
package ai.h2o.automl;

import ai.h2o.automl.AutoMLBuildSpec.AutoMLBuildModels;
import ai.h2o.automl.AutoMLBuildSpec.AutoMLInput;
import ai.h2o.automl.AutoMLBuildSpec.AutoMLStoppingCriteria;
import ai.h2o.automl.StepResultState.ResultStatus;
import ai.h2o.automl.WorkAllocations.Work;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.events.EventLogEntry.Stage;
import ai.h2o.automl.leaderboard.ModelGroup;
import ai.h2o.automl.leaderboard.ModelProvider;
import ai.h2o.automl.leaderboard.ModelStep;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.leaderboard.*;
import hex.splitframe.ShuffleSplitFrame;
import water.*;
import water.automl.api.schemas3.AutoMLV99;
import water.exceptions.H2OAutoMLException;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.nbhm.NonBlockingHashMap;
import water.util.*;

import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static ai.h2o.automl.AutoMLBuildSpec.AutoMLStoppingCriteria.AUTO_STOPPING_TOLERANCE;


/**
 * H2O AutoML
 *
 * AutoML is used for automating the machine learning workflow, which includes automatic training and
 * tuning of many models within a user-specified time-limit. Stacked Ensembles will be automatically
 * trained on collections of individual models to produce highly predictive ensemble models which, in most cases,
 * will be the top performing models in the AutoML Leaderboard.
 */
public final class AutoML extends Lockable implements TimedH2ORunnable {

  /**
   * Kinds of run constraints (model count, timeout, consecutive failure count).
   * Individual modeling steps may declare that they ignore some of them (see {@code ModelingStep#ignores}).
   */
  public enum Constraint {
    MODEL_COUNT, TIMEOUT, FAILURE_COUNT
  }

  /** Orders AutoML instances by their creation time ({@code _startTime}), earliest first. */
  public static final Comparator<AutoML> byStartTime = Comparator.comparing(a -> a._startTime);
  /** String constant "@@"; presumably used to separate parts of composite keys — TODO confirm against usages. */
  public static final String keySeparator = "@@";

  /** Default for {@code _maxConsecutiveModelFailures}. */
  private static final int DEFAULT_MAX_CONSECUTIVE_MODEL_FAILURES = 10;

  private static final boolean verifyImmutability = true; // check that trainingFrame hasn't been messed with
  // SimpleDateFormat is not thread-safe, hence one instance per thread.
  private static final ThreadLocal<SimpleDateFormat> timestampFormatForKeys = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_HHmmss"));
  private static final Logger log = LoggerFactory.getLogger(AutoML.class);

  /**
   * Builds the provider of the AutoML-specific leaderboard columns for the given AutoML instance.
   *
   * Only the instance key is captured: the AutoML object is fetched again from its key
   * every time extensions are created for a model.
   *
   * @param automl the AutoML instance whose leaderboard will use this provider.
   * @return a provider creating training time, scoring time per row, algo name, and modeling step related cells.
   */
  private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(AutoML automl) {
    final Key<AutoML> amlKey = automl._key;

    return new LeaderboardExtensionsProvider() {
      @Override
      public LeaderboardCell[] createExtensions(Model model) {
        final AutoML owner = amlKey.get();
        final ModelingStep step = owner.session().getModelingStep(model.getKey());
        // scoring time is measured on the leaderboard frame, falling back to the training frame when there is none.
        final Frame scoringFrame = owner.getLeaderboardFrame() != null
                ? owner.getLeaderboardFrame()
                : owner.getTrainingFrame();
        // a ModelSize cell is currently disabled.
        return new LeaderboardCell[] {
                new TrainingTime(model),
                new ScoringTimePerRow(model, scoringFrame),
                new AlgoName(model),
                new ModelProvider(model, step),
                new ModelStep(model, step),
                new ModelGroup(model, step),
        };
      }
    };
  }

  /**
   * Instantiates an AutoML object and starts it running. Progress can be tracked via its {@link #job()}.
   *
   * @param buildSpec all the parameters of this AutoML build.
   * @return a new running AutoML instance.
   */
  public static AutoML startAutoML(AutoMLBuildSpec buildSpec) {
    return startAutoML(buildSpec, false);
  }

  /**
   * Same as {@link #startAutoML(AutoMLBuildSpec)}, optionally enabling test mode before the job is submitted.
   *
   * @param buildSpec all the parameters of this AutoML build.
   * @param testMode when true, internal states (e.g. the result of each executed step) are kept for inspection.
   * @return a new running AutoML instance.
   */
  static AutoML startAutoML(AutoMLBuildSpec buildSpec, boolean testMode) {
    final AutoML automl = new AutoML(buildSpec);
    automl._testMode = testMode;
    automl.submit();
    return automl;
  }

  /** Returns the schema class used for this object: {@code AutoMLV99.AutoMLKeyV3}. */
  @Override
  public Class makeSchema() {
    return AutoMLV99.AutoMLKeyV3.class;
  }

  private AutoMLBuildSpec _buildSpec;     // all parameters for this AutoML build, assigned only once validated
  private Frame _origTrainingFrame;       // untouched original training frame

  /** @return all the parameters of this AutoML build. */
  public AutoMLBuildSpec getBuildSpec() {
    return _buildSpec;
  }

  /** @return the training frame used for this run: a split of the original training frame, or an internal copy of it. */
  public Frame getTrainingFrame() {
    return _trainingFrame;
  }

  /** @return the validation frame, either provided by the user or split from the training frame; may be null. */
  public Frame getValidationFrame() {
    return _validationFrame;
  }

  /** @return the blending frame, either provided by the user or split from the training frame; may be null. */
  public Frame getBlendingFrame() {
    return _blendingFrame;
  }

  /** @return the leaderboard frame, either provided by the user or, when CV is disabled, the validation frame; may be null. */
  public Frame getLeaderboardFrame() {
    return _leaderboardFrame;
  }

  /** @return the response column of the training frame. */
  public Vec getResponseColumn() {
    return _responseColumn;
  }

  /** @return the fold column of the training frame, or null if none was specified. */
  public Vec getFoldColumn() {
    return _foldColumn;
  }

  /** @return the weights column of the training frame, or null if none was specified. */
  public Vec getWeightsColumn() {
    return _weightsColumn;
  }

  /** @return the distribution family inferred from (or validated against) the response column. */
  public DistributionFamily getDistributionFamily() {
    return _distributionFamily;
  }

  /**
   * Returns the class distribution of the response column, computing it on first access and caching it afterwards.
   */
  public double[] getClassDistribution() {
    if (_classDistribution != null) return _classDistribution;
    _classDistribution = new MRUtils.ClassDist(_responseColumn).doAll(_responseColumn).dist();
    return _classDistribution;
  }

  /** @return the output definition, listing only the modeling steps that were actually used. */
  public StepDefinition[] getActualModelingSteps() {
    return _actualModelingSteps;
  }

  Frame _trainingFrame;    // required training frame: can add and remove Vecs, but not mutate Vec data in place.
  Frame _validationFrame;  // optional validation frame; when CV is disabled and none is given, it is split from the training frame.
  Frame _blendingFrame;    // optional blending frame for SE; when CV is disabled and auto-blending is on, it may be split from the training frame.
  Frame _leaderboardFrame; // optional test frame used for leaderboard scoring; when CV is disabled and none is given, the validation frame is used.

  // columns of _trainingFrame; fold and weights columns are null when not specified.
  Vec _responseColumn;
  Vec _foldColumn;
  Vec _weightsColumn;

  DistributionFamily _distributionFamily;
  private double[] _classDistribution; // lazily computed by getClassDistribution()

  Date _startTime;           // creation time of this instance, logged when the job is created.
  Countdown _runCountdown;   // built from max_runtime_secs once the build spec is validated.
  Job _job;                  // the Job object for the build of this AutoML.
  WorkAllocations _workAllocations; // frozen work plan computed by planWork().
  StepDefinition[] _actualModelingSteps; // the output definition, listing only the steps that were actually used

  int _maxConsecutiveModelFailures = DEFAULT_MAX_CONSECUTIVE_MODEL_FAILURES;
  AtomicInteger _consecutiveModelFailures = new AtomicInteger(); // reset on each successful step, incremented on each failed one.
  AtomicLong _incrementalSeed = new AtomicLong(); // initialized with the user seed in the constructor.
  private String _runId;     // instance id of the build spec, used as prefix for the keys of internal frames.

  private ModelingStepsExecutor _modelingStepsExecutor;
  private AutoMLSession _session;
  private Leaderboard _leaderboard;
  private EventLog _eventLog;

  // check that we haven't messed up the original Frame
  private Vec[] _originalTrainingFrameVecs;
  private String[] _originalTrainingFrameNames;
  private long[] _originalTrainingFrameChecksums;
  private transient NonBlockingHashMap _trackedKeys = new NonBlockingHashMap<>();
  private transient ModelingStep[] _executionPlan;          // lazily computed by getExecutionPlan()
  private transient PreprocessingStep[] _preprocessing;     // null when no preprocessing is defined in the build spec
  transient StepResultState[] _stepsResults;                // only filled in test mode

  private boolean _useAutoBlending; // set during validation when blending is chosen automatically instead of CV
  private boolean _testMode;  // when on, internal states are kept for inspection
  /**
   * DO NOT USE explicitly: for schema/reflection only.
   */
  public AutoML() {
    super(null);
  }

  /**
   * Creates a new AutoML instance starting now, with a key generated from the build spec.
   *
   * @param buildSpec all the parameters of this AutoML build.
   */
  public AutoML(AutoMLBuildSpec buildSpec) {
    this(null, new Date(), buildSpec);
  }

  /**
   * Creates a new AutoML instance starting now.
   *
   * @param key the key of this instance; if null, a key is generated from the build spec.
   * @param buildSpec all the parameters of this AutoML build.
   */
  public AutoML(Key key, AutoMLBuildSpec buildSpec) {
    this(key, new Date(), buildSpec);
  }

  /**
   * @deprecated use {@link #AutoML(AutoMLBuildSpec)} instead.
   */
  @Deprecated
  public AutoML(Date startTime, AutoMLBuildSpec buildSpec) {
    this(null, startTime, buildSpec);
  }

  /**
   * Creates and initializes a new AutoML instance.
   *
   * Initialization covers:
   * - validating the build spec;
   * - preparing the data;
   * - setting up the leaderboard, the preprocessing steps and the modeling steps executor.
   *
   * If anything fails during initialization, {@code delete()} is called to clean up potentially leaked keys
   * before the original exception is rethrown. If that cleanup also fails, its exception is attached as suppressed,
   * so the root cause is not masked.
   *
   * @param key the key of this instance; if null, a key is generated from the build spec.
   * @param startTime the creation time of this instance.
   * @param buildSpec all the parameters of this AutoML build.
   * @deprecated use {@link #AutoML(Key, AutoMLBuildSpec)} instead.
   */
  @Deprecated
  public AutoML(Key key, Date startTime, AutoMLBuildSpec buildSpec) {
    super(key == null ? buildSpec.makeKey() : key);
    try {
      _startTime = startTime;
      _session = AutoMLSession.getInstance(_key.toString());
      _eventLog = EventLog.getOrMake(_key);
      eventLog().info(Stage.Workflow, "Project: "+buildSpec.project());

      validateBuildSpec(buildSpec);
      // now that buildSpec is validated, we can assign it: all future logic can now safely access parameters through _buildSpec.
      _buildSpec = buildSpec;
      _runId = _buildSpec.instanceId();
      _runCountdown = Countdown.fromSeconds(_buildSpec.build_control.stopping_criteria.max_runtime_secs());
      _incrementalSeed.set(_buildSpec.build_control.stopping_criteria.seed());

      prepareData();
      initLeaderboard();
      initPreprocessing();
      _modelingStepsExecutor = new ModelingStepsExecutor(_leaderboard, _eventLog, _runCountdown);
    } catch (Exception e) {
      try {
        delete(); //cleanup potentially leaked keys
      } catch (Exception cleanupFailure) {
        // keep the initialization failure as the primary exception.
        e.addSuppressed(cleanupFailure);
      }
      throw e;
    }
  }

  /**
   * Validates all buildSpec parameters, providing reasonable defaults dynamically or cleaning parameters when necessary.
   *
   * Ideally, validation should be fast, so that in the future it can be called directly from a client (e.g. Flow)
   * to check parameters before starting the AutoML run.
   * That is also why the validate methods should not modify data.
   * They may only read data to validate parameters that depend on it, while adjusting the parameters themselves
   * (e.g. ignored columns, nfolds, default max runtime, stopping tolerance).
   *
   * In the future, ModelBuilder.ValidationMessage may be reused to return all validation results at once
   * to the client (cf. ModelBuilder).
   *
   * @param buildSpec all the AutoML parameters to validate.
   */
  private void validateBuildSpec(AutoMLBuildSpec buildSpec) {
    final AutoMLInput input = buildSpec.input_spec;
    validateInput(input);
    validateModelValidation(buildSpec);
    validateModelBuilding(buildSpec.build_models);
    validateEarlyStopping(buildSpec.build_control.stopping_criteria, input);
    validateReproducibility(buildSpec);
  }

  /**
   * Validates the input frames and special columns.
   * Also removes the response, fold and weights columns from the ignored columns if the user listed them there.
   *
   * @param input the input section of the build spec; its {@code ignored_columns} may be rewritten.
   * @throws H2OIllegalArgumentException in any of these cases:
   *         - the training frame is missing;
   *         - the response column is absent from one of the provided frames;
   *         - the fold or weights column is absent from the training frame.
   */
  private void validateInput(AutoMLInput input) {
    final Frame trainingFrame = DKV.getGet(input.training_frame);
    if (trainingFrame == null)
      throw new H2OIllegalArgumentException("No training data has been specified, either as a path or a key, or it is not available anymore.");

    final Frame validationFrame = DKV.getGet(input.validation_frame);
    final Frame blendingFrame = DKV.getGet(input.blending_frame);
    final Frame leaderboardFrame = DKV.getGet(input.leaderboard_frame);

    // Use a plain map: double-brace initialization would create an anonymous subclass capturing this AutoML instance.
    final Map<String, Frame> compatibleFrames = new LinkedHashMap<>();
    compatibleFrames.put("training", trainingFrame);
    compatibleFrames.put("validation", validationFrame);
    compatibleFrames.put("blending", blendingFrame);
    compatibleFrames.put("leaderboard", leaderboardFrame);
    for (Map.Entry<String, Frame> entry : compatibleFrames.entrySet()) {
      final Frame frame = entry.getValue();
      if (frame != null && frame.find(input.response_column) < 0) {
        throw new H2OIllegalArgumentException("Response column '"+input.response_column+"' is not in the "+entry.getKey()+" frame.");
      }
    }

    if (input.fold_column != null && trainingFrame.find(input.fold_column) < 0) {
      throw new H2OIllegalArgumentException("Fold column '"+input.fold_column+"' is not in the training frame.");
    }
    if (input.weights_column != null && trainingFrame.find(input.weights_column) < 0) {
      throw new H2OIllegalArgumentException("Weights column '"+input.weights_column+"' is not in the training frame.");
    }

    if (input.ignored_columns != null) {
      final List<String> ignoredColumns = new ArrayList<>(Arrays.asList(input.ignored_columns));
      final Map<String, String> doNotIgnore = new LinkedHashMap<>();
      doNotIgnore.put("response_column", input.response_column);
      doNotIgnore.put("fold_column", input.fold_column);
      doNotIgnore.put("weights_column", input.weights_column);
      for (Map.Entry<String, String> entry : doNotIgnore.entrySet()) {
        final String column = entry.getValue();
        if (column != null && ignoredColumns.contains(column)) {
          eventLog().info(Stage.Validation,
                  "Removing "+entry.getKey()+" '"+column+"' from list of ignored columns.");
          // Remove every occurrence: List.remove(Object) only drops the first one,
          // so a duplicated entry would leave the column ignored.
          ignoredColumns.removeIf(column::equals);
        }
      }
      input.ignored_columns = ignoredColumns.toArray(new String[0]);
    }
  }


  /**
   * Decides how models are validated and adjusts {@code buildSpec.build_control.nfolds} in place.
   *
   * - A fold column always wins: nfolds is set to 0, and the fold column is used for cross-validation.
   * - nfolds == -1 (automatic) selects blending (nfolds = 0, {@code _useAutoBlending = true}) when all of these hold:
   *   - a positive max runtime is set;
   *   - the data is large relative to the runtime budget and the total number of threads in the cloud;
   *   - no cross-validation artifact (predictions, models, fold assignment) was requested to be kept.
   *   Otherwise, 5-fold cross-validation is used.
   * - Any other nfolds <= 1 disables cross-validation (nfolds = 0).
   *
   * It also warns when a validation frame is provided while cross-validation is enabled,
   * and rejects distributions that AutoML does not support.
   *
   * NOTE(review): validateBuildSpec calls this before validateEarlyStopping. So when neither max_models
   * nor max_runtime_secs is set, max_runtime here is still <= 0 (the 3600s default is applied later),
   * and blending is never selected automatically in that case.
   *
   * @param buildSpec the build spec; {@code build_control.nfolds} is modified in place.
   * @throws H2OIllegalArgumentException if the distribution is fractionalbinomial, quasibinomial or ordinal.
   */
  private void validateModelValidation(AutoMLBuildSpec buildSpec) {
    if (buildSpec.input_spec.fold_column != null) {
      eventLog().warn(Stage.Validation, "Fold column " + buildSpec.input_spec.fold_column + " will be used for cross-validation. nfolds parameter will be ignored.");
      buildSpec.build_control.nfolds = 0;
    } else if (buildSpec.build_control.nfolds == -1) {
      Frame trainingFrame = DKV.getGet(buildSpec.input_spec.training_frame);
      long nrows = trainingFrame.numRows();
      // all columns minus the non-predictors of the build spec and the ignored columns.
      long ncols = trainingFrame.numCols() - (buildSpec.getNonPredictors().length +
              (buildSpec.input_spec.ignored_columns != null ? buildSpec.input_spec.ignored_columns.length : 0));

      double max_runtime = buildSpec.build_control.stopping_criteria.max_runtime_secs();
      // total number of threads over all the cloud members.
      long nthreads = Stream.of(H2O.CLOUD.members())
              .mapToInt((h2o) -> h2o._heartbeat._nthreads)
              .sum();

      // Heuristic: number of cells per (second of runtime * thread). Floating-point division, so a 0 max_runtime
      // gives Infinity/NaN instead of throwing; the `max_runtime > 0` guard below discards that case anyway.
      boolean use_blending = ((ncols * nrows) / (max_runtime * nthreads)) > 2064;
      if (max_runtime > 0 && use_blending &&
              !(buildSpec.build_control.keep_cross_validation_predictions ||
                buildSpec.build_control.keep_cross_validation_models ||
                buildSpec.build_control.keep_cross_validation_fold_assignment)) {
        _useAutoBlending = true;
        buildSpec.build_control.nfolds = 0;
        eventLog().info(Stage.Validation, "Blending will be used.");
      } else {
        buildSpec.build_control.nfolds = 5;
        eventLog().info(Stage.Validation, "5-fold cross-validation will be used.");
      }
    } else if (buildSpec.build_control.nfolds <= 1) {
      eventLog().info(Stage.Validation, "Cross-validation disabled by user: no fold column nor nfolds > 1.");
      buildSpec.build_control.nfolds = 0;
    }
    // With CV enabled, the validation frame only provides informative metrics.
    if ((buildSpec.build_control.nfolds > 0 || buildSpec.input_spec.fold_column != null)
            && DKV.getGet(buildSpec.input_spec.validation_frame) != null) {
      eventLog().warn(Stage.Validation, "User specified a validation frame with cross-validation still enabled."
              + " Please note that the models will still be validated using cross-validation only,"
              + " the validation frame will be used to provide purely informative validation metrics on the trained models.");
    }
    if (Arrays.asList(
            DistributionFamily.fractionalbinomial,
            DistributionFamily.quasibinomial,
            DistributionFamily.ordinal
            ).contains(buildSpec.build_control.distribution)) {
      throw new H2OIllegalArgumentException("Distribution \"" + buildSpec.build_control.distribution.name() + "\" is not supported in AutoML!");
    }
  }

  /**
   * Validates the model-building parameters.
   *
   * Negative {@code exploitation_ratio} values are accepted here: the exploration/exploitation
   * rebalancing is skipped for them.
   *
   * @param modelBuilding the model-building section of the build spec.
   * @throws H2OIllegalArgumentException if both {@code exclude_algos} and {@code include_algos} are set,
   *         or if {@code exploitation_ratio} is greater than 1.
   */
  private void validateModelBuilding(AutoMLBuildModels modelBuilding) {
    final boolean bothAlgoListsProvided = modelBuilding.exclude_algos != null && modelBuilding.include_algos != null;
    if (bothAlgoListsProvided)
      throw new  H2OIllegalArgumentException("Parameters `exclude_algos` and `include_algos` are mutually exclusive: please use only one of them if necessary.");

    if (modelBuilding.exploitation_ratio > 1)
      throw new H2OIllegalArgumentException("`exploitation_ratio` must be between 0 and 1.");
  }
  
  /**
   * Validates the stopping criteria, applying defaults where needed.
   *
   * - Without any runtime constraint (max_models and max_runtime_secs both <= 0), max runtime defaults to 1 hour.
   * - An automatic stopping tolerance is replaced by a default derived from the training frame.
   * - A user-provided tolerance is kept, with a warning if it is much lower than that default.
   *
   * @param stoppingCriteria the stopping criteria of the build spec, updated in place.
   * @param input the input section of the build spec, used to fetch the training frame.
   */
  private void validateEarlyStopping(AutoMLStoppingCriteria stoppingCriteria, AutoMLInput input) {
    final boolean noRuntimeConstraint = stoppingCriteria.max_models() <= 0 && stoppingCriteria.max_runtime_secs() <= 0;
    if (noRuntimeConstraint) {
      stoppingCriteria.set_max_runtime_secs(3600);
      eventLog().info(Stage.Validation, "User didn't set any runtime constraints (max runtime or max models), using default 1h time limit");
    }

    final Frame refFrame = DKV.getGet(input.training_frame);
    if (stoppingCriteria.stopping_tolerance() != AUTO_STOPPING_TOLERANCE) {
      eventLog().info(Stage.Validation, "Stopping tolerance set by the user: "+stoppingCriteria.stopping_tolerance());
      final double recommendedTolerance = AutoMLStoppingCriteria.default_stopping_tolerance_for_frame(refFrame);
      if (stoppingCriteria.stopping_tolerance() < 0.7 * recommendedTolerance) {
        eventLog().warn(Stage.Validation, "Stopping tolerance set by the user is < 70% of the recommended default of "+recommendedTolerance+", so models may take a long time to converge or may not converge at all.");
      }
      return;
    }

    stoppingCriteria.set_default_stopping_tolerance_for_frame(refFrame);
    eventLog().info(Stage.Validation, "Setting stopping tolerance adaptively based on the training frame: "+stoppingCriteria.stopping_tolerance());
  }


  /**
   * Logs the seed used for this build, flagging the -1 seed as random.
   *
   * @param buildSpec the build spec holding the seed.
   */
  private void validateReproducibility(AutoMLBuildSpec buildSpec) {
    final long seed = buildSpec.build_control.stopping_criteria.seed();
    final String suffix = seed == -1 ? " (random)" : "";
    eventLog().info(Stage.Validation, "Build control seed: " + seed + suffix);
  }

  /**
   * Attaches this run to its leaderboard, reusing (with a warning) an existing leaderboard with the same key,
   * or creating a new one.
   *
   * The sort metric from the input spec is normalized as follows:
   * - null or "AUTO" (any case) means the leaderboard default;
   * - any other value is lower-cased;
   * - "deviance" is mapped to the leaderboard column name "mean_residual_deviance".
   */
  private void initLeaderboard() {
    String sortMetric = _buildSpec.input_spec.sort_metric;
    // Locale.ROOT: with a locale-sensitive lower-casing, e.g. a Turkish default locale would turn
    // "MEAN_RESIDUAL_DEVIANCE" into "mean_resıdual_devıance" (dotless i), which matches no leaderboard column.
    sortMetric = sortMetric == null || StoppingMetric.AUTO.name().equalsIgnoreCase(sortMetric) ? null : sortMetric.toLowerCase(Locale.ROOT);
    if ("deviance".equalsIgnoreCase(sortMetric)) {
      sortMetric = "mean_residual_deviance"; //compatibility with names used in leaderboard
    }
    _leaderboard = Leaderboard.getInstance(_key.toString(), eventLog().asLogger(Stage.ModelTraining), _leaderboardFrame, sortMetric, Leaderboard.ScoreData.auto);
    if (null != _leaderboard) {
      eventLog().warn(Stage.Workflow, "New models will be added to existing leaderboard "+_key.toString()
              +" (leaderboard frame="+(_leaderboardFrame == null ? null : _leaderboardFrame._key)+") with already "+_leaderboard.getModelKeys().length+" models.");
    } else {
      _leaderboard = Leaderboard.getOrMake(_key.toString(), eventLog().asLogger(Stage.ModelTraining), _leaderboardFrame, sortMetric, Leaderboard.ScoreData.auto);
    }
    _leaderboard.setExtensionsProvider(createLeaderboardExtensionProvider(this));
  }

  /**
   * Instantiates the preprocessing steps defined in the build spec for this run.
   * Leaves {@code _preprocessing} null when the build spec defines none.
   */
  private void initPreprocessing() {
    if (_buildSpec.build_models.preprocessing == null) {
      _preprocessing = null;
      return;
    }
    _preprocessing = Arrays.stream(_buildSpec.build_models.preprocessing)
            .map(definition -> definition.newPreprocessingStep(this))
            .toArray(PreprocessingStep[]::new);
  }

  /** @return the preprocessing steps of this run, or null if none are defined. */
  PreprocessingStep[] getPreprocessing() {
    return _preprocessing;
  }

  /**
   * Returns the ordered modeling steps to execute.
   * The plan is resolved from the session's steps registry on first access and cached afterwards.
   */
  ModelingStep[] getExecutionPlan() {
    if (_executionPlan != null) return _executionPlan;
    _executionPlan = session().getModelingStepsRegistry().getOrderedSteps(selectModelingPlan(null), this);
    return _executionPlan;
  }

  /**
   * Returns the modeling plan of the build spec, assigning one first if none was set.
   *
   * When a plan has to be assigned, the first applicable option is used:
   * - the given {@code plan}, if not null;
   * - the reproducible plan, when max_models is set;
   * - the default plan.
   *
   * @param plan the plan to use if the build spec defines none; may be null.
   * @return the modeling plan stored in the build spec.
   */
  StepDefinition[] selectModelingPlan(StepDefinition[] plan) {
    final AutoMLBuildModels buildModels = _buildSpec.build_models;
    if (buildModels.modeling_plan == null) {
      if (plan != null) {
        buildModels.modeling_plan = plan;
      } else if (_buildSpec.build_control.stopping_criteria.max_models() > 0) {
        // as soon as user specifies max_models, consider that user expects reproducibility.
        buildModels.modeling_plan = ModelingPlans.REPRODUCIBLE;
      } else {
        buildModels.modeling_plan = ModelingPlans.defaultPlan();
      }
    }
    return buildModels.modeling_plan;
  }

  /**
   * Computes the work allocations of this run from the execution plan, then freezes them in {@code _workAllocations}.
   *
   * Before freezing, it:
   * - removes the algos skipped by the user's include_algos/exclude_algos;
   * - removes the algos that are not available in the current environment;
   * - rebalances exploration vs. exploitation work.
   */
  void planWork() {
    // Algos skipped as a direct consequence of the user's include_algos/exclude_algos.
    final Set<IAlgo> userSkippedAlgos = new HashSet<>();
    if (_buildSpec.build_models.exclude_algos != null) {
      userSkippedAlgos.addAll(Arrays.asList(_buildSpec.build_models.exclude_algos));
    } else if (_buildSpec.build_models.include_algos != null) {
      userSkippedAlgos.addAll(Arrays.asList(Algo.values()));
      userSkippedAlgos.removeAll(Arrays.asList(_buildSpec.build_models.include_algos));
    }

    // Algos that were not skipped by the user but can't run here. They are tracked separately,
    // so they are not reported as "requested by the user".
    final Set<IAlgo> unavailableAlgos = new HashSet<>();
    final boolean isMultinode = H2O.CLOUD.size() > 1;
    for (Algo algo : Algo.values()) {
      if (!userSkippedAlgos.contains(algo) && !algo.enabled()) {
        eventLog().warn(Stage.Workflow,
                isMultinode ? "AutoML: "+algo.name()+" is not available in multi-node cluster; skipping it."
                        + " See http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html#experimental-features for details."
                        : "AutoML: "+algo.name()+" is not available; skipping it."
        );
        unavailableAlgos.add(algo);
      }
    }

    final WorkAllocations workAllocations = new WorkAllocations();
    for (ModelingStep step : getExecutionPlan()) {
      workAllocations.allocate(step.makeWork());
    }
    for (IAlgo skippedAlgo : userSkippedAlgos) {
      eventLog().info(Stage.Workflow, "Disabling Algo: "+skippedAlgo+" as requested by the user.");
      workAllocations.remove(skippedAlgo);
    }
    for (IAlgo unavailableAlgo : unavailableAlgos) {
      workAllocations.remove(unavailableAlgo); // already reported with a warning above
    }
    eventLog().debug(Stage.Workflow, "Defined work allocations: "+workAllocations);
    distributeExplorationVsExploitationWork(workAllocations);
    eventLog().debug(Stage.Workflow, "Actual work allocations: "+workAllocations);
    workAllocations.freeze();
    _workAllocations = workAllocations;
  }

  /**
   * Rescales work weights so that exploitation steps get {@code exploitation_ratio} of the total work.
   * The total weight of exploration steps is kept unchanged.
   *
   * Does nothing when {@code exploitation_ratio} is negative.
   * When it is 1 (or more), exploration would get no budget at all, which can't be achieved while keeping the
   * exploration weights: a warning is logged and the allocations are left unchanged.
   *
   * @param allocations the (not yet frozen) work allocations, rescaled in place.
   */
  private void distributeExplorationVsExploitationWork(WorkAllocations allocations) {
    final double exploitationRatio = _buildSpec.build_models.exploitation_ratio;
    if (exploitationRatio < 0) return;
    final double explorationRatio = 1 - exploitationRatio;
    if (explorationRatio <= 0) {
      // Dividing by 0.0 would give an infinite new total. Its int conversion, (int)Math.round(Infinity) == -1,
      // would then produce negative exploitation weights.
      eventLog().warn(Stage.Workflow, "exploitation_ratio="+exploitationRatio
              +" leaves no budget for exploration: work allocations are left unchanged.");
      return;
    }
    final int sumExploration = allocations.remainingWork(ModelingStep.isExplorationWork);
    final int sumExploitation = allocations.remainingWork(ModelingStep.isExploitationWork);
    final int newTotal = (int)Math.round(sumExploration / explorationRatio);
    final int newSumExploration = sumExploration; // keeping the same weight for exploration steps (principle of less surprise).
    final int newSumExploitation = newTotal - newSumExploration;
    for (Work work : allocations.getAllocations(ModelingStep.isExplorationWork)) {
      work._weight = (int)Math.round((double)work._weight * newSumExploration/sumExploration);
    }
    for (Work work : allocations.getAllocations(ModelingStep.isExploitationWork)) {
      work._weight = (int)Math.round((double)work._weight * newSumExploitation/sumExploitation);
    }
  }

  /**
   * Creates a job for the current AutoML instance and submits it to the task runner.
   * Calling this on an already running AutoML instance has no effect.
   *
   * Each (re)submission recomputes the work allocations before creating the job.
   * The job gets the time remaining on the run countdown, and the total remaining work of those allocations.
   */
  public void submit() {
    if (_job == null || !_job.isRunning()) {
      planWork();
      H2OJob j = new H2OJob<>(this, _key, _runCountdown.remainingTime());
      _job = j._job;
      eventLog().info(Stage.Workflow, "AutoML job created: " + EventLogEntry.dateTimeFormat.get().format(_startTime))
              .setNamedValue("creation_epoch", _startTime, EventLogEntry.epochFormat.get());
      j.start(_workAllocations.remainingWork());
      // publish the updated instance (with its job) to the DKV.
      DKV.put(this);
    }
  }

  /**
   * Body of the AutoML job: starts the modeling steps executor and runs the execution plan via learn().
   * stop() is always called afterwards, even if learn() throws.
   */
  @Override
  public void run() {
    _modelingStepsExecutor.start();
    eventLog().info(Stage.Workflow, "AutoML build started: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.start_time()))
            .setNamedValue("start_epoch", _runCountdown.start_time(), EventLogEntry.epochFormat.get());
    try {
      learn();
    } finally {
      stop();
    }
  }
  
  /**
   * Stops the modeling steps executor and finalizes the run. Finalization:
   * - logs the stop time, model count and duration;
   * - dumps the event log and the leaderboard (or an empty-leaderboard warning);
   * - detaches the session;
   * - optionally verifies that the original training frame was not modified;
   * - cleans up CV predictions unless they must be kept.
   */
  @Override
  public void stop() {
    // NOTE(review): the executor is created last in the constructor, so this only guards an instance whose
    // initialization did not complete. Nothing visible here resets the field after stopping, so a second call
    // would presumably run the whole summary again — confirm whether stop() can be called twice.
    if (null == _modelingStepsExecutor) return; // already stopped
    _modelingStepsExecutor.stop();
    eventLog().info(Stage.Workflow, "AutoML build stopped: " + EventLogEntry.dateTimeFormat.get().format(_runCountdown.stop_time()))
            .setNamedValue("stop_epoch", _runCountdown.stop_time(), EventLogEntry.epochFormat.get());
    eventLog().info(Stage.Workflow, "AutoML build done: built " + _modelingStepsExecutor.modelCount() + " models");
    eventLog().info(Stage.Workflow, "AutoML duration: "+ PrettyPrint.msecs(_runCountdown.duration(), true))
            .setNamedValue("duration_secs", Math.round(_runCountdown.duration() / 1000.));

    log.info("AutoML run summary:");
    for (EventLogEntry event : eventLog()._events)
      log.info(event.toString());
    if (0 < leaderboard().getModelKeys().length) {
      log.info(leaderboard().toLogString());
    } else {
      long max_runtime_secs = (long)_buildSpec.build_control.stopping_criteria.max_runtime_secs();
      eventLog().warn(Stage.Workflow, "Empty leaderboard.\n"
              +"AutoML was not able to build any model within a max runtime constraint of "+max_runtime_secs+" seconds, "
              +"you may want to increase this value before retrying.");
    }

    session().detach();
    possiblyVerifyImmutability();
    if (!_buildSpec.build_control.keep_cross_validation_predictions) {
      cleanUpModelsCVPreds();
    }
  }

  /**
   * Blocks until this AutoML's job is completed; returns immediately if no job was ever created.
   */
  public void get() {
    if (_job == null) return;
    _job.get();
  }

  /** @return the latest DKV version of this AutoML's job, or null if no job was created yet. */
  public Job job() {
    if (_job != null) return DKV.getGet(_job._key);
    return null;
  }

  /** @return the current leader of the leaderboard, or null if there is no leaderboard. */
  public Model leader() {
    final Leaderboard lb = leaderboard();
    return lb == null ? null : lb.getLeader();
  }

  /**
   * Refreshes the session from its key and attaches this instance to it.
   *
   * @return the refreshed session, or null if there is none.
   */
  public AutoMLSession session() {
    if (_session != null) {
      _session = _session._key.get();
      if (_session != null) _session.attach(this, false);
    }
    return _session;
  }

  /** @return the leaderboard, refreshed from its key, or null if there is none. */
  public Leaderboard leaderboard() {
    if (_leaderboard != null) _leaderboard = _leaderboard._key.get();
    return _leaderboard;
  }

  /** @return the event log, refreshed from its key, or null if there is none. */
  public EventLog eventLog() {
    if (_eventLog != null) _eventLog = _eventLog._key.get();
    return _eventLog;
  }

  /** @return the project name of the build spec, or null if there is no build spec. */
  public String projectName() {
    if (_buildSpec == null) return null;
    return _buildSpec.project();
  }

  /** @return the time remaining on the run countdown, in milliseconds. */
  public long timeRemainingMs() {
    final Countdown countdown = _runCountdown;
    return countdown.remainingTime();
  }

  /**
   * Returns how many more models may be built under the {@code max_models} constraint.
   *
   * Any {@code max_models <= 0} means no model-count constraint, consistent with validateEarlyStopping and
   * selectModelingPlan. Previously only 0 was treated that way, so a negative value stopped the run immediately.
   *
   * @return the remaining model budget, or {@link Integer#MAX_VALUE} when unconstrained.
   */
  public int remainingModels() {
    final int maxModels = _buildSpec.build_control.stopping_criteria.max_models();
    if (maxModels <= 0)
      return Integer.MAX_VALUE;
    return maxModels - _modelingStepsExecutor.modelCount();
  }

  /**
   * @return true while the run countdown has not timed out and the model budget is not exhausted.
   */
  @Override
  public boolean keepRunning() {
    final boolean timeLeft = !_runCountdown.timedOut();
    return timeLeft && remainingModels() > 0;
  }

  /** @return true if models are cross-validated, either with nfolds > 0 or with a user-provided fold column. */
  public boolean isCVEnabled() {
    final boolean hasFoldColumn = _buildSpec.input_spec.fold_column != null;
    return _buildSpec.build_control.nfolds > 0 || hasFoldColumn;
  }
  

  //*****************  Data Preparation Section  *****************//

  /**
   * When cross-validation is disabled, splits the original training frame to produce the missing frames:
   * - a validation frame (10%), if none was provided;
   * - a blending frame (20%), if auto-blending is on and none was provided.
   * The remainder becomes {@code _trainingFrame}. Splits that are empty or not needed are deleted.
   * Finally, if no leaderboard frame was provided, the validation frame is used as leaderboard frame.
   *
   * If no split is needed (including when CV is enabled), {@code _trainingFrame} is left untouched.
   */
  private void optionallySplitTrainingDataset() {
    // If no cross-validation and validation or leaderboard frame are missing,
    // then we need to create one out of the original training set.
    if (!isCVEnabled()) {
      double[] splitRatios = null;
      double validationRatio = null == _validationFrame ? 0.1 : 0;
      double blendingRatio = (_useAutoBlending && null == _blendingFrame) ? 0.2 : 0;
      if (validationRatio + blendingRatio > 0) {
        // ratios for [training, validation, blending]; a 0 ratio yields a split that is deleted below.
        splitRatios = new double[]{
                1 - (validationRatio + blendingRatio),
                validationRatio,
                blendingRatio
        };
        ArrayList frames = new ArrayList();
        if (null == _validationFrame) frames.add("validation");
        if (null == _blendingFrame && _useAutoBlending) frames.add("blending");

        String framesStr = String.join(", ", frames);
        String ratioStr = Arrays.stream(splitRatios)
                .mapToObj(d -> Integer.toString((int) (d * 100)))
                .collect(Collectors.joining("/"));
        eventLog().info(Stage.DataImport, "Since cross-validation is disabled, and " + framesStr + " frame(s) were not provided, " +
                "automatically split the training data into training, " + framesStr + " frame(s) in the ratio " + ratioStr + ".");
      }
      if (splitRatios != null) {
        // keys are prefixed with the run id to keep the splits of different runs apart.
        Key[] keys = new Key[] {
            Key.make(_runId+"_training_"+ _origTrainingFrame._key),
            Key.make(_runId+"_validation_"+ _origTrainingFrame._key),
            Key.make(_runId+"_blending_"+ _origTrainingFrame._key),
        };
        Frame[] splits = ShuffleSplitFrame.shuffleSplitFrame(
                _origTrainingFrame, 
                keys, 
                splitRatios, 
                _buildSpec.build_control.stopping_criteria.seed()
        );
        _trainingFrame = splits[0];

        // keep a split only if it was requested and is not empty, otherwise release it.
        if (_validationFrame == null && splits[1].numRows() > 0) {
          _validationFrame = splits[1];
        } else {
          splits[1].delete();
        }

        if (_blendingFrame == null && splits[2].numRows() > 0) {
          _blendingFrame = splits[2];
        } else {
          splits[2].delete();
        }
      }
      if (_leaderboardFrame == null)
        _leaderboardFrame = _validationFrame;
    }
  }

  /**
   * Infers the distribution family from the response column, or validates the user-specified one against it.
   *
   * With {@code distribution=AUTO}, the number of response levels decides:
   * - no domain (numeric response): gaussian;
   * - 2 levels: bernoulli;
   * - more than 2 levels: multinomial.
   * Otherwise, the user-specified distribution must be compatible with the response type.
   *
   * A single-class categorical response is always rejected. Previously, with an explicit distribution,
   * it fell through to the regression checks: it was either accepted or rejected with a misleading message.
   *
   * @param response the response column.
   * @return the distribution family to use.
   * @throws H2OAutoMLException if the response has a single class,
   *         or if the specified distribution does not match the response type.
   */
  private DistributionFamily inferDistribution(Vec response) {
    final int numOfDomains = response.domain() == null ? 0 : response.domain().length;
    if (numOfDomains == 1)
      throw new H2OAutoMLException("Number of classes is equal to 1.");

    final DistributionFamily distribution = _buildSpec.build_control.distribution;
    if (distribution == DistributionFamily.AUTO) {
      if (numOfDomains == 0)
        return DistributionFamily.gaussian;
      if (numOfDomains == 2)
        return DistributionFamily.bernoulli;
      return DistributionFamily.multinomial;
    }

    if (numOfDomains > 2) {
      if (!EnumSet.of(
              DistributionFamily.multinomial,
              DistributionFamily.ordinal,
              DistributionFamily.custom
      ).contains(distribution)) {
        throw new H2OAutoMLException("Wrong distribution specified! Number of classes of response is greater than 2." +
                " Possible distribution values: \"multinomial\"," +
                /*" \"ordinal\"," + */ // Currently unsupported in AutoML
                " \"custom\".");
      }
    } else if (numOfDomains == 2) {
      if (!EnumSet.of(
              DistributionFamily.bernoulli,
              DistributionFamily.quasibinomial,
              DistributionFamily.fractionalbinomial,
              DistributionFamily.custom
      ).contains(distribution)) {
        throw new H2OAutoMLException("Wrong distribution specified! Number of classes of response is 2." +
                " Possible distribution values: \"bernoulli\"," +
                /*" \"quasibinomial\", \"fractionalbinomial\"," + */ // Currently unsupported in AutoML
                " \"custom\".");
      }
    } else {
      if (!EnumSet.of(
              DistributionFamily.gaussian,
              DistributionFamily.poisson,
              DistributionFamily.negativebinomial,
              DistributionFamily.gamma,
              DistributionFamily.laplace,
              DistributionFamily.quantile,
              DistributionFamily.huber,
              DistributionFamily.tweedie,
              DistributionFamily.custom
      ).contains(distribution)) {
        throw new H2OAutoMLException("Wrong distribution specified! Response type suggests a regression task." +
                " Possible distribution values: \"gaussian\", \"poisson\", \"negativebinomial\", \"gamma\", " +
                "\"laplace\", \"quantile\", \"huber\", \"tweedie\", \"custom\".");
      }
    }
    return distribution;
  }

  /**
   * Resolves the input frames from the DKV and optionally splits the training frame.
   * It then sets up the special columns and the distribution family, and logs a description of the data.
   * When immutability verification is on, it also records the vecs, names and checksums of the original
   * training frame.
   */
  private void prepareData() {
    final AutoMLInput input = _buildSpec.input_spec;
    _origTrainingFrame = DKV.getGet(input.training_frame);
    _validationFrame = DKV.getGet(input.validation_frame);
    _blendingFrame = DKV.getGet(input.blending_frame);
    _leaderboardFrame = DKV.getGet(input.leaderboard_frame);

    optionallySplitTrainingDataset();

    if (null == _trainingFrame) {
      // when no split was made (e.g. nfolds>0), let trainingFrame be the original frame
      // but cloning to keep an internal ref just in case the original ref gets deleted from client side
      // (can occur in some corner cases with Python GC for example if frame gets out of scope during an AutoML rerun)
      _trainingFrame = new Frame(_origTrainingFrame);
      _trainingFrame._key = Key.make(_runId+"_training_" + _origTrainingFrame._key);
      DKV.put(_trainingFrame);
    }

    _responseColumn = _trainingFrame.vec(input.response_column);
    _foldColumn = _trainingFrame.vec(input.fold_column);
    _weightsColumn = _trainingFrame.vec(input.weights_column);

    _distributionFamily = inferDistribution(_responseColumn);

    logInputFrame("training", _trainingFrame);
    logInputFrame("validation", _validationFrame);
    logInputFrame("leaderboard", _leaderboardFrame);
    logInputFrame("blending", _blendingFrame);

    // log column names (as for the response column) rather than the Vec objects themselves.
    eventLog().info(Stage.DataImport, "response column: "+input.response_column);
    eventLog().info(Stage.DataImport, "fold column: "+input.fold_column);
    eventLog().info(Stage.DataImport, "weights column: "+input.weights_column);

    if (verifyImmutability) {
      // check that we haven't messed up the original Frame
      _originalTrainingFrameVecs = _origTrainingFrame.vecs().clone();
      _originalTrainingFrameNames = _origTrainingFrame.names().clone();
      _originalTrainingFrameChecksums = new long[_originalTrainingFrameVecs.length];

      for (int i = 0; i < _originalTrainingFrameVecs.length; i++)
        _originalTrainingFrameChecksums[i] = _originalTrainingFrameVecs[i].checksum();
    }
  }

  /**
   * Logs a one-line description of an input frame with its checksum, or "NULL" when the frame is absent.
   *
   * @param role the role of the frame in this run (e.g. "training"), used as log prefix.
   * @param frame the frame to describe; may be null.
   */
  private void logInputFrame(String role, Frame frame) {
    if (null == frame) {
      eventLog().info(Stage.DataImport, role+" frame: NULL");
    } else {
      eventLog().info(Stage.DataImport,
          role+" frame: "+frame.toString().replace("\n", " ")+" checksum: "+frame.checksum());
    }
  }


  //*****************  Training Jobs  *****************//

  /**
   * Runs the modeling phase of this AutoML run.
   * <p>
   * First, each preprocessing step is prepared. Then every step of the execution plan that is not blocked by a
   * search limit (see {@link #exceededSearchLimits}) is submitted to the modeling-steps executor.
   * Successful steps reset the consecutive-failure counter and are recorded as completed.
   * Failed steps increment the counter, unless they ignore {@code Constraint.FAILURE_COUNT}.
   * The preprocessing steps are always disposed, even if the loop aborts. Finally, the completed steps are
   * converted into {@code _actualModelingSteps} and written to the event log.
   *
   * @throws H2OAutoMLException when the number of consecutive failures reaches
   *         {@code _maxConsecutiveModelFailures}, or when a step failed with an {@code H2OAutoMLException}
   */
  private void learn() {
    // Typed list: with a raw List, toArray(T[]) is erased and returns Object[], which does not
    // match the ModelingStep[] expected by createDefinitionPlanFromSteps.
    final List<ModelingStep> completed = new ArrayList<>();
    if (_preprocessing != null) {
      for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.prepare();
    }
    try {
      for (ModelingStep step : getExecutionPlan()) {
        if (!exceededSearchLimits(step)) {
          StepResultState state = _modelingStepsExecutor.submit(step, job());
          log.info("AutoML step returned with state: "+state.toString());
          if (_testMode) _stepsResults = ArrayUtils.append(_stepsResults, state);
          if (state.is(ResultStatus.success)) {
            _consecutiveModelFailures.set(0);
            completed.add(step);
          } else if (state.is(ResultStatus.failed)) {
            if (!step.ignores(Constraint.FAILURE_COUNT)
                    && _consecutiveModelFailures.incrementAndGet() >= _maxConsecutiveModelFailures) {
              throw new H2OAutoMLException("Aborting AutoML after too many consecutive model failures", state.error());
            }
            if (state.error() instanceof H2OAutoMLException) { // if a step throws this exception, this will immediately abort the entire AutoML run.
              throw (H2OAutoMLException) state.error();
            }
          }
        }
      }
    } finally {
      // Dispose the preprocessing steps even when the run is aborted by one of the throws above;
      // previously they were only disposed on the normal path.
      if (_preprocessing != null) {
        for (PreprocessingStep preprocessingStep : _preprocessing) preprocessingStep.dispose();
      }
    }
    _actualModelingSteps = session().getModelingStepsRegistry().createDefinitionPlanFromSteps(completed.toArray(new ModelingStep[0]));
    eventLog().info(Stage.Workflow, "Actual modeling steps: "+Arrays.toString(_actualModelingSteps));
  }

  /**
   * Builds a DKV key for an object produced by this AutoML run.
   * The key name is {@code algoName[_type][_counter]_runId}, joined with underscores.
   *
   * @param algoName     algorithm name, always the first token
   * @param type         optional sub-type token; omitted when null or empty
   * @param with_counter when true, appends the session's next model counter for (algoName, type)
   * @return the newly made key
   */
  public Key makeKey(String algoName, String type, boolean with_counter) {
    final List<String> tokens = new ArrayList<>();
    tokens.add(algoName);
    if (!StringUtils.isNullOrEmpty(type)) tokens.add(type);
    if (with_counter) tokens.add(Integer.toString(session().nextModelCounter(algoName, type)));
    tokens.add(_runId);
    return Key.make(String.join("_", tokens));
  }

  /**
   * Registers keys in {@code _trackedKeys} so that {@code remove_impl} deletes them together with this AutoML.
   * Each key is mapped to the current thread's stack trace, rendered as a string
   * (presumably so the origin of a tracked key can be traced while debugging).
   *
   * @param keys the keys to track
   */
  public void trackKeys(Key... keys) {
    final String registeredFrom = Arrays.toString(Thread.currentThread().getStackTrace());
    Arrays.stream(keys).forEach(k -> _trackedKeys.put(k, registeredFrom));
  }

  /**
   * Tells whether {@code step} must be skipped because a search limit has been reached.
   * <p>
   * The checks run in priority order:
   * <ol>
   *   <li>job cancellation;</li>
   *   <li>the time budget, unless the step ignores {@code Constraint.TIMEOUT};</li>
   *   <li>the model budget, unless the step ignores {@code Constraint.MODEL_COUNT}.</li>
   * </ol>
   * The first limit found is reported at debug level in the event log.
   *
   * @param step the modeling step about to be submitted
   * @return {@code true} if the step should be skipped
   */
  private boolean exceededSearchLimits(ModelingStep step) {
    final String skipMessage;
    if (_job.stop_requested()) {
      skipMessage = "AutoML: job cancelled; skipping ";
    } else if (!step.ignores(Constraint.TIMEOUT) && _runCountdown.timedOut()) {
      skipMessage = "AutoML: out of time; skipping ";
    } else if (!step.ignores(Constraint.MODEL_COUNT) && remainingModels() <= 0) {
      skipMessage = "AutoML: hit the max_models limit; skipping ";
    } else {
      return false;
    }
    eventLog().debug(EventLogEntry.Stage.ModelTraining, skipMessage + step._description);
    return true;
  }

  //*****************  Clean Up + other utility functions *****************//

  /**
   * Delete the AutoML-related objects, including the grids and models that it built if cascade=true.
   * <p>
   * In order, this removes:
   * <ul>
   *   <li>the original training frame and the validation frame, but only when the build spec did not reference
   *       them by key (i.e. the frame was not supplied by the user as an existing key);</li>
   *   <li>the internal training frame, via {@code Frame.deleteTempFrameAndItsNonSharedVecs} with the original
   *       training frame as second argument (presumably sparing vecs shared with it, as the name suggests);</li>
   *   <li>the leaderboard, the event log and the session, forwarding {@code cascade};</li>
   *   <li>the preprocessing steps, only when {@code cascade} is true;</li>
   *   <li>every key registered through {@code trackKeys}, always removed with cascade enabled.</li>
   * </ul>
   *
   * @param fs      futures passed through to every removal call
   * @param cascade whether owned sub-objects should be removed as well
   * @return the result of the superclass {@code remove_impl}
   */
  @Override
  protected Futures remove_impl(Futures fs, boolean cascade) {
    Key jobKey = _job == null ? null : _job._key;
    log.debug("Cleaning up AutoML "+jobKey);
    if (_buildSpec != null) {
      // If the Frame was made here (e.g. the buildspec contained a path rather than a frame key), then it will be deleted.
      if (_buildSpec.input_spec.training_frame == null && _origTrainingFrame != null) {
        _origTrainingFrame.delete(jobKey, fs, true);
      }
      if (_buildSpec.input_spec.validation_frame == null && _validationFrame != null) {
        _validationFrame.delete(jobKey, fs, true);
      }
    }
    // Internal training frame: either the split result or the run-specific copy of the original frame.
    if (_trainingFrame != null && _origTrainingFrame != null)
      Frame.deleteTempFrameAndItsNonSharedVecs(_trainingFrame, _origTrainingFrame);
    if (leaderboard() != null) leaderboard().remove(fs, cascade);
    if (eventLog() != null) eventLog().remove(fs, cascade);
    if (session() != null) session().remove(fs, cascade);
    if (cascade && _preprocessing != null) {
      for (PreprocessingStep preprocessingStep : _preprocessing) {
        preprocessingStep.remove();
      }
    }
    // Keys registered via trackKeys() are always removed, regardless of the cascade flag.
    for (Key key : _trackedKeys.keySet()) Keyed.remove(key, fs, true);

    return super.remove_impl(fs, cascade);
  }

  /**
   * When {@code verifyImmutability} is enabled, checks that the original training frame still matches the
   * snapshot taken in {@code prepareData()}.
   * <p>
   * The following are compared:
   * <ul>
   *   <li>the vec count and the name count;</li>
   *   <li>for each column present in both snapshots, the vec (via {@code equals}), the name and the vec
   *       checksum.</li>
   * </ul>
   * Each difference is logged as a warning. The outcome is recorded in the event log.
   *
   * @return {@code true} if any mutation was detected; {@code false} otherwise, including when verification
   *         is disabled
   */
  private boolean possiblyVerifyImmutability() {
    if (!verifyImmutability) {
      eventLog().debug(Stage.Workflow, "Not verifying training frame immutability. . .  This is turned off for efficiency.");
      return false;
    }

    // check that we haven't messed up the original Frame
    eventLog().debug(Stage.Workflow, "Verifying training frame immutability. . .");
    boolean warning = false;

    final Vec[] vecsRightNow = _origTrainingFrame.vecs();
    final String[] namesRightNow = _origTrainingFrame.names();

    if (_originalTrainingFrameVecs.length != vecsRightNow.length) {
      log.warn("Training frame vec count has changed from: " +
              _originalTrainingFrameVecs.length + " to: " + vecsRightNow.length);
      warning = true;
    }
    if (_originalTrainingFrameNames.length != namesRightNow.length) {
      log.warn("Training frame name count has changed from: " +
              _originalTrainingFrameNames.length + " to: " + namesRightNow.length);
      warning = true;
    }

    // Only compare positions that exist in both snapshots. If the frame lost columns, indexing with the
    // original length would throw ArrayIndexOutOfBoundsException instead of reporting the mutation
    // (the count change is already reported above).
    final int common = Math.min(
            Math.min(_originalTrainingFrameVecs.length, vecsRightNow.length),
            Math.min(_originalTrainingFrameNames.length, namesRightNow.length));

    for (int i = 0; i < common; i++) {
      if (!_originalTrainingFrameVecs[i].equals(vecsRightNow[i])) {
        log.warn("Training frame vec number " + i + " has changed keys.  Was: " +
                _originalTrainingFrameVecs[i] + " , now: " + vecsRightNow[i]);
        warning = true;
      }
      if (!_originalTrainingFrameNames[i].equals(namesRightNow[i])) {
        log.warn("Training frame vec number " + i + " has changed names.  Was: " +
                _originalTrainingFrameNames[i] + " , now: " + namesRightNow[i]);
        warning = true;
      }
      // Compute the current checksum once; it is used both for the comparison and in the message.
      final long checksumNow = vecsRightNow[i].checksum();
      if (_originalTrainingFrameChecksums[i] != checksumNow) {
        log.warn("Training frame vec number " + i + " has changed checksum.  Was: " +
                _originalTrainingFrameChecksums[i] + " , now: " + checksumNow);
        warning = true;
      }
    }

    if (warning)
      eventLog().warn(Stage.Workflow, "Training frame was mutated!  This indicates a bug in the AutoML software.");
    else
      eventLog().debug(Stage.Workflow, "Training frame was not mutated (as expected).");

    return warning;
  }

  /**
   * Calls {@code deleteCrossValidationPreds()} on every model currently on the leaderboard.
   */
  private void cleanUpModelsCVPreds() {
    log.info("Cleaning up all CV Predictions for AutoML");
    for (final Model leaderboardModel : leaderboard().getModels()) {
      leaderboardModel.deleteCrossValidationPreds();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy