// Source listing: ai.h2o.sparkling.ml.params.H2OExtendedIsolationForestParams.scala (Maven / Gradle / Ivy)
// This is the newest published version.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.tree.isoforextended.ExtendedIsolationForestModel.ExtendedIsolationForestParameters
import ai.h2o.sparkling.H2OFrame
import hex.Model.Parameters.CategoricalEncodingScheme
/**
 * Spark ML parameter definitions for the H2O Extended Isolation Forest algorithm.
 *
 * Every parameter declared here has a Spark-side (camelCase) name and an H2O-side (snake_case)
 * name. The correspondence is listed in `getSWtoH2OParamNameMap`, and the H2O-side values are
 * assembled by `getH2OExtendedIsolationForestParams`. Ignored-column support comes from the
 * mixed-in `HasIgnoredCols`.
 */
trait H2OExtendedIsolationForestParams extends H2OAlgoParamsBase with HasIgnoredCols {

  /** Class tag of the H2O-side parameter class, `ExtendedIsolationForestParameters`. */
  protected def paramTag: reflect.ClassTag[ExtendedIsolationForestParameters] =
    reflect.classTag[ExtendedIsolationForestParameters]

  // ===== Parameter declarations =====
  // Keep these vals above the `setDefault(...)` call. Trait vals are initialized in declaration
  // order, so a param referenced by `setDefault` before its val has run would still be null.

  protected val ntrees = intParam(
    name = "ntrees",
    doc = """Number of Extended Isolation Forest trees.""")

  protected val sampleSize = intParam(
    name = "sampleSize",
    doc = """Number of randomly sampled observations used to train each Extended Isolation Forest tree.""")

  protected val extensionLevel = intParam(
    name = "extensionLevel",
    doc = """Maximum is N - 1 (N = numCols). Minimum is 0. Extended Isolation Forest with extension_Level = 0 behaves like Isolation Forest.""")

  protected val seed = longParam(
    name = "seed",
    doc = """Seed for pseudo random number generator (if applicable).""")

  protected val scoreTreeInterval = intParam(
    name = "scoreTreeInterval",
    doc = """Score the model after every so many trees. Disabled if set to 0.""")

  protected val disableTrainingMetrics = booleanParam(
    name = "disableTrainingMetrics",
    doc = """Disable calculating training metrics (expensive on large datasets).""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val categoricalEncoding = stringParam(
    name = "categoricalEncoding",
    doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  // ===== Defaults (applied when the user never calls the matching setter) =====

  setDefault(
    ntrees -> 100,
    sampleSize -> 256,
    extensionLevel -> 0,
    seed -> -1L,
    scoreTreeInterval -> 0,
    disableTrainingMetrics -> true,
    modelId -> null,
    categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
    ignoreConstCols -> true,
    scoreEachIteration -> false)

  // ===== Getters: current value, or the default declared above =====

  /** Number of trees; defaults to 100. */
  def getNtrees(): Int = $(ntrees)

  /** Observations sampled per tree; defaults to 256. */
  def getSampleSize(): Int = $(sampleSize)

  /** Extension level; defaults to 0 (per the param doc, 0 behaves like plain Isolation Forest). */
  def getExtensionLevel(): Int = $(extensionLevel)

  /** Random seed; defaults to -1. */
  def getSeed(): Long = $(seed)

  /** Scoring interval in trees; defaults to 0 (per the param doc, 0 disables it). */
  def getScoreTreeInterval(): Int = $(scoreTreeInterval)

  /** Whether training metrics are skipped; defaults to true. */
  def getDisableTrainingMetrics(): Boolean = $(disableTrainingMetrics)

  /** Destination model id; defaults to null (no explicit id). */
  def getModelId(): String = $(modelId)

  /** Categorical encoding scheme name; defaults to `CategoricalEncodingScheme.AUTO.name()`. */
  def getCategoricalEncoding(): String = $(categoricalEncoding)

  /** Whether constant columns are ignored; defaults to true. */
  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  /** Whether to score on every iteration; defaults to false. */
  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  // ===== Setters: each stores the value and returns this instance for chaining =====

  def setNtrees(value: Int): this.type = set(ntrees, value)

  def setSampleSize(value: Int): this.type = set(sampleSize, value)

  def setExtensionLevel(value: Int): this.type = set(extensionLevel, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setScoreTreeInterval(value: Int): this.type = set(scoreTreeInterval, value)

  def setDisableTrainingMetrics(value: Boolean): this.type = set(disableTrainingMetrics, value)

  def setModelId(value: String): this.type = set(modelId, value)

  /**
   * Sets the categorical encoding after checking `value` against `CategoricalEncodingScheme`
   * with `EnumParamValidator`. The validator's result, not the raw input, is what gets stored.
   */
  def setCategoricalEncoding(value: String): this.type =
    set(categoricalEncoding, EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value))

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  // ===== Translation to H2O =====

  /**
   * All H2O algorithm parameters: the parent trait's map extended with this algorithm's entries.
   * On a key collision the entries built here take precedence (right-hand side of `++`).
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val inherited = super.getH2OAlgorithmParams(trainingFrame)
    inherited ++ getH2OExtendedIsolationForestParams(trainingFrame)
  }

  /**
   * H2O-side (snake_case) values for the parameters declared in this trait. They are combined
   * with the ignored-columns entry from `getIgnoredColsParam` through `+++`, an operator
   * provided outside this trait (presumably by a parent trait).
   */
  private[sparkling] def getH2OExtendedIsolationForestParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val explicitParams: Map[String, Any] = Map(
      "ntrees" -> getNtrees(),
      "sample_size" -> getSampleSize(),
      "extension_level" -> getExtensionLevel(),
      "seed" -> getSeed(),
      "score_tree_interval" -> getScoreTreeInterval(),
      "disable_training_metrics" -> getDisableTrainingMetrics(),
      "model_id" -> getModelId(),
      "categorical_encoding" -> getCategoricalEncoding(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration())
    explicitParams +++ getIgnoredColsParam(trainingFrame)
  }

  /** Spark-name to H2O-name mapping: the parent's entries plus this trait's, which win on collision. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inherited = super.getSWtoH2OParamNameMap()
    inherited ++ Map(
      "ntrees" -> "ntrees",
      "sampleSize" -> "sample_size",
      "extensionLevel" -> "extension_level",
      "seed" -> "seed",
      "scoreTreeInterval" -> "score_tree_interval",
      "disableTrainingMetrics" -> "disable_training_metrics",
      "modelId" -> "model_id",
      "categoricalEncoding" -> "categorical_encoding",
      "ignoreConstCols" -> "ignore_const_cols",
      "scoreEachIteration" -> "score_each_iteration")
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy