// ai.h2o.sparkling.ml.params.H2ODRFParams.scala Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.tree.drf.DRFModel.DRFParameters
import ai.h2o.sparkling.H2OFrame
import hex.tree.SharedTreeModel.SharedTreeParameters.HistogramType
import hex.tree.CalibrationHelper.CalibrationMethod
import hex.genmodel.utils.DistributionFamily
import hex.Model.Parameters.FoldAssignmentScheme
import hex.Model.Parameters.CategoricalEncodingScheme
import hex.ScoreKeeper.StoppingMetric
import hex.MultinomialAucType
trait H2ODRFParams
extends H2OAlgoParamsBase
with HasCalibrationDataFrame
with HasIgnoredCols {
/** Runtime class tag of H2O's `DRFParameters`, the backend parameter class this trait corresponds to. */
protected def paramTag: reflect.ClassTag[DRFParameters] = reflect.classTag[DRFParameters]
//
// Parameter definitions
//
// Number of candidate variables sampled at each split (-1 selects the documented default).
// The parenthetical at the end of the description text is properly closed.
protected val mtries =
  intParam(
    name = "mtries",
    doc = """Number of variables randomly sampled as candidates at each split. If set to -1, defaults to sqrt{p} for classification and p/3 for regression (where p is the # of predictors).""")
// Row sampling and class-balancing parameters; the description of each one is its `doc` text.
protected val binomialDoubleTrees =
  booleanParam(
    name = "binomialDoubleTrees",
    doc = """For binary classification: Build 2x as many trees (one per class) - can lead to higher accuracy.""")

protected val sampleRate =
  doubleParam(
    name = "sampleRate",
    doc = """Row sample rate per tree (from 0.0 to 1.0).""")

protected val balanceClasses =
  booleanParam(
    name = "balanceClasses",
    doc = """Balance training data class counts via over/under-sampling (for imbalanced data).""")

// Nullable: no explicit factors are set unless the user provides them.
protected val classSamplingFactors =
  nullableFloatArrayParam(
    name = "classSamplingFactors",
    doc = """Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling factors will be automatically computed to obtain class balance during training. Requires balance_classes.""")

protected val maxAfterBalanceSize =
  floatParam(
    name = "maxAfterBalanceSize",
    doc = """Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires balance_classes.""")

// Marked "[Deprecated]" in its own description.
protected val maxConfusionMatrixSize =
  intParam(
    name = "maxConfusionMatrixSize",
    doc = """[Deprecated] Maximum size (# classes) for confusion matrices to be printed in the Logs.""")
// Forest size, tree shape and histogram-bin parameters.
protected val ntrees =
  intParam(
    name = "ntrees",
    doc = """Number of trees.""")

protected val maxDepth =
  intParam(
    name = "maxDepth",
    doc = """Maximum tree depth (0 for unlimited).""")

protected val minRows =
  doubleParam(
    name = "minRows",
    doc = """Fewest allowed (weighted) observations in a leaf.""")

protected val nbins =
  intParam(
    name = "nbins",
    doc = """For numerical columns (real/int), build a histogram of (at least) this many bins, then split at the best point.""")

protected val nbinsTopLevel =
  intParam(
    name = "nbinsTopLevel",
    doc = """For numerical columns (real/int), build a histogram of (at most) this many bins at the root level, then decrease by factor of two per level.""")

protected val nbinsCats =
  intParam(
    name = "nbinsCats",
    doc = """For categorical columns (factors), build a histogram of this many bins, then split at the best point. Higher values can lead to more overfitting.""")
// Randomness, column sampling, scoring cadence, split search and calibration parameters.
protected val seed =
  longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

protected val buildTreeOneNode =
  booleanParam(
    name = "buildTreeOneNode",
    doc = """Run on one node only; no network overhead but fewer cpus used. Suitable for small datasets.""")

// Nullable: per-class rates are optional.
protected val sampleRatePerClass =
  nullableDoubleArrayParam(
    name = "sampleRatePerClass",
    doc = """A list of row sample rates per class (relative fraction for each class, from 0.0 to 1.0), for each tree.""")

protected val colSampleRatePerTree =
  doubleParam(
    name = "colSampleRatePerTree",
    doc = """Column sample rate per tree (from 0.0 to 1.0).""")

protected val colSampleRateChangePerLevel =
  doubleParam(
    name = "colSampleRateChangePerLevel",
    doc = """Relative change of the column sampling rate for every level (must be > 0.0 and <= 2.0).""")

protected val scoreTreeInterval =
  intParam(
    name = "scoreTreeInterval",
    doc = """Score the model after every so many trees. Disabled if set to 0.""")

protected val minSplitImprovement =
  doubleParam(
    name = "minSplitImprovement",
    doc = """Minimum relative improvement in squared error reduction for a split to happen.""")

// Stored as a string; the accepted names are listed in the description.
protected val histogramType =
  stringParam(
    name = "histogramType",
    doc = """What type of histogram to use for finding optimal split points. Possible values are ``"AUTO"``, ``"UniformAdaptive"``, ``"Random"``, ``"QuantilesGlobal"``, ``"RoundRobin"``, ``"UniformRobust"``.""")

protected val calibrateModel =
  booleanParam(
    name = "calibrateModel",
    doc = """Use Platt Scaling (default) or Isotonic Regression to calculate calibrated class probabilities. Calibration can provide more accurate estimates of class probabilities.""")

// Stored as a string; the accepted names are listed in the description.
protected val calibrationMethod =
  stringParam(
    name = "calibrationMethod",
    doc = """Calibration method to use. Possible values are ``"AUTO"``, ``"PlattScaling"``, ``"IsotonicRegression"``.""")
// Toggles the constant-response check described in the text below.
// The description text separates its two sentences with a space ("... constant value. If disabled ...").
protected val checkConstantResponse =
  booleanParam(
    name = "checkConstantResponse",
    doc = """Check if response column is constant. If enabled, then an exception is thrown if the response column is a constant value. If disabled, then model will train regardless of the response column being a constant value or not.""")
// Model identity and cross-validation parameters.
// Nullable: the description states the id is auto-generated when not specified.
protected val modelId =
  nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

protected val nfolds =
  intParam(
    name = "nfolds",
    doc = """Number of folds for K-fold cross-validation (0 to disable or >= 2).""")

protected val keepCrossValidationModels =
  booleanParam(
    name = "keepCrossValidationModels",
    doc = """Whether to keep the cross-validation models.""")

protected val keepCrossValidationPredictions =
  booleanParam(
    name = "keepCrossValidationPredictions",
    doc = """Whether to keep the predictions of the cross-validation models.""")

protected val keepCrossValidationFoldAssignment =
  booleanParam(
    name = "keepCrossValidationFoldAssignment",
    doc = """Whether to keep the cross-validation fold assignment.""")
// Distribution, special-purpose columns and categorical handling parameters.
// Stored as a string; the accepted names are listed in the description.
protected val distribution =
  stringParam(
    name = "distribution",
    doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

protected val labelCol =
  stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

// The weight, offset and fold columns are nullable, i.e. optional.
protected val weightCol =
  nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

protected val offsetCol =
  nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

protected val foldCol =
  nullableStringParam(
    name = "foldCol",
    doc = """Column with cross-validation fold index assignment per observation.""")

// Stored as a string; the accepted names are listed in the description.
protected val foldAssignment =
  stringParam(
    name = "foldAssignment",
    doc = """Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will stratify the folds based on the response variable, for classification problems. Possible values are ``"AUTO"``, ``"Random"``, ``"Modulo"``, ``"Stratified"``.""")

// Stored as a string; the accepted names are listed in the description.
protected val categoricalEncoding =
  stringParam(
    name = "categoricalEncoding",
    doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

protected val ignoreConstCols =
  booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")
// Scoring, early-stopping and runtime-limit parameters.
protected val scoreEachIteration =
  booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

protected val stoppingRounds =
  intParam(
    name = "stoppingRounds",
    doc = """Early stopping based on convergence of stopping_metric. Stop if simple moving average of length k of the stopping_metric does not improve for k:=stopping_rounds scoring events (0 to disable).""")

protected val maxRuntimeSecs =
  doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

// Stored as a string; the accepted names are listed in the description.
protected val stoppingMetric =
  stringParam(
    name = "stoppingMetric",
    doc = """Metric to use for early stopping (AUTO: logloss for classification, deviance for regression and anomaly_score for Isolation Forest). Note that custom and custom_increasing can only be used in GBM and DRF with the Python client. Possible values are ``"AUTO"``, ``"deviance"``, ``"logloss"``, ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"RMSLE"``, ``"AUC"``, ``"AUCPR"``, ``"lift_top_group"``, ``"misclassification"``, ``"mean_per_class_error"``, ``"anomaly_score"``, ``"AUUC"``, ``"ATE"``, ``"ATT"``, ``"ATC"``, ``"qini"``, ``"custom"``, ``"custom_increasing"``.""")

protected val stoppingTolerance =
  doubleParam(
    name = "stoppingTolerance",
    doc = """Relative tolerance for metric-based stopping criterion (stop if relative improvement is not at least this much).""")
// Number of bins of the Gains/Lift table (0 disables it, -1 means automatic binning).
// The description text uses a single period between its sentences.
protected val gainsliftBins =
  intParam(
    name = "gainsliftBins",
    doc = """Gains/Lift table number of bins. 0 means disabled. Default value -1 means automatic binning.""")
// Custom metric, checkpoint export and multinomial AUC parameters.
// Nullable: no custom metric function unless one is supplied.
protected val customMetricFunc =
  nullableStringParam(
    name = "customMetricFunc",
    doc = """Reference to custom evaluation function, format: `language:keyName=funcName`.""")

// Nullable: no export directory unless one is supplied.
protected val exportCheckpointsDir =
  nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

// Stored as a string; the accepted names are listed in the description.
protected val aucType =
  stringParam(
    name = "aucType",
    doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")
//
// Default values
//
// Default values are registered in the same order as the parameters are declared,
// split into several calls purely for readability.

// Row sampling and class balancing.
setDefault(
  mtries -> -1,
  binomialDoubleTrees -> false,
  sampleRate -> 0.632,
  balanceClasses -> false,
  classSamplingFactors -> null,
  maxAfterBalanceSize -> 5.0f,
  maxConfusionMatrixSize -> 20)

// Forest size, tree shape and histogram bins.
setDefault(
  ntrees -> 50,
  maxDepth -> 20,
  minRows -> 1.0,
  nbins -> 20,
  nbinsTopLevel -> 1024,
  nbinsCats -> 1024)

// Randomness, column sampling, scoring cadence and split search.
setDefault(
  seed -> -1L,
  buildTreeOneNode -> false,
  sampleRatePerClass -> null,
  colSampleRatePerTree -> 1.0,
  colSampleRateChangePerLevel -> 1.0,
  scoreTreeInterval -> 0,
  minSplitImprovement -> 1.0e-5,
  histogramType -> HistogramType.AUTO.name())

// Calibration, response check and model id.
setDefault(
  calibrateModel -> false,
  calibrationMethod -> CalibrationMethod.AUTO.name(),
  checkConstantResponse -> true,
  modelId -> null)

// Cross-validation.
setDefault(
  nfolds -> 0,
  keepCrossValidationModels -> true,
  keepCrossValidationPredictions -> false,
  keepCrossValidationFoldAssignment -> false)

// Distribution, special-purpose columns and categorical handling.
setDefault(
  distribution -> DistributionFamily.AUTO.name(),
  labelCol -> "label",
  weightCol -> null,
  offsetCol -> null,
  foldCol -> null,
  foldAssignment -> FoldAssignmentScheme.AUTO.name(),
  categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
  ignoreConstCols -> true)

// Scoring, early stopping and runtime limit.
setDefault(
  scoreEachIteration -> false,
  stoppingRounds -> 0,
  maxRuntimeSecs -> 0.0,
  stoppingMetric -> StoppingMetric.AUTO.name(),
  stoppingTolerance -> 0.001)

// Gains/Lift, custom metric, checkpoint export and AUC type.
setDefault(
  gainsliftBins -> -1,
  customMetricFunc -> null,
  exportCheckpointsDir -> null,
  aucType -> MultinomialAucType.AUTO.name())
//
// Getters
//
// Each getter returns the explicitly set value of its parameter or, failing that, its default.
// Getters of nullable parameters (arrays and optional strings) may therefore return null.
def getMtries(): Int = getOrDefault(mtries)
def getBinomialDoubleTrees(): Boolean = getOrDefault(binomialDoubleTrees)
def getSampleRate(): Double = getOrDefault(sampleRate)
def getBalanceClasses(): Boolean = getOrDefault(balanceClasses)
def getClassSamplingFactors(): Array[Float] = getOrDefault(classSamplingFactors)
def getMaxAfterBalanceSize(): Float = getOrDefault(maxAfterBalanceSize)
def getMaxConfusionMatrixSize(): Int = getOrDefault(maxConfusionMatrixSize)
def getNtrees(): Int = getOrDefault(ntrees)
def getMaxDepth(): Int = getOrDefault(maxDepth)
def getMinRows(): Double = getOrDefault(minRows)
def getNbins(): Int = getOrDefault(nbins)
def getNbinsTopLevel(): Int = getOrDefault(nbinsTopLevel)
def getNbinsCats(): Int = getOrDefault(nbinsCats)
def getSeed(): Long = getOrDefault(seed)
def getBuildTreeOneNode(): Boolean = getOrDefault(buildTreeOneNode)
def getSampleRatePerClass(): Array[Double] = getOrDefault(sampleRatePerClass)
def getColSampleRatePerTree(): Double = getOrDefault(colSampleRatePerTree)
def getColSampleRateChangePerLevel(): Double = getOrDefault(colSampleRateChangePerLevel)
def getScoreTreeInterval(): Int = getOrDefault(scoreTreeInterval)
def getMinSplitImprovement(): Double = getOrDefault(minSplitImprovement)
def getHistogramType(): String = getOrDefault(histogramType)
def getCalibrateModel(): Boolean = getOrDefault(calibrateModel)
def getCalibrationMethod(): String = getOrDefault(calibrationMethod)
def getCheckConstantResponse(): Boolean = getOrDefault(checkConstantResponse)
def getModelId(): String = getOrDefault(modelId)
def getNfolds(): Int = getOrDefault(nfolds)
def getKeepCrossValidationModels(): Boolean = getOrDefault(keepCrossValidationModels)
def getKeepCrossValidationPredictions(): Boolean = getOrDefault(keepCrossValidationPredictions)
def getKeepCrossValidationFoldAssignment(): Boolean = getOrDefault(keepCrossValidationFoldAssignment)
def getDistribution(): String = getOrDefault(distribution)
def getLabelCol(): String = getOrDefault(labelCol)
def getWeightCol(): String = getOrDefault(weightCol)
def getOffsetCol(): String = getOrDefault(offsetCol)
def getFoldCol(): String = getOrDefault(foldCol)
def getFoldAssignment(): String = getOrDefault(foldAssignment)
def getCategoricalEncoding(): String = getOrDefault(categoricalEncoding)
def getIgnoreConstCols(): Boolean = getOrDefault(ignoreConstCols)
def getScoreEachIteration(): Boolean = getOrDefault(scoreEachIteration)
def getStoppingRounds(): Int = getOrDefault(stoppingRounds)
def getMaxRuntimeSecs(): Double = getOrDefault(maxRuntimeSecs)
def getStoppingMetric(): String = getOrDefault(stoppingMetric)
def getStoppingTolerance(): Double = getOrDefault(stoppingTolerance)
def getGainsliftBins(): Int = getOrDefault(gainsliftBins)
def getCustomMetricFunc(): String = getOrDefault(customMetricFunc)
def getExportCheckpointsDir(): String = getOrDefault(exportCheckpointsDir)
def getAucType(): String = getOrDefault(aucType)
//
// Setters
//
// Every setter stores the given value under its parameter and returns this instance (`this.type`),
// so calls can be chained. Setters of enum-backed string parameters first pass the value through
// `EnumParamValidator.getValidatedEnumValue` for the corresponding H2O enum and store its result
// (presumably rejecting or normalizing unknown names -- confirm in EnumParamValidator).
def setMtries(value: Int): this.type = set(mtries, value)

def setBinomialDoubleTrees(value: Boolean): this.type = set(binomialDoubleTrees, value)

def setSampleRate(value: Double): this.type = set(sampleRate, value)

def setBalanceClasses(value: Boolean): this.type = set(balanceClasses, value)

def setClassSamplingFactors(value: Array[Float]): this.type = set(classSamplingFactors, value)

def setMaxAfterBalanceSize(value: Float): this.type = set(maxAfterBalanceSize, value)

def setMaxConfusionMatrixSize(value: Int): this.type = set(maxConfusionMatrixSize, value)

def setNtrees(value: Int): this.type = set(ntrees, value)

def setMaxDepth(value: Int): this.type = set(maxDepth, value)

def setMinRows(value: Double): this.type = set(minRows, value)

def setNbins(value: Int): this.type = set(nbins, value)

def setNbinsTopLevel(value: Int): this.type = set(nbinsTopLevel, value)

def setNbinsCats(value: Int): this.type = set(nbinsCats, value)

def setSeed(value: Long): this.type = set(seed, value)

def setBuildTreeOneNode(value: Boolean): this.type = set(buildTreeOneNode, value)

def setSampleRatePerClass(value: Array[Double]): this.type = set(sampleRatePerClass, value)

def setColSampleRatePerTree(value: Double): this.type = set(colSampleRatePerTree, value)

def setColSampleRateChangePerLevel(value: Double): this.type = set(colSampleRateChangePerLevel, value)

def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

def setMinSplitImprovement(value: Double): this.type = set(minSplitImprovement, value)

def setHistogramType(value: String): this.type =
  set(histogramType, EnumParamValidator.getValidatedEnumValue[HistogramType](value))

def setCalibrateModel(value: Boolean): this.type = set(calibrateModel, value)

def setCalibrationMethod(value: String): this.type =
  set(calibrationMethod, EnumParamValidator.getValidatedEnumValue[CalibrationMethod](value))

def setCheckConstantResponse(value: Boolean): this.type = set(checkConstantResponse, value)

def setModelId(value: String): this.type = set(modelId, value)

def setNfolds(value: Int): this.type = set(nfolds, value)

def setKeepCrossValidationModels(value: Boolean): this.type = set(keepCrossValidationModels, value)

def setKeepCrossValidationPredictions(value: Boolean): this.type =
  set(keepCrossValidationPredictions, value)

def setKeepCrossValidationFoldAssignment(value: Boolean): this.type =
  set(keepCrossValidationFoldAssignment, value)

def setDistribution(value: String): this.type =
  set(distribution, EnumParamValidator.getValidatedEnumValue[DistributionFamily](value))

def setLabelCol(value: String): this.type = set(labelCol, value)

def setWeightCol(value: String): this.type = set(weightCol, value)

def setOffsetCol(value: String): this.type = set(offsetCol, value)

def setFoldCol(value: String): this.type = set(foldCol, value)

def setFoldAssignment(value: String): this.type =
  set(foldAssignment, EnumParamValidator.getValidatedEnumValue[FoldAssignmentScheme](value))

def setCategoricalEncoding(value: String): this.type =
  set(categoricalEncoding, EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value))

def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

def setStoppingRounds(value: Int): this.type = set(stoppingRounds, value)

def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

def setStoppingMetric(value: String): this.type =
  set(stoppingMetric, EnumParamValidator.getValidatedEnumValue[StoppingMetric](value))

def setStoppingTolerance(value: Double): this.type = set(stoppingTolerance, value)

def setGainsliftBins(value: Int): this.type = set(gainsliftBins, value)

def setCustomMetricFunc(value: String): this.type = set(customMetricFunc, value)

def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

def setAucType(value: String): this.type =
  set(aucType, EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value))
/**
 * Collects the H2O parameters of this algorithm for the given training frame.
 *
 * The map produced by the parent implementation is combined with the DRF-specific map from
 * `getH2ODRFParams`; on a key collision the DRF-specific entry replaces the inherited one.
 *
 * @param trainingFrame frame passed on to the parent implementation and to `getH2ODRFParams`
 * @return map from H2O parameter name to value
 */
override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
  val inheritedParams = super.getH2OAlgorithmParams(trainingFrame)
  val drfParams = getH2ODRFParams(trainingFrame)
  inheritedParams ++ drfParams
}
/**
 * Builds the DRF-specific part of the H2O parameter map.
 *
 * Keys are the H2O (snake_case) parameter names and values are read through this trait's getters.
 * The results of `getCalibrationDataFrameParam` and `getIgnoredColsParam` are then appended with
 * the `+++` operator (defined outside this file -- presumably a map merge; confirm at its
 * definition).
 *
 * @param trainingFrame frame forwarded to the calibration-frame and ignored-columns helpers
 * @return map from H2O parameter name to value
 */
private[sparkling] def getH2ODRFParams(trainingFrame: H2OFrame): Map[String, Any] = {
  val scalarParams = Map(
    "mtries" -> getMtries(),
    "binomial_double_trees" -> getBinomialDoubleTrees(),
    "sample_rate" -> getSampleRate(),
    "balance_classes" -> getBalanceClasses(),
    "class_sampling_factors" -> getClassSamplingFactors(),
    "max_after_balance_size" -> getMaxAfterBalanceSize(),
    "max_confusion_matrix_size" -> getMaxConfusionMatrixSize(),
    "ntrees" -> getNtrees(),
    "max_depth" -> getMaxDepth(),
    "min_rows" -> getMinRows(),
    "nbins" -> getNbins(),
    "nbins_top_level" -> getNbinsTopLevel(),
    "nbins_cats" -> getNbinsCats(),
    "seed" -> getSeed(),
    "build_tree_one_node" -> getBuildTreeOneNode(),
    "sample_rate_per_class" -> getSampleRatePerClass(),
    "col_sample_rate_per_tree" -> getColSampleRatePerTree(),
    "col_sample_rate_change_per_level" -> getColSampleRateChangePerLevel(),
    "score_tree_interval" -> getScoreTreeInterval(),
    "min_split_improvement" -> getMinSplitImprovement(),
    "histogram_type" -> getHistogramType(),
    "calibrate_model" -> getCalibrateModel(),
    "calibration_method" -> getCalibrationMethod(),
    "check_constant_response" -> getCheckConstantResponse(),
    "model_id" -> getModelId(),
    "nfolds" -> getNfolds(),
    "keep_cross_validation_models" -> getKeepCrossValidationModels(),
    "keep_cross_validation_predictions" -> getKeepCrossValidationPredictions(),
    "keep_cross_validation_fold_assignment" -> getKeepCrossValidationFoldAssignment(),
    "distribution" -> getDistribution(),
    "response_column" -> getLabelCol(),
    "weights_column" -> getWeightCol(),
    "offset_column" -> getOffsetCol(),
    "fold_column" -> getFoldCol(),
    "fold_assignment" -> getFoldAssignment(),
    "categorical_encoding" -> getCategoricalEncoding(),
    "ignore_const_cols" -> getIgnoreConstCols(),
    "score_each_iteration" -> getScoreEachIteration(),
    "stopping_rounds" -> getStoppingRounds(),
    "max_runtime_secs" -> getMaxRuntimeSecs(),
    "stopping_metric" -> getStoppingMetric(),
    "stopping_tolerance" -> getStoppingTolerance(),
    "gainslift_bins" -> getGainsliftBins(),
    "custom_metric_func" -> getCustomMetricFunc(),
    "export_checkpoints_dir" -> getExportCheckpointsDir(),
    "auc_type" -> getAucType())

  // Frame-dependent entries come from the mixed-in helper traits.
  scalarParams +++
    getCalibrationDataFrameParam(trainingFrame) +++
    getIgnoredColsParam(trainingFrame)
}
/**
 * Returns the translation table from Sparkling Water (camelCase) parameter names to H2O
 * (snake_case) parameter names.
 *
 * The entries declared here are added on top of the parent's map; on a key collision the entry
 * declared here replaces the inherited one.
 *
 * @return map from Sparkling Water parameter name to H2O parameter name
 */
override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
  val inheritedNames = super.getSWtoH2OParamNameMap()
  val drfNames = Map(
    "mtries" -> "mtries",
    "binomialDoubleTrees" -> "binomial_double_trees",
    "sampleRate" -> "sample_rate",
    "balanceClasses" -> "balance_classes",
    "classSamplingFactors" -> "class_sampling_factors",
    "maxAfterBalanceSize" -> "max_after_balance_size",
    "maxConfusionMatrixSize" -> "max_confusion_matrix_size",
    "ntrees" -> "ntrees",
    "maxDepth" -> "max_depth",
    "minRows" -> "min_rows",
    "nbins" -> "nbins",
    "nbinsTopLevel" -> "nbins_top_level",
    "nbinsCats" -> "nbins_cats",
    "seed" -> "seed",
    "buildTreeOneNode" -> "build_tree_one_node",
    "sampleRatePerClass" -> "sample_rate_per_class",
    "colSampleRatePerTree" -> "col_sample_rate_per_tree",
    "colSampleRateChangePerLevel" -> "col_sample_rate_change_per_level",
    "scoreTreeInterval" -> "score_tree_interval",
    "minSplitImprovement" -> "min_split_improvement",
    "histogramType" -> "histogram_type",
    "calibrateModel" -> "calibrate_model",
    "calibrationMethod" -> "calibration_method",
    "checkConstantResponse" -> "check_constant_response",
    "modelId" -> "model_id",
    "nfolds" -> "nfolds",
    "keepCrossValidationModels" -> "keep_cross_validation_models",
    "keepCrossValidationPredictions" -> "keep_cross_validation_predictions",
    "keepCrossValidationFoldAssignment" -> "keep_cross_validation_fold_assignment",
    "distribution" -> "distribution",
    "labelCol" -> "response_column",
    "weightCol" -> "weights_column",
    "offsetCol" -> "offset_column",
    "foldCol" -> "fold_column",
    "foldAssignment" -> "fold_assignment",
    "categoricalEncoding" -> "categorical_encoding",
    "ignoreConstCols" -> "ignore_const_cols",
    "scoreEachIteration" -> "score_each_iteration",
    "stoppingRounds" -> "stopping_rounds",
    "maxRuntimeSecs" -> "max_runtime_secs",
    "stoppingMetric" -> "stopping_metric",
    "stoppingTolerance" -> "stopping_tolerance",
    "gainsliftBins" -> "gainslift_bins",
    "customMetricFunc" -> "custom_metric_func",
    "exportCheckpointsDir" -> "export_checkpoints_dir",
    "aucType" -> "auc_type")
  inheritedNames ++ drfNames
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy