ai.h2o.sparkling.ml.params.H2OXGBoostParams.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.tree.xgboost.XGBoostModel.XGBoostParameters
import ai.h2o.sparkling.H2OFrame
import hex.tree.CalibrationHelper.CalibrationMethod
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.TreeMethod
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.GrowPolicy
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.Booster
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DartSampleType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DartNormalizeType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.DMatrixType
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.Backend
import hex.genmodel.utils.DistributionFamily
import hex.Model.Parameters.FoldAssignmentScheme
import hex.Model.Parameters.CategoricalEncodingScheme
import hex.ScoreKeeper.StoppingMetric
import hex.MultinomialAucType
/**
 * Spark ML parameters of the H2O XGBoost algorithm.
 *
 * The trait:
 *  - declares every parameter with its default value, a getter and a setter;
 *  - translates the configured values into the parameter map handed to H2O
 *    ([[getH2OXGBoostParams]]), keyed by H2O's snake_case parameter names;
 *  - exposes the mapping from Sparkling Water parameter names to H2O names
 *    ([[getSWtoH2OParamNameMap]]).
 *
 * Several parameters are aliases of each other (their docs say "same as ..."), for example
 * `minRows` / `minChildWeight`, `learnRate` / `eta` and `sampleRate` / `subsample`. Here they are
 * independent Spark params, and both values of every pair are forwarded to H2O. How H2O
 * reconciles two aliases set to different values is not visible in this trait.
 */
trait H2OXGBoostParams
  extends H2OAlgoParamsBase
  with HasMonotoneConstraints
  with HasCalibrationDataFrame
  with HasIgnoredCols {

  // Class tag of the backing H2O parameter class. Presumably used by H2OAlgoParamsBase to look up
  // H2O parameter metadata -- TODO confirm in the base trait.
  protected def paramTag = reflect.classTag[XGBoostParameters]

  //
  // Parameter definitions
  //
  protected val ntrees = intParam(
    name = "ntrees",
    doc = """(same as n_estimators) Number of trees.""")

  protected val maxDepth = intParam(
    name = "maxDepth",
    doc = """Maximum tree depth (0 for unlimited).""")

  protected val minRows = doubleParam(
    name = "minRows",
    doc = """(same as min_child_weight) Fewest allowed (weighted) observations in a leaf.""")

  protected val minChildWeight = doubleParam(
    name = "minChildWeight",
    doc = """(same as min_rows) Fewest allowed (weighted) observations in a leaf.""")

  protected val learnRate = doubleParam(
    name = "learnRate",
    doc = """(same as eta) Learning rate (from 0.0 to 1.0).""")

  protected val eta = doubleParam(
    name = "eta",
    doc = """(same as learn_rate) Learning rate (from 0.0 to 1.0).""")

  protected val sampleRate = doubleParam(
    name = "sampleRate",
    doc = """(same as subsample) Row sample rate per tree (from 0.0 to 1.0).""")

  protected val subsample = doubleParam(
    name = "subsample",
    doc = """(same as sample_rate) Row sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleRate = doubleParam(
    name = "colSampleRate",
    doc = """(same as colsample_bylevel) Column sample rate (from 0.0 to 1.0).""")

  protected val colSampleByLevel = doubleParam(
    name = "colSampleByLevel",
    doc = """(same as col_sample_rate) Column sample rate (from 0.0 to 1.0).""")

  protected val colSampleRatePerTree = doubleParam(
    name = "colSampleRatePerTree",
    doc = """(same as colsample_bytree) Column sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleByTree = doubleParam(
    name = "colSampleByTree",
    doc = """(same as col_sample_rate_per_tree) Column sample rate per tree (from 0.0 to 1.0).""")

  protected val colSampleByNode = doubleParam(
    name = "colSampleByNode",
    doc = """Column sample rate per tree node (from 0.0 to 1.0).""")

  protected val maxAbsLeafnodePred = floatParam(
    name = "maxAbsLeafnodePred",
    doc = """(same as max_delta_step) Maximum absolute value of a leaf node prediction.""")

  protected val maxDeltaStep = floatParam(
    name = "maxDeltaStep",
    doc = """(same as max_abs_leafnode_pred) Maximum absolute value of a leaf node prediction.""")

  protected val scoreTreeInterval = intParam(
    name = "scoreTreeInterval",
    doc = """Score the model after every so many trees. Disabled if set to 0.""")

  protected val seed = longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

  protected val minSplitImprovement = floatParam(
    name = "minSplitImprovement",
    doc = """(same as gamma) Minimum relative improvement in squared error reduction for a split to happen.""")

  protected val gamma = floatParam(
    name = "gamma",
    doc = """(same as min_split_improvement) Minimum relative improvement in squared error reduction for a split to happen.""")

  protected val nthread = intParam(
    name = "nthread",
    doc = """Number of parallel threads that can be used to run XGBoost. Cannot exceed H2O cluster limits (-nthreads parameter). Defaults to maximum available.""")

  protected val buildTreeOneNode = booleanParam(
    name = "buildTreeOneNode",
    doc = """Run on one node only; no network overhead but fewer cpus used. Suitable for small datasets.""")

  protected val saveMatrixDirectory = nullableStringParam(
    name = "saveMatrixDirectory",
    doc = """Directory where to save matrices passed to XGBoost library. Useful for debugging.""")

  protected val calibrateModel = booleanParam(
    name = "calibrateModel",
    doc = """Use Platt Scaling (default) or Isotonic Regression to calculate calibrated class probabilities. Calibration can provide more accurate estimates of class probabilities.""")

  protected val calibrationMethod = stringParam(
    name = "calibrationMethod",
    doc = """Calibration method to use. Possible values are ``"AUTO"``, ``"PlattScaling"``, ``"IsotonicRegression"``.""")

  protected val maxBins = intParam(
    name = "maxBins",
    doc = """For tree_method=hist only: maximum number of bins.""")

  protected val maxLeaves = intParam(
    name = "maxLeaves",
    doc = """For tree_method=hist only: maximum number of leaves.""")

  protected val treeMethod = stringParam(
    name = "treeMethod",
    doc = """Tree method. Possible values are ``"auto"``, ``"exact"``, ``"approx"``, ``"hist"``.""")

  protected val growPolicy = stringParam(
    name = "growPolicy",
    doc = """Grow policy - depthwise is standard GBM, lossguide is LightGBM. Possible values are ``"depthwise"``, ``"lossguide"``.""")

  protected val booster = stringParam(
    name = "booster",
    doc = """Booster type. Possible values are ``"gbtree"``, ``"gblinear"``, ``"dart"``.""")

  protected val regLambda = floatParam(
    name = "regLambda",
    doc = """L2 regularization.""")

  protected val regAlpha = floatParam(
    name = "regAlpha",
    doc = """L1 regularization.""")

  protected val quietMode = booleanParam(
    name = "quietMode",
    doc = """Enable quiet mode.""")

  protected val sampleType = stringParam(
    name = "sampleType",
    doc = """For booster=dart only: sample_type. Possible values are ``"uniform"``, ``"weighted"``.""")

  protected val normalizeType = stringParam(
    name = "normalizeType",
    doc = """For booster=dart only: normalize_type. Possible values are ``"tree"``, ``"forest"``.""")

  protected val rateDrop = floatParam(
    name = "rateDrop",
    doc = """For booster=dart only: rate_drop (0..1).""")

  protected val oneDrop = booleanParam(
    name = "oneDrop",
    doc = """For booster=dart only: one_drop.""")

  protected val skipDrop = floatParam(
    name = "skipDrop",
    doc = """For booster=dart only: skip_drop (0..1).""")

  protected val dmatrixType = stringParam(
    name = "dmatrixType",
    doc = """Type of DMatrix. For sparse, NAs and 0 are treated equally. Possible values are ``"auto"``, ``"dense"``, ``"sparse"``.""")

  protected val backend = stringParam(
    name = "backend",
    doc = """Backend. By default (auto), a GPU is used if available. Possible values are ``"auto"``, ``"gpu"``, ``"cpu"``.""")

  protected val gpuId = nullableIntArrayParam(
    name = "gpuId",
    doc = """Which GPU(s) to use.""")

  protected val interactionConstraints = nullableStringArrayArrayParam(
    name = "interactionConstraints",
    doc = """A set of allowed column interactions.""")

  protected val scalePosWeight = floatParam(
    name = "scalePosWeight",
    doc = """Controls the effect of observations with positive labels in relation to the observations with negative labels on gradient calculation. Useful for imbalanced problems.""")

  protected val evalMetric = nullableStringParam(
    name = "evalMetric",
    doc = """Specification of evaluation metric that will be passed to the native XGBoost backend.""")

  protected val scoreEvalMetricOnly = booleanParam(
    name = "scoreEvalMetricOnly",
    doc = """If enabled, score only the evaluation metric. This can make model training faster if scoring is frequent (eg. each iteration).""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val nfolds = intParam(
    name = "nfolds",
    doc = """Number of folds for K-fold cross-validation (0 to disable or >= 2).""")

  protected val keepCrossValidationModels = booleanParam(
    name = "keepCrossValidationModels",
    doc = """Whether to keep the cross-validation models.""")

  protected val keepCrossValidationPredictions = booleanParam(
    name = "keepCrossValidationPredictions",
    doc = """Whether to keep the predictions of the cross-validation models.""")

  protected val keepCrossValidationFoldAssignment = booleanParam(
    name = "keepCrossValidationFoldAssignment",
    doc = """Whether to keep the cross-validation fold assignment.""")

  protected val parallelizeCrossValidation = booleanParam(
    name = "parallelizeCrossValidation",
    doc = """Allow parallel training of cross-validation models.""")

  protected val distribution = stringParam(
    name = "distribution",
    doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

  protected val tweediePower = doubleParam(
    name = "tweediePower",
    doc = """Tweedie power for Tweedie regression, must be between 1 and 2.""")

  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val offsetCol = nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

  protected val foldCol = nullableStringParam(
    name = "foldCol",
    doc = """Column with cross-validation fold index assignment per observation.""")

  protected val foldAssignment = stringParam(
    name = "foldAssignment",
    doc = """Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will stratify the folds based on the response variable, for classification problems. Possible values are ``"AUTO"``, ``"Random"``, ``"Modulo"``, ``"Stratified"``.""")

  protected val categoricalEncoding = stringParam(
    name = "categoricalEncoding",
    doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val stoppingRounds = intParam(
    name = "stoppingRounds",
    doc = """Early stopping based on convergence of stopping_metric. Stop if simple moving average of length k of the stopping_metric does not improve for k:=stopping_rounds scoring events (0 to disable).""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val stoppingMetric = stringParam(
    name = "stoppingMetric",
    doc = """Metric to use for early stopping (AUTO: logloss for classification, deviance for regression and anomaly_score for Isolation Forest). Note that custom and custom_increasing can only be used in GBM and DRF with the Python client. Possible values are ``"AUTO"``, ``"deviance"``, ``"logloss"``, ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"RMSLE"``, ``"AUC"``, ``"AUCPR"``, ``"lift_top_group"``, ``"misclassification"``, ``"mean_per_class_error"``, ``"anomaly_score"``, ``"AUUC"``, ``"ATE"``, ``"ATT"``, ``"ATC"``, ``"qini"``, ``"custom"``, ``"custom_increasing"``.""")

  protected val stoppingTolerance = doubleParam(
    name = "stoppingTolerance",
    doc = """Relative tolerance for metric-based stopping criterion (stop if relative improvement is not at least this much).""")

  protected val gainsliftBins = intParam(
    name = "gainsliftBins",
    doc = """Gains/Lift table number of bins. 0 means disabled. Default value -1 means automatic binning.""")

  protected val customMetricFunc = nullableStringParam(
    name = "customMetricFunc",
    doc = """Reference to custom evaluation function, format: `language:keyName=funcName`.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  protected val aucType = stringParam(
    name = "aucType",
    doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")

  //
  // Default values
  //
  // String-valued enum defaults are derived from the same H2O enum types that the corresponding
  // setters validate against, so the default and the validated values share one source.
  // Every `nullable*` param defaults to `null`, meaning "not specified".
  setDefault(
    ntrees -> 50,
    maxDepth -> 6,
    minRows -> 1.0,
    minChildWeight -> 1.0,
    learnRate -> 0.3,
    eta -> 0.3,
    sampleRate -> 1.0,
    subsample -> 1.0,
    colSampleRate -> 1.0,
    colSampleByLevel -> 1.0,
    colSampleRatePerTree -> 1.0,
    colSampleByTree -> 1.0,
    colSampleByNode -> 1.0,
    maxAbsLeafnodePred -> 0.0f,
    maxDeltaStep -> 0.0f,
    scoreTreeInterval -> 0,
    seed -> -1L,
    minSplitImprovement -> 0.0f,
    gamma -> 0.0f,
    nthread -> -1,
    buildTreeOneNode -> false,
    saveMatrixDirectory -> null,
    calibrateModel -> false,
    calibrationMethod -> CalibrationMethod.AUTO.name(),
    maxBins -> 256,
    maxLeaves -> 0,
    treeMethod -> TreeMethod.auto.name(),
    growPolicy -> GrowPolicy.depthwise.name(),
    booster -> Booster.gbtree.name(),
    regLambda -> 1.0f,
    regAlpha -> 0.0f,
    quietMode -> true,
    sampleType -> DartSampleType.uniform.name(),
    normalizeType -> DartNormalizeType.tree.name(),
    rateDrop -> 0.0f,
    oneDrop -> false,
    skipDrop -> 0.0f,
    dmatrixType -> DMatrixType.auto.name(),
    backend -> Backend.auto.name(),
    gpuId -> null,
    interactionConstraints -> null,
    scalePosWeight -> 1.0f,
    evalMetric -> null,
    scoreEvalMetricOnly -> false,
    modelId -> null,
    nfolds -> 0,
    keepCrossValidationModels -> true,
    keepCrossValidationPredictions -> false,
    keepCrossValidationFoldAssignment -> false,
    parallelizeCrossValidation -> true,
    distribution -> DistributionFamily.AUTO.name(),
    tweediePower -> 1.5,
    labelCol -> "label",
    weightCol -> null,
    offsetCol -> null,
    foldCol -> null,
    foldAssignment -> FoldAssignmentScheme.AUTO.name(),
    categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    stoppingRounds -> 0,
    maxRuntimeSecs -> 0.0,
    stoppingMetric -> StoppingMetric.AUTO.name(),
    stoppingTolerance -> 0.001,
    gainsliftBins -> -1,
    customMetricFunc -> null,
    exportCheckpointsDir -> null,
    aucType -> MultinomialAucType.AUTO.name())

  //
  // Getters
  //
  // Getters of `nullable*` params (String, Array[Int], Array[Array[String]]) return `null`
  // while the param is left at its default.
  def getNtrees(): Int = $(ntrees)

  def getMaxDepth(): Int = $(maxDepth)

  def getMinRows(): Double = $(minRows)

  def getMinChildWeight(): Double = $(minChildWeight)

  def getLearnRate(): Double = $(learnRate)

  def getEta(): Double = $(eta)

  def getSampleRate(): Double = $(sampleRate)

  def getSubsample(): Double = $(subsample)

  def getColSampleRate(): Double = $(colSampleRate)

  def getColSampleByLevel(): Double = $(colSampleByLevel)

  def getColSampleRatePerTree(): Double = $(colSampleRatePerTree)

  def getColSampleByTree(): Double = $(colSampleByTree)

  def getColSampleByNode(): Double = $(colSampleByNode)

  def getMaxAbsLeafnodePred(): Float = $(maxAbsLeafnodePred)

  def getMaxDeltaStep(): Float = $(maxDeltaStep)

  def getScoreTreeInterval(): Int = $(scoreTreeInterval)

  def getSeed(): Long = $(seed)

  def getMinSplitImprovement(): Float = $(minSplitImprovement)

  def getGamma(): Float = $(gamma)

  def getNthread(): Int = $(nthread)

  def getBuildTreeOneNode(): Boolean = $(buildTreeOneNode)

  def getSaveMatrixDirectory(): String = $(saveMatrixDirectory)

  def getCalibrateModel(): Boolean = $(calibrateModel)

  def getCalibrationMethod(): String = $(calibrationMethod)

  def getMaxBins(): Int = $(maxBins)

  def getMaxLeaves(): Int = $(maxLeaves)

  def getTreeMethod(): String = $(treeMethod)

  def getGrowPolicy(): String = $(growPolicy)

  def getBooster(): String = $(booster)

  def getRegLambda(): Float = $(regLambda)

  def getRegAlpha(): Float = $(regAlpha)

  def getQuietMode(): Boolean = $(quietMode)

  def getSampleType(): String = $(sampleType)

  def getNormalizeType(): String = $(normalizeType)

  def getRateDrop(): Float = $(rateDrop)

  def getOneDrop(): Boolean = $(oneDrop)

  def getSkipDrop(): Float = $(skipDrop)

  def getDmatrixType(): String = $(dmatrixType)

  def getBackend(): String = $(backend)

  def getGpuId(): Array[Int] = $(gpuId)

  def getInteractionConstraints(): Array[Array[String]] = $(interactionConstraints)

  def getScalePosWeight(): Float = $(scalePosWeight)

  def getEvalMetric(): String = $(evalMetric)

  def getScoreEvalMetricOnly(): Boolean = $(scoreEvalMetricOnly)

  def getModelId(): String = $(modelId)

  def getNfolds(): Int = $(nfolds)

  def getKeepCrossValidationModels(): Boolean = $(keepCrossValidationModels)

  def getKeepCrossValidationPredictions(): Boolean = $(keepCrossValidationPredictions)

  def getKeepCrossValidationFoldAssignment(): Boolean = $(keepCrossValidationFoldAssignment)

  def getParallelizeCrossValidation(): Boolean = $(parallelizeCrossValidation)

  def getDistribution(): String = $(distribution)

  def getTweediePower(): Double = $(tweediePower)

  def getLabelCol(): String = $(labelCol)

  def getWeightCol(): String = $(weightCol)

  def getOffsetCol(): String = $(offsetCol)

  def getFoldCol(): String = $(foldCol)

  def getFoldAssignment(): String = $(foldAssignment)

  def getCategoricalEncoding(): String = $(categoricalEncoding)

  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  def getStoppingRounds(): Int = $(stoppingRounds)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  def getStoppingMetric(): String = $(stoppingMetric)

  def getStoppingTolerance(): Double = $(stoppingTolerance)

  def getGainsliftBins(): Int = $(gainsliftBins)

  def getCustomMetricFunc(): String = $(customMetricFunc)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  def getAucType(): String = $(aucType)

  //
  // Setters
  //
  // Setters of enum-valued params pass the input through EnumParamValidator with the matching
  // H2O enum type and store the value it returns. Presumably it rejects names outside that enum
  // and normalizes accepted ones -- TODO confirm in EnumParamValidator.
  def setNtrees(value: Int): this.type = set(ntrees, value)

  def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  def setMinRows(value: Double): this.type = set(minRows, value)

  def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)

  def setLearnRate(value: Double): this.type = set(learnRate, value)

  def setEta(value: Double): this.type = set(eta, value)

  def setSampleRate(value: Double): this.type = set(sampleRate, value)

  def setSubsample(value: Double): this.type = set(subsample, value)

  def setColSampleRate(value: Double): this.type = set(colSampleRate, value)

  def setColSampleByLevel(value: Double): this.type = set(colSampleByLevel, value)

  def setColSampleRatePerTree(value: Double): this.type = set(colSampleRatePerTree, value)

  def setColSampleByTree(value: Double): this.type = set(colSampleByTree, value)

  def setColSampleByNode(value: Double): this.type = set(colSampleByNode, value)

  def setMaxAbsLeafnodePred(value: Float): this.type = set(maxAbsLeafnodePred, value)

  def setMaxDeltaStep(value: Float): this.type = set(maxDeltaStep, value)

  def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setMinSplitImprovement(value: Float): this.type = set(minSplitImprovement, value)

  def setGamma(value: Float): this.type = set(gamma, value)

  def setNthread(value: Int): this.type = set(nthread, value)

  def setBuildTreeOneNode(value: Boolean): this.type = set(buildTreeOneNode, value)

  def setSaveMatrixDirectory(value: String): this.type = set(saveMatrixDirectory, value)

  def setCalibrateModel(value: Boolean): this.type = set(calibrateModel, value)

  def setCalibrationMethod(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CalibrationMethod](value)
    set(calibrationMethod, validated)
  }

  def setMaxBins(value: Int): this.type = set(maxBins, value)

  def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)

  def setTreeMethod(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[TreeMethod](value)
    set(treeMethod, validated)
  }

  def setGrowPolicy(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[GrowPolicy](value)
    set(growPolicy, validated)
  }

  def setBooster(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[Booster](value)
    set(booster, validated)
  }

  def setRegLambda(value: Float): this.type = set(regLambda, value)

  def setRegAlpha(value: Float): this.type = set(regAlpha, value)

  def setQuietMode(value: Boolean): this.type = set(quietMode, value)

  def setSampleType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DartSampleType](value)
    set(sampleType, validated)
  }

  def setNormalizeType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DartNormalizeType](value)
    set(normalizeType, validated)
  }

  def setRateDrop(value: Float): this.type = set(rateDrop, value)

  def setOneDrop(value: Boolean): this.type = set(oneDrop, value)

  def setSkipDrop(value: Float): this.type = set(skipDrop, value)

  def setDmatrixType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DMatrixType](value)
    set(dmatrixType, validated)
  }

  def setBackend(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[Backend](value)
    set(backend, validated)
  }

  def setGpuId(value: Array[Int]): this.type = set(gpuId, value)

  def setInteractionConstraints(value: Array[Array[String]]): this.type =
    set(interactionConstraints, value)

  def setScalePosWeight(value: Float): this.type = set(scalePosWeight, value)

  def setEvalMetric(value: String): this.type = set(evalMetric, value)

  def setScoreEvalMetricOnly(value: Boolean): this.type = set(scoreEvalMetricOnly, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setNfolds(value: Int): this.type = set(nfolds, value)

  def setKeepCrossValidationModels(value: Boolean): this.type =
    set(keepCrossValidationModels, value)

  def setKeepCrossValidationPredictions(value: Boolean): this.type =
    set(keepCrossValidationPredictions, value)

  def setKeepCrossValidationFoldAssignment(value: Boolean): this.type =
    set(keepCrossValidationFoldAssignment, value)

  def setParallelizeCrossValidation(value: Boolean): this.type =
    set(parallelizeCrossValidation, value)

  def setDistribution(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[DistributionFamily](value)
    set(distribution, validated)
  }

  def setTweediePower(value: Double): this.type = set(tweediePower, value)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  def setFoldCol(value: String): this.type = set(foldCol, value)

  def setFoldAssignment(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[FoldAssignmentScheme](value)
    set(foldAssignment, validated)
  }

  def setCategoricalEncoding(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value)
    set(categoricalEncoding, validated)
  }

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setStoppingRounds(value: Int): this.type = set(stoppingRounds, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setStoppingMetric(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[StoppingMetric](value)
    set(stoppingMetric, validated)
  }

  def setStoppingTolerance(value: Double): this.type = set(stoppingTolerance, value)

  def setGainsliftBins(value: Int): this.type = set(gainsliftBins, value)

  def setCustomMetricFunc(value: String): this.type = set(customMetricFunc, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  def setAucType(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value)
    set(aucType, validated)
  }

  /**
   * Full H2O parameter map for this algorithm.
   *
   * It combines the entries contributed by the parent trait with the XGBoost-specific entries
   * produced by [[getH2OXGBoostParams]]. On key collisions the XGBoost entries win, because they
   * are on the right-hand side of `++`.
   *
   * @param trainingFrame the H2O frame the model will be trained on
   * @return H2O parameter name (snake_case) to value
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OXGBoostParams(trainingFrame)
  }

  /**
   * XGBoost-specific H2O parameters, keyed by H2O's snake_case parameter names.
   *
   * Both members of every alias pair are included (e.g. `min_rows` and `min_child_weight`).
   * Values of unset `nullable*` params are forwarded as `null`.
   *
   * The map is then extended, via the project's `+++` operator, with the entries produced by the
   * monotone-constraints, calibration-data-frame and ignored-columns mixins. Those helpers are
   * resolved against `trainingFrame`; their exact contents are defined in the mixin traits.
   *
   * @param trainingFrame the H2O frame the model will be trained on
   * @return H2O parameter name to value
   */
  private[sparkling] def getH2OXGBoostParams(trainingFrame: H2OFrame): Map[String, Any] = {
    Map(
      "ntrees" -> getNtrees(),
      "max_depth" -> getMaxDepth(),
      "min_rows" -> getMinRows(),
      "min_child_weight" -> getMinChildWeight(),
      "learn_rate" -> getLearnRate(),
      "eta" -> getEta(),
      "sample_rate" -> getSampleRate(),
      "subsample" -> getSubsample(),
      "col_sample_rate" -> getColSampleRate(),
      "colsample_bylevel" -> getColSampleByLevel(),
      "col_sample_rate_per_tree" -> getColSampleRatePerTree(),
      "colsample_bytree" -> getColSampleByTree(),
      "colsample_bynode" -> getColSampleByNode(),
      "max_abs_leafnode_pred" -> getMaxAbsLeafnodePred(),
      "max_delta_step" -> getMaxDeltaStep(),
      "score_tree_interval" -> getScoreTreeInterval(),
      "seed" -> getSeed(),
      "min_split_improvement" -> getMinSplitImprovement(),
      "gamma" -> getGamma(),
      "nthread" -> getNthread(),
      "build_tree_one_node" -> getBuildTreeOneNode(),
      "save_matrix_directory" -> getSaveMatrixDirectory(),
      "calibrate_model" -> getCalibrateModel(),
      "calibration_method" -> getCalibrationMethod(),
      "max_bins" -> getMaxBins(),
      "max_leaves" -> getMaxLeaves(),
      "tree_method" -> getTreeMethod(),
      "grow_policy" -> getGrowPolicy(),
      "booster" -> getBooster(),
      "reg_lambda" -> getRegLambda(),
      "reg_alpha" -> getRegAlpha(),
      "quiet_mode" -> getQuietMode(),
      "sample_type" -> getSampleType(),
      "normalize_type" -> getNormalizeType(),
      "rate_drop" -> getRateDrop(),
      "one_drop" -> getOneDrop(),
      "skip_drop" -> getSkipDrop(),
      "dmatrix_type" -> getDmatrixType(),
      "backend" -> getBackend(),
      "gpu_id" -> getGpuId(),
      "interaction_constraints" -> getInteractionConstraints(),
      "scale_pos_weight" -> getScalePosWeight(),
      "eval_metric" -> getEvalMetric(),
      "score_eval_metric_only" -> getScoreEvalMetricOnly(),
      "model_id" -> getModelId(),
      "nfolds" -> getNfolds(),
      "keep_cross_validation_models" -> getKeepCrossValidationModels(),
      "keep_cross_validation_predictions" -> getKeepCrossValidationPredictions(),
      "keep_cross_validation_fold_assignment" -> getKeepCrossValidationFoldAssignment(),
      "parallelize_cross_validation" -> getParallelizeCrossValidation(),
      "distribution" -> getDistribution(),
      "tweedie_power" -> getTweediePower(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "offset_column" -> getOffsetCol(),
      "fold_column" -> getFoldCol(),
      "fold_assignment" -> getFoldAssignment(),
      "categorical_encoding" -> getCategoricalEncoding(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "stopping_rounds" -> getStoppingRounds(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "stopping_metric" -> getStoppingMetric(),
      "stopping_tolerance" -> getStoppingTolerance(),
      "gainslift_bins" -> getGainsliftBins(),
      "custom_metric_func" -> getCustomMetricFunc(),
      "export_checkpoints_dir" -> getExportCheckpointsDir(),
      "auc_type" -> getAucType()) +++
      getMonotoneConstraintsParam(trainingFrame) +++
      getCalibrationDataFrameParam(trainingFrame) +++
      getIgnoredColsParam(trainingFrame)
  }

  /**
   * Mapping from Sparkling Water (camelCase) parameter names to H2O (snake_case) names.
   *
   * It extends the parent trait's mapping. The entries mirror the keys used in
   * [[getH2OXGBoostParams]]. Note the irregular column names: `labelCol` maps to
   * `response_column` and `weightCol` maps to `weights_column`.
   *
   * @return Sparkling Water parameter name to H2O parameter name
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "ntrees" -> "ntrees",
        "maxDepth" -> "max_depth",
        "minRows" -> "min_rows",
        "minChildWeight" -> "min_child_weight",
        "learnRate" -> "learn_rate",
        "eta" -> "eta",
        "sampleRate" -> "sample_rate",
        "subsample" -> "subsample",
        "colSampleRate" -> "col_sample_rate",
        "colSampleByLevel" -> "colsample_bylevel",
        "colSampleRatePerTree" -> "col_sample_rate_per_tree",
        "colSampleByTree" -> "colsample_bytree",
        "colSampleByNode" -> "colsample_bynode",
        "maxAbsLeafnodePred" -> "max_abs_leafnode_pred",
        "maxDeltaStep" -> "max_delta_step",
        "scoreTreeInterval" -> "score_tree_interval",
        "seed" -> "seed",
        "minSplitImprovement" -> "min_split_improvement",
        "gamma" -> "gamma",
        "nthread" -> "nthread",
        "buildTreeOneNode" -> "build_tree_one_node",
        "saveMatrixDirectory" -> "save_matrix_directory",
        "calibrateModel" -> "calibrate_model",
        "calibrationMethod" -> "calibration_method",
        "maxBins" -> "max_bins",
        "maxLeaves" -> "max_leaves",
        "treeMethod" -> "tree_method",
        "growPolicy" -> "grow_policy",
        "booster" -> "booster",
        "regLambda" -> "reg_lambda",
        "regAlpha" -> "reg_alpha",
        "quietMode" -> "quiet_mode",
        "sampleType" -> "sample_type",
        "normalizeType" -> "normalize_type",
        "rateDrop" -> "rate_drop",
        "oneDrop" -> "one_drop",
        "skipDrop" -> "skip_drop",
        "dmatrixType" -> "dmatrix_type",
        "backend" -> "backend",
        "gpuId" -> "gpu_id",
        "interactionConstraints" -> "interaction_constraints",
        "scalePosWeight" -> "scale_pos_weight",
        "evalMetric" -> "eval_metric",
        "scoreEvalMetricOnly" -> "score_eval_metric_only",
        "modelId" -> "model_id",
        "nfolds" -> "nfolds",
        "keepCrossValidationModels" -> "keep_cross_validation_models",
        "keepCrossValidationPredictions" -> "keep_cross_validation_predictions",
        "keepCrossValidationFoldAssignment" -> "keep_cross_validation_fold_assignment",
        "parallelizeCrossValidation" -> "parallelize_cross_validation",
        "distribution" -> "distribution",
        "tweediePower" -> "tweedie_power",
        "labelCol" -> "response_column",
        "weightCol" -> "weights_column",
        "offsetCol" -> "offset_column",
        "foldCol" -> "fold_column",
        "foldAssignment" -> "fold_assignment",
        "categoricalEncoding" -> "categorical_encoding",
        "ignoreConstCols" -> "ignore_const_cols",
        "scoreEachIteration" -> "score_each_iteration",
        "stoppingRounds" -> "stopping_rounds",
        "maxRuntimeSecs" -> "max_runtime_secs",
        "stoppingMetric" -> "stopping_metric",
        "stoppingTolerance" -> "stopping_tolerance",
        "gainsliftBins" -> "gainslift_bins",
        "customMetricFunc" -> "custom_metric_func",
        "exportCheckpointsDir" -> "export_checkpoints_dir",
        "aucType" -> "auc_type")
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy