All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.sparkling.ml.params.H2OXGBoostParams.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.tree.xgboost.XGBoostModel.XGBoostParameters
import ai.h2o.sparkling.H2OFrame
import hex.tree.CalibrationHelper.CalibrationMethod
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.TreeMethod
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.GrowPolicy
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.Booster
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DartSampleType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DartNormalizeType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DMatrixType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.Backend
import hex.genmodel.utils.DistributionFamily
import hex.Model.Parameters.FoldAssignmentScheme
import hex.Model.Parameters.CategoricalEncodingScheme
import hex.ScoreKeeper.StoppingMetric
import hex.MultinomialAucType

/**
 * Spark ML parameters of the H2O XGBoost algorithm.
 *
 * Each parameter is declared below with its documentation, receives a default value in the
 * `setDefault` call, and is exposed through a getter/setter pair. [[getH2OXGBoostParams]]
 * translates the current values into a map keyed by H2O's snake_case parameter names, and
 * [[getSWtoH2OParamNameMap]] records the same Sparkling Water -> H2O name mapping.
 *
 * Several parameters are aliases of each other (for example `minRows`/`minChildWeight`,
 * `learnRate`/`eta`, `sampleRate`/`subsample`, `minSplitImprovement`/`gamma`). Each alias is
 * stored independently and all of them are forwarded to H2O.
 */
trait H2OXGBoostParams
  extends H2OAlgoParamsBase
  with HasMonotoneConstraints
  with HasCalibrationDataFrame
  with HasIgnoredCols {

  /** Class tag of H2O's `XGBoostParameters`. */
  protected def paramTag: reflect.ClassTag[XGBoostParameters] = reflect.classTag[XGBoostParameters]

  //
  // Parameter definitions
  //
  protected val ntrees = intParam(
    name = "ntrees",
    doc = """(same as n_estimators) Number of trees.""")

  protected val maxDepth = intParam(
    name = "maxDepth",
    doc = """Maximum tree depth (0 for unlimited).""")

  protected val minRows = doubleParam(
    name = "minRows",
    doc = """(same as min_child_weight) Fewest allowed (weighted) observations in a leaf.""")

  protected val minChildWeight = doubleParam(
    name = "minChildWeight",
    doc = """(same as min_rows) Fewest allowed (weighted) observations in a leaf.""")

  protected val learnRate = doubleParam(
    name = "learnRate",
    doc = """(same as eta) Learning rate (from 0.0 to 1.0).""")

  protected val eta = doubleParam(
    name = "eta",
    doc = """(same as learn_rate) Learning rate (from 0.0 to 1.0).""")

  protected val sampleRate = doubleParam(
    name = "sampleRate",
    doc = """(same as subsample) Row sample rate per tree (from 0.0 to 1.0).""")

  protected val subsample = doubleParam(
    name = "subsample",
    doc = """(same as sample_rate) Row sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleRate = doubleParam(
    name = "colSampleRate",
    doc = """(same as colsample_bylevel) Column sample rate (from 0.0 to 1.0).""")

  protected val colSampleByLevel = doubleParam(
    name = "colSampleByLevel",
    doc = """(same as col_sample_rate) Column sample rate (from 0.0 to 1.0).""")

  protected val colSampleRatePerTree = doubleParam(
    name = "colSampleRatePerTree",
    doc = """(same as colsample_bytree) Column sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleByTree = doubleParam(
    name = "colSampleByTree",
    doc = """(same as col_sample_rate_per_tree) Column sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleByNode = doubleParam(
    name = "colSampleByNode",
    doc = """Column sample rate per tree node (from 0.0 to 1.0).""")

  protected val maxAbsLeafnodePred = floatParam(
    name = "maxAbsLeafnodePred",
    doc = """(same as max_delta_step) Maximum absolute value of a leaf node prediction.""")

  protected val maxDeltaStep = floatParam(
    name = "maxDeltaStep",
    doc = """(same as max_abs_leafnode_pred) Maximum absolute value of a leaf node prediction.""")

  protected val scoreTreeInterval = intParam(
    name = "scoreTreeInterval",
    doc = """Score the model after every so many trees. Disabled if set to 0.""")

  protected val seed = longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

  protected val minSplitImprovement = floatParam(
    name = "minSplitImprovement",
    doc = """(same as gamma) Minimum relative improvement in squared error reduction for a split to happen.""")

  protected val gamma = floatParam(
    name = "gamma",
    doc = """(same as min_split_improvement) Minimum relative improvement in squared error reduction for a split to happen.""")

  protected val nthread = intParam(
    name = "nthread",
    doc = """Number of parallel threads that can be used to run XGBoost. Cannot exceed H2O cluster limits (-nthreads parameter). Defaults to maximum available.""")

  protected val buildTreeOneNode = booleanParam(
    name = "buildTreeOneNode",
    doc = """Run on one node only; no network overhead but fewer cpus used. Suitable for small datasets.""")

  protected val saveMatrixDirectory = nullableStringParam(
    name = "saveMatrixDirectory",
    doc = """Directory where to save matrices passed to XGBoost library. Useful for debugging.""")

  protected val calibrateModel = booleanParam(
    name = "calibrateModel",
    doc = """Use Platt Scaling (default) or Isotonic Regression to calculate calibrated class probabilities. Calibration can provide more accurate estimates of class probabilities.""")

  protected val calibrationMethod = stringParam(
    name = "calibrationMethod",
    doc = """Calibration method to use. Possible values are ``"AUTO"``, ``"PlattScaling"``, ``"IsotonicRegression"``.""")

  protected val maxBins = intParam(
    name = "maxBins",
    doc = """For tree_method=hist only: maximum number of bins.""")

  protected val maxLeaves = intParam(
    name = "maxLeaves",
    doc = """For tree_method=hist only: maximum number of leaves.""")

  protected val treeMethod = stringParam(
    name = "treeMethod",
    doc = """Tree method. Possible values are ``"auto"``, ``"exact"``, ``"approx"``, ``"hist"``.""")

  protected val growPolicy = stringParam(
    name = "growPolicy",
    doc = """Grow policy - depthwise is standard GBM, lossguide is LightGBM. Possible values are ``"depthwise"``, ``"lossguide"``.""")

  protected val booster = stringParam(
    name = "booster",
    doc = """Booster type. Possible values are ``"gbtree"``, ``"gblinear"``, ``"dart"``.""")

  protected val regLambda = floatParam(
    name = "regLambda",
    doc = """L2 regularization.""")

  protected val regAlpha = floatParam(
    name = "regAlpha",
    doc = """L1 regularization.""")

  protected val quietMode = booleanParam(
    name = "quietMode",
    doc = """Enable quiet mode.""")

  protected val sampleType = stringParam(
    name = "sampleType",
    doc = """For booster=dart only: sample_type. Possible values are ``"uniform"``, ``"weighted"``.""")

  protected val normalizeType = stringParam(
    name = "normalizeType",
    doc = """For booster=dart only: normalize_type. Possible values are ``"tree"``, ``"forest"``.""")

  protected val rateDrop = floatParam(
    name = "rateDrop",
    doc = """For booster=dart only: rate_drop (0..1).""")

  protected val oneDrop = booleanParam(
    name = "oneDrop",
    doc = """For booster=dart only: one_drop.""")

  protected val skipDrop = floatParam(
    name = "skipDrop",
    doc = """For booster=dart only: skip_drop (0..1).""")

  protected val dmatrixType = stringParam(
    name = "dmatrixType",
    doc = """Type of DMatrix. For sparse, NAs and 0 are treated equally. Possible values are ``"auto"``, ``"dense"``, ``"sparse"``.""")

  protected val backend = stringParam(
    name = "backend",
    doc = """Backend. By default (auto), a GPU is used if available. Possible values are ``"auto"``, ``"gpu"``, ``"cpu"``.""")

  protected val gpuId = nullableIntArrayParam(
    name = "gpuId",
    doc = """Which GPU(s) to use.""")

  protected val interactionConstraints = nullableStringArrayArrayParam(
    name = "interactionConstraints",
    doc = """A set of allowed column interactions.""")

  protected val scalePosWeight = floatParam(
    name = "scalePosWeight",
    doc = """Controls the effect of observations with positive labels in relation to the observations with negative labels on gradient calculation. Useful for imbalanced problems.""")

  protected val evalMetric = nullableStringParam(
    name = "evalMetric",
    doc = """Specification of evaluation metric that will be passed to the native XGBoost backend.""")

  protected val scoreEvalMetricOnly = booleanParam(
    name = "scoreEvalMetricOnly",
    doc = """If enabled, score only the evaluation metric. This can make model training faster if scoring is frequent (eg. each iteration).""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val nfolds = intParam(
    name = "nfolds",
    doc = """Number of folds for K-fold cross-validation (0 to disable or >= 2).""")

  protected val keepCrossValidationModels = booleanParam(
    name = "keepCrossValidationModels",
    doc = """Whether to keep the cross-validation models.""")

  protected val keepCrossValidationPredictions = booleanParam(
    name = "keepCrossValidationPredictions",
    doc = """Whether to keep the predictions of the cross-validation models.""")

  protected val keepCrossValidationFoldAssignment = booleanParam(
    name = "keepCrossValidationFoldAssignment",
    doc = """Whether to keep the cross-validation fold assignment.""")

  protected val parallelizeCrossValidation = booleanParam(
    name = "parallelizeCrossValidation",
    doc = """Allow parallel training of cross-validation models.""")

  protected val distribution = stringParam(
    name = "distribution",
    doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

  protected val tweediePower = doubleParam(
    name = "tweediePower",
    doc = """Tweedie power for Tweedie regression, must be between 1 and 2.""")

  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val offsetCol = nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

  protected val foldCol = nullableStringParam(
    name = "foldCol",
    doc = """Column with cross-validation fold index assignment per observation.""")

  protected val foldAssignment = stringParam(
    name = "foldAssignment",
    doc = """Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will stratify the folds based on the response variable, for classification problems. Possible values are ``"AUTO"``, ``"Random"``, ``"Modulo"``, ``"Stratified"``.""")

  protected val categoricalEncoding = stringParam(
    name = "categoricalEncoding",
    doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val stoppingRounds = intParam(
    name = "stoppingRounds",
    doc = """Early stopping based on convergence of stopping_metric. Stop if simple moving average of length k of the stopping_metric does not improve for k:=stopping_rounds scoring events (0 to disable).""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val stoppingMetric = stringParam(
    name = "stoppingMetric",
    doc = """Metric to use for early stopping (AUTO: logloss for classification, deviance for regression and anomaly_score for Isolation Forest). Note that custom and custom_increasing can only be used in GBM and DRF with the Python client. Possible values are ``"AUTO"``, ``"deviance"``, ``"logloss"``, ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"RMSLE"``, ``"AUC"``, ``"AUCPR"``, ``"lift_top_group"``, ``"misclassification"``, ``"mean_per_class_error"``, ``"anomaly_score"``, ``"AUUC"``, ``"ATE"``, ``"ATT"``, ``"ATC"``, ``"qini"``, ``"custom"``, ``"custom_increasing"``.""")

  protected val stoppingTolerance = doubleParam(
    name = "stoppingTolerance",
    doc = """Relative tolerance for metric-based stopping criterion (stop if relative improvement is not at least this much).""")

  protected val gainsliftBins = intParam(
    name = "gainsliftBins",
    doc = """Gains/Lift table number of bins. 0 means disabled. Default value -1 means automatic binning.""")

  protected val customMetricFunc = nullableStringParam(
    name = "customMetricFunc",
    doc = """Reference to custom evaluation function, format: `language:keyName=funcName`.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  protected val aucType = stringParam(
    name = "aucType",
    doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")

  //
  // Default values
  // Enum-typed parameters are stored as strings and default to the `name()` of the matching
  // H2O enum constant.
  //
  setDefault(
    ntrees -> 50,
    maxDepth -> 6,
    minRows -> 1.0,
    minChildWeight -> 1.0,
    learnRate -> 0.3,
    eta -> 0.3,
    sampleRate -> 1.0,
    subsample -> 1.0,
    colSampleRate -> 1.0,
    colSampleByLevel -> 1.0,
    colSampleRatePerTree -> 1.0,
    colSampleByTree -> 1.0,
    colSampleByNode -> 1.0,
    maxAbsLeafnodePred -> 0.0f,
    maxDeltaStep -> 0.0f,
    scoreTreeInterval -> 0,
    seed -> -1L,
    minSplitImprovement -> 0.0f,
    gamma -> 0.0f,
    nthread -> -1,
    buildTreeOneNode -> false,
    saveMatrixDirectory -> null,
    calibrateModel -> false,
    calibrationMethod -> CalibrationMethod.AUTO.name(),
    maxBins -> 256,
    maxLeaves -> 0,
    treeMethod -> TreeMethod.auto.name(),
    growPolicy -> GrowPolicy.depthwise.name(),
    booster -> Booster.gbtree.name(),
    regLambda -> 1.0f,
    regAlpha -> 0.0f,
    quietMode -> true,
    sampleType -> DartSampleType.uniform.name(),
    normalizeType -> DartNormalizeType.tree.name(),
    rateDrop -> 0.0f,
    oneDrop -> false,
    skipDrop -> 0.0f,
    dmatrixType -> DMatrixType.auto.name(),
    backend -> Backend.auto.name(),
    gpuId -> null,
    interactionConstraints -> null,
    scalePosWeight -> 1.0f,
    evalMetric -> null,
    scoreEvalMetricOnly -> false,
    modelId -> null,
    nfolds -> 0,
    keepCrossValidationModels -> true,
    keepCrossValidationPredictions -> false,
    keepCrossValidationFoldAssignment -> false,
    parallelizeCrossValidation -> true,
    distribution -> DistributionFamily.AUTO.name(),
    tweediePower -> 1.5,
    labelCol -> "label",
    weightCol -> null,
    offsetCol -> null,
    foldCol -> null,
    foldAssignment -> FoldAssignmentScheme.AUTO.name(),
    categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    stoppingRounds -> 0,
    maxRuntimeSecs -> 0.0,
    stoppingMetric -> StoppingMetric.AUTO.name(),
    stoppingTolerance -> 0.001,
    gainsliftBins -> -1,
    customMetricFunc -> null,
    exportCheckpointsDir -> null,
    aucType -> MultinomialAucType.AUTO.name())

  //
  // Getters
  //
  def getNtrees(): Int = $(ntrees)

  def getMaxDepth(): Int = $(maxDepth)

  def getMinRows(): Double = $(minRows)

  def getMinChildWeight(): Double = $(minChildWeight)

  def getLearnRate(): Double = $(learnRate)

  def getEta(): Double = $(eta)

  def getSampleRate(): Double = $(sampleRate)

  def getSubsample(): Double = $(subsample)

  def getColSampleRate(): Double = $(colSampleRate)

  def getColSampleByLevel(): Double = $(colSampleByLevel)

  def getColSampleRatePerTree(): Double = $(colSampleRatePerTree)

  def getColSampleByTree(): Double = $(colSampleByTree)

  def getColSampleByNode(): Double = $(colSampleByNode)

  def getMaxAbsLeafnodePred(): Float = $(maxAbsLeafnodePred)

  def getMaxDeltaStep(): Float = $(maxDeltaStep)

  def getScoreTreeInterval(): Int = $(scoreTreeInterval)

  def getSeed(): Long = $(seed)

  def getMinSplitImprovement(): Float = $(minSplitImprovement)

  def getGamma(): Float = $(gamma)

  def getNthread(): Int = $(nthread)

  def getBuildTreeOneNode(): Boolean = $(buildTreeOneNode)

  def getSaveMatrixDirectory(): String = $(saveMatrixDirectory)

  def getCalibrateModel(): Boolean = $(calibrateModel)

  def getCalibrationMethod(): String = $(calibrationMethod)

  def getMaxBins(): Int = $(maxBins)

  def getMaxLeaves(): Int = $(maxLeaves)

  def getTreeMethod(): String = $(treeMethod)

  def getGrowPolicy(): String = $(growPolicy)

  def getBooster(): String = $(booster)

  def getRegLambda(): Float = $(regLambda)

  def getRegAlpha(): Float = $(regAlpha)

  def getQuietMode(): Boolean = $(quietMode)

  def getSampleType(): String = $(sampleType)

  def getNormalizeType(): String = $(normalizeType)

  def getRateDrop(): Float = $(rateDrop)

  def getOneDrop(): Boolean = $(oneDrop)

  def getSkipDrop(): Float = $(skipDrop)

  def getDmatrixType(): String = $(dmatrixType)

  def getBackend(): String = $(backend)

  def getGpuId(): Array[Int] = $(gpuId)

  def getInteractionConstraints(): Array[Array[String]] = $(interactionConstraints)

  def getScalePosWeight(): Float = $(scalePosWeight)

  def getEvalMetric(): String = $(evalMetric)

  def getScoreEvalMetricOnly(): Boolean = $(scoreEvalMetricOnly)

  def getModelId(): String = $(modelId)

  def getNfolds(): Int = $(nfolds)

  def getKeepCrossValidationModels(): Boolean = $(keepCrossValidationModels)

  def getKeepCrossValidationPredictions(): Boolean = $(keepCrossValidationPredictions)

  def getKeepCrossValidationFoldAssignment(): Boolean = $(keepCrossValidationFoldAssignment)

  def getParallelizeCrossValidation(): Boolean = $(parallelizeCrossValidation)

  def getDistribution(): String = $(distribution)

  def getTweediePower(): Double = $(tweediePower)

  def getLabelCol(): String = $(labelCol)

  def getWeightCol(): String = $(weightCol)

  def getOffsetCol(): String = $(offsetCol)

  def getFoldCol(): String = $(foldCol)

  def getFoldAssignment(): String = $(foldAssignment)

  def getCategoricalEncoding(): String = $(categoricalEncoding)

  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  def getStoppingRounds(): Int = $(stoppingRounds)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  def getStoppingMetric(): String = $(stoppingMetric)

  def getStoppingTolerance(): Double = $(stoppingTolerance)

  def getGainsliftBins(): Int = $(gainsliftBins)

  def getCustomMetricFunc(): String = $(customMetricFunc)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  def getAucType(): String = $(aucType)

  //
  // Setters
  // Plain setters store the value as given. Enum-valued setters first pass the input through
  // EnumParamValidator.getValidatedEnumValue for the matching H2O enum type and store the
  // value it returns.
  //
  def setNtrees(value: Int): this.type = set(ntrees, value)

  def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  def setMinRows(value: Double): this.type = set(minRows, value)

  def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)

  def setLearnRate(value: Double): this.type = set(learnRate, value)

  def setEta(value: Double): this.type = set(eta, value)

  def setSampleRate(value: Double): this.type = set(sampleRate, value)

  def setSubsample(value: Double): this.type = set(subsample, value)

  def setColSampleRate(value: Double): this.type = set(colSampleRate, value)

  def setColSampleByLevel(value: Double): this.type = set(colSampleByLevel, value)

  def setColSampleRatePerTree(value: Double): this.type = set(colSampleRatePerTree, value)

  def setColSampleByTree(value: Double): this.type = set(colSampleByTree, value)

  def setColSampleByNode(value: Double): this.type = set(colSampleByNode, value)

  def setMaxAbsLeafnodePred(value: Float): this.type = set(maxAbsLeafnodePred, value)

  def setMaxDeltaStep(value: Float): this.type = set(maxDeltaStep, value)

  def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setMinSplitImprovement(value: Float): this.type = set(minSplitImprovement, value)

  def setGamma(value: Float): this.type = set(gamma, value)

  def setNthread(value: Int): this.type = set(nthread, value)

  def setBuildTreeOneNode(value: Boolean): this.type = set(buildTreeOneNode, value)

  def setSaveMatrixDirectory(value: String): this.type = set(saveMatrixDirectory, value)

  def setCalibrateModel(value: Boolean): this.type = set(calibrateModel, value)

  def setCalibrationMethod(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CalibrationMethod](value)
    set(calibrationMethod, validated)
  }

  def setMaxBins(value: Int): this.type = set(maxBins, value)

  def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)

  def setTreeMethod(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[TreeMethod](value)
    set(treeMethod, validated)
  }

  def setGrowPolicy(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[GrowPolicy](value)
    set(growPolicy, validated)
  }

  def setBooster(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[Booster](value)
    set(booster, validated)
  }

  def setRegLambda(value: Float): this.type = set(regLambda, value)

  def setRegAlpha(value: Float): this.type = set(regAlpha, value)

  def setQuietMode(value: Boolean): this.type = set(quietMode, value)

  def setSampleType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DartSampleType](value)
    set(sampleType, validated)
  }

  def setNormalizeType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DartNormalizeType](value)
    set(normalizeType, validated)
  }

  def setRateDrop(value: Float): this.type = set(rateDrop, value)

  def setOneDrop(value: Boolean): this.type = set(oneDrop, value)

  def setSkipDrop(value: Float): this.type = set(skipDrop, value)

  def setDmatrixType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DMatrixType](value)
    set(dmatrixType, validated)
  }

  def setBackend(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[Backend](value)
    set(backend, validated)
  }

  def setGpuId(value: Array[Int]): this.type = set(gpuId, value)

  def setInteractionConstraints(value: Array[Array[String]]): this.type = set(interactionConstraints, value)

  def setScalePosWeight(value: Float): this.type = set(scalePosWeight, value)

  def setEvalMetric(value: String): this.type = set(evalMetric, value)

  def setScoreEvalMetricOnly(value: Boolean): this.type = set(scoreEvalMetricOnly, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setNfolds(value: Int): this.type = set(nfolds, value)

  def setKeepCrossValidationModels(value: Boolean): this.type = set(keepCrossValidationModels, value)

  def setKeepCrossValidationPredictions(value: Boolean): this.type = set(keepCrossValidationPredictions, value)

  def setKeepCrossValidationFoldAssignment(value: Boolean): this.type =
    set(keepCrossValidationFoldAssignment, value)

  def setParallelizeCrossValidation(value: Boolean): this.type = set(parallelizeCrossValidation, value)

  def setDistribution(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DistributionFamily](value)
    set(distribution, validated)
  }

  def setTweediePower(value: Double): this.type = set(tweediePower, value)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  def setFoldCol(value: String): this.type = set(foldCol, value)

  def setFoldAssignment(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[FoldAssignmentScheme](value)
    set(foldAssignment, validated)
  }

  def setCategoricalEncoding(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value)
    set(categoricalEncoding, validated)
  }

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setStoppingRounds(value: Int): this.type = set(stoppingRounds, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setStoppingMetric(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[StoppingMetric](value)
    set(stoppingMetric, validated)
  }

  def setStoppingTolerance(value: Double): this.type = set(stoppingTolerance, value)

  def setGainsliftBins(value: Int): this.type = set(gainsliftBins, value)

  def setCustomMetricFunc(value: String): this.type = set(customMetricFunc, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  def setAucType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value)
    set(aucType, validated)
  }

  /**
   * Extends the parent's H2O parameter map with the XGBoost-specific entries.
   *
   * @param trainingFrame frame the model will be trained on; forwarded to the parent
   *                      implementation and to [[getH2OXGBoostParams]]
   * @return the parent's parameters combined (via `++`) with [[getH2OXGBoostParams]]
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OXGBoostParams(trainingFrame)
  }

  /**
   * Translates the current XGBoost parameter values into a map keyed by H2O's parameter names.
   *
   * Alias pairs (e.g. `min_rows` / `min_child_weight`, `learn_rate` / `eta`) are both emitted,
   * each with its own stored value. The label column is sent as `response_column`.
   *
   * The result is then combined via `+++` with the parameters produced by
   * `getMonotoneConstraintsParam`, `getCalibrationDataFrameParam` and `getIgnoredColsParam`
   * for `trainingFrame`.
   * NOTE(review): `+++` and those helpers are defined outside this file; presumably `+++` merges
   * the maps. Confirm the precedence on key collisions against its definition.
   *
   * @param trainingFrame frame the model will be trained on
   * @return H2O parameter name -> value
   */
  private[sparkling] def getH2OXGBoostParams(trainingFrame: H2OFrame): Map[String, Any] = {
    Map(
      "ntrees" -> getNtrees(),
      "max_depth" -> getMaxDepth(),
      "min_rows" -> getMinRows(),
      "min_child_weight" -> getMinChildWeight(),
      "learn_rate" -> getLearnRate(),
      "eta" -> getEta(),
      "sample_rate" -> getSampleRate(),
      "subsample" -> getSubsample(),
      "col_sample_rate" -> getColSampleRate(),
      "colsample_bylevel" -> getColSampleByLevel(),
      "col_sample_rate_per_tree" -> getColSampleRatePerTree(),
      "colsample_bytree" -> getColSampleByTree(),
      "colsample_bynode" -> getColSampleByNode(),
      "max_abs_leafnode_pred" -> getMaxAbsLeafnodePred(),
      "max_delta_step" -> getMaxDeltaStep(),
      "score_tree_interval" -> getScoreTreeInterval(),
      "seed" -> getSeed(),
      "min_split_improvement" -> getMinSplitImprovement(),
      "gamma" -> getGamma(),
      "nthread" -> getNthread(),
      "build_tree_one_node" -> getBuildTreeOneNode(),
      "save_matrix_directory" -> getSaveMatrixDirectory(),
      "calibrate_model" -> getCalibrateModel(),
      "calibration_method" -> getCalibrationMethod(),
      "max_bins" -> getMaxBins(),
      "max_leaves" -> getMaxLeaves(),
      "tree_method" -> getTreeMethod(),
      "grow_policy" -> getGrowPolicy(),
      "booster" -> getBooster(),
      "reg_lambda" -> getRegLambda(),
      "reg_alpha" -> getRegAlpha(),
      "quiet_mode" -> getQuietMode(),
      "sample_type" -> getSampleType(),
      "normalize_type" -> getNormalizeType(),
      "rate_drop" -> getRateDrop(),
      "one_drop" -> getOneDrop(),
      "skip_drop" -> getSkipDrop(),
      "dmatrix_type" -> getDmatrixType(),
      "backend" -> getBackend(),
      "gpu_id" -> getGpuId(),
      "interaction_constraints" -> getInteractionConstraints(),
      "scale_pos_weight" -> getScalePosWeight(),
      "eval_metric" -> getEvalMetric(),
      "score_eval_metric_only" -> getScoreEvalMetricOnly(),
      "model_id" -> getModelId(),
      "nfolds" -> getNfolds(),
      "keep_cross_validation_models" -> getKeepCrossValidationModels(),
      "keep_cross_validation_predictions" -> getKeepCrossValidationPredictions(),
      "keep_cross_validation_fold_assignment" -> getKeepCrossValidationFoldAssignment(),
      "parallelize_cross_validation" -> getParallelizeCrossValidation(),
      "distribution" -> getDistribution(),
      "tweedie_power" -> getTweediePower(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "offset_column" -> getOffsetCol(),
      "fold_column" -> getFoldCol(),
      "fold_assignment" -> getFoldAssignment(),
      "categorical_encoding" -> getCategoricalEncoding(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "stopping_rounds" -> getStoppingRounds(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "stopping_metric" -> getStoppingMetric(),
      "stopping_tolerance" -> getStoppingTolerance(),
      "gainslift_bins" -> getGainsliftBins(),
      "custom_metric_func" -> getCustomMetricFunc(),
      "export_checkpoints_dir" -> getExportCheckpointsDir(),
      "auc_type" -> getAucType()) +++
      getMonotoneConstraintsParam(trainingFrame) +++
      getCalibrationDataFrameParam(trainingFrame) +++
      getIgnoredColsParam(trainingFrame)
  }

  /**
   * Maps each Sparkling Water (camelCase) parameter name to its H2O (snake_case) name.
   *
   * The XGBoost entries are appended (via `++`) to the parent's mapping. They mirror the keys
   * emitted by [[getH2OXGBoostParams]], for example `labelCol` -> `response_column`.
   *
   * @return Sparkling Water name -> H2O name
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "ntrees" -> "ntrees",
        "maxDepth" -> "max_depth",
        "minRows" -> "min_rows",
        "minChildWeight" -> "min_child_weight",
        "learnRate" -> "learn_rate",
        "eta" -> "eta",
        "sampleRate" -> "sample_rate",
        "subsample" -> "subsample",
        "colSampleRate" -> "col_sample_rate",
        "colSampleByLevel" -> "colsample_bylevel",
        "colSampleRatePerTree" -> "col_sample_rate_per_tree",
        "colSampleByTree" -> "colsample_bytree",
        "colSampleByNode" -> "colsample_bynode",
        "maxAbsLeafnodePred" -> "max_abs_leafnode_pred",
        "maxDeltaStep" -> "max_delta_step",
        "scoreTreeInterval" -> "score_tree_interval",
        "seed" -> "seed",
        "minSplitImprovement" -> "min_split_improvement",
        "gamma" -> "gamma",
        "nthread" -> "nthread",
        "buildTreeOneNode" -> "build_tree_one_node",
        "saveMatrixDirectory" -> "save_matrix_directory",
        "calibrateModel" -> "calibrate_model",
        "calibrationMethod" -> "calibration_method",
        "maxBins" -> "max_bins",
        "maxLeaves" -> "max_leaves",
        "treeMethod" -> "tree_method",
        "growPolicy" -> "grow_policy",
        "booster" -> "booster",
        "regLambda" -> "reg_lambda",
        "regAlpha" -> "reg_alpha",
        "quietMode" -> "quiet_mode",
        "sampleType" -> "sample_type",
        "normalizeType" -> "normalize_type",
        "rateDrop" -> "rate_drop",
        "oneDrop" -> "one_drop",
        "skipDrop" -> "skip_drop",
        "dmatrixType" -> "dmatrix_type",
        "backend" -> "backend",
        "gpuId" -> "gpu_id",
        "interactionConstraints" -> "interaction_constraints",
        "scalePosWeight" -> "scale_pos_weight",
        "evalMetric" -> "eval_metric",
        "scoreEvalMetricOnly" -> "score_eval_metric_only",
        "modelId" -> "model_id",
        "nfolds" -> "nfolds",
        "keepCrossValidationModels" -> "keep_cross_validation_models",
        "keepCrossValidationPredictions" -> "keep_cross_validation_predictions",
        "keepCrossValidationFoldAssignment" -> "keep_cross_validation_fold_assignment",
        "parallelizeCrossValidation" -> "parallelize_cross_validation",
        "distribution" -> "distribution",
        "tweediePower" -> "tweedie_power",
        "labelCol" -> "response_column",
        "weightCol" -> "weights_column",
        "offsetCol" -> "offset_column",
        "foldCol" -> "fold_column",
        "foldAssignment" -> "fold_assignment",
        "categoricalEncoding" -> "categorical_encoding",
        "ignoreConstCols" -> "ignore_const_cols",
        "scoreEachIteration" -> "score_each_iteration",
        "stoppingRounds" -> "stopping_rounds",
        "maxRuntimeSecs" -> "max_runtime_secs",
        "stoppingMetric" -> "stopping_metric",
        "stoppingTolerance" -> "stopping_tolerance",
        "gainsliftBins" -> "gainslift_bins",
        "customMetricFunc" -> "custom_metric_func",
        "exportCheckpointsDir" -> "export_checkpoints_dir",
        "aucType" -> "auc_type")
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy