// Source listing: ai.h2o.sparkling.ml.params.H2OGLRMParams.scala (Maven / Gradle / Ivy)
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import hex.glrm.GLRMModel.GLRMParameters
import ai.h2o.sparkling.H2OFrame
import hex.DataInfo.TransformType
import hex.genmodel.algos.glrm.GlrmLoss
import hex.genmodel.algos.glrm.GlrmLoss
import hex.genmodel.algos.glrm.GlrmRegularizer
import hex.genmodel.algos.glrm.GlrmRegularizer
import hex.genmodel.algos.glrm.GlrmInitialization
import hex.svd.SVDModel.SVDParameters.Method
/**
 * Spark ML parameters for H2O's GLRM (Generalized Low Rank Model) algorithm.
 *
 * Every param declared here has a Sparkling Water (camelCase) name and an H2O (snake_case)
 * counterpart: [[getSWtoH2OParamNameMap]] lists the correspondence, and [[getH2OGLRMParams]]
 * produces the values keyed by the H2O names. The user-supplied X / Y frames and the
 * loss-by-column-name setting are contributed by the mixed-in [[HasUserX]], [[HasUserY]] and
 * [[HasLossByColNames]] traits.
 */
trait H2OGLRMParams
  extends H2OAlgoParamsBase
  with HasUserX
  with HasUserY
  with HasLossByColNames {

  /** Class tag of the H2O parameter class (`GLRMParameters`) these settings mirror. */
  protected def paramTag: reflect.ClassTag[GLRMParameters] = reflect.classTag[GLRMParameters]

  // ===========================================================================
  // Param declarations.
  //
  // These vals must stay above the setDefault calls further down: a trait body is
  // initialized top to bottom, so each param has to exist before a default is
  // registered for it.
  // ===========================================================================

  // --- Model shape and preprocessing ---

  protected val transform = stringParam(
    name = "transform",
    doc = """Transformation of training data. Possible values are ``"NONE"``, ``"STANDARDIZE"``, ``"NORMALIZE"``, ``"DEMEAN"``, ``"DESCALE"``.""")

  protected val k = intParam(name = "k", doc = """Rank of matrix approximation.""")

  // --- Loss functions ---

  protected val loss = stringParam(
    name = "loss",
    doc = """Numeric loss function. Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val multiLoss = stringParam(
    name = "multiLoss",
    doc = """Categorical loss function. Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val lossByCol = nullableStringArrayParam(
    name = "lossByCol",
    doc = """Loss function by column (override). Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val period = intParam(
    name = "period",
    doc = """Length of period (only used with periodic loss function).""")

  // --- Regularization ---

  protected val regularizationX = stringParam(
    name = "regularizationX",
    doc = """Regularization function for X matrix. Possible values are ``"None"``, ``"Quadratic"``, ``"L2"``, ``"L1"``, ``"NonNegative"``, ``"OneSparse"``, ``"UnitOneSparse"``, ``"Simplex"``.""")

  protected val regularizationY = stringParam(
    name = "regularizationY",
    doc = """Regularization function for Y matrix. Possible values are ``"None"``, ``"Quadratic"``, ``"L2"``, ``"L1"``, ``"NonNegative"``, ``"OneSparse"``, ``"UnitOneSparse"``, ``"Simplex"``.""")

  protected val gammaX = doubleParam(name = "gammaX", doc = """Regularization weight on X matrix.""")

  protected val gammaY = doubleParam(name = "gammaY", doc = """Regularization weight on Y matrix.""")

  // --- Optimization loop ---

  protected val maxIterations = intParam(name = "maxIterations", doc = """Maximum number of iterations.""")

  protected val maxUpdates = intParam(
    name = "maxUpdates",
    doc = """Maximum number of updates, defaults to 2*max_iterations.""")

  protected val initStepSize = doubleParam(name = "initStepSize", doc = """Initial step size.""")

  protected val minStepSize = doubleParam(name = "minStepSize", doc = """Minimum step size.""")

  // --- Initialization ---

  protected val seed = longParam(name = "seed", doc = """RNG seed for initialization.""")

  protected val init = stringParam(
    name = "init",
    doc = """Initialization mode. Possible values are ``"Random"``, ``"SVD"``, ``"PlusPlus"``, ``"User"``, ``"Power"``.""")

  protected val svdMethod = stringParam(
    name = "svdMethod",
    doc = """Method for computing SVD during initialization (Caution: Randomized is currently experimental and unstable). Possible values are ``"GramSVD"``, ``"Power"``, ``"Randomized"``.""")

  // --- Output frames and post-processing ---

  // Kept for compatibility; its own doc text marks it deprecated in favour of representationName.
  protected val loadingName = nullableStringParam(
    name = "loadingName",
    doc = """[Deprecated] Use representation_name instead. Frame key to save resulting X.""")

  protected val representationName = nullableStringParam(
    name = "representationName",
    doc = """Frame key to save resulting X.""")

  protected val expandUserY = booleanParam(
    name = "expandUserY",
    doc = """Expand categorical columns in user-specified initial Y.""")

  protected val imputeOriginal = booleanParam(
    name = "imputeOriginal",
    doc = """Reconstruct original training data by reversing transform.""")

  protected val recoverSvd = booleanParam(
    name = "recoverSvd",
    doc = """Recover singular values and eigenvectors of XY.""")

  // --- Generic model-building settings ---

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val ignoredCols = nullableStringArrayParam(
    name = "ignoredCols",
    doc = """Names of columns to ignore for training.""")

  protected val ignoreConstCols = booleanParam(name = "ignoreConstCols", doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  // ===========================================================================
  // Defaults, registered group by group in declaration order.
  // ===========================================================================

  setDefault(
    transform -> TransformType.NONE.name(),
    k -> 1)

  setDefault(
    loss -> GlrmLoss.Quadratic.name(),
    multiLoss -> GlrmLoss.Categorical.name(),
    lossByCol -> null,
    period -> 1)

  setDefault(
    regularizationX -> GlrmRegularizer.None.name(),
    regularizationY -> GlrmRegularizer.None.name(),
    gammaX -> 0.0,
    gammaY -> 0.0)

  // NOTE(review): maxUpdates defaults to the literal 2000 (2 * the default maxIterations);
  // it does not follow later changes to maxIterations.
  setDefault(
    maxIterations -> 1000,
    maxUpdates -> 2000,
    initStepSize -> 1.0,
    minStepSize -> 1.0e-4)

  // NOTE(review): svdMethod defaults to Randomized, which its own doc text calls experimental.
  setDefault(
    seed -> -1L,
    init -> GlrmInitialization.PlusPlus.name(),
    svdMethod -> Method.Randomized.name())

  setDefault(
    loadingName -> null,
    representationName -> null,
    expandUserY -> true,
    imputeOriginal -> false,
    recoverSvd -> false)

  setDefault(
    modelId -> null,
    ignoredCols -> null,
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  // ===========================================================================
  // Getters: each returns the explicitly set value, falling back to the default above.
  // ===========================================================================

  def getTransform(): String = getOrDefault(transform)
  def getK(): Int = getOrDefault(k)
  def getLoss(): String = getOrDefault(loss)
  def getMultiLoss(): String = getOrDefault(multiLoss)
  def getLossByCol(): Array[String] = getOrDefault(lossByCol)
  def getPeriod(): Int = getOrDefault(period)
  def getRegularizationX(): String = getOrDefault(regularizationX)
  def getRegularizationY(): String = getOrDefault(regularizationY)
  def getGammaX(): Double = getOrDefault(gammaX)
  def getGammaY(): Double = getOrDefault(gammaY)
  def getMaxIterations(): Int = getOrDefault(maxIterations)
  def getMaxUpdates(): Int = getOrDefault(maxUpdates)
  def getInitStepSize(): Double = getOrDefault(initStepSize)
  def getMinStepSize(): Double = getOrDefault(minStepSize)
  def getSeed(): Long = getOrDefault(seed)
  def getInit(): String = getOrDefault(init)
  def getSvdMethod(): String = getOrDefault(svdMethod)
  def getLoadingName(): String = getOrDefault(loadingName)
  def getRepresentationName(): String = getOrDefault(representationName)
  def getExpandUserY(): Boolean = getOrDefault(expandUserY)
  def getImputeOriginal(): Boolean = getOrDefault(imputeOriginal)
  def getRecoverSvd(): Boolean = getOrDefault(recoverSvd)
  def getModelId(): String = getOrDefault(modelId)
  def getIgnoredCols(): Array[String] = getOrDefault(ignoredCols)
  def getIgnoreConstCols(): Boolean = getOrDefault(ignoreConstCols)
  def getScoreEachIteration(): Boolean = getOrDefault(scoreEachIteration)
  def getMaxRuntimeSecs(): Double = getOrDefault(maxRuntimeSecs)
  def getExportCheckpointsDir(): String = getOrDefault(exportCheckpointsDir)

  // ===========================================================================
  // Setters. Enum-backed params pass the value through EnumParamValidator before storing it;
  // the remaining setters store the value as given.
  // ===========================================================================

  def setTransform(value: String): this.type =
    set(transform, EnumParamValidator.getValidatedEnumValue[TransformType](value))

  def setK(value: Int): this.type = set(k, value)

  def setLoss(value: String): this.type =
    set(loss, EnumParamValidator.getValidatedEnumValue[GlrmLoss](value))

  def setMultiLoss(value: String): this.type =
    set(multiLoss, EnumParamValidator.getValidatedEnumValue[GlrmLoss](value))

  // A null array is accepted here (nullEnabled = true), matching the null default.
  def setLossByCol(value: Array[String]): this.type =
    set(lossByCol, EnumParamValidator.getValidatedEnumValues[GlrmLoss](value, nullEnabled = true))

  def setPeriod(value: Int): this.type = set(period, value)

  def setRegularizationX(value: String): this.type =
    set(regularizationX, EnumParamValidator.getValidatedEnumValue[GlrmRegularizer](value))

  def setRegularizationY(value: String): this.type =
    set(regularizationY, EnumParamValidator.getValidatedEnumValue[GlrmRegularizer](value))

  def setGammaX(value: Double): this.type = set(gammaX, value)

  def setGammaY(value: Double): this.type = set(gammaY, value)

  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def setMaxUpdates(value: Int): this.type = set(maxUpdates, value)

  def setInitStepSize(value: Double): this.type = set(initStepSize, value)

  def setMinStepSize(value: Double): this.type = set(minStepSize, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setInit(value: String): this.type =
    set(init, EnumParamValidator.getValidatedEnumValue[GlrmInitialization](value))

  def setSvdMethod(value: String): this.type =
    set(svdMethod, EnumParamValidator.getValidatedEnumValue[Method](value))

  def setLoadingName(value: String): this.type = set(loadingName, value)

  def setRepresentationName(value: String): this.type = set(representationName, value)

  def setExpandUserY(value: Boolean): this.type = set(expandUserY, value)

  def setImputeOriginal(value: Boolean): this.type = set(imputeOriginal, value)

  def setRecoverSvd(value: Boolean): this.type = set(recoverSvd, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setIgnoredCols(value: Array[String]): this.type = set(ignoredCols, value)

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  // ===========================================================================
  // Translation to H2O-side parameters.
  // ===========================================================================

  /** Parent-trait parameters, with the GLRM-specific ones from [[getH2OGLRMParams]] merged on top. */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] =
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OGLRMParams(trainingFrame)

  /**
   * Current GLRM settings keyed by their H2O (snake_case) names.
   *
   * The plain params are collected first; then the user-X, user-Y and loss-by-column-name
   * entries derived from `trainingFrame` by the mixed-in traits are appended, in that order,
   * with `+++` (defined outside this trait).
   */
  private[sparkling] def getH2OGLRMParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val plainParams: Map[String, Any] = Map(
      "transform" -> getTransform(),
      "k" -> getK(),
      "loss" -> getLoss(),
      "multi_loss" -> getMultiLoss(),
      "loss_by_col" -> getLossByCol(),
      "period" -> getPeriod(),
      "regularization_x" -> getRegularizationX(),
      "regularization_y" -> getRegularizationY(),
      "gamma_x" -> getGammaX(),
      "gamma_y" -> getGammaY(),
      "max_iterations" -> getMaxIterations(),
      "max_updates" -> getMaxUpdates(),
      "init_step_size" -> getInitStepSize(),
      "min_step_size" -> getMinStepSize(),
      "seed" -> getSeed(),
      "init" -> getInit(),
      "svd_method" -> getSvdMethod(),
      "loading_name" -> getLoadingName(),
      "representation_name" -> getRepresentationName(),
      "expand_user_y" -> getExpandUserY(),
      "impute_original" -> getImputeOriginal(),
      "recover_svd" -> getRecoverSvd(),
      "model_id" -> getModelId(),
      "ignored_columns" -> getIgnoredCols(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())

    plainParams +++
      getUserXParam(trainingFrame) +++
      getUserYParam(trainingFrame) +++
      getLossByColNamesParam(trainingFrame)
  }

  /** Parent-trait name mapping extended with this trait's Sparkling Water -> H2O param names. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inherited = super.getSWtoH2OParamNameMap()
    val glrmNames = Map(
      "transform" -> "transform",
      "k" -> "k",
      "loss" -> "loss",
      "multiLoss" -> "multi_loss",
      "lossByCol" -> "loss_by_col",
      "period" -> "period",
      "regularizationX" -> "regularization_x",
      "regularizationY" -> "regularization_y",
      "gammaX" -> "gamma_x",
      "gammaY" -> "gamma_y",
      "maxIterations" -> "max_iterations",
      "maxUpdates" -> "max_updates",
      "initStepSize" -> "init_step_size",
      "minStepSize" -> "min_step_size",
      "seed" -> "seed",
      "init" -> "init",
      "svdMethod" -> "svd_method",
      "loadingName" -> "loading_name",
      "representationName" -> "representation_name",
      "expandUserY" -> "expand_user_y",
      "imputeOriginal" -> "impute_original",
      "recoverSvd" -> "recover_svd",
      "modelId" -> "model_id",
      "ignoredCols" -> "ignored_columns",
      "ignoreConstCols" -> "ignore_const_cols",
      "scoreEachIteration" -> "score_each_iteration",
      "maxRuntimeSecs" -> "max_runtime_secs",
      "exportCheckpointsDir" -> "export_checkpoints_dir")
    inherited ++ glrmNames
  }
}