All downloads are free. Search and download functionality uses the official Maven repository.

ai.h2o.sparkling.ml.params.H2OKMeansParams.scala Maven / Gradle / Ivy

There is a newer version: 3.46.0.4-1-3.5
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.kmeans.KMeansModel.KMeansParameters
import ai.h2o.sparkling.H2OFrame
import hex.kmeans.KMeans.Initialization
import hex.Model.Parameters.FoldAssignmentScheme
import hex.Model.Parameters.CategoricalEncodingScheme

/**
 * Spark ML parameters exposed by Sparkling Water for the H2O K-Means algorithm.
 *
 * Every parameter declared here has a Sparkling Water (camelCase) name and a
 * corresponding H2O (snake_case) name; the correspondence is listed in
 * `getSWtoH2OParamNameMap`, and `getH2OKMeansParams` collects the current values
 * keyed by the H2O names. User points and ignored columns are contributed by the
 * `HasUserPoints` and `HasIgnoredCols` mixins.
 */
trait H2OKMeansParams extends H2OAlgoParamsBase with HasUserPoints with HasIgnoredCols {

  /** Class tag of the H2O-side parameter class (`KMeansParameters`). */
  protected def paramTag = reflect.classTag[KMeansParameters]

  // ---------------------------------------------------------------------------
  // Parameter declarations.
  // These must stay above `setDefault` below: trait vals are initialized in
  // declaration order, so the params have to exist before defaults refer to them.
  // ---------------------------------------------------------------------------

  protected val maxIterations =
    intParam(name = "maxIterations", doc = """Maximum training iterations (if estimate_k is enabled, then this is for each inner Lloyds iteration).""")

  protected val standardize =
    booleanParam(name = "standardize", doc = """Standardize columns before computing distances.""")

  protected val seed =
    longParam(name = "seed", doc = """RNG Seed.""")

  protected val init =
    stringParam(name = "init", doc = """Initialization mode. Possible values are ``"Random"``, ``"PlusPlus"``, ``"Furthest"``, ``"User"``.""")

  protected val estimateK =
    booleanParam(name = "estimateK", doc = """Whether to estimate the number of clusters (<=k) iteratively and deterministically.""")

  protected val clusterSizeConstraints =
    nullableIntArrayParam(name = "clusterSizeConstraints", doc = """An array specifying the minimum number of points that should be in each cluster. The length of the constraints array has to be the same as the number of clusters.""")

  protected val k =
    intParam(name = "k", doc = """The max. number of clusters. If estimate_k is disabled, the model will find k centroids, otherwise it will find up to k centroids.""")

  protected val modelId =
    nullableStringParam(name = "modelId", doc = """Destination id for this model; auto-generated if not specified.""")

  protected val nfolds =
    intParam(name = "nfolds", doc = """Number of folds for K-fold cross-validation (0 to disable or >= 2).""")

  protected val keepCrossValidationModels =
    booleanParam(name = "keepCrossValidationModels", doc = """Whether to keep the cross-validation models.""")

  protected val keepCrossValidationPredictions =
    booleanParam(name = "keepCrossValidationPredictions", doc = """Whether to keep the predictions of the cross-validation models.""")

  protected val keepCrossValidationFoldAssignment =
    booleanParam(name = "keepCrossValidationFoldAssignment", doc = """Whether to keep the cross-validation fold assignment.""")

  protected val foldCol =
    nullableStringParam(name = "foldCol", doc = """Column with cross-validation fold index assignment per observation.""")

  protected val foldAssignment =
    stringParam(name = "foldAssignment", doc = """Cross-validation fold assignment scheme, if fold_column is not specified. The 'Stratified' option will stratify the folds based on the response variable, for classification problems. Possible values are ``"AUTO"``, ``"Random"``, ``"Modulo"``, ``"Stratified"``.""")

  protected val categoricalEncoding =
    stringParam(name = "categoricalEncoding", doc = """Encoding scheme for categorical features. Possible values are ``"AUTO"``, ``"OneHotInternal"``, ``"OneHotExplicit"``, ``"Enum"``, ``"Binary"``, ``"Eigen"``, ``"LabelEncoder"``, ``"SortByResponse"``, ``"EnumLimited"``.""")

  protected val ignoreConstCols =
    booleanParam(name = "ignoreConstCols", doc = """Ignore constant columns.""")

  protected val scoreEachIteration =
    booleanParam(name = "scoreEachIteration", doc = """Whether to score during each iteration of model training.""")

  protected val maxRuntimeSecs =
    doubleParam(name = "maxRuntimeSecs", doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir =
    nullableStringParam(name = "exportCheckpointsDir", doc = """Automatically export generated models to this directory.""")

  // ---------------------------------------------------------------------------
  // Defaults. Enum-valued params default to the `name()` of an H2O enum constant;
  // nullable params default to null.
  // ---------------------------------------------------------------------------
  setDefault(
    maxIterations -> 10,
    standardize -> true,
    seed -> -1L,
    init -> Initialization.Furthest.name(),
    estimateK -> false,
    clusterSizeConstraints -> null,
    k -> 1,
    modelId -> null,
    nfolds -> 0,
    keepCrossValidationModels -> true,
    keepCrossValidationPredictions -> false,
    keepCrossValidationFoldAssignment -> false,
    foldCol -> null,
    foldAssignment -> FoldAssignmentScheme.AUTO.name(),
    categoricalEncoding -> CategoricalEncodingScheme.AUTO.name(),
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  // ---------------------------------------------------------------------------
  // Getters: each returns the explicitly set value or, failing that, the default.
  // ---------------------------------------------------------------------------
  def getMaxIterations(): Int = $(maxIterations)
  def getStandardize(): Boolean = $(standardize)
  def getSeed(): Long = $(seed)
  def getInit(): String = $(init)
  def getEstimateK(): Boolean = $(estimateK)
  /** May be null, which is also the default. */
  def getClusterSizeConstraints(): Array[Int] = $(clusterSizeConstraints)
  def getK(): Int = $(k)
  /** May be null, which is also the default. */
  def getModelId(): String = $(modelId)
  def getNfolds(): Int = $(nfolds)
  def getKeepCrossValidationModels(): Boolean = $(keepCrossValidationModels)
  def getKeepCrossValidationPredictions(): Boolean = $(keepCrossValidationPredictions)
  def getKeepCrossValidationFoldAssignment(): Boolean = $(keepCrossValidationFoldAssignment)
  /** May be null, which is also the default. */
  def getFoldCol(): String = $(foldCol)
  def getFoldAssignment(): String = $(foldAssignment)
  def getCategoricalEncoding(): String = $(categoricalEncoding)
  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)
  def getScoreEachIteration(): Boolean = $(scoreEachIteration)
  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)
  /** May be null, which is also the default. */
  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  // ---------------------------------------------------------------------------
  // Setters: each stores the value and returns this instance for chaining.
  // String params backed by an H2O enum are first passed through
  // `EnumParamValidator.getValidatedEnumValue` for that enum type, and the
  // validator's result is what gets stored.
  // ---------------------------------------------------------------------------
  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def setStandardize(value: Boolean): this.type = set(standardize, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setInit(value: String): this.type =
    set(init, EnumParamValidator.getValidatedEnumValue[Initialization](value))

  def setEstimateK(value: Boolean): this.type = set(estimateK, value)

  def setClusterSizeConstraints(value: Array[Int]): this.type = set(clusterSizeConstraints, value)

  def setK(value: Int): this.type = set(k, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setNfolds(value: Int): this.type = set(nfolds, value)

  def setKeepCrossValidationModels(value: Boolean): this.type = set(keepCrossValidationModels, value)

  def setKeepCrossValidationPredictions(value: Boolean): this.type = set(keepCrossValidationPredictions, value)

  def setKeepCrossValidationFoldAssignment(value: Boolean): this.type = set(keepCrossValidationFoldAssignment, value)

  def setFoldCol(value: String): this.type = set(foldCol, value)

  def setFoldAssignment(value: String): this.type =
    set(foldAssignment, EnumParamValidator.getValidatedEnumValue[FoldAssignmentScheme](value))

  def setCategoricalEncoding(value: String): this.type =
    set(categoricalEncoding, EnumParamValidator.getValidatedEnumValue[CategoricalEncodingScheme](value))

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  // ---------------------------------------------------------------------------
  // Conversion to H2O parameters.
  // ---------------------------------------------------------------------------

  /** Parent-provided algorithm params, extended with the K-Means-specific ones. */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val inherited = super.getH2OAlgorithmParams(trainingFrame)
    val kMeansOnly = getH2OKMeansParams(trainingFrame)
    inherited ++ kMeansOnly
  }

  /**
   * Current values of the K-Means params keyed by their H2O names, combined via
   * `+++` with the entries produced by `getUserPointsParam` and `getIgnoredColsParam`
   * for the given training frame.
   */
  private[sparkling] def getH2OKMeansParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val byH2OName: Map[String, Any] = Map(
      "max_iterations" -> getMaxIterations(),
      "standardize" -> getStandardize(),
      "seed" -> getSeed(),
      "init" -> getInit(),
      "estimate_k" -> getEstimateK(),
      "cluster_size_constraints" -> getClusterSizeConstraints(),
      "k" -> getK(),
      "model_id" -> getModelId(),
      "nfolds" -> getNfolds(),
      "keep_cross_validation_models" -> getKeepCrossValidationModels(),
      "keep_cross_validation_predictions" -> getKeepCrossValidationPredictions(),
      "keep_cross_validation_fold_assignment" -> getKeepCrossValidationFoldAssignment(),
      "fold_column" -> getFoldCol(),
      "fold_assignment" -> getFoldAssignment(),
      "categorical_encoding" -> getCategoricalEncoding(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())
    byH2OName +++ getUserPointsParam(trainingFrame) +++ getIgnoredColsParam(trainingFrame)
  }

  /** Parent name mapping extended with Sparkling Water -> H2O names for K-Means params. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val kMeansNames = Map(
      "maxIterations" -> "max_iterations",
      "standardize" -> "standardize",
      "seed" -> "seed",
      "init" -> "init",
      "estimateK" -> "estimate_k",
      "clusterSizeConstraints" -> "cluster_size_constraints",
      "k" -> "k",
      "modelId" -> "model_id",
      "nfolds" -> "nfolds",
      "keepCrossValidationModels" -> "keep_cross_validation_models",
      "keepCrossValidationPredictions" -> "keep_cross_validation_predictions",
      "keepCrossValidationFoldAssignment" -> "keep_cross_validation_fold_assignment",
      "foldCol" -> "fold_column",
      "foldAssignment" -> "fold_assignment",
      "categoricalEncoding" -> "categorical_encoding",
      "ignoreConstCols" -> "ignore_const_cols",
      "scoreEachIteration" -> "score_each_iteration",
      "maxRuntimeSecs" -> "max_runtime_secs",
      "exportCheckpointsDir" -> "export_checkpoints_dir")
    super.getSWtoH2OParamNameMap() ++ kMeansNames
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy