// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.lightgbm

import com.microsoft.azure.synapse.ml.core.utils.{ClusterUtil, ParamsStringBuilder}
import com.microsoft.azure.synapse.ml.io.http.SharedSingleton
import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
import com.microsoft.azure.synapse.ml.lightgbm.dataset._
import com.microsoft.azure.synapse.ml.lightgbm.params._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.ml.param.shared.{HasFeaturesCol => HasFeaturesColSpark, HasLabelCol => HasLabelColSpark}
import org.apache.spark.ml.{ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql._
import org.apache.spark.sql.types._

import scala.collection.immutable.HashSet
import scala.language.existentials
import scala.math.min
import scala.util.matching.Regex

// scalastyle:off file.size.limit
trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams] extends Estimator[TrainedModel]
  with LightGBMParams with ComplexParamsWritable
  with HasFeaturesColSpark with HasLabelColSpark with LightGBMPerformance with SynapseMLLogging {

  /** Trains the LightGBM model.  If batches are specified, breaks the training dataset into batches for training.
    *
    * @param dataset The input dataset to train.
    * @return The trained model.
    */
  protected def train(dataset: Dataset[_]): TrainedModel = {
    LightGBMUtils.initializeNativeLibrary()

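    // Determine how many training batches to run; by default all data is trained in a single batch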
    val isMultiBatch = getNumBatches > 0
    val numBatches = if (isMultiBatch) getNumBatches else 1
    setBatchPerformanceMeasures(Array.fill(numBatches)(None))

    logFit({
      getOptGroupCol.foreach(DatasetUtils.validateGroupColumn(_, dataset.schema))
      if (isMultiBatch) {
        val ratio = 1.0 / numBatches
        val datasetBatches = dataset.randomSplit((0 until numBatches).map(_ => ratio).toArray)
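        // Train the batches sequentially, carrying each trained model forward as the init model for the next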
        datasetBatches.zipWithIndex.foldLeft(None: Option[TrainedModel]) { (model, datasetWithIndex) =>
          if (model.isDefined) {
            setModelString(stringFromTrainedModel(model.get))
          }
          val batchDataset = datasetWithIndex._1
          val batchIndex = datasetWithIndex._2

          beforeTrainBatch(batchIndex, batchDataset, model)
          val newModel = trainOneDataBatch(batchDataset, batchIndex, numBatches)
          afterTrainBatch(batchIndex, batchDataset, newModel)

          Some(newModel)
        }.get
      } else {
        trainOneDataBatch(dataset, batchIndex = 0, 1)
      }
    }, dataset.columns.length)
  }

  def beforeTrainBatch(batchIndex: Int, dataset: Dataset[_], model: Option[TrainedModel]): Unit = {
    if (getDelegate.isDefined) {
      val previousBooster: Option[LightGBMBooster] = model match {
        case Some(_) => Option(new LightGBMBooster(stringFromTrainedModel(model.get)))
        case None => None
      }

      getDelegate.get.beforeTrainBatch(batchIndex, log, dataset, previousBooster)
    }
  }

  def afterTrainBatch(batchIndex: Int, dataset: Dataset[_], model: TrainedModel): Unit = {
    if (getDelegate.isDefined) {
      val booster = new LightGBMBooster(stringFromTrainedModel(model))
      getDelegate.get.afterTrainBatch(batchIndex, log, dataset, booster)
    }
  }

  protected def castColumns(dataset: Dataset[_], trainingCols: Array[(String, Seq[DataType])]): DataFrame = {
    val schema = dataset.schema
    // Cast columns to correct types
    dataset.select(
      trainingCols.map {
        case (name, datatypes: Seq[DataType]) =>
          val index = schema.fieldIndex(name)
          // Note: we only want to cast if the original column was of numeric type
          schema.fields(index).dataType match {
            case numericDataType: NumericType =>
              // If more than one datatype is allowed, see if any match
              if (datatypes.contains(numericDataType)) {
                dataset(name)
              } else {
                // If none match, cast to the first option
                dataset(name).cast(datatypes.head)
              }
            case _ => dataset(name)
          }
      }: _*
    ).toDF()
  }

  protected def prepareDataframe(dataset: Dataset[_], numTasks: Int): DataFrame = {
    val df = castColumns(dataset, getTrainingCols)
    // Reduce number of partitions to number of executor tasks
    /* Note: with barrier execution mode we must use repartition instead of coalesce when
     * running on spark standalone.
     * Using coalesce, we get the error:
     *
     * org.apache.spark.scheduler.BarrierJobUnsupportedRDDChainException:
     * [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following
     * pattern of RDD chain within a barrier stage:
     * 1. Ancestor RDDs that have different number of partitions from the resulting
     * RDD (eg. union()/coalesce()/first()/take()/PartitionPruningRDD). A workaround
     * for first()/take() can be barrierRdd.collect().head (scala) or barrierRdd.collect()[0] (python).
     * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
     *
     * Without repartition, we may hit the error:
     * org.apache.spark.scheduler.BarrierJobSlotsNumberCheckFailed: [SPARK-24819]:
     * Barrier execution mode does not allow run a barrier stage that requires more
     * slots than the total number of slots in the cluster currently. Please init a
     * new cluster with more CPU cores or repartition the input RDD(s) to reduce the
     * number of slots required to run this barrier stage.
     *
     * Hence we still need to estimate the number of tasks and repartition even when using
     * barrier execution, which is unfortunate as repartition is more expensive than coalesce.
     */
    if (getUseBarrierExecutionMode) {
      val numPartitions = df.rdd.getNumPartitions
      if (numPartitions > numTasks) {
        df.repartition(numTasks)
      } else {
        df
      }
    } else {
      df.coalesce(numTasks)
    }
  }

  protected def getTrainingCols: Array[(String, Seq[DataType])] = {
    val colsToCheck: Array[(Option[String], Seq[DataType])] = Array(
      (Some(getLabelCol), Seq(DoubleType)),
      (Some(getFeaturesCol), Seq(VectorType)),
      (get(weightCol), Seq(DoubleType)),
      (getOptGroupCol, Seq(IntegerType, LongType, StringType)),
      (get(validationIndicatorCol), Seq(BooleanType)),
      (get(initScoreCol), Seq(DoubleType, VectorType)))

    colsToCheck.flatMap {
      case (col: Option[String], colType: Seq[DataType]) =>
        if (col.isDefined) Some((col.get, colType)) else None
    }
  }

  /**
    * Retrieves the categorical indexes in the features column.
    *
    * @param featuresSchema The schema of the features column
    * @return the categorical indexes in the features column.
    */
  protected def getCategoricalIndexes(featuresSchema: StructField): Array[Int] = {
    val categoricalColumnSlotNames = get(categoricalSlotNames).getOrElse(Array.empty[String])
    val categoricalIndexes = if (getSlotNames.nonEmpty) {
      categoricalColumnSlotNames.map(getSlotNames.indexOf(_))
    } else {
      val categoricalSlotNamesSet = HashSet(categoricalColumnSlotNames: _*)
      val attributes = AttributeGroup.fromStructField(featuresSchema).attributes
      if (attributes.isEmpty) {
        Array[Int]()
      } else {
        attributes.get.zipWithIndex.flatMap {
          case (null, _) => None  // scalastyle:ignore null
          case (attr, idx) =>
            if (attr.name.isDefined && categoricalSlotNamesSet.contains(attr.name.get)) {
              Some(idx)
            } else {
              attr match {
                case _: NumericAttribute | UnresolvedAttribute => None
                // Note: it seems that BinaryAttribute is not considered categorical,
                // since all OHE cols are marked with this, but StringIndexed are always Nominal
                case _: BinaryAttribute => None
                case _: NominalAttribute => Some(idx)
              }
            }
        }
      }
    }

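    // Merge user-specified categorical slot indexes with the indexes inferred from slot names/metadata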
    get(categoricalSlotIndexes)
      .getOrElse(Array.empty[Int])
      .union(categoricalIndexes).distinct
  }

  private def getSlotNamesWithMetadata(featuresSchema: StructField): Option[Array[String]] = {
    if (getSlotNames.nonEmpty) {
      Some(getSlotNames)
    } else {
      AttributeGroup.fromStructField(featuresSchema).attributes.flatMap(attributes =>
        if (attributes.isEmpty) {
          None
        } else {
          val colNames = attributes.indices.map(_.toString).toArray
          attributes.foreach(attr =>
            attr.index.foreach(index => colNames(index) = attr.name.getOrElse(index.toString)))
          Some(colNames)
        }
      )
    }
  }

  private def validateSlotNames(featuresSchema: StructField): Unit = {
    val metadata = AttributeGroup.fromStructField(featuresSchema)
    if (metadata.attributes.isDefined) {
      val slotNamesOpt = getSlotNamesWithMetadata(featuresSchema)
      val pattern = new Regex("[\",:\\[\\]{}]")
      slotNamesOpt.foreach(slotNames => {
        val badSlotNames = slotNames.flatMap(slotName =>
          if (pattern.findFirstIn(slotName).isEmpty) None else Option(slotName))
        if (badSlotNames.nonEmpty) {
          throw new IllegalArgumentException(
            s"Invalid slot names detected in features column: ${badSlotNames.mkString(",")}" +
            " \n Special characters \" , : \\ [ ] { } will cause unexpected behavior in LGBM unless changed." +
            " This error can be fixed by renaming the problematic columns prior to vector assembly.")
        }
      })
    }
  }

  /**
    * Constructs the DartModeParams
    *
    * @return DartModeParams object containing parameters related to dart mode.
    */
  protected def getDartParams: DartModeParams = {
    DartModeParams(getDropRate, getMaxDrop, getSkipDrop, getXGBoostDartMode, getUniformDrop)
  }

  /**
    * Constructs the GeneralParams.
    *
    * @return GeneralParams object containing parameters related to general LightGBM parameters.
    */
  protected def getGeneralParams(numTasks: Int, featuresSchema: StructField): GeneralParams = {
    GeneralParams(
      getParallelism,
      get(topK),
      getNumIterations,
      getLearningRate,
      get(numLeaves),
      get(maxBin),
      get(binSampleCount),
      get(baggingFraction),
      get(posBaggingFraction),
      get(negBaggingFraction),
      get(baggingFreq),
      get(baggingSeed),
      getEarlyStoppingRound,
      getImprovementTolerance,
      get(featureFraction),
      get(featureFractionByNode),
      get(maxDepth),
      get(minSumHessianInLeaf),
      numTasks,
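      // Pass any previously set model string as the init model (e.g. one carried over from an earlier batch)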
      if (getModelString == null || getModelString.isEmpty) None else get(modelString),
      getCategoricalIndexes(featuresSchema),
      getVerbosity,
      getBoostingType,
      get(lambdaL1),
      get(lambdaL2),
      get(metric),
      get(minGainToSplit),
      get(maxDeltaStep),
      getMaxBinByFeature,
      get(minDataPerBin),
      get(minDataInLeaf),
      get(topRate),
      get(otherRate),
      getMonotoneConstraints,
      get(monotoneConstraintsMethod),
      get(monotonePenalty),
      getSlotNames)
  }

  /**
    * Constructs the DatasetParams.
    *
    * @return DatasetParams object containing parameters related to LightGBM Dataset parameters.
    */
  protected def getDatasetParams: DatasetParams = {
    DatasetParams(
      get(isEnableSparse),
      get(useMissing),
      get(zeroAsMissing)
    )
  }

  /**
    * Constructs the ExecutionParams.
    *
    * @return ExecutionParams object containing parameters related to LightGBM execution.
    */
  protected def getExecutionParams(numTasksPerExec: Int): ExecutionParams = {
    val execNumThreads =
      if (getUseSingleDatasetMode) get(numThreads).getOrElse(numTasksPerExec - 1)
      else getNumThreads
    ExecutionParams(getChunkSize,
                    getMatrixType,
                    execNumThreads,
                    getDataTransferMode,
                    getSamplingMode,
                    getSamplingSubsetSize,
                    getMicroBatchSize,
                    getUseSingleDatasetMode,
                    getMaxStreamingOMPThreads)
  }

  /**
    * Constructs the ColumnParams.
    *
    * @return ColumnParams object containing the parameters related to LightGBM columns.
    */
  protected def getColumnParams: ColumnParams = {
    ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol)
  }

  /**
    * Constructs the ObjectiveParams.
    *
    * @return ObjectiveParams object containing parameters related to the objective function.
    */
  protected def getObjectiveParams: ObjectiveParams = {
    ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None)
  }

  /**
    * Constructs the SeedParams.
    *
    * @return SeedParams object containing the parameters related to LightGBM seeds and determinism.
    */
  protected def getSeedParams: SeedParams = {
    SeedParams(
      get(seed),
      get(deterministic),
      get(baggingSeed),
      get(featureFractionSeed),
      get(extraSeed),
      get(dropSeed),
      get(dataRandomSeed),
      get(objectiveSeed),
      getBoostingType,
      getObjective)
  }

  /**
    * Constructs the CategoricalParams.
    *
    * @return CategoricalParams object containing the parameters related to LightGBM categorical features.
    */
  protected def getCategoricalParams: CategoricalParams = {
    CategoricalParams(
      get(minDataPerGroup),
      get(maxCatThreshold),
      get(catl2),
      get(catSmooth),
      get(maxCatToOnehot))
  }

  protected def getDatasetCreationParams(categoricalIndexes: Array[Int], numThreads: Int): String = {
    new ParamsStringBuilder(prefix = "", delimiter = "=")
      .appendParamValueIfNotThere("is_pre_partition", Option("True"))
      .appendParamValueIfNotThere("max_bin", Option(getMaxBin))
      .appendParamValueIfNotThere("bin_construct_sample_cnt", Option(getBinSampleCount))
      .appendParamValueIfNotThere("min_data_in_leaf", Option(getMinDataInLeaf))
      .appendParamValueIfNotThere("num_threads", Option(numThreads))
      .appendParamListIfNotThere("categorical_feature", categoricalIndexes)
      .appendParamValueIfNotThere("data_random_seed", get(dataRandomSeed).orElse(get(seed)))
      .result()

    // TODO do we need to add any of the new Dataset parameters here? (e.g. is_enable_sparse)
  }

  /**
    * Inner train method for LightGBM learners.  Calculates the number of workers,
    * creates a driver thread, and runs mapPartitions on the dataset.
    *
    * @param dataset    The dataset to train on.
    * @param batchIndex If running in batch training mode, the index of the current batch.
    * @return The LightGBM Model from the trained LightGBM Booster.
    */
  private def trainOneDataBatch(dataset: Dataset[_], batchIndex: Int, batchCount: Int): TrainedModel = {
    val measures = new InstrumentationMeasures()
    setBatchPerformanceMeasure(batchIndex, measures)

    val numTasksPerExecutor = ClusterUtil.getNumTasksPerExecutor(dataset.sparkSession, log)
    val numTasks = determineNumTasks(dataset, getNumTasks, numTasksPerExecutor)
    val sc = dataset.sparkSession.sparkContext

    val df = prepareDataframe(dataset, numTasks)

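    // If a validation indicator column is set, filter those rows out of the training data
    // and broadcast the collected validation rows to the workers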
    val (trainingData, validationData) =
      if (get(validationIndicatorCol).isDefined && dataset.columns.contains(getValidationIndicatorCol))
        (df.filter(x => !x.getBoolean(x.fieldIndex(getValidationIndicatorCol))),
          Some(sc.broadcast(collectValidationData(df, measures))))
      else (df, None)

    val preprocessedDF = preprocessData(trainingData)

    val (numCols, numInitScoreClasses) = calculateColumnStatistics(preprocessedDF, measures)

    val featuresSchema = dataset.schema(getFeaturesCol)
    val generalTrainParams: BaseTrainParams = getTrainParams(numTasks, featuresSchema, numTasksPerExecutor)
    val trainParams = addCustomTrainParams(generalTrainParams, dataset)
    log.info(s"LightGBM batch $batchIndex of $batchCount, parameters: ${trainParams.toString()}")

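    // Streaming mode requires a serialized reference Dataset and per-partition row counts computed up front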
    val isStreamingMode = getDataTransferMode == LightGBMConstants.StreamingDataTransferMode
    val (serializedReferenceDataset: Option[Array[Byte]], partitionCounts: Option[Array[Long]]) =
      if (isStreamingMode) {
        val (referenceDataset, partitionCounts) =
          calculateRowStatistics(trainingData, trainParams, numCols, measures)

        // Save the reference Dataset so it's available to client and other batches
        if (getReferenceDataset.isEmpty) {
          log.info(s"Saving reference dataset of length: ${referenceDataset.length}")
          setReferenceDataset(referenceDataset)
        }
        (Some(referenceDataset), Some(partitionCounts))
      } else (None, None)

    validateSlotNames(featuresSchema)
    executeTraining(preprocessedDF,
                    validationData,
                    serializedReferenceDataset,
                    partitionCounts,
                    trainParams,
                    numCols,
                    numInitScoreClasses,
                    batchIndex,
                    numTasks,
                    numTasksPerExecutor,
                    measures)
  }

  private def determineNumTasks(dataset: Dataset[_], configNumTasks: Int, numTasksPerExecutor: Int) = {
    // By default, we try to intelligently calculate the number of tasks, but the user can override this with numTasks
    if (configNumTasks > 0) configNumTasks
    else {
      val numExecutorTasks = ClusterUtil.getNumExecutorTasks(dataset.sparkSession, numTasksPerExecutor, log)
      min(numExecutorTasks, dataset.rdd.getNumPartitions)
    }
  }

  /**
    * Get the validation data from the initial dataset.
    *
    * @param df The dataframe containing the validation indicator column.
    * @return The collected validation rows.
    */
  private def collectValidationData(df: DataFrame, measures: InstrumentationMeasures): Array[Row] = {
    measures.markValidDataCollectionStart()
    val data = preprocessData(df.filter(x =>
      x.getBoolean(x.fieldIndex(getValidationIndicatorCol)))).collect()
    measures.markValidDataCollectionStop()
    data
  }

  /**
    * Extract column counts from the dataset.
    *
    * @param dataframe The dataset to train on.
    * @return The number of feature columns and initial score classes
    */
  private def calculateColumnStatistics(dataframe: DataFrame, measures: InstrumentationMeasures): (Int, Int) = {
    measures.markColumnStatisticsStart()
    // Use the first row to get the column count
    val firstRow: Row = dataframe.first()

    val featureData = firstRow.getAs[Any](getFeaturesCol)
    val numCols = featureData match {
      case sparse: SparseVector => sparse.size
      case dense: DenseVector => dense.size
      case _ => throw new IllegalArgumentException("Unknown row data type to push")
    }

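    // A vector-valued initial score column implies one initial score per class (multiclass);
    // otherwise there is a single initial score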
    val numInitScoreClasses =
      if (get(initScoreCol).isEmpty) 1
      else if (dataframe.schema(getInitScoreCol).dataType == VectorType)
        firstRow.getAs[DenseVector](getInitScoreCol).size
      else 1

    measures.markColumnStatisticsStop()
    (numCols, numInitScoreClasses)
  }

  /**
    * Calculate row statistics for streaming mode. Gets the reference dataset and partition counts.
    *
    * @param dataframe The dataset to train on.
    * @param trainingParams The training parameters.
    * @param numCols The number of feature columns.
    * @param measures Instrumentation measures.
    * @return The serialized Dataset reference and an array of partition counts.
    */
  private def calculateRowStatistics(dataframe: DataFrame,
                                     trainingParams: BaseTrainParams,
                                     numCols: Int,
                                     measures: InstrumentationMeasures): (Array[Byte], Array[Long]) = {
    measures.markRowStatisticsStart()

    // Get the row counts per partition
    measures.markRowCountsStart()
    // Get an array where the index is implicitly the partition id
    val rowCounts: Array[Long] = ClusterUtil.getNumRowsPerPartition(dataframe, dataframe.col(getLabelCol))
    val totalNumRows = rowCounts.sum
    measures.markRowCountsStop()

    val datasetParams = getDatasetCreationParams(
      trainingParams.generalParams.categoricalFeatures,
      trainingParams.executionParams.numThreads)

    // Either get a reference dataset (as bytes) from params, or calculate it
    val precalculatedDataset = getReferenceDataset
    val serializedReference = if (precalculatedDataset.nonEmpty) {
      log.info(s"Using precalculated reference Dataset of length: ${precalculatedDataset.length}")
      precalculatedDataset
    } else {
      // Get sample data rows
      measures.markSamplingStart()
      val collectedSampleData = getSampledRows(dataframe, totalNumRows, trainingParams)
      log.info(s"Using ${collectedSampleData.length} sample rows")
      measures.markSamplingStop()

      ReferenceDatasetUtils.createReferenceDatasetFromSample(
        datasetParams,
        getFeaturesCol,
        totalNumRows,
        numCols,
        collectedSampleData,
        measures,
        log)
    }

    measures.markRowStatisticsStop()
    (serializedReference, rowCounts)
  }

  /**
    * Run a parallel job via mapPartitions to initialize the native library and network,
    * translate the data to the LightGBM in-memory representation and train the model.
    *
    * @param dataframe The dataset to train on.
    * @param validationData The dataset to use as validation. (optional)
    * @param serializedReferenceDataset The serialized reference dataset (optional).
    * @param partitionCounts The count per partition for streaming mode (optional).
    * @param trainParams Training parameters.
    * @param numCols Number of columns.
    * @param numInitValueClasses Number of classes for initial values (used only for multiclass).
    * @param batchIndex If running in batch training mode, the index of the current batch.
    * @param numTasks Number of tasks/partitions.
    * @param numTasksPerExecutor Number of tasks per executor.
    * @param measures Instrumentation measures to populate.
    * @return The LightGBM Model from the trained LightGBM Booster.
    */
  private def executeTraining(dataframe: DataFrame,
                              validationData: Option[Broadcast[Array[Row]]],
                              serializedReferenceDataset: Option[Array[Byte]],
                              partitionCounts: Option[Array[Long]],
                              trainParams: BaseTrainParams,
                              numCols: Int,
                              numInitValueClasses: Int,
                              batchIndex: Int,
                              numTasks: Int,
                              numTasksPerExecutor: Int,
                              measures: InstrumentationMeasures): TrainedModel = {
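    // Create the network manager that coordinates the LightGBM machine list across worker tasks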
    val networkManager = NetworkManager.create(numTasks,
                                               dataframe.sparkSession,
                                               getDriverListenPort,
                                               getTimeout,
                                               getUseBarrierExecutionMode)
    val ctx = getTrainingContext(dataframe,
                                 validationData,
                                 serializedReferenceDataset,
                                 partitionCounts,
                                 trainParams,
                                 getSlotNamesWithMetadata(dataframe.schema(getFeaturesCol)),
                                 numCols,
                                 numInitValueClasses,
                                 batchIndex,
                                 numTasksPerExecutor,
                                 networkManager)

    // Execute the Tasks on workers
    val lightGBMBooster = executePartitionTasks(ctx, dataframe, measures)

    // Wait for network to complete (should be done by now)
    networkManager.waitForNetworkCommunicationsDone()
    measures.markTrainingStop()
    val model = getModel(trainParams, lightGBMBooster)
    measures.markExecutionEnd()
    model
  }

  private def executePartitionTasks(ctx: TrainingContext,
                                    dataframe: DataFrame,
                                    measures: InstrumentationMeasures): LightGBMBooster = {
    // Create the object that will manage the mapPartitions function
    val workerTaskHandler: BasePartitionTask =
      if (ctx.isStreaming) new StreamingPartitionTask()
      else new BulkPartitionTask()
    val mapPartitionsFunc = workerTaskHandler.mapPartitionTask(ctx)(_)

    val encoder = Encoders.kryo[PartitionResult]
    measures.markTrainingStart()
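    // In barrier execution mode all partition tasks are scheduled together in a single barrier stage;
    // otherwise use a plain mapPartitions over the dataframe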
    val results: Array[PartitionResult] =
      if (getUseBarrierExecutionMode)
        dataframe.rdd.barrier().mapPartitions(mapPartitionsFunc).collect()
      else
        dataframe.mapPartitions(mapPartitionsFunc)(encoder).collect()

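    // Only some tasks return the trained booster; take the first one that did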
    val lightGBMBooster = results.flatMap(r => r.booster).head
    measures.setTaskMeasures(results.map(r => r.taskMeasures))
    lightGBMBooster
  }

  /**
    * Get the object that holds all relevant context information for the training session.
    *
    * @param dataframe    The dataset to train on.
    * @param validationData The dataset to use as validation. (optional)
    * @param serializedReferenceDataset The serialized reference Dataset. (optional).
    * @param partitionCounts The count per partition for streaming mode (optional).
    * @param trainParams Training parameters.
    * @param numCols Number of columns.
    * @param numInitValueClasses Number of classes for initial values (used only for multiclass).
    * @param batchIndex If running in batch training mode, the index of the current batch.
    * @param numTasksPerExecutor Number of tasks per executor.
    * @param networkManager The network manager.
    * @return The context of the training session.
    */
  private def getTrainingContext(dataframe: DataFrame,
                                 validationData: Option[Broadcast[Array[Row]]],
                                 serializedReferenceDataset: Option[Array[Byte]],
                                 partitionCounts: Option[Array[Long]],
                                 trainParams: BaseTrainParams,
                                 featureNames: Option[Array[String]],
                                 numCols: Int,
                                 numInitValueClasses: Int,
                                 batchIndex: Int,
                                 numTasksPerExecutor: Int,
                                 networkManager: NetworkManager): TrainingContext = {
    val networkParams = NetworkParams(
      getDefaultListenPort,
      networkManager.host,
      networkManager.port,
      networkManager.useBarrierExecutionMode)
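    // SharedSingleton gives each executor JVM a single shared instance of the training state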
    val sharedState = SharedSingleton(new SharedState(trainParams))
    val datasetParams = getDatasetCreationParams(
      trainParams.generalParams.categoricalFeatures,
      trainParams.executionParams.numThreads)

    TrainingContext(batchIndex, // TODO log this context?
                    sharedState,
                    dataframe.schema,
                    numCols,
                    numInitValueClasses,
                    trainParams,
                    networkParams,
                    getColumnParams,
                    datasetParams,
                    featureNames,
                    numTasksPerExecutor,
                    validationData,
                    serializedReferenceDataset,
                    partitionCounts)
  }

  /** Optional group column for Ranking, set to None by default.
    *
    * @return None
    */
  protected def getOptGroupCol: Option[String] = None

  /** Gets the trained model given the train parameters and booster.
    *
    * @return trained model.
    */
  protected def getModel(trainParams: BaseTrainParams, lightGBMBooster: LightGBMBooster): TrainedModel

  /** Gets the training parameters.
    *
    * @param numTasks        The total number of tasks.
    * @param featuresSchema  The features column schema.
    * @param numTasksPerExec The number of tasks per executor.
    * @return train parameters.
    */
  protected def getTrainParams(numTasks: Int,
                               featuresSchema: StructField,
                               numTasksPerExec: Int): BaseTrainParams

  protected def addCustomTrainParams(params: BaseTrainParams, dataset: Dataset[_]): BaseTrainParams = params

  protected def stringFromTrainedModel(model: TrainedModel): String

  /** Allow algorithm specific preprocessing of dataset.
    *
    * @param df The dataframe to preprocess prior to training.
    * @return The preprocessed data.
    */
  protected def preprocessData(df: DataFrame): DataFrame = df

  /**
    * Creates an array of Rows to use as sample data.
    *
    * @param dataframe      The dataset to train on.
    * @param totalNumRows   The total number of rows in the dataset.
    * @param trainingParams The training parameters.
    * @return An array of rows sampled from the features column.
    */
  private def getSampledRows(dataframe: DataFrame,
                             totalNumRows: Long,
                             trainingParams: BaseTrainParams): Array[Row] = {
    val sampleCount: Int = getBinSampleCount
    val seed: Int = getSeedParams.dataRandomSeed.getOrElse(0)
    val featureColName = getFeaturesCol
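    // Oversample slightly so that limit(numSamples) still has enough rows after the random sample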
    val fraction = if (sampleCount > totalNumRows) 1.0
    else Math.min(1.0, (sampleCount.toDouble + 10000) / totalNumRows)
    val numSamples = Math.min(sampleCount, totalNumRows).toInt
    val samplingSubsetSize = Math.min(trainingParams.executionParams.samplingSetSize, numSamples)

    val samplingMode = trainingParams.executionParams.samplingMode
    log.info(s"Using sampling mode: $samplingMode (if subset, size is $samplingSubsetSize)")
    samplingMode match {
      case LightGBMConstants.SubsetSamplingModeGlobal =>
        // sample randomly from all rows (expensive for large data sets)
        dataframe.select(dataframe.col(featureColName)).sample(fraction, seed).limit(numSamples).collect()
      case LightGBMConstants.SubsetSamplingModeSubset =>
        // sample randomly from first 'samplingSetSize' rows (optimization to save time on large data sets)
        dataframe.select(dataframe.col(featureColName)).limit(samplingSubsetSize).sample(fraction, seed).collect()
      case LightGBMConstants.SubsetSamplingModeFixed =>
        // just take first 'N' rows.  Quick but assumes data already randomized and representative.
        dataframe.select(dataframe.col(featureColName)).limit(numSamples).collect()
      case _ => throw new NotImplementedError(s"Unknown sampling mode: $samplingMode")
    }
  }
}