All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.lightgbm.LightGBMDelegate.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.lightgbm

import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
import com.microsoft.azure.synapse.ml.lightgbm.params.BaseTrainParams
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger

/**
 * Extension point for user code around LightGBM training.
 *
 * Every hook ships with a default implementation, so an implementor overrides only the
 * callbacks it needs. All `before*` / `after*` hooks default to doing nothing, and
 * `getLearningRate` defaults to handing back the previous learning rate unchanged.
 *
 * The trait extends `Serializable`. Presumably this lets implementations be shipped to
 * Spark executors, as the per-partition hooks (those taking `partitionId`) suggest.
 * TODO confirm against the caller.
 *
 * NOTE(review): the points at which each hook fires are decided by the caller, not by
 * this trait. The per-method descriptions below follow the method names.
 */
trait LightGBMDelegate extends Serializable {

  /**
   * Hook named for the point before training batch `batchIndex`. Default: no-op.
   *
   * @param batchIndex      index of the batch
   * @param log             logger the hook may write to
   * @param dataset         dataset supplied by the caller for this batch
   * @param previousBooster presumably the booster from an earlier batch, if any
   */
  def beforeTrainBatch(
      batchIndex: Int,
      log: Logger,
      dataset: Dataset[_],
      previousBooster: Option[LightGBMBooster]): Unit = ()

  /**
   * Hook named for the point after training batch `batchIndex`. Default: no-op.
   *
   * @param batchIndex index of the batch
   * @param log        logger the hook may write to
   * @param dataset    dataset supplied by the caller for this batch
   * @param booster    booster supplied by the caller for this batch
   */
  def afterTrainBatch(
      batchIndex: Int,
      log: Logger,
      dataset: Dataset[_],
      booster: LightGBMBooster): Unit = ()

  /**
   * Hook named for the point before the training dataset is generated on a partition.
   * Default: no-op.
   *
   * @param batchIndex   index of the batch
   * @param partitionId  id of the partition
   * @param columnParams column configuration supplied by the caller
   * @param schema       Spark schema supplied by the caller
   * @param log          logger the hook may write to
   * @param trainParams  training parameters supplied by the caller
   */
  def beforeGenerateTrainDataset(
      batchIndex: Int,
      partitionId: Int,
      columnParams: ColumnParams,
      schema: StructType,
      log: Logger,
      trainParams: BaseTrainParams): Unit = ()

  /**
   * Hook named for the point after the training dataset is generated on a partition.
   * Default: no-op. Parameters mirror `beforeGenerateTrainDataset`.
   */
  def afterGenerateTrainDataset(
      batchIndex: Int,
      partitionId: Int,
      columnParams: ColumnParams,
      schema: StructType,
      log: Logger,
      trainParams: BaseTrainParams): Unit = ()

  /**
   * Hook named for the point before the validation dataset is generated on a partition.
   * Default: no-op. Parameters mirror `beforeGenerateTrainDataset`.
   */
  def beforeGenerateValidDataset(
      batchIndex: Int,
      partitionId: Int,
      columnParams: ColumnParams,
      schema: StructType,
      log: Logger,
      trainParams: BaseTrainParams): Unit = ()

  /**
   * Hook named for the point after the validation dataset is generated on a partition.
   * Default: no-op. Parameters mirror `beforeGenerateTrainDataset`.
   */
  def afterGenerateValidDataset(
      batchIndex: Int,
      partitionId: Int,
      columnParams: ColumnParams,
      schema: StructType,
      log: Logger,
      trainParams: BaseTrainParams): Unit = ()

  /**
   * Hook named for the point before a training iteration. Default: no-op.
   *
   * @param batchIndex  index of the batch
   * @param partitionId id of the partition
   * @param curIters    iteration counter supplied by the caller
   * @param log         logger the hook may write to
   * @param trainParams training parameters supplied by the caller
   * @param booster     booster supplied by the caller
   * @param hasValid    flag supplied by the caller; presumably whether validation data exists
   */
  def beforeTrainIteration(
      batchIndex: Int,
      partitionId: Int,
      curIters: Int,
      log: Logger,
      trainParams: BaseTrainParams,
      booster: LightGBMBooster,
      hasValid: Boolean): Unit = ()

  /**
   * Hook named for the point after a training iteration. Default: no-op.
   *
   * Parameters shared with `beforeTrainIteration` have the same meaning.
   *
   * @param isFinished       flag supplied by the caller; presumably whether training stops here
   * @param trainEvalResults optional metric-name-to-value map for the training data
   * @param validEvalResults optional metric-name-to-value map for the validation data
   */
  def afterTrainIteration(
      batchIndex: Int,
      partitionId: Int,
      curIters: Int,
      log: Logger,
      trainParams: BaseTrainParams,
      booster: LightGBMBooster,
      hasValid: Boolean,
      isFinished: Boolean,
      trainEvalResults: Option[Map[String, Double]],
      validEvalResults: Option[Map[String, Double]]): Unit = ()

  /**
   * Lets an implementor supply the learning rate for the current iteration.
   *
   * The default keeps the learning rate constant by returning `previousLearningRate`
   * as-is. Override to implement a custom schedule.
   *
   * @param batchIndex           index of the batch
   * @param partitionId          id of the partition
   * @param curIters             iteration counter supplied by the caller
   * @param log                  logger the hook may write to
   * @param trainParams          training parameters supplied by the caller
   * @param previousLearningRate learning rate supplied by the caller
   * @return the learning rate to use; by default `previousLearningRate`
   */
  def getLearningRate(
      batchIndex: Int,
      partitionId: Int,
      curIters: Int,
      log: Logger,
      trainParams: BaseTrainParams,
      previousLearningRate: Double): Double =
    previousLearningRate
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy