ai.h2o.sparkling.ml.params.H2OAutoMLInputParams.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.h2o.sparkling.ml.params
import ai.h2o.automl.AutoMLBuildSpec.AutoMLInput
import ai.h2o.sparkling.H2OFrame
import ai.h2o.sparkling.ml.utils.H2OAutoMLSortMetric
/**
  * Parameters describing the input of an H2O AutoML run: the response, fold and weight columns
  * and the metric used to sort the leaderboard.
  *
  * The trait also mixes in the ignored-columns, leaderboard-data-frame and blending-data-frame
  * parameters and contributes all of them to the parameter map handed over to H2O.
  */
trait H2OAutoMLInputParams
  extends H2OAlgoParamsBase
  with HasIgnoredCols
  with HasLeaderboardDataFrame
  with HasBlendingDataFrame {

  //
  // Parameter definitions
  //

  /** Response column; sent to H2O under the key `response_column`. */
  protected val labelCol = stringParam(
    name = "labelCol",
    doc = """Response column.""")

  /** Optional fold column; sent to H2O under the key `fold_column`. */
  protected val foldCol = nullableStringParam(
    name = "foldCol",
    doc = """Fold column (contains fold IDs) in the training frame. These assignments are used to create the folds for cross-validation of the models.""")

  /** Optional weights column; sent to H2O under the key `weights_column`. */
  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Weights column in the training frame, which specifies the row weights used in model training.""")

  /** Leaderboard sort metric; sent to H2O under the key `sort_metric`. */
  protected val sortMetric = stringParam(
    name = "sortMetric",
    doc = """Metric used to sort leaderboard. Possible values are ``"AUTO"``, ``"deviance"``, ``"logloss"``, ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"RMSLE"``, ``"AUC"``, ``"mean_per_class_error"``.""")

  //
  // Default values
  //

  // Must stay below the parameter definitions: the vals above have to be initialized
  // before they are referenced here.
  setDefault(
    labelCol -> "label",
    foldCol -> null,
    weightCol -> null,
    sortMetric -> H2OAutoMLSortMetric.AUTO.name())

  //
  // Getters
  //

  /** @return the current value of `labelCol` (default `"label"`) */
  def getLabelCol(): String = $(labelCol)

  /** @return the current value of `foldCol`; `null` by default */
  def getFoldCol(): String = $(foldCol)

  /** @return the current value of `weightCol`; `null` by default */
  def getWeightCol(): String = $(weightCol)

  /** @return the current value of `sortMetric` (default is the name of `H2OAutoMLSortMetric.AUTO`) */
  def getSortMetric(): String = $(sortMetric)

  //
  // Setters
  //

  /**
    * Sets the response column.
    *
    * @param value name of the response column
    * @return this instance, to allow call chaining
    */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
    * Sets the fold column.
    *
    * @param value name of the fold column
    * @return this instance, to allow call chaining
    */
  def setFoldCol(value: String): this.type = set(foldCol, value)

  /**
    * Sets the weights column.
    *
    * @param value name of the weights column
    * @return this instance, to allow call chaining
    */
  def setWeightCol(value: String): this.type = set(weightCol, value)

  /**
    * Sets the leaderboard sort metric. The value is first passed through
    * `EnumParamValidator.getValidatedEnumValue` for `H2OAutoMLSortMetric`, and the result of
    * that call is what gets stored.
    *
    * @param value name of the sort metric
    * @return this instance, to allow call chaining
    */
  def setSortMetric(value: String): this.type =
    set(sortMetric, EnumParamValidator.getValidatedEnumValue[H2OAutoMLSortMetric](value))

  /**
    * Extends the parameter map produced by the parent with the AutoML input parameters.
    *
    * @param trainingFrame frame forwarded to both the parent and [[getH2OAutoMLInputParams]]
    * @return the parent's entries combined (via `++`) with the AutoML input entries
    */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val inheritedParams = super.getH2OAlgorithmParams(trainingFrame)
    inheritedParams ++ getH2OAutoMLInputParams(trainingFrame)
  }

  /**
    * Builds the H2O-side entries for the parameters of this trait and of the mixed-in
    * ignored-columns, leaderboard and blending traits.
    *
    * @param trainingFrame frame forwarded to the mixed-in `get...Param` helpers
    * @return map keyed by H2O parameter names
    */
  private[sparkling] def getH2OAutoMLInputParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val columnAndMetricParams = Map(
      "response_column" -> getLabelCol(),
      "fold_column" -> getFoldCol(),
      "weights_column" -> getWeightCol(),
      "sort_metric" -> getSortMetric())

    // `+++` is a map-combining operator defined outside this file; the chain is kept
    // left-to-right so the helpers are evaluated in the same order as before.
    columnAndMetricParams +++
      getIgnoredColsParam(trainingFrame) +++
      getLeaderboardDataFrameParam(trainingFrame) +++
      getBlendingDataFrameParam(trainingFrame)
  }

  /**
    * Extends the parent's mapping from Sparkling Water parameter names to H2O parameter names
    * with the four parameters defined in this trait.
    *
    * @return the parent's mapping combined (via `++`) with the entries of this trait
    */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inheritedNames = super.getSWtoH2OParamNameMap()
    inheritedNames ++ Map(
      "labelCol" -> "response_column",
      "foldCol" -> "fold_column",
      "weightCol" -> "weights_column",
      "sortMetric" -> "sort_metric")
  }
}