All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.sparkling.ml.params.H2ODRFParams.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.tree.drf.DRFModel.DRFParameters
import ai.h2o.sparkling.H2OFrame
import hex.tree.SharedTreeModel.SharedTreeParameters.HistogramType
import hex.tree.CalibrationHelper.CalibrationMethod
import hex.genmodel.utils.DistributionFamily
import hex.Model.Parameters.FoldAssignmentScheme
import hex.Model.Parameters.CategoricalEncodingScheme
import hex.ScoreKeeper.StoppingMetric
import hex.MultinomialAucType

/**
 * Spark ML parameters for the H2O DRF algorithm as exposed by Sparkling Water.
 *
 * Every parameter is declared under its Sparkling Water (camelCase) name, given a default value, and
 * exposed through a getter/setter pair. `getH2ODRFParams` converts the current values into a map keyed
 * by the H2O (snake_case) parameter names, and `getSWtoH2OParamNameMap` records the same name
 * correspondence.
 *
 * Parameters declared with `nullable*Param` default to `null`. Their getters therefore return `null`
 * unless a value has been set explicitly.
 */
trait H2ODRFParams
  extends H2OAlgoParamsBase
  with HasCalibrationDataFrame
  with HasIgnoredCols {

  /** Class tag of the H2O parameter class (`DRFParameters`) these Spark params correspond to. */
  protected def paramTag: reflect.ClassTag[DRFParameters] = reflect.classTag[DRFParameters]

  //
  // Parameter definitions
  //
  protected val mtries = intParam(
    name = "mtries",
    doc = """Number of variables randomly sampled as candidates at each split. If set to -1, defaults to sqrt{p} for classification and p/3 for regression (where p is the # of predictors).""")

  protected val binomialDoubleTrees = booleanParam(
    name = "binomialDoubleTrees",
    doc = """For binary classification: Build 2x as many trees (one per class) - can lead to higher accuracy.""")

  protected val sampleRate = doubleParam(
    name = "sampleRate",
    doc = """Row sample rate per tree (from 0.0 to 1.0).""")

  protected val balanceClasses = booleanParam(
    name = "balanceClasses",
    doc = """Balance training data class counts via over/under-sampling (for imbalanced data).""")

  protected val classSamplingFactors = nullableFloatArrayParam(
    name = "classSamplingFactors",
    doc = """Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling factors will be automatically computed to obtain class balance during training. Requires balance_classes.""")

  protected val maxAfterBalanceSize = floatParam(
    name = "maxAfterBalanceSize",
    doc = """Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires balance_classes.""")

  protected val maxConfusionMatrixSize = intParam(
    name = "maxConfusionMatrixSize",
    doc = """[Deprecated] Maximum size (# classes) for confusion matrices to be printed in the Logs.""")

  protected val ntrees = intParam(
    name = "ntrees",
    doc = """Number of trees.""")

  protected val maxDepth = intParam(
    name = "maxDepth",
    doc = """Maximum tree depth (0 for unlimited).""")

  protected val minRows = doubleParam(
    name = "minRows",
    doc = """Fewest allowed (weighted) observations in a leaf.""")

  protected val nbins = intParam(
    name = "nbins",
    doc = """For numerical columns (real/int), build a histogram of (at least) this many bins, then split at the best point.""")

  protected val nbinsTopLevel = intParam(
    name = "nbinsTopLevel",
    doc = """For numerical columns (real/int), build a histogram of (at most) this many bins at the root level, then decrease by factor of two per level.""")

  protected val nbinsCats = intParam(
    name = "nbinsCats",
    doc = """For categorical columns (factors), build a histogram of this many bins, then split at the best point. Higher values can lead to more overfitting.""")

  protected val seed = longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

  protected val buildTreeOneNode = booleanParam(
    name = "buildTreeOneNode",
    doc = """Run on one node only; no network overhead but fewer cpus used. Suitable for small datasets.""")

  protected val sampleRatePerClass = nullableDoubleArrayParam(
    name = "sampleRatePerClass",
    doc = """A list of row sample rates per class (relative fraction for each class, from 0.0 to 1.0), for each tree.""")

  protected val colSampleRatePerTree = doubleParam(
    name = "colSampleRatePerTree",
    doc = """Column sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleRateChangePerLevel = doubleParam(
    name = "colSampleRateChangePerLevel",
    doc = """Relative change of the column sampling rate for every level (must be > 0.0 and <= 2.0).""")

  protected val scoreTreeInterval = intParam(
    name = "scoreTreeInterval",
    doc = """Score the model after every so many trees. Disabled if set to 0.""")

  protected val minSplitImprovement = doubleParam(
    name = "minSplitImprovement",
    doc = """Minimum relative improvement in squared error reduction for a split to happen.""")

  protected val histogramType = stringParam(
    name = "histogramType",
    doc = """What type of histogram to use for finding optimal split points. Possible values are ``"AUTO"``, ``"UniformAdaptive"``, ``"Random"``, ``"QuantilesGlobal"``, ``"RoundRobin"``, ``"UniformRobust"``.""")

  protected val calibrateModel = booleanParam(
    name = "calibrateModel",
    doc = """Use Platt Scaling (default) or Isotonic Regression to calculate calibrated class probabilities. Calibration can provide more accurate estimates of class probabilities.""")

  protected val calibrationMethod = stringParam(
    name = "calibrationMethod",
    doc = """Calibration method to use. Possible values are ``"AUTO"``, ``"PlattScaling"``, ``"IsotonicRegression"``.""")

  protected val checkConstantResponse = booleanParam(
    name = "checkConstantResponse",
    doc = """Check if response column is constant. If enabled, then an exception is thrown if the response column is a constant value. If disabled, then model will train regardless of the response column being a constant value or not.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val nfolds = intParam(
    name = "nfolds",
    doc = """Number of folds for K-fold cross-validation (0 to disable or >= 2).""")

  protected val keepCrossValidationModels = booleanParam(
    name = "keepCrossValidationModels",
    doc = """Whether to keep the cross-validation models.""")

  protected val keepCrossValidationPredictions = booleanParam(
    name = "keepCrossValidationPredictions",
    doc = """Whether to keep the predictions of the cross-validation models.""")

  protected val keepCrossValidationFoldAssignment = booleanParam(
    name = "keepCrossValidationFoldAssignment",
    doc = """Whether to keep the cross-validation fold assignment.""")

  protected val distribution = stringParam(
    name = "distribution",
    doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val offsetCol = nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

  protected val foldCol = nullableStringParam(
    name = "foldCol",
    doc = """Column with cross-validation fold index assignment per observation.""")

  protected val foldAssignment = stringParam(
    name = "foldAssignment",
    doc = """Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will stratify the folds based on the response variable, for classification problems. Possible values are ``"AUTO"``, ``"Random"``, ``"Modulo"``, ``"Stratified"``.""")

  protected val categoricalEncoding = stringParam(
    name = "categoricalEncoding",
    doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val stoppingRounds = intParam(
    name = "stoppingRounds",
    doc = """Early stopping based on convergence of stopping_metric. Stop if simple moving average of length k of the stopping_metric does not improve for k:=stopping_rounds scoring events (0 to disable).""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val stoppingMetric = stringParam(
    name = "stoppingMetric",
    doc = """Metric to use for early stopping (AUTO: logloss for classification, deviance for regression and anomaly_score for Isolation Forest). Note that custom and custom_increasing can only be used in GBM and DRF with the Python client. Possible values are ``"AUTO"``, ``"deviance"``, ``"logloss"``, ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"RMSLE"``, ``"AUC"``, ``"AUCPR"``, ``"lift_top_group"``, ``"misclassification"``, ``"mean_per_class_error"``, ``"anomaly_score"``, ``"AUUC"``, ``"ATE"``, ``"ATT"``, ``"ATC"``, ``"qini"``, ``"custom"``, ``"custom_increasing"``.""")

  protected val stoppingTolerance = doubleParam(
    name = "stoppingTolerance",
    doc = """Relative tolerance for metric-based stopping criterion (stop if relative improvement is not at least this much).""")

  protected val gainsliftBins = intParam(
    name = "gainsliftBins",
    doc = """Gains/Lift table number of bins. 0 means disabled. Default value -1 means automatic binning.""")

  protected val customMetricFunc = nullableStringParam(
    name = "customMetricFunc",
    doc = """Reference to custom evaluation function, format: `language:keyName=funcName`.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  protected val aucType = stringParam(
    name = "aucType",
    doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")

  //
  // Default values
  //
  // Enum-backed string params default to the name of the corresponding H2O enum's AUTO constant.
  // Nullable params default to null, which means "not set".
  setDefault(
    mtries -> -1,
    binomialDoubleTrees -> false,
    sampleRate -> 0.632,
    balanceClasses -> false,
    classSamplingFactors -> null,
    maxAfterBalanceSize -> 5.0f,
    maxConfusionMatrixSize -> 20,
    ntrees -> 50,
    maxDepth -> 20,
    minRows -> 1.0,
    nbins -> 20,
    nbinsTopLevel -> 1024,
    nbinsCats -> 1024,
    seed -> -1L,
    buildTreeOneNode -> false,
    sampleRatePerClass -> null,
    colSampleRatePerTree -> 1.0,
    colSampleRateChangePerLevel -> 1.0,
    scoreTreeInterval -> 0,
    minSplitImprovement -> 1.0e-5,
    histogramType -> HistogramType.AUTO.name(),
    calibrateModel -> false,
    calibrationMethod -> CalibrationMethod.AUTO.name(),
    checkConstantResponse -> true,
    modelId -> null,
    nfolds -> 0,
    keepCrossValidationModels -> true,
    keepCrossValidationPredictions -> false,
    keepCrossValidationFoldAssignment -> false,
    distribution -> DistributionFamily.AUTO.name(),
    labelCol -> "label",
    weightCol -> null,
    offsetCol -> null,
    foldCol -> null,
    foldAssignment -> FoldAssignmentScheme.AUTO.name(),
    categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    stoppingRounds -> 0,
    maxRuntimeSecs -> 0.0,
    stoppingMetric -> StoppingMetric.AUTO.name(),
    stoppingTolerance -> 0.001,
    gainsliftBins -> -1,
    customMetricFunc -> null,
    exportCheckpointsDir -> null,
    aucType -> MultinomialAucType.AUTO.name())

  //
  // Getters
  //
  // Each getter returns the explicitly set value, or the default above when none was set.
  def getMtries(): Int = $(mtries)

  def getBinomialDoubleTrees(): Boolean = $(binomialDoubleTrees)

  def getSampleRate(): Double = $(sampleRate)

  def getBalanceClasses(): Boolean = $(balanceClasses)

  def getClassSamplingFactors(): Array[Float] = $(classSamplingFactors)

  def getMaxAfterBalanceSize(): Float = $(maxAfterBalanceSize)

  def getMaxConfusionMatrixSize(): Int = $(maxConfusionMatrixSize)

  def getNtrees(): Int = $(ntrees)

  def getMaxDepth(): Int = $(maxDepth)

  def getMinRows(): Double = $(minRows)

  def getNbins(): Int = $(nbins)

  def getNbinsTopLevel(): Int = $(nbinsTopLevel)

  def getNbinsCats(): Int = $(nbinsCats)

  def getSeed(): Long = $(seed)

  def getBuildTreeOneNode(): Boolean = $(buildTreeOneNode)

  def getSampleRatePerClass(): Array[Double] = $(sampleRatePerClass)

  def getColSampleRatePerTree(): Double = $(colSampleRatePerTree)

  def getColSampleRateChangePerLevel(): Double = $(colSampleRateChangePerLevel)

  def getScoreTreeInterval(): Int = $(scoreTreeInterval)

  def getMinSplitImprovement(): Double = $(minSplitImprovement)

  def getHistogramType(): String = $(histogramType)

  def getCalibrateModel(): Boolean = $(calibrateModel)

  def getCalibrationMethod(): String = $(calibrationMethod)

  def getCheckConstantResponse(): Boolean = $(checkConstantResponse)

  def getModelId(): String = $(modelId)

  def getNfolds(): Int = $(nfolds)

  def getKeepCrossValidationModels(): Boolean = $(keepCrossValidationModels)

  def getKeepCrossValidationPredictions(): Boolean = $(keepCrossValidationPredictions)

  def getKeepCrossValidationFoldAssignment(): Boolean = $(keepCrossValidationFoldAssignment)

  def getDistribution(): String = $(distribution)

  def getLabelCol(): String = $(labelCol)

  def getWeightCol(): String = $(weightCol)

  def getOffsetCol(): String = $(offsetCol)

  def getFoldCol(): String = $(foldCol)

  def getFoldAssignment(): String = $(foldAssignment)

  def getCategoricalEncoding(): String = $(categoricalEncoding)

  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  def getStoppingRounds(): Int = $(stoppingRounds)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  def getStoppingMetric(): String = $(stoppingMetric)

  def getStoppingTolerance(): Double = $(stoppingTolerance)

  def getGainsliftBins(): Int = $(gainsliftBins)

  def getCustomMetricFunc(): String = $(customMetricFunc)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  def getAucType(): String = $(aucType)

  //
  // Setters
  //
  // All setters return `this` so that calls can be chained. Setters for enum-backed string params
  // pass the value through `EnumParamValidator.getValidatedEnumValue` against the matching H2O enum
  // before storing it. That call presumably rejects unknown names; see EnumParamValidator for the
  // exact behaviour.
  def setMtries(value: Int): this.type = set(mtries, value)

  def setBinomialDoubleTrees(value: Boolean): this.type = set(binomialDoubleTrees, value)

  def setSampleRate(value: Double): this.type = set(sampleRate, value)

  def setBalanceClasses(value: Boolean): this.type = set(balanceClasses, value)

  def setClassSamplingFactors(value: Array[Float]): this.type = set(classSamplingFactors, value)

  def setMaxAfterBalanceSize(value: Float): this.type = set(maxAfterBalanceSize, value)

  def setMaxConfusionMatrixSize(value: Int): this.type = set(maxConfusionMatrixSize, value)

  def setNtrees(value: Int): this.type = set(ntrees, value)

  def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  def setMinRows(value: Double): this.type = set(minRows, value)

  def setNbins(value: Int): this.type = set(nbins, value)

  def setNbinsTopLevel(value: Int): this.type = set(nbinsTopLevel, value)

  def setNbinsCats(value: Int): this.type = set(nbinsCats, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setBuildTreeOneNode(value: Boolean): this.type = set(buildTreeOneNode, value)

  def setSampleRatePerClass(value: Array[Double]): this.type = set(sampleRatePerClass, value)

  def setColSampleRatePerTree(value: Double): this.type = set(colSampleRatePerTree, value)

  def setColSampleRateChangePerLevel(value: Double): this.type = set(colSampleRateChangePerLevel, value)

  def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

  def setMinSplitImprovement(value: Double): this.type = set(minSplitImprovement, value)

  def setHistogramType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[HistogramType](value)
    set(histogramType, validated)
  }

  def setCalibrateModel(value: Boolean): this.type = set(calibrateModel, value)

  def setCalibrationMethod(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CalibrationMethod](value)
    set(calibrationMethod, validated)
  }

  def setCheckConstantResponse(value: Boolean): this.type = set(checkConstantResponse, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setNfolds(value: Int): this.type = set(nfolds, value)

  def setKeepCrossValidationModels(value: Boolean): this.type = set(keepCrossValidationModels, value)

  def setKeepCrossValidationPredictions(value: Boolean): this.type = set(keepCrossValidationPredictions, value)

  def setKeepCrossValidationFoldAssignment(value: Boolean): this.type = set(keepCrossValidationFoldAssignment, value)

  def setDistribution(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DistributionFamily](value)
    set(distribution, validated)
  }

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  def setFoldCol(value: String): this.type = set(foldCol, value)

  def setFoldAssignment(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[FoldAssignmentScheme](value)
    set(foldAssignment, validated)
  }

  def setCategoricalEncoding(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value)
    set(categoricalEncoding, validated)
  }

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setStoppingRounds(value: Int): this.type = set(stoppingRounds, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setStoppingMetric(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[StoppingMetric](value)
    set(stoppingMetric, validated)
  }

  def setStoppingTolerance(value: Double): this.type = set(stoppingTolerance, value)

  def setGainsliftBins(value: Int): this.type = set(gainsliftBins, value)

  def setCustomMetricFunc(value: String): this.type = set(customMetricFunc, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  def setAucType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value)
    set(aucType, validated)
  }

  /**
   * All H2O parameters for this algorithm.
   *
   * The parent trait's parameters are merged with the DRF-specific ones from `getH2ODRFParams`.
   * Because `Map.++` keeps the right-hand value on key collisions, DRF entries override parent
   * entries with the same key.
   *
   * @param trainingFrame the frame the model will be trained on
   * @return parameter values keyed by H2O parameter name
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2ODRFParams(trainingFrame)
  }

  /**
   * Current DRF parameter values keyed by their H2O (snake_case) names.
   *
   * The calibration-frame and ignored-columns entries are appended with `+++`, using
   * `getCalibrationDataFrameParam` and `getIgnoredColsParam`. Those members and the `+++` operator
   * are defined outside this file; they presumably come from the mixed-in `HasCalibrationDataFrame`
   * and `HasIgnoredCols` traits and from `H2OAlgoParamsBase` respectively.
   *
   * @param trainingFrame the frame the model will be trained on
   * @return parameter values keyed by H2O parameter name; values of nullable params may be null
   */
  private[sparkling] def getH2ODRFParams(trainingFrame: H2OFrame): Map[String, Any] = {
    Map(
      "mtries" -> getMtries(),
      "binomial_double_trees" -> getBinomialDoubleTrees(),
      "sample_rate" -> getSampleRate(),
      "balance_classes" -> getBalanceClasses(),
      "class_sampling_factors" -> getClassSamplingFactors(),
      "max_after_balance_size" -> getMaxAfterBalanceSize(),
      "max_confusion_matrix_size" -> getMaxConfusionMatrixSize(),
      "ntrees" -> getNtrees(),
      "max_depth" -> getMaxDepth(),
      "min_rows" -> getMinRows(),
      "nbins" -> getNbins(),
      "nbins_top_level" -> getNbinsTopLevel(),
      "nbins_cats" -> getNbinsCats(),
      "seed" -> getSeed(),
      "build_tree_one_node" -> getBuildTreeOneNode(),
      "sample_rate_per_class" -> getSampleRatePerClass(),
      "col_sample_rate_per_tree" -> getColSampleRatePerTree(),
      "col_sample_rate_change_per_level" -> getColSampleRateChangePerLevel(),
      "score_tree_interval" -> getScoreTreeInterval(),
      "min_split_improvement" -> getMinSplitImprovement(),
      "histogram_type" -> getHistogramType(),
      "calibrate_model" -> getCalibrateModel(),
      "calibration_method" -> getCalibrationMethod(),
      "check_constant_response" -> getCheckConstantResponse(),
      "model_id" -> getModelId(),
      "nfolds" -> getNfolds(),
      "keep_cross_validation_models" -> getKeepCrossValidationModels(),
      "keep_cross_validation_predictions" -> getKeepCrossValidationPredictions(),
      "keep_cross_validation_fold_assignment" -> getKeepCrossValidationFoldAssignment(),
      "distribution" -> getDistribution(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "offset_column" -> getOffsetCol(),
      "fold_column" -> getFoldCol(),
      "fold_assignment" -> getFoldAssignment(),
      "categorical_encoding" -> getCategoricalEncoding(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "stopping_rounds" -> getStoppingRounds(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "stopping_metric" -> getStoppingMetric(),
      "stopping_tolerance" -> getStoppingTolerance(),
      "gainslift_bins" -> getGainsliftBins(),
      "custom_metric_func" -> getCustomMetricFunc(),
      "export_checkpoints_dir" -> getExportCheckpointsDir(),
      "auc_type" -> getAucType()) +++
      getCalibrationDataFrameParam(trainingFrame) +++
      getIgnoredColsParam(trainingFrame)
  }

  /**
   * Correspondence from Sparkling Water (camelCase) parameter names to H2O (snake_case) names,
   * added on top of the parent trait's mapping.
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "mtries" -> "mtries",
        "binomialDoubleTrees" -> "binomial_double_trees",
        "sampleRate" -> "sample_rate",
        "balanceClasses" -> "balance_classes",
        "classSamplingFactors" -> "class_sampling_factors",
        "maxAfterBalanceSize" -> "max_after_balance_size",
        "maxConfusionMatrixSize" -> "max_confusion_matrix_size",
        "ntrees" -> "ntrees",
        "maxDepth" -> "max_depth",
        "minRows" -> "min_rows",
        "nbins" -> "nbins",
        "nbinsTopLevel" -> "nbins_top_level",
        "nbinsCats" -> "nbins_cats",
        "seed" -> "seed",
        "buildTreeOneNode" -> "build_tree_one_node",
        "sampleRatePerClass" -> "sample_rate_per_class",
        "colSampleRatePerTree" -> "col_sample_rate_per_tree",
        "colSampleRateChangePerLevel" -> "col_sample_rate_change_per_level",
        "scoreTreeInterval" -> "score_tree_interval",
        "minSplitImprovement" -> "min_split_improvement",
        "histogramType" -> "histogram_type",
        "calibrateModel" -> "calibrate_model",
        "calibrationMethod" -> "calibration_method",
        "checkConstantResponse" -> "check_constant_response",
        "modelId" -> "model_id",
        "nfolds" -> "nfolds",
        "keepCrossValidationModels" -> "keep_cross_validation_models",
        "keepCrossValidationPredictions" -> "keep_cross_validation_predictions",
        "keepCrossValidationFoldAssignment" -> "keep_cross_validation_fold_assignment",
        "distribution" -> "distribution",
        "labelCol" -> "response_column",
        "weightCol" -> "weights_column",
        "offsetCol" -> "offset_column",
        "foldCol" -> "fold_column",
        "foldAssignment" -> "fold_assignment",
        "categoricalEncoding" -> "categorical_encoding",
        "ignoreConstCols" -> "ignore_const_cols",
        "scoreEachIteration" -> "score_each_iteration",
        "stoppingRounds" -> "stopping_rounds",
        "maxRuntimeSecs" -> "max_runtime_secs",
        "stoppingMetric" -> "stopping_metric",
        "stoppingTolerance" -> "stopping_tolerance",
        "gainsliftBins" -> "gainslift_bins",
        "customMetricFunc" -> "custom_metric_func",
        "exportCheckpointsDir" -> "export_checkpoints_dir",
        "aucType" -> "auc_type")
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy