All downloads are free. The search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.lightgbm

import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
import com.microsoft.azure.synapse.ml.lightgbm.params.{BaseTrainParams, ClassifierTrainParams,
  LightGBMModelParams, LightGBMPredictionParams}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructField

/** Companion object of [[LightGBMClassifier]]; extends `DefaultParamsReadable` so that the estimator
  * can be read back through Spark ML's default params reader.
  */
object LightGBMClassifier extends DefaultParamsReadable[LightGBMClassifier]

/** Trains a LightGBM Classification model, a fast, distributed, high performance gradient boosting
  * framework based on decision tree algorithms.
  * For more information please see here: https://github.com/Microsoft/LightGBM.
  * For parameter information see here: https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst
  * @param uid The unique ID.
  */
class LightGBMClassifier(override val uid: String)
  extends ProbabilisticClassifier[Vector, LightGBMClassifier, LightGBMClassificationModel]
  with LightGBMBase[LightGBMClassificationModel] with SynapseMLLogging {
  logClass(FeatureNames.LightGBM)

  /** Creates a classifier with a randomly generated unique ID. */
  def this() = this(Identifiable.randomUID("LightGBMClassifier"))

  // Set default objective to be binary classification
  setDefault(objective -> LightGBMConstants.BinaryObjective)

  val isUnbalance = new BooleanParam(this, "isUnbalance",
    "Set to true if training data is unbalanced in binary classification scenario")
  setDefault(isUnbalance -> false)

  /** @return the current value of the `isUnbalance` param (defaults to false). */
  def getIsUnbalance: Boolean = $(isUnbalance)

  /** Sets the `isUnbalance` param.
    *
    * @param value the new flag value
    * @return this estimator, for call chaining
    */
  def setIsUnbalance(value: Boolean): this.type = set(isUnbalance, value)

  /** Collects the classifier's parameter values into a [[ClassifierTrainParams]] instance.
    *
    * @param numTasks        forwarded to `getGeneralParams`
    * @param featuresSchema  forwarded to `getGeneralParams`
    * @param numTasksPerExec forwarded to `getExecutionParams`
    * @return the assembled training parameters
    */
  def getTrainParams(numTasks: Int, featuresSchema: StructField, numTasksPerExec: Int): BaseTrainParams = {
    ClassifierTrainParams(
      get(passThroughArgs),
      getIsUnbalance,
      getBoostFromAverage,
      get(isProvideTrainingMetric),
      getDelegate,
      getGeneralParams(numTasks, featuresSchema),
      getDatasetParams,
      getDartParams,
      getExecutionParams(numTasksPerExec),
      getObjectiveParams,
      getSeedParams,
      getCategoricalParams)
  }

  /** Fills in the number of classes for non-binary objectives.
    *
    * @param params  training parameters; must be a [[ClassifierTrainParams]] (cast fails otherwise)
    * @param dataset training data from which the number of classes is inferred when needed
    * @return `params` unchanged for binary objectives, otherwise the result of `setNumClass`
    */
  override protected def addCustomTrainParams(params: BaseTrainParams, dataset: Dataset[_]): BaseTrainParams = {
    /* The native code for getting numClasses always returns 1 unless it is a multiclass classification
     * problem, so we infer the actual numClasses from the dataset here.  Since this could be a full
     * pass over the data, explicitly call it out as a calculation and only do it if needed.
     */
    val classifierParams = params.asInstanceOf[ClassifierTrainParams]
    if (classifierParams.isBinary) params
    else classifierParams.setNumClass(getNumClasses(dataset, getMaxNumClasses))
  }

  /** Wraps a trained booster in a [[LightGBMClassificationModel]] configured like this estimator.
    *
    * @param trainParams     training parameters; must be a [[ClassifierTrainParams]], whose `numClass`
    *                        becomes the model's actual number of classes
    * @param lightGBMBooster the booster to wrap; its `bestIteration` becomes the model's `numIterations`
    * @return the configured model, with thresholds copied over only when they are defined here
    */
  def getModel(trainParams: BaseTrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
    val classifierTrainParams = trainParams.asInstanceOf[ClassifierTrainParams]
    val model = new LightGBMClassificationModel(uid)
      .setLightGBMBooster(lightGBMBooster)
      .setFeaturesCol(getFeaturesCol)
      .setPredictionCol(getPredictionCol)
      .setProbabilityCol(getProbabilityCol)
      .setRawPredictionCol(getRawPredictionCol)
      .setLeafPredictionCol(getLeafPredictionCol)
      .setFeaturesShapCol(getFeaturesShapCol)
      .setActualNumClasses(classifierTrainParams.numClass)
      .setNumIterations(lightGBMBooster.bestIteration)
    if (isDefined(thresholds)) model.setThresholds(getThresholds) else model
  }

  /** Returns the native model string held by the model's booster.
    *
    * @param model the trained model
    * @return the booster's model string
    * @throws java.util.NoSuchElementException if the booster carries no model string
    */
  def stringFromTrainedModel(model: LightGBMClassificationModel): String = {
    // Same exception type as a bare `.get`, but with a message that identifies the offending model.
    model.getModel.modelStr.getOrElse(
      throw new NoSuchElementException(
        s"LightGBMClassificationModel ${model.uid} has no native model string available"))
  }

  override def copy(extra: ParamMap): LightGBMClassifier = defaultCopy(extra)
}

/** Mixin that stores the actual number of classes seen by a classification model.
  */
trait HasActualNumClasses extends Params {
  val actualNumClasses = new IntParam(
    this,
    "actualNumClasses",
    "Inferred number of classes based on dataset metadata or, if there is no metadata, unique count")

  /** @return the value currently held by the `actualNumClasses` param */
  def getActualNumClasses: Int = getOrDefault(actualNumClasses)

  /** Stores the number of classes.
    *
    * @param value the class count to record
    * @return this instance, for call chaining
    */
  def setActualNumClasses(value: Int): this.type = {
    set(actualNumClasses, value)
  }
}

/** Model produced by [[LightGBMClassifier]].
  *
  * Raw scores and probabilities both come from the booster's `score` method, so
  * `raw2probabilityInPlace` is deliberately unsupported and must never be reached.
  *
  * @param uid The unique ID.
  */
class LightGBMClassificationModel(override val uid: String)
    extends ProbabilisticClassificationModel[Vector, LightGBMClassificationModel]
      with LightGBMModelParams with LightGBMModelMethods with LightGBMPredictionParams
      with HasActualNumClasses with ComplexParamsWritable with SynapseMLLogging {
  logClass(FeatureNames.LightGBM)

  /** Creates a model with a randomly generated unique ID. */
  def this() = this(Identifiable.randomUID("LightGBMClassificationModel"))

  override protected lazy val pyInternalWrapper = true

  /**
    * Implementation based on ProbabilisticClassifier with slight modifications to
    * avoid calling raw2probabilityInPlace to defer the probability calculation
    * to lightgbm native code.
    *
    * Output columns are appended in a fixed order (raw prediction, probability, prediction,
    * leaf prediction, features SHAP); a column is skipped when its name is empty.
    *
    * @param dataset input dataset
    * @return transformed dataset
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      updateBoosterParamsBeforePredict()
      transformSchema(dataset.schema, logging = true)
      if (isDefined(thresholds)) {
        require(getThresholds.length == numClasses, this.getClass.getSimpleName +
          ".transform() called with non-matching numClasses and thresholds.length." +
          s" numClasses=$numClasses, but thresholds has length ${getThresholds.length}")
      }

      def featuresColumn: Column = col(getFeaturesCol)

      // Column builders are thunks so that a UDF is only created for columns that were requested.
      // The order matters: the prediction column may read the raw/probability columns added before it.
      val candidateOutputs: Seq[(String, () => Column)] = Seq(
        (getRawPredictionCol, () => udf(predictRaw _).apply(featuresColumn)),
        (getProbabilityCol, () => udf(predictProbability _).apply(featuresColumn)),
        (getPredictionCol, () => predictColumn),
        (getLeafPredictionCol, () => udf(predictLeaf _).apply(featuresColumn)),
        (getFeaturesShapCol, () => udf(featuresShap _).apply(featuresColumn))
      )

      // Output selected columns only.
      val requestedOutputs = candidateOutputs.filter { case (name, _) => name.nonEmpty }
      if (requestedOutputs.isEmpty) {
        this.logWarning(s"$uid: LightGBMClassificationModel.transform() was called as NOOP" +
          " since no output columns were set.")
      }

      val outputData = requestedOutputs.foldLeft[Dataset[_]](dataset) {
        case (acc, (name, buildColumn)) => acc.withColumn(name, buildColumn())
      }
      outputData.toDF
    }, dataset.columns.length)
  }

  /** Always throws: probabilities are produced by the booster, not derived from raw predictions. */
  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
    throw new NotImplementedError("Unexpected error in LightGBMClassificationModel:" +
      " raw2probabilityInPlace should not be called!")
  }

  /** @return the value of the `actualNumClasses` param */
  override def numClasses: Int = getActualNumClasses

  /** Scores the features with the booster in raw mode.
    *
    * @param features the feature vector to score
    * @return the booster's raw scores as a dense vector
    */
  override def predictRaw(features: Vector): Vector = {
    Vectors.dense(getModel.score(features, raw = true, classification = true, getPredictDisableShapeCheck))
  }

  /** Scores the features with the booster in non-raw mode.
    *
    * @param features the feature vector to score
    * @return the booster's non-raw scores as a dense vector
    */
  override def predictProbability(features: Vector): Vector = {
    Vectors.dense(getModel.score(features, raw = false, classification = true, getPredictDisableShapeCheck))
  }

  override def copy(extra: ParamMap): LightGBMClassificationModel = defaultCopy(extra)

  /** Builds the expression for the prediction column, reusing already-computed columns when possible.
    *
    * @return a column expression that evaluates to the predicted label
    */
  protected def predictColumn: Column = {
    val thresholdsDefined = isDefined(thresholds)
    if (getRawPredictionCol.nonEmpty && !thresholdsDefined) {
      // Note: Only call raw2prediction if thresholds not defined
      udf(raw2prediction _).apply(col(getRawPredictionCol))
    } else if (getProbabilityCol.nonEmpty) {
      udf(probability2prediction _).apply(col(getProbabilityCol))
    } else if (thresholdsDefined) {
      // With thresholds set, the inherited predict() converts raw scores to probabilities through
      // raw2probabilityInPlace, which this model does not support. Use booster probabilities instead.
      val thresholdedPredict = (features: Vector) => probability2prediction(predictProbability(features))
      udf(thresholdedPredict).apply(col(getFeaturesCol))
    } else {
      udf(predict _).apply(col(getFeaturesCol))
    }
  }
}

/** Companion of [[LightGBMClassificationModel]] with loaders for native LightGBM model text. */
object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] {

  /** Loads a model from a native model text file.
    *
    * The file is read as text through the current (or a newly created) SparkSession, its rows are
    * collected and joined with newlines, and the result is handed to [[loadNativeModelFromString]].
    *
    * @param filename path of the text file to read
    * @return a model wrapping a booster built from the file's contents
    */
  def loadNativeModelFromFile(filename: String): LightGBMClassificationModel = {
    val spark = SparkSession.builder().getOrCreate()
    val lines = spark.read.text(filename).collect().map(_.getString(0))
    loadNativeModelFromString(lines.mkString("\n"))
  }

  /** Builds a model from a native model string.
    *
    * @param model the native model text
    * @return a model with a fresh random uid, the booster set, and the number of classes taken
    *         from the booster
    */
  def loadNativeModelFromString(model: String): LightGBMClassificationModel = {
    val booster = new LightGBMBooster(model)
    val classCount = booster.numClasses
    val freshUid = Identifiable.randomUID("LightGBMClassificationModel")
    new LightGBMClassificationModel(freshUid)
      .setLightGBMBooster(booster)
      .setActualNumClasses(classCount)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy