All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.lightgbm.params.ClassifierTrainParams.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.lightgbm.params

import com.microsoft.azure.synapse.ml.core.utils.ParamsStringBuilder
import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMDelegate}

/** Defines the Booster parameters passed to the LightGBM classifier.
  *
  * On top of the parameter groups it carries, this class contributes the classifier-specific
  * settings `num_class`, `is_unbalance` and `boost_from_average` (see [[appendSpecializedParams]]).
  *
  * Constructor parameters that are not listed below (`passThroughArgs`, `isProvideTrainingMetric`,
  * `delegate`, `generalParams`, `datasetParams`, `dartModeParams`, `executionParams`, `seedParams`,
  * `categoricalParams`) are not read in this class; presumably they implement members declared by
  * [[BaseTrainParams]] -- confirm there.
  *
  * @param isUnbalance      Value emitted as `is_unbalance`, only when the objective is binary.
  * @param boostFromAverage Value emitted as `boost_from_average`.
  * @param objectiveParams  Its `objective` is compared with [[LightGBMConstants.BinaryObjective]]
  *                         to decide whether this is a binary classifier.
  * @param numClass         Value emitted as `num_class`, only when the objective is not binary.
  *                         Defaults to 1.
  */
case class ClassifierTrainParams(passThroughArgs: Option[String],
                                 isUnbalance: Boolean,
                                 boostFromAverage: Boolean,
                                 isProvideTrainingMetric: Option[Boolean],
                                 delegate: Option[LightGBMDelegate],
                                 generalParams: GeneralParams,
                                 datasetParams: DatasetParams,
                                 dartModeParams: DartModeParams,
                                 executionParams: ExecutionParams,
                                 objectiveParams: ObjectiveParams,
                                 seedParams: SeedParams,
                                 categoricalParams: CategoricalParams,
                                 numClass: Int = 1) extends BaseTrainParams {
  /** True when the configured objective equals [[LightGBMConstants.BinaryObjective]]. */
  val isBinary: Boolean = objectiveParams.objective == LightGBMConstants.BinaryObjective

  /** Returns the class count this instance was constructed with. */
  override def getNumClass(): Int = numClass

  /** Adds the classifier-specific settings to the given builder.
    *
    * `num_class` is offered only for non-binary objectives and `is_unbalance` only for the binary
    * objective; `boost_from_average` is always offered. Each value goes through
    * `appendParamValueIfNotThere`, which (judging by its name) leaves an already-present setting
    * untouched -- confirm in `ParamsStringBuilder`.
    *
    * @param sb Builder to append to.
    * @return The builder returned by the last append call.
    */
  override def appendSpecializedParams(sb: ParamsStringBuilder): ParamsStringBuilder = {
    // The two settings below are mutually exclusive: exactly one of them is defined.
    val numClassOpt: Option[Int] = if (isBinary) None else Option(numClass)
    val isUnbalanceOpt: Option[Boolean] = if (isBinary) Option(isUnbalance) else None

    sb.appendParamValueIfNotThere("num_class", numClassOpt)
      .appendParamValueIfNotThere("is_unbalance", isUnbalanceOpt)
      .appendParamValueIfNotThere("boost_from_average", Option(boostFromAverage))
  }

  /** Returns parameters identical to these except for the class count.
    *
    * @param n Desired number of classes.
    * @return This same instance when `n` already equals `numClass`; otherwise a copy with
    *         `numClass` replaced by `n`.
    */
  def setNumClass(n: Int): BaseTrainParams =
    if (n == numClass) this
    // `copy` carries every other field over automatically, so the field list does not have to be
    // repeated (and kept in sync) here.
    else copy(numClass = n)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy