ai.h2o.sparkling.ml.params.H2OCoxPHParams.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.coxph.CoxPHModel.CoxPHParameters
import ai.h2o.sparkling.H2OFrame
import hex.coxph.CoxPHModel.CoxPHParameters.CoxPHTies
/**
 * Spark ML parameters of the H2O Cox Proportional Hazards (CoxPH) algorithm.
 *
 * Every parameter declared here has a Spark-side (camelCase) name and an H2O-side
 * (snake_case) counterpart in `CoxPHParameters`. [[getSWtoH2OParamNameMap]] records that
 * name mapping, and [[getH2OCoxPHParams]] forwards the current values under the H2O names.
 */
trait H2OCoxPHParams
  extends H2OAlgoParamsBase
  with HasIgnoredCols
  with HasInteractionPairs {

  // This member overrides a base definition; the result type is stated explicitly rather
  // than left to inference so the override's contract is visible and cannot drift.
  protected def paramTag: reflect.ClassTag[CoxPHParameters] = reflect.classTag[CoxPHParameters]

  //
  // Parameter definitions
  //
  protected val startCol = nullableStringParam(
    name = "startCol",
    doc = """Start Time Column.""")

  protected val stopCol = nullableStringParam(
    name = "stopCol",
    doc = """Stop Time Column.""")

  protected val stratifyBy = nullableStringArrayParam(
    name = "stratifyBy",
    doc = """List of columns to use for stratification.""")

  // Values are validated against the CoxPHTies enum in setTies.
  protected val ties = stringParam(
    name = "ties",
    doc = """Method for Handling Ties. Possible values are ``"efron"``, ``"breslow"``.""")

  protected val init = doubleParam(
    name = "init",
    doc = """Coefficient starting value.""")

  protected val lreMin = doubleParam(
    name = "lreMin",
    doc = """Minimum log-relative error.""")

  protected val maxIterations = intParam(
    name = "maxIterations",
    doc = """Maximum number of iterations.""")

  protected val interactionsOnly = nullableStringArrayParam(
    name = "interactionsOnly",
    doc = """A list of columns that should only be used to create interactions but should not itself participate in model training.""")

  // NOTE(review): the doc text speaks of column "indices", but the param holds an
  // Array[String]; presumably column names are expected (wording mirrors upstream) — confirm.
  protected val interactions = nullableStringArrayParam(
    name = "interactions",
    doc = """A list of predictor column indices to interact. All pairwise combinations will be computed for the list.""")

  protected val useAllFactorLevels = booleanParam(
    name = "useAllFactorLevels",
    doc = """(Internal. For development only!) Indicates whether to use all factor levels.""")

  protected val singleNodeMode = booleanParam(
    name = "singleNodeMode",
    doc = """Run on a single node to reduce the effect of network overhead (for smaller datasets).""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val offsetCol = nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values
  //
  // A trait body initializes in textual order, so this call must stay below the param
  // vals above; a val referenced before its definition would still be null here.
  setDefault(
    startCol -> null,
    stopCol -> null,
    stratifyBy -> null,
    ties -> CoxPHTies.efron.name(),
    init -> 0.0,
    lreMin -> 9.0,
    maxIterations -> 20,
    interactionsOnly -> null,
    interactions -> null,
    useAllFactorLevels -> false,
    singleNodeMode -> false,
    modelId -> null,
    labelCol -> "label",
    weightCol -> null,
    offsetCol -> null,
    exportCheckpointsDir -> null)

  //
  // Getters
  //
  // Each getter reads the param through `$`. Params whose default above is null may
  // return null when they have not been set explicitly.
  def getStartCol(): String = $(startCol)

  def getStopCol(): String = $(stopCol)

  def getStratifyBy(): Array[String] = $(stratifyBy)

  def getTies(): String = $(ties)

  def getInit(): Double = $(init)

  def getLreMin(): Double = $(lreMin)

  def getMaxIterations(): Int = $(maxIterations)

  def getInteractionsOnly(): Array[String] = $(interactionsOnly)

  def getInteractions(): Array[String] = $(interactions)

  def getUseAllFactorLevels(): Boolean = $(useAllFactorLevels)

  def getSingleNodeMode(): Boolean = $(singleNodeMode)

  def getModelId(): String = $(modelId)

  def getLabelCol(): String = $(labelCol)

  def getWeightCol(): String = $(weightCol)

  def getOffsetCol(): String = $(offsetCol)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  //
  // Setters
  //
  // Each setter stores the value as-is and returns `this` so calls can be chained.
  // The exception is setTies, which validates its input first.
  def setStartCol(value: String): this.type = set(startCol, value)

  def setStopCol(value: String): this.type = set(stopCol, value)

  def setStratifyBy(value: Array[String]): this.type = set(stratifyBy, value)

  /**
   * Sets the tie-handling method.
   *
   * The input is checked against the `CoxPHTies` enum by `EnumParamValidator`, and the
   * validator's result (not the raw input) is what gets stored.
   * NOTE(review): presumably the validator rejects unknown values with an exception — confirm.
   */
  def setTies(value: String): this.type =
    set(ties, EnumParamValidator.getValidatedEnumValue[CoxPHTies](value))

  def setInit(value: Double): this.type = set(init, value)

  def setLreMin(value: Double): this.type = set(lreMin, value)

  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def setInteractionsOnly(value: Array[String]): this.type = set(interactionsOnly, value)

  def setInteractions(value: Array[String]): this.type = set(interactions, value)

  def setUseAllFactorLevels(value: Boolean): this.type = set(useAllFactorLevels, value)

  def setSingleNodeMode(value: Boolean): this.type = set(singleNodeMode, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  /**
   * Extends the inherited algorithm parameters with the CoxPH-specific ones.
   *
   * Combines the two maps with `Map.++`, so on a key collision the CoxPH entry replaces
   * the inherited one.
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OCoxPHParams(trainingFrame)
  }

  /**
   * Builds the CoxPH parameter map, keyed by H2O (snake_case) parameter names.
   *
   * The values come from the getters above, and nullable params may contribute null values.
   * The result is then combined via `+++` with the ignored-columns and interaction-pairs
   * entries derived from `trainingFrame`.
   *
   * @param trainingFrame frame used to derive the ignored-columns and interaction-pairs entries
   * @return H2O parameter name to value
   */
  private[sparkling] def getH2OCoxPHParams(trainingFrame: H2OFrame): Map[String, Any] = {
    Map(
      "start_column" -> getStartCol(),
      "stop_column" -> getStopCol(),
      "stratify_by" -> getStratifyBy(),
      "ties" -> getTies(),
      "init" -> getInit(),
      "lre_min" -> getLreMin(),
      "max_iterations" -> getMaxIterations(),
      "interactions_only" -> getInteractionsOnly(),
      "interactions" -> getInteractions(),
      "use_all_factor_levels" -> getUseAllFactorLevels(),
      "single_node_mode" -> getSingleNodeMode(),
      "model_id" -> getModelId(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "offset_column" -> getOffsetCol(),
      "export_checkpoints_dir" -> getExportCheckpointsDir()) +++
      getIgnoredColsParam(trainingFrame) +++
      getInteractionPairsParam(trainingFrame)
  }

  /**
   * Extends the inherited name mapping with this trait's params, mapping each Sparkling
   * Water param name to its H2O name.
   *
   * The right-hand keys mirror those emitted by [[getH2OCoxPHParams]]; keep the two in sync
   * when adding or renaming a parameter.
   */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "startCol" -> "start_column",
        "stopCol" -> "stop_column",
        "stratifyBy" -> "stratify_by",
        "ties" -> "ties",
        "init" -> "init",
        "lreMin" -> "lre_min",
        "maxIterations" -> "max_iterations",
        "interactionsOnly" -> "interactions_only",
        "interactions" -> "interactions",
        "useAllFactorLevels" -> "use_all_factor_levels",
        "singleNodeMode" -> "single_node_mode",
        "modelId" -> "model_id",
        "labelCol" -> "response_column",
        "weightCol" -> "weights_column",
        "offsetCol" -> "offset_column",
        "exportCheckpointsDir" -> "export_checkpoints_dir")
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy