All downloads are free. The search and download functionalities use the official Maven repository.

hex.deepwater.DeepWaterParameters Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.deepwater;

import hex.Model;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import water.H2O;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;

import javax.imageio.ImageIO;
import java.io.File;
import java.io.FilenameFilter;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.Arrays;

import static hex.deepwater.DeepWaterParameters.ProblemType.auto;

/**
 * Parameters for a Deep Water image classification model
 */
public class DeepWaterParameters extends Model.Parameters {
  /** Short algorithm identifier used by the framework (REST/schema lookup). */
  public String algoName() { return "DeepWater"; }
  /** Human-readable algorithm name. */
  public String fullName() { return "Deep Water"; }
  /** Fully qualified class name of the corresponding model implementation. */
  public String javaName() { return DeepWaterModel.class.getName(); }
  // Early-stopping tolerance defaults to 0 for Deep Water (any non-improvement counts).
  @Override protected double defaultStoppingTolerance() { return 0; }
  /** Default constructor: enables early stopping out of the box. */
  public DeepWaterParameters() {
    super();
    _stopping_rounds = 5; // stop after 5 scoring rounds without improvement by default
  }
  /** Total amount of work for the progress bar: rows times epochs (1 if no training frame yet). */
  @Override
  public long progressUnits() {
    return train() == null ? 1 : (long) Math.ceil(_epochs * train().numRows());
  }
  /** Annealed learning rate after {@code n} training samples: lr / (1 + annealing * n). */
  public float learningRate(double n) {
    double annealed = _learning_rate / (1 + _learning_rate_annealing * n);
    return (float) annealed;
  }
  /**
   * Momentum after {@code n} training samples: linearly ramped from {@code _momentum_start}
   * to {@code _momentum_stable} over {@code _momentum_ramp} samples (constant start value
   * when no ramp is configured).
   */
  final public float momentum(double n) {
    if (_momentum_ramp <= 0) return (float) _momentum_start;
    if (n >= _momentum_ramp) return (float) _momentum_stable;
    return (float) (_momentum_start + (_momentum_stable - _momentum_start) * n / _momentum_ramp);
  }

  /** Pre-defined network architectures; {@code user} means a user-supplied network definition file. */
  public enum Network {
    auto, user, lenet, alexnet, vgg, googlenet, inception_bn, resnet,
  }

  /** Native backend that executes training. */
  public enum Backend {
    unknown,
    mxnet, caffe, tensorflow, // C++
    xgrpc // anything that speaks grpc
  }

  /** High-level problem category; {@code auto} lets the builder guess from the data (see guessProblemType). */
  public enum ProblemType {
    auto, image, text, dataset
  }

  // Magnitude to which gradients are clipped; validate() requires this to be positive.
  public double _clip_gradient = 10.0;
  // Whether to attempt training on GPU (backend-specific).
  public boolean _gpu = true;
  // GPU device id(s) to use when _gpu is enabled.
  public int[] _device_id = new int[]{0};

  // Network architecture; 'auto' picks a default per problem type (see Sanity.modifyParms).
  public Network _network = Network.auto;
  public Backend _backend = Backend.mxnet;
  // Path to a user-supplied network definition (required when _network == user).
  public String _network_definition_file;
  // Path to previously saved native network parameters; existence is checked in validate().
  public String _network_parameters_file;
  // NOTE(review): presumably a path prefix under which native parameters are exported — confirm against backend usage.
  public String _export_native_parameters_prefix;

  public ProblemType _problem_type = auto;

  // specific parameters for image_classification
  public int[] _image_shape = new int[]{0,0}; //width x height
  public int _channels = 3; //either 1 (monochrome) or 3 (RGB)
  public String _mean_image_file; //optional file with mean image (backend specific)

  /**
   * If enabled, store the best model under the destination key of this model at the end of training.
   * Only applicable if training is not cancelled.
   */
  public boolean _overwrite_with_best_model = true;

  // Autoencoder mode — currently rejected by validate().
  public boolean _autoencoder = false;

  // Sparse data handling hint.
  public boolean _sparse = false;

  // Use all factor levels of categorical variables (instead of dropping one reference level).
  public boolean _use_all_factor_levels = true;

  /** Strategy for handling missing values in the training data. */
  public enum MissingValuesHandling {
    MeanImputation, Skip
  }

  public MissingValuesHandling _missing_values_handling = MissingValuesHandling.MeanImputation;

  /**
   * If enabled, automatically standardize the data. If disabled, the user must provide properly scaled input data.
   */
  public boolean _standardize = true;

  /**
   * The number of passes over the training dataset to be carried out.
   * It is recommended to start with lower values for initial experiments.
   * This value can be modified during checkpoint restarts and allows continuation
   * of selected models.
   */
  public double _epochs = 10;

  /**
   * Activation functions
   */
  public enum Activation {
    Rectifier, Tanh
  }

  /**
   * The activation function (non-linearity) to be used the neurons in the hidden layers.
   * Tanh: Hyperbolic tangent function (same as scaled and shifted sigmoid).
   * Rectifier: Chooses the maximum of (0, x) where x is the input value.
   */
  public Activation _activation = null;

  /**
   * The number and size of each hidden layer in the model.
   * For example, if a user specifies "100,200,100" a model with 3 hidden
   * layers will be produced, and the middle hidden layer will have 200
   * neurons.
   */
  public int[] _hidden = null;

  /**
   * A fraction of the features for each training row to be omitted from training in order
   * to improve generalization (dimension sampling).
   */
  public double _input_dropout_ratio = 0.0f;

  /**
   * A fraction of the inputs for each hidden layer to be omitted from training in order
   * to improve generalization. Defaults to 0.5 for each hidden layer if omitted.
   */
  public double[] _hidden_dropout_ratios = null;

  /**
   * The number of training data rows to be processed per iteration. Note that
   * independent of this parameter, each row is used immediately to update the model
   * with (online) stochastic gradient descent. This parameter controls the
   * synchronization period between nodes in a distributed environment and the
   * frequency at which scoring and model cancellation can happen. For example, if
   * it is set to 10,000 on H2O running on 4 nodes, then each node will
   * process 2,500 rows per iteration, sampling randomly from their local data.
   * Then, model averaging between the nodes takes place, and scoring can happen
   * (dependent on scoring interval and duty factor). Special values are 0 for
   * one epoch per iteration, -1 for processing the maximum amount of data
   * per iteration (if **replicate training data** is enabled, N epochs
   * will be trained per iteration on N nodes, otherwise one epoch). Special value
   * of -2 turns on automatic mode (auto-tuning).
   */
  public long _train_samples_per_iteration = -2;

  // NOTE(review): presumably the communication/computation time ratio targeted by the
  // auto-tuning mode (_train_samples_per_iteration == -2) — confirm against the driver.
  public double _target_ratio_comm_to_comp = 0.05;

  /*Learning Rate*/
  /**
   * When adaptive learning rate is disabled, the magnitude of the weight
   * updates are determined by the user specified learning rate
   * (potentially annealed), and are a function  of the difference
   * between the predicted value and the target value. That difference,
   * generally called delta, is only available at the output layer. To
   * correct the output at each hidden layer, back propagation is
   * used. Momentum modifies back propagation by allowing prior
   * iterations to influence the current update. Using the momentum
   * parameter can aid in avoiding local minima and the associated
   * instability. Too much momentum can lead to instabilities, that's
   * why the momentum is best ramped up slowly.
   * This parameter is only active if adaptive learning rate is disabled.
   */
  public double _learning_rate = 1e-3;

  /**
   * Learning rate annealing reduces the learning rate to "freeze" into
   * local minima in the optimization landscape.  The annealing rate is the
   * inverse of the number of training samples it takes to cut the learning rate in half
   * (e.g., 1e-6 means that it takes 1e6 training samples to halve the learning rate).
   * This parameter is only active if adaptive learning rate is disabled.
   */
  public double _learning_rate_annealing = 1e-6;

  /**
   * The momentum_start parameter controls the amount of momentum at the beginning of training.
   * This parameter is only active if adaptive learning rate is disabled.
   */
  public double _momentum_start = 0.9;

  /**
   * The momentum_ramp parameter controls the amount of learning for which momentum increases
   * (assuming momentum_stable is larger than momentum_start). The ramp is measured in the number
   * of training samples.
   * This parameter is only active if adaptive learning rate is disabled.
   */
  public double _momentum_ramp = 1e4;

  /**
   * The momentum_stable parameter controls the final momentum value reached after momentum_ramp training samples.
   * The momentum used for training will remain the same for training beyond reaching that point.
   * This parameter is only active if adaptive learning rate is disabled.
   */
  public double _momentum_stable = 0.9;


  /**
   * The minimum time (in seconds) to elapse between model scoring. The actual
   * interval is determined by the number of training samples per iteration and the scoring duty cycle.
   */
  public double _score_interval = 5;

  /**
   * The number of training dataset points to be used for scoring. Will be
   * randomly sampled. Use 0 for selecting the entire training dataset.
   */
  public long _score_training_samples = 10000l;

  /**
   * The number of validation dataset points to be used for scoring. Can be
   * randomly sampled or stratified (if "balance classes" is set and "score
   * validation sampling" is set to stratify). Use 0 for selecting the entire
   * training dataset.
   */
  public long _score_validation_samples = 0l;

  /**
   * Maximum fraction of wall clock time spent on model scoring on training and validation samples,
   * and on diagnostics such as computation of feature importances (i.e., not on training).
   */
  public double _score_duty_cycle = 0.1;

  /**
   * Enable quiet mode for less output to standard output.
   */
  public boolean _quiet_mode = false;

  /**
   * Replicate the entire training dataset onto every node for faster training on small datasets.
   */
  public boolean _replicate_training_data = true;

  /**
   * Run on a single node for fine-tuning of model parameters. Can be useful for
   * checkpoint resumes after training on multiple nodes for fast initial
   * convergence.
   */
  public boolean _single_node_mode = false;

  /**
   * Enable shuffling of training data (on each node). This option is
   * recommended if training data is replicated on N nodes, and the number of training samples per iteration
   * is close to N times the dataset size, where all nodes train with (almost) all
   * the data. It is automatically enabled if the number of training samples per iteration is set to -1 (or to N
   * times the dataset size or larger).
   */
  public boolean _shuffle_training_data = true;

  // Mini-batch size (rows per gradient update); validate() requires >= 1.
  public int _mini_batch_size = 32;

  // NOTE(review): presumably caches converted/preprocessed training data between
  // iterations for speed — confirm against the backend implementation.
  public boolean _cache_data = true;

  /**
   * Validate model parameters.
   * Problems are reported through {@code dl.error(...)}, {@code dl.warn(...)} and
   * {@code dl.hide(...)} rather than thrown, so all issues are collected in one pass.
   * @param dl DL Model Builder (Driver)
   * @param expensive (whether or not this is the "final" check)
   */
  void validate(DeepWater dl, boolean expensive) {
    // During the cheap pass the response may not be parsed yet (nclasses()==0); fall back to
    // the distribution family. BUGFIX: the second clause used to test bernoulli twice —
    // multinomial is the other classification family.
    boolean classification = expensive || dl.nclasses() != 0 ? dl.isClassifier() : _distribution == DistributionFamily.bernoulli || _distribution == DistributionFamily.multinomial;
    if (_mini_batch_size < 1)
      dl.error("_mini_batch_size", "Mini-batch size must be >= 1");

    if (_weights_column!=null && expensive) {
      Vec w = (train().vec(_weights_column));
      if (!w.isInt() || w.max() > 1 || w.min() < 0) {
        dl.error("_weights_column", "only supporting weights of 0 or 1 right now");
      }
    }

    if (_clip_gradient<=0)
      dl.error("_clip_gradient", "Clip gradient must be > 0"); // BUGFIX: message said ">= 0" but the check rejects 0

    if (_hidden != null && _network_definition_file != null && !_network_definition_file.isEmpty())
      dl.error("_hidden", "Cannot provide hidden layers and a network definition file at the same time.");

    if (_activation != null && _network_definition_file != null && !_network_definition_file.isEmpty())
      dl.error("_activation", "Cannot provide activation functions and a network definition file at the same time.");

    if (_problem_type == ProblemType.image) {
      if (_image_shape.length != 2)
        dl.error("_image_shape", "image_shape must have 2 dimensions (width, height)");
      if (_image_shape[0] < 0)
        dl.error("_image_shape", "image_shape[0] must be >=1 or automatic (0).");
      if (_image_shape[1] < 0)
        dl.error("_image_shape", "image_shape[1] must be >=1 or automatic (0).");
      if (_channels != 1 && _channels != 3)
        dl.error("_channels", "channels must be either 1 or 3.");
    } else if (_problem_type != auto) {
      // Image-specific settings are silently ignored for non-image problems; tell the user.
      dl.warn("_image_shape", "image shape is ignored, only used for image_classification");
      dl.warn("_channels", "channels shape is ignored, only used for image_classification");
      dl.warn("_mean_image_file", "mean_image_file shape is ignored, only used for image_classification");
    }
    if (_categorical_encoding==CategoricalEncodingScheme.Enum) {
      dl.error("_categorical_encoding", "categorical encoding scheme cannot be Enum: the neural network must have numeric columns as input.");
    }

    if (_autoencoder)
      dl.error("_autoencoder", "Autoencoder is not supported right now.");

    if (_network == Network.user) {
      if (_network_definition_file == null || _network_definition_file.isEmpty())
        dl.error("_network_definition_file", "network_definition_file must be provided if the network is user-specified.");
      else if (!new File(_network_definition_file).exists())
        dl.error("_network_definition_file", "network_definition_file " + _network_definition_file + " not found.");
    } else {
      if (_network_definition_file != null && !_network_definition_file.isEmpty() && _network != Network.auto)
        dl.error("_network_definition_file", "network_definition_file cannot be provided if a pre-defined network is chosen.");
    }
    if (_network_parameters_file != null && !_network_parameters_file.isEmpty()) {
      if (!DeepWaterModelInfo.paramFilesExist(_network_parameters_file)) {
        dl.error("_network_parameters_file", "network_parameters_file " + _network_parameters_file + " not found.");
      }
    }
    if (_checkpoint!=null) {
      DeepWaterModel other = (DeepWaterModel) _checkpoint.get();
      // BUGFIX: previously the null case reported a "width mismatch" under the nonexistent
      // key "_width" and then dereferenced 'other' anyway (NPE).
      if (other == null) {
        dl.error("_checkpoint", "Checkpointed model not found.");
      } else if (!Arrays.equals(_image_shape, other.get_params()._image_shape)) {
        // The image shape must not change between checkpoint restarts.
        dl.error("_image_shape", "Invalid checkpoint provided: image_shape mismatch.");
      }
    }

    if (!_autoencoder) {
      if (classification) {
        dl.hide("_regression_stop", "regression_stop is used only with regression.");
      } else {
        dl.hide("_classification_stop", "classification_stop is used only with classification.");
      }
      // Hide unless this is classification with a validation frame.
      if (!classification && _valid != null || _valid == null)
        dl.hide("_score_validation_sampling", "score_validation_sampling requires classification and a validation frame.");
    } else {
      if (_nfolds > 1) {
        dl.error("_nfolds", "N-fold cross-validation is not supported for Autoencoder.");
      }
    }
    if (H2O.CLOUD.size() == 1 && _replicate_training_data)
      dl.hide("_replicate_training_data", "replicate_training_data is only valid with cloud size greater than 1.");
    if (_single_node_mode && (H2O.CLOUD.size() == 1 || !_replicate_training_data))
      dl.hide("_single_node_mode", "single_node_mode is only used with multi-node operation with replicated training data.");
    if (H2O.ARGS.client && _single_node_mode)
      dl.error("_single_node_mode", "Cannot run on a single node in client mode");
    if (_autoencoder)
      dl.hide("_use_all_factor_levels", "use_all_factor_levels is mandatory in combination with autoencoder.");
    if (_nfolds != 0)
      dl.hide("_overwrite_with_best_model", "overwrite_with_best_model is unsupported in combination with n-fold cross-validation.");
    if (expensive) dl.checkDistributions();

    if (_score_training_samples < 0)
      dl.error("_score_training_samples", "Number of training samples for scoring must be >= 0 (0 for all).");
    if (_score_validation_samples < 0)
      dl.error("_score_validation_samples", "Number of validation samples for scoring must be >= 0 (0 for all)."); // BUGFIX: message was a copy-paste of the training-samples text
    if (classification && dl.hasOffsetCol())
      dl.error("_offset_column", "Offset is only supported for regression.");

    // reason for the error message below is that validation might not have the same horizontalized features as the training data (or different order)
    if (expensive) {
      if (!classification && _balance_classes) {
        dl.error("_balance_classes", "balance_classes requires classification.");
      }
      if (_class_sampling_factors != null && !_balance_classes) {
        dl.error("_class_sampling_factors", "class_sampling_factors requires balance_classes to be enabled.");
      }
      if (_replicate_training_data && null != train() && train().byteSize() > 0.9*H2O.CLOUD.free_mem()/H2O.CLOUD.size() && H2O.CLOUD.size() > 1) {
        dl.error("_replicate_training_data", "Compressed training dataset takes more than 90% of avg. free available memory per node (" + 0.9*H2O.CLOUD.free_mem()/H2O.CLOUD.size() + "), cannot run with replicate_training_data.");
      }
    }
    if (_autoencoder && _stopping_metric != ScoreKeeper.StoppingMetric.AUTO && _stopping_metric != ScoreKeeper.StoppingMetric.MSE) {
      dl.error("_stopping_metric", "Stopping metric must either be AUTO or MSE for autoencoder.");
    }
  }

  /**
   * Attempt to guess the problem type from the dataset.
   * Probes the first value of the first column: if it can be read as an image (from a local
   * file or a URL), the problem is {@code image}; an image-like file extension also counts
   * (with a warning if unreadable). A string column in a narrow frame (&le; 4 columns) is
   * {@code text}; everything else falls back to {@code dataset}.
   * @return the resolved problem type (never {@code auto})
   */
  ProblemType guessProblemType() {
    if (_problem_type == auto) {
      boolean image = false;
      boolean text = false;
      String first = null;
      Vec v = train().vec(0);
      if (v.isString() || v.isCategorical() /*small data parser artefact*/) {
        BufferedString bs = new BufferedString();
        first = v.atStr(bs, 0).toString();
        // Best-effort probes: any failure just means "not an image" here.
        // NOTE(review): ImageIO.read can also return null without throwing — in that case
        // 'image' is still set to true; confirm this is intended.
        try {
          ImageIO.read(new File(first));
          image = true;
        } catch (Throwable t) {
        }
        try {
          ImageIO.read(new URL(first));
          image = true;
        } catch (Throwable t) {
        }
      }

      if (first != null) {
        // Image-like extension but unreadable content: keep treating it as an image, but warn.
        if (!image && (first.endsWith(".jpg") || first.endsWith(".png") || first.endsWith(".tif"))) {
          image = true;
          Log.warn("Cannot read first image at " + first + " - Check data.");
        } else if (v.isString() && train().numCols() <= 4) { //at most text, label, fold_col, weight
          text = true;
        }
      }
      if (image) return ProblemType.image;
      else if (text) return ProblemType.text;
      else return ProblemType.dataset;
    } else {
      return _problem_type;
    }
  }

  static class Sanity {
    // The following parameters can be modified when restarting from a checkpoint
    // (copied from the user-given parameters by updateParametersDuringCheckpointRestart).
    transient static private final String[] cp_modifiable = new String[]{
        "_seed",
        "_checkpoint",
        "_epochs",
        "_score_interval",
        "_train_samples_per_iteration",
        "_target_ratio_comm_to_comp",
        "_score_duty_cycle",
        "_score_training_samples",
        "_score_validation_samples",
        "_score_validation_sampling",
        "_classification_stop",
        "_regression_stop",
        "_stopping_rounds",
        "_stopping_metric",
        "_quiet_mode",
        "_max_confusion_matrix_size",
        "_max_hit_ratio_k",
        "_diagnostics",
        "_variable_importances",
        "_replicate_training_data",
        "_shuffle_training_data",
        "_single_node_mode",
        "_overwrite_with_best_model",
        "_mini_batch_size",
        "_network_parameters_file",
        "_clip_gradient",
        "_learning_rate",
        "_learning_rate_annealing",
        "_gpu",
        "_sparse",
        "_device_id",
        "_cache_data", // BUGFIX: was listed twice; duplicate entry removed
        "_input_dropout_ratio",
        "_hidden_dropout_ratios",
        "_export_native_parameters_prefix",
        "_image_shape", //since it's hard to do equals on this in the check - should not change between checkpoint restarts
    };

    // The following parameters must not be modified when restarting from a checkpoint
    // (enforced reflectively by checkIfParameterChangeAllowed).
    transient static private final String[] cp_not_modifiable = new String[]{
        "_drop_na20_cols",
        "_missing_values_handling",
        "_response_column",
        "_activation",
        "_use_all_factor_levels",
        "_problem_type",
        "_channels",
        "_standardize",
        "_autoencoder",
        "_network",
        "_backend",
        "_momentum_start",
        "_momentum_ramp",
        "_momentum_stable",
        "_ignore_const_cols",
        "_max_categorical_features",
        "_nfolds",
        "_distribution",
        "_network_definition_file",
        "_mean_image_file"
    };

    /**
     * Ensure every declared parameter field is classified as either checkpoint-modifiable
     * or not; throws via {@code H2O.unimpl} when a field was forgotten in both lists.
     */
    static void checkCompleteness() {
      for (Field f : hex.deepwater.DeepWaterParameters.class.getDeclaredFields()) {
        String name = f.getName();
        if (ArrayUtils.contains(cp_not_modifiable, name) || ArrayUtils.contains(cp_modifiable, name))
          continue;
        // Deliberately unclassified fields (and the JaCoCo instrumentation field).
        if (name.equals("_hidden") || name.equals("_ignored_columns") || name.equals("$jacocoData"))
          continue;
        throw H2O.unimpl("Please add " + name + " to either cp_modifiable or cp_not_modifiable");
      }
    }

    /**
     * Check that checkpoint continuation is possible.
     * Throws when cross-validation is requested, when the presence of a validation frame,
     * the response column or the ignored columns differ, or when any field listed in
     * {@code cp_not_modifiable} changed between the checkpointed and the new parameters.
     *
     * @param oldP old DL parameters (from checkpoint)
     * @param newP new DL parameters (user-given, to restart from checkpoint)
     */
    static void checkIfParameterChangeAllowed(final hex.deepwater.DeepWaterParameters oldP, final hex.deepwater.DeepWaterParameters newP) {
      checkCompleteness();
      if (newP._nfolds != 0)
        throw new UnsupportedOperationException("nfolds must be 0: Cross-validation is not supported during checkpoint restarts.");
      if ((newP._valid == null) != (oldP._valid == null)) {
        throw new H2OIllegalArgumentException("Presence of validation dataset must agree with the checkpointed model.");
      }
      if (!newP._autoencoder && (newP._response_column == null || !newP._response_column.equals(oldP._response_column))) {
        throw new H2OIllegalArgumentException("Response column (" + newP._response_column + ") is not the same as for the checkpointed model: " + oldP._response_column);
      }
      if (!Arrays.equals(newP._ignored_columns, oldP._ignored_columns)) {
        throw new H2OIllegalArgumentException("Ignored columns must be the same as for the checkpointed model.");
      }

      //compare the user-given parameters before and after and check that they are not changed
      // (both objects are the same class, so matching Field objects compare equal)
      for (Field fBefore : oldP.getClass().getFields()) {
        if (ArrayUtils.contains(cp_not_modifiable, fBefore.getName())) {
          for (Field fAfter : newP.getClass().getFields()) {
            if (fBefore.equals(fAfter)) {
              try {
                // Values are compared via toString(); only null==null counts as unchanged nulls.
                if (fAfter.get(newP) == null || fBefore.get(oldP) == null || !fBefore.get(oldP).toString().equals(fAfter.get(newP).toString())) { // if either of the two parameters is null, skip the toString()
                  if (fBefore.get(oldP) == null && fAfter.get(newP) == null)
                    continue; //if both parameters are null, we don't need to do anything
                  throw new H2OIllegalArgumentException("Cannot change parameter: '" + fBefore.getName() + "': " + fBefore.get(oldP) + " -> " + fAfter.get(newP));
                }
              } catch (IllegalAccessException e) {
                e.printStackTrace();
              }
            }
          }
        }
      }
    }

    /**
     * Update the parameters from checkpoint to user-specified.
     * Reflectively copies every field listed in {@code cp_modifiable} whose value differs
     * from the checkpointed one into the target parameters.
     *
     * @param srcParms source: user-specified parameters
     * @param tgtParms target: parameters to be modified
     * @param doIt     whether to overwrite target parameters (or just print the message)
     * @param quiet    whether to suppress the notifications about parameter changes
     */
    static void updateParametersDuringCheckpointRestart(hex.deepwater.DeepWaterParameters srcParms, hex.deepwater.DeepWaterParameters tgtParms/*actually used during training*/, boolean doIt, boolean quiet) {
      for (Field fTarget : tgtParms.getClass().getFields()) {
        if (ArrayUtils.contains(cp_modifiable, fTarget.getName())) {
          for (Field fSource : srcParms.getClass().getFields()) {
            if (fTarget.equals(fSource)) {
              try {
                // Values are compared via toString(); only null==null counts as "no change".
                if (fSource.get(srcParms) == null || fTarget.get(tgtParms) == null || !fTarget.get(tgtParms).toString().equals(fSource.get(srcParms).toString())) { // if either of the two parameters is null, skip the toString()
                  if (fTarget.get(tgtParms) == null && fSource.get(srcParms) == null)
                    continue; //if both parameters are null, we don't need to do anything
                  if (!tgtParms._quiet_mode && !quiet)
                    Log.info("Applying user-requested modification of '" + fTarget.getName() + "': " + fTarget.get(tgtParms) + " -> " + fSource.get(srcParms));
                  if (doIt)
                    fTarget.set(tgtParms, fSource.get(srcParms));
                }
              } catch (IllegalAccessException e) {
                e.printStackTrace();
              }
            }
          }
        }
      }
    }

    /**
     * Take user-given parameters and turn them into usable, fully populated parameters (e.g., to be used by Neurons during training)
     *
     * @param fromParms raw user-given parameters from the REST API (READ ONLY)
     * @param toParms   modified set of parameters, with defaults filled in (WILL BE MODIFIED)
     * @param nClasses  number of classes (1 for regression or autoencoder)
     */
    static void modifyParms(hex.deepwater.DeepWaterParameters fromParms, hex.deepwater.DeepWaterParameters toParms, int nClasses) {
      // Replicating the training data on a single node is pointless; silently disable.
      if (H2O.CLOUD.size() == 1 && fromParms._replicate_training_data) {
        if (!fromParms._quiet_mode)
          Log.info("_replicate_training_data: Disabling replicate_training_data on 1 node.");
        toParms._replicate_training_data = false;
      }
      // Automatically set the distribution
      if (fromParms._distribution == DistributionFamily.AUTO) {
        // For classification, allow AUTO/bernoulli/multinomial with losses CrossEntropy/Quadratic/Huber/Absolute
        if (nClasses > 1) {
          toParms._distribution = nClasses == 2 ? DistributionFamily.bernoulli : DistributionFamily.multinomial;
        } else {
          toParms._distribution = DistributionFamily.gaussian;
        }
      }
      if (fromParms._single_node_mode && (H2O.CLOUD.size() == 1 || !fromParms._replicate_training_data)) {
        if (!fromParms._quiet_mode)
          Log.info("_single_node_mode: Disabling single_node_mode (only for multi-node operation with replicated training data).");
        toParms._single_node_mode = false;
      }
      if (fromParms._overwrite_with_best_model && fromParms._nfolds != 0) {
        if (!fromParms._quiet_mode)
          Log.info("_overwrite_with_best_model: Disabling overwrite_with_best_model in combination with n-fold cross-validation.");
        toParms._overwrite_with_best_model = false;
      }
      // Automatically set the problem_type
      if (fromParms._problem_type == auto) {
        toParms._problem_type = fromParms.guessProblemType();
        if (!fromParms._quiet_mode)
          Log.info("_problem_type: Automatically selecting problem_type: " + toParms._problem_type.toString());
      }
      if (fromParms._categorical_encoding==CategoricalEncodingScheme.AUTO) {
        if (!fromParms._quiet_mode)
          Log.info("_categorical_encoding: Automatically enabling OneHotInternal categorical encoding.");
        toParms._categorical_encoding = CategoricalEncodingScheme.OneHotInternal;
      }
      // NOTE(review): this block repeats the _overwrite_with_best_model disabling above
      // (both fire under _overwrite_with_best_model && _nfolds != 0), only the log message differs.
      if (fromParms._nfolds != 0) {
        if (fromParms._overwrite_with_best_model) {
          if (!fromParms._quiet_mode)
            Log.info("_overwrite_with_best_model: Automatically disabling overwrite_with_best_model, since the final model is the only scored model with n-fold cross-validation.");
          toParms._overwrite_with_best_model = false;
        }
      }
      // automatic selection
      if (fromParms._network == Network.auto || fromParms._network==null) {
        // if the user specified the network, then keep that
        if (fromParms._network_definition_file != null && !fromParms._network_definition_file.equals("")) {
          if (!fromParms._quiet_mode)
            Log.info("_network_definition_file: Automatically setting network type to 'user', since a network definition file was provided.");
          toParms._network = Network.user;
        } else {
          // pick something reasonable
          if (toParms._problem_type == ProblemType.image) toParms._network = Network.inception_bn;
          if (toParms._problem_type == ProblemType.text || toParms._problem_type == ProblemType.dataset) {
            toParms._network = null; // text/dataset problems get a plain MLP defined by _hidden instead
            if (fromParms._hidden == null) {
              toParms._hidden = new int[]{200, 200};
              toParms._activation = Activation.Rectifier;
              toParms._hidden_dropout_ratios = new double[toParms._hidden.length];
            }
          }
          if (!fromParms._quiet_mode && toParms._network != null && toParms._network != Network.user)
            Log.info("_network: Using " + toParms._network + " model by default.");
        }
      }
      if (fromParms._autoencoder && fromParms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
        if (!fromParms._quiet_mode)
          Log.info("_stopping_metric: Automatically setting stopping_metric to MSE for autoencoder.");
        toParms._stopping_metric = ScoreKeeper.StoppingMetric.MSE;
      }
      if (toParms._hidden!=null) {
        // Fill in MLP defaults: zero dropout per hidden layer and Rectifier activation.
        if (toParms._hidden_dropout_ratios==null) {
          if (!fromParms._quiet_mode)
            Log.info("_hidden_dropout_ratios: Automatically setting hidden_dropout_ratios to 0 for all layers.");
          toParms._hidden_dropout_ratios = new double[toParms._hidden.length];
        }
        if (toParms._activation==null) {
          toParms._activation = Activation.Rectifier;
          if (!fromParms._quiet_mode)
            Log.info("_activation: Automatically setting activation to " + toParms._activation + " for all layers.");
        }
        if (!fromParms._quiet_mode) {
          Log.info("Hidden layers: " + Arrays.toString(toParms._hidden));
          Log.info("Activation function: " + toParms._activation);
          Log.info("Input dropout ratio: " + toParms._input_dropout_ratio);
          Log.info("Hidden layer dropout ratio: " + Arrays.toString(toParms._hidden_dropout_ratios));
        }
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy