ai.h2o.sparkling.ml.params.H2ORuleFitParams.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.rulefit.RuleFitModel.RuleFitParameters
import ai.h2o.sparkling.H2OFrame
import hex.rulefit.RuleFitModel.Algorithm
import hex.rulefit.RuleFitModel.ModelType
import hex.genmodel.utils.DistributionFamily
import hex.MultinomialAucType
/**
 * Spark ML parameter definitions for the Sparkling Water RuleFit estimator.
 *
 * Each parameter is declared with a Sparkling Water (camelCase) name and is forwarded to H2O under
 * the snake_case name listed in [[getSWtoH2OParamNameMap]]. Enum-valued parameters (algorithm,
 * model type, distribution, AUC type) are passed through `EnumParamValidator` in their setters.
 */
trait H2ORuleFitParams extends H2OAlgoParamsBase with HasUnsupportedOffsetCol with HasIgnoredCols {

  /** Class tag of the H2O parameter class (`RuleFitParameters`) backing this estimator. */
  protected def paramTag: reflect.ClassTag[RuleFitParameters] = reflect.classTag[RuleFitParameters]

  // ---------------------------------------------------------------------------------------------
  // Parameter declarations (the `doc` texts are user-visible and must stay verbatim)
  // ---------------------------------------------------------------------------------------------

  protected val seed = longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

  protected val algorithm = stringParam(
    name = "algorithm",
    doc = """The algorithm to use to generate rules. Possible values are ``"DRF"``, ``"GBM"``, ``"AUTO"``.""")

  protected val minRuleLength = intParam(
    name = "minRuleLength",
    doc = """Minimum length of rules. Defaults to 3.""")

  protected val maxRuleLength = intParam(
    name = "maxRuleLength",
    doc = """Maximum length of rules. Defaults to 3.""")

  // NOTE: the continuation line of this doc string intentionally starts at column 0 so that the
  // string value (including its embedded newline) is unchanged.
  protected val maxNumRules = intParam(
    name = "maxNumRules",
    doc = """The maximum number of rules to return. defaults to -1 which means the number of rules is selected
by diminishing returns in model deviance.""")

  protected val modelType = stringParam(
    name = "modelType",
    doc = """Specifies type of base learners in the ensemble. Possible values are ``"RULES"``, ``"RULES_AND_LINEAR"``, ``"LINEAR"``.""")

  protected val ruleGenerationNtrees = intParam(
    name = "ruleGenerationNtrees",
    doc = """Specifies the number of trees to build in the tree model. Defaults to 50.""")

  protected val removeDuplicates = booleanParam(
    name = "removeDuplicates",
    doc = """Whether to remove rules which are identical to an earlier rule. Defaults to true.""")

  protected val lambdaValue = nullableDoubleArrayParam(
    name = "lambdaValue",
    doc = """Lambda for LASSO regressor.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val distribution = stringParam(
    name = "distribution",
    doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val maxCategoricalLevels = intParam(
    name = "maxCategoricalLevels",
    doc = """For every categorical feature, only use this many most frequent categorical levels for model training. Only used for categorical_encoding == EnumLimited.""")

  protected val aucType = stringParam(
    name = "aucType",
    doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")

  // ---------------------------------------------------------------------------------------------
  // Defaults (must run after the declarations above have been initialized)
  // ---------------------------------------------------------------------------------------------

  setDefault(
    seed -> -1L,
    algorithm -> Algorithm.AUTO.name(),
    minRuleLength -> 3,
    maxRuleLength -> 3,
    maxNumRules -> -1,
    modelType -> ModelType.RULES_AND_LINEAR.name(),
    ruleGenerationNtrees -> 50,
    removeDuplicates -> true,
    lambdaValue -> null,
    modelId -> null,
    distribution -> DistributionFamily.AUTO.name(),
    labelCol -> "label",
    weightCol -> null,
    maxCategoricalLevels -> 10,
    aucType -> MultinomialAucType.AUTO.name())

  // ---------------------------------------------------------------------------------------------
  // Getters: each returns the explicitly set value, or the default registered above
  // ---------------------------------------------------------------------------------------------

  def getSeed(): Long = $(seed)

  def getAlgorithm(): String = $(algorithm)

  def getMinRuleLength(): Int = $(minRuleLength)

  def getMaxRuleLength(): Int = $(maxRuleLength)

  def getMaxNumRules(): Int = $(maxNumRules)

  def getModelType(): String = $(modelType)

  def getRuleGenerationNtrees(): Int = $(ruleGenerationNtrees)

  def getRemoveDuplicates(): Boolean = $(removeDuplicates)

  /** May be `null` (the registered default). */
  def getLambdaValue(): Array[Double] = $(lambdaValue)

  /** May be `null` (the registered default). */
  def getModelId(): String = $(modelId)

  def getDistribution(): String = $(distribution)

  def getLabelCol(): String = $(labelCol)

  /** May be `null` (the registered default). */
  def getWeightCol(): String = $(weightCol)

  def getMaxCategoricalLevels(): Int = $(maxCategoricalLevels)

  def getAucType(): String = $(aucType)

  // ---------------------------------------------------------------------------------------------
  // Setters: all return `this` for chaining. Enum-valued setters store the value returned by
  // `EnumParamValidator.getValidatedEnumValue` for the matching H2O enum type rather than the raw
  // input (presumably normalized / rejected when unknown — see EnumParamValidator).
  // ---------------------------------------------------------------------------------------------

  def setSeed(value: Long): this.type = set(seed, value)

  def setAlgorithm(value: String): this.type =
    set(algorithm, EnumParamValidator.getValidatedEnumValue[Algorithm](value))

  def setMinRuleLength(value: Int): this.type = set(minRuleLength, value)

  def setMaxRuleLength(value: Int): this.type = set(maxRuleLength, value)

  def setMaxNumRules(value: Int): this.type = set(maxNumRules, value)

  def setModelType(value: String): this.type =
    set(modelType, EnumParamValidator.getValidatedEnumValue[ModelType](value))

  def setRuleGenerationNtrees(value: Int): this.type = set(ruleGenerationNtrees, value)

  def setRemoveDuplicates(value: Boolean): this.type = set(removeDuplicates, value)

  def setLambdaValue(value: Array[Double]): this.type = set(lambdaValue, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setDistribution(value: String): this.type =
    set(distribution, EnumParamValidator.getValidatedEnumValue[DistributionFamily](value))

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def setMaxCategoricalLevels(value: Int): this.type = set(maxCategoricalLevels, value)

  def setAucType(value: String): this.type =
    set(aucType, EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value))

  // ---------------------------------------------------------------------------------------------
  // Translation to H2O
  // ---------------------------------------------------------------------------------------------

  /**
   * Extends the parent's H2O parameter map with the RuleFit-specific entries.
   * On a key collision the RuleFit entry wins, because it is the right-hand operand of `++`.
   *
   * @param trainingFrame frame the model will be trained on
   * @return H2O parameter name -> value
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val inherited = super.getH2OAlgorithmParams(trainingFrame)
    inherited ++ getH2ORuleFitParams(trainingFrame)
  }

  /**
   * Builds the RuleFit parameter map keyed by H2O (snake_case) names. The offset-column and
   * ignored-columns entries are merged in afterwards via the `+++` operator, which is defined
   * outside this file.
   *
   * @param trainingFrame frame passed to the offset-column and ignored-columns helpers
   * @return H2O parameter name -> value
   */
  private[sparkling] def getH2ORuleFitParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val ruleFitSettings: Map[String, Any] = Map(
      "seed" -> getSeed(),
      "algorithm" -> getAlgorithm(),
      "min_rule_length" -> getMinRuleLength(),
      "max_rule_length" -> getMaxRuleLength(),
      "max_num_rules" -> getMaxNumRules(),
      "model_type" -> getModelType(),
      "rule_generation_ntrees" -> getRuleGenerationNtrees(),
      "remove_duplicates" -> getRemoveDuplicates(),
      "lambda" -> getLambdaValue(),
      "model_id" -> getModelId(),
      "distribution" -> getDistribution(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "max_categorical_levels" -> getMaxCategoricalLevels(),
      "auc_type" -> getAucType())
    val withOffsetCol = ruleFitSettings +++ getUnsupportedOffsetColParam(trainingFrame)
    withOffsetCol +++ getIgnoredColsParam(trainingFrame)
  }

  /**
   * Maps each Sparkling Water parameter name declared in this trait to its H2O counterpart, on
   * top of the names contributed by parent traits.
   * On a key collision this trait's entry wins, because it is the right-hand operand of `++`.
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val ruleFitNames = Map(
      "seed" -> "seed",
      "algorithm" -> "algorithm",
      "minRuleLength" -> "min_rule_length",
      "maxRuleLength" -> "max_rule_length",
      "maxNumRules" -> "max_num_rules",
      "modelType" -> "model_type",
      "ruleGenerationNtrees" -> "rule_generation_ntrees",
      "removeDuplicates" -> "remove_duplicates",
      "lambdaValue" -> "lambda",
      "modelId" -> "model_id",
      "distribution" -> "distribution",
      "labelCol" -> "response_column",
      "weightCol" -> "weights_column",
      "maxCategoricalLevels" -> "max_categorical_levels",
      "aucType" -> "auc_type")
    super.getSWtoH2OParamNameMap() ++ ruleFitNames
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy