All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.scala.spark.XGBoost.scala Maven / Gradle / Ivy

The newest version!
/*
 Copyright (c) 2014-2024 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark

import java.io.File

import scala.collection.mutable
import scala.util.Random
import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.FileSystem

import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.sql.SparkSession

/**
 * Rabit tracker configurations.
 *
 * @param timeout The number of seconds before timeout waiting for workers to connect. and
 *                for the tracker to shutdown.
 * @param hostIp The Rabit Tracker host IP address.
 *               This is only needed if the host IP cannot be automatically guessed.
 * @param port The port number for the tracker to listen to. Use a system allocated one by
 *             default.
 */
case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)

object TrackerConf {
  def apply(): TrackerConf = TrackerConf(0)
}

private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)

private[scala] case class XGBoostExecutionParams(
    numWorkers: Int,
    numRounds: Int,
    useExternalMemory: Boolean,
    obj: ObjectiveTrait,
    eval: EvalTrait,
    missing: Float,
    allowNonZeroForMissing: Boolean,
    trackerConf: TrackerConf,
    checkpointParam: Option[ExternalCheckpointParams],
    xgbInputParams: XGBoostExecutionInputParams,
    earlyStoppingRounds: Int,
    cacheTrainingSet: Boolean,
    device: Option[String],
    isLocal: Boolean,
    featureNames: Option[Array[String]],
    featureTypes: Option[Array[String]],
    runOnGpu: Boolean) {

  private var rawParamMap: Map[String, Any] = _

  def setRawParamMap(inputMap: Map[String, Any]): Unit = {
    rawParamMap = inputMap
  }

  def toMap: Map[String, Any] = {
    rawParamMap
  }
}

private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){

  private val logger = LogFactory.getLog("XGBoostSpark")

  private val isLocal = sc.isLocal

  private val overridedParams = overrideParams(rawParams, sc)

  validateSparkSslConf()

  /**
   * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
   * If so, throw an exception unless this safety measure has been explicitly overridden
   * via conf `xgboost.spark.ignoreSsl`.
   */
  private def validateSparkSslConf(): Unit = {
    val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
      SparkSession.getActiveSession match {
        case Some(ss) =>
          (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
            ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
        case None =>
          (sc.getConf.getBoolean("spark.ssl.enabled", false),
            sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
      }
    if (sparkSslEnabled) {
      if (xgboostSparkIgnoreSsl) {
        logger.warn(s"spark-xgboost is being run without encrypting data in transit!  " +
          s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
      } else {
        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
          "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
          "To override this protection and still use xgboost-spark at your own risk, " +
          "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
      }
    }
  }

  /**
   * we should not include any nested structure in the output of this function as the map is
   * eventually to be feed to xgboost4j layer
   */
  private def overrideParams(
      params: Map[String, Any],
      sc: SparkContext): Map[String, Any] = {
    val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
    var overridedParams = params
    if (overridedParams.contains("nthread")) {
      val nThread = overridedParams("nthread").toString.toInt
      require(nThread <= coresPerTask,
        s"the nthread configuration ($nThread) must be no larger than " +
          s"spark.task.cpus ($coresPerTask)")
    } else {
      overridedParams = overridedParams + ("nthread" -> coresPerTask)
    }

    val numEarlyStoppingRounds = overridedParams.getOrElse(
      "num_early_stopping_rounds", 0).asInstanceOf[Int]
    overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
    if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
        throw new IllegalArgumentException("custom_eval does not support early stopping")
    }
    overridedParams
  }

  /**
   * The Map parameters accepted by estimator's constructor may have string type,
   * Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
   * kind of parameters into the correct type in the function.
   *
   * @return XGBoostExecutionParams
   */
  def buildXGBRuntimeParams: XGBoostExecutionParams = {

    val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
    val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
    if (obj != null) {
      require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
        "is not defined, you have to specify the objective type as classification or regression" +
        " with a customized objective function")
    }

    var trainTestRatio = 1.0
    if (overridedParams.contains("train_test_ratio")) {
      logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
        " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
        "'eval_set_names'")
      trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
    }

    val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
    val round = overridedParams("num_round").asInstanceOf[Int]
    val useExternalMemory = overridedParams
      .getOrElse("use_external_memory", false).asInstanceOf[Boolean]

    val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
    val allowNonZeroForMissing = overridedParams
                                 .getOrElse("allow_non_zero_for_missing", false)
                                 .asInstanceOf[Boolean]

    val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
    val device: Option[String] = overridedParams.get("device").map(_.toString)
    val deviceIsGpu = device.exists(_ == "cuda")

    require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
      "The tree method \"approx\" is not yet supported for Spark GPU cluster")

    // back-compatible with "gpu_hist"
    val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu

    val trackerConf = overridedParams.get("tracker_conf") match {
      case None => TrackerConf()
      case Some(conf: TrackerConf) => conf
      case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
        "instance of TrackerConf.")
    }

    val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)

    val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
    val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)

    val earlyStoppingRounds = overridedParams.getOrElse(
      "num_early_stopping_rounds", 0).asInstanceOf[Int]

    val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
      .asInstanceOf[Boolean]

    val featureNames = if (overridedParams.contains("feature_names")) {
      Some(overridedParams("feature_names").asInstanceOf[Array[String]])
    } else None
    val featureTypes = if (overridedParams.contains("feature_types")){
      Some(overridedParams("feature_types").asInstanceOf[Array[String]])
    } else None

    val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
      missing, allowNonZeroForMissing, trackerConf,
      checkpointParam,
      inputParams,
      earlyStoppingRounds,
      cacheTrainingSet,
      device,
      isLocal,
      featureNames,
      featureTypes,
      runOnGpu
    )
    xgbExecParam.setRawParamMap(overridedParams)
    xgbExecParam
  }
}

/**
 * A trait to manage stage-level scheduling
 */
private[spark] trait XGBoostStageLevel extends Serializable {
  private val logger = LogFactory.getLog("XGBoostSpark")

  private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
    val master = conf.get("spark.master")
    master != null && (master.startsWith("spark://") || master.startsWith("local-cluster"))
  }

  /**
   * To determine if stage-level scheduling should be skipped according to the spark version
   * and spark configurations
   *
   * @param sparkVersion spark version
   * @param runOnGpu     if xgboost training run on GPUs
   * @param conf         spark configurations
   * @return Boolean to skip stage-level scheduling or not
   */
  private[spark] def skipStageLevelScheduling(
      sparkVersion: String,
      runOnGpu: Boolean,
      conf: SparkConf): Boolean = {
    if (runOnGpu) {
      if (sparkVersion < "3.4.0") {
        logger.info("Stage-level scheduling in xgboost requires spark version 3.4.0+")
        return true
      }

      if (!isStandaloneOrLocalCluster(conf)) {
        logger.info("Stage-level scheduling in xgboost requires spark standalone or " +
          "local-cluster mode")
        return true
      }

      val executorCores = conf.getInt("spark.executor.cores", -1)
      val executorGpus = conf.getInt("spark.executor.resource.gpu.amount", -1)
      if (executorCores == -1 || executorGpus == -1) {
        logger.info("Stage-level scheduling in xgboost requires spark.executor.cores, " +
          "spark.executor.resource.gpu.amount to be set.")
        return true
      }

      if (executorCores == 1) {
        logger.info("Stage-level scheduling in xgboost requires spark.executor.cores > 1")
        return true
      }

      if (executorGpus > 1) {
        logger.info("Stage-level scheduling in xgboost will not work " +
          "when spark.executor.resource.gpu.amount > 1")
        return true
      }

      val taskGpuAmount = conf.getDouble("spark.task.resource.gpu.amount", -1.0).toFloat

      if (taskGpuAmount == -1.0) {
        // The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
        // but with stage-level scheduling, we can make training task grab the gpu.
        return false
      }

      if (taskGpuAmount == executorGpus.toFloat) {
        // spark.executor.resource.gpu.amount = spark.task.resource.gpu.amount
        // results in only 1 task running at a time, which may cause perf issue.
        return true
      }
      // We can enable stage-level scheduling
      false
    } else true // Skip stage-level scheduling for cpu training.
  }

  /**
   * Attempt to modify the task resources so that only one task can be executed
   * on a single executor simultaneously.
   *
   * @param sc  the spark context
   * @param rdd which rdd to be applied with new resource profile
   * @return the original rdd or the changed rdd
   */
  private[spark] def tryStageLevelScheduling(
      sc: SparkContext,
      xgbExecParams: XGBoostExecutionParams,
      rdd: RDD[(Booster, Map[String, Array[Float]])]
    ): RDD[(Booster, Map[String, Array[Float]])] = {

    val conf = sc.getConf
    if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
      return rdd
    }

    // Ensure executor_cores is not None
    val executor_cores = conf.getInt("spark.executor.cores", -1)
    if (executor_cores == -1) {
      throw new RuntimeException("Wrong spark.executor.cores")
    }

    // Spark-rapids is a GPU-acceleration project for Spark SQL.
    // When spark-rapids is enabled, we prevent concurrent execution of other ETL tasks
    // that utilize GPUs alongside training tasks in order to avoid GPU out-of-memory errors.
    val spark_plugins = conf.get("spark.plugins", " ")
    val spark_rapids_sql_enabled = conf.get("spark.rapids.sql.enabled", "true")

    // Determine the number of cores required for each task.
    val task_cores = if (spark_plugins.contains("com.nvidia.spark.SQLPlugin") &&
      spark_rapids_sql_enabled.toLowerCase == "true") {
      executor_cores
    } else {
      (executor_cores / 2) + 1
    }

    // Each training task requires cpu cores > total executor cores//2 + 1 to
    // ensure tasks are sent to different executors.
    // Note: We cannot use GPUs to limit concurrent tasks
    // due to https://issues.apache.org/jira/browse/SPARK-45527.
    val task_gpus = 1.0
    val treqs = new TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    val rp = new ResourceProfileBuilder().require(treqs).build()

    logger.info(s"XGBoost training tasks require the resource(cores=$task_cores, gpu=$task_gpus).")
    rdd.withResources(rp)
  }
}

object XGBoost extends XGBoostStageLevel {
  private val logger = LogFactory.getLog("XGBoostSpark")

  def getGPUAddrFromResources: Int = {
    val tc = TaskContext.get()
    if (tc == null) {
      throw new RuntimeException("Something wrong for task context")
    }
    val resources = tc.resources()
    if (resources.contains("gpu")) {
      val addrs = resources("gpu").addresses
      if (addrs.size > 1) {
        // TODO should we throw exception ?
        logger.warn("XGBoost only supports 1 gpu per worker")
      }
      // take the first one
      addrs.head.toInt
    } else {
      throw new RuntimeException("gpu is not allocated by spark, " +
        "please check if gpu scheduling is enabled")
    }
  }

  private def buildWatchesAndCheck(buildWatchesFun: () => Watches): Watches = {
    val watches = buildWatchesFun()
    // to workaround the empty partitions in training dataset,
    // this might not be the best efficient implementation, see
    // (https://github.com/dmlc/xgboost/issues/1277)
    if (!watches.toMap.contains("train")) {
      throw new XGBoostError(
        s"detected an empty partition in the training data, partition ID:" +
          s" ${TaskContext.getPartitionId()}")
    }
    watches
  }

  private def buildDistributedBooster(
      buildWatches: () => Watches,
      xgbExecutionParam: XGBoostExecutionParams,
      rabitEnv: java.util.Map[String, Object],
      obj: ObjectiveTrait,
      eval: EvalTrait,
      prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {

    var watches: Watches = null
    val taskId = TaskContext.getPartitionId().toString
    val attempt = TaskContext.get().attemptNumber.toString
    rabitEnv.put("DMLC_TASK_ID", taskId)
    val numRounds = xgbExecutionParam.numRounds
    val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0

    try {
      Communicator.init(rabitEnv)

      watches = buildWatchesAndCheck(buildWatches)

      val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
      val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
      val externalCheckpointParams = xgbExecutionParam.checkpointParam

      var params = xgbExecutionParam.toMap
      if (xgbExecutionParam.runOnGpu) {
        val gpuId = if (xgbExecutionParam.isLocal) {
          // For local mode, force gpu id to primary device
          0
        } else {
          getGPUAddrFromResources
        }
        logger.info("Leveraging gpu device " + gpuId + " to train")
        params = params + ("device" -> s"cuda:$gpuId")
      }

      val booster = if (makeCheckpoint) {
        SXGBoost.trainAndSaveCheckpoint(
          watches.toMap("train"), params, numRounds,
          watches.toMap, metrics, obj, eval,
          earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
      } else {
        SXGBoost.train(watches.toMap("train"), params, numRounds,
          watches.toMap, metrics, obj, eval,
          earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
      }
      if (TaskContext.get().partitionId() == 0) {
        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
      } else {
        Iterator.empty
      }
    } catch {
      case xgbException: XGBoostError =>
        logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
        throw xgbException
    } finally {
      Communicator.shutdown()
      if (watches != null) watches.delete()
    }
  }

  // Executes the provided code block inside a tracker and then stops the tracker
  private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
    val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
    require(tracker.start(), "FAULT: Failed to start tracker")
    try {
      block(tracker)
    } finally {
      tracker.stop()
    }
  }

  /**
   * @return A tuple of the booster and the metrics used to build training summary
   */
  @throws(classOf[XGBoostError])
  private[spark] def trainDistributed(
      sc: SparkContext,
      buildTrainingData: XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]),
      params: Map[String, Any]):
    (Booster, Map[String, Array[Float]]) = {

    logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")

    val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
    val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams

    val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
      val checkpointManager = new ExternalCheckpointManager(
        checkpointParam.checkpointPath,
        FileSystem.get(sc.hadoopConfiguration))
      checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
      checkpointManager.loadCheckpointAsScalaBooster()
    }.orNull

    // Get the training data RDD and the cachedRDD
    val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)

    try {
      val (booster, metrics) = withTracker(
        runtimeParams.numWorkers,
        runtimeParams.trackerConf
      ) { tracker =>
        val rabitEnv = tracker.getWorkerArgs()

        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
          var optionWatches: Option[() => Watches] = None

          // take the first Watches to train
          if (iter.hasNext) {
            optionWatches = Some(iter.next())
          }

          optionWatches.map { buildWatches =>
              buildDistributedBooster(buildWatches,
                runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
            }.getOrElse(throw new RuntimeException("No Watches to train"))
        }

        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
          boostersAndMetrics)
        // The repartition step is to make training stage as ShuffleMapStage, so that when one
        // of the training task fails the training stage can retry. ResultStage won't retry when
        // it fails.
        val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
        (booster, metrics)
      }

      // we should delete the checkpoint directory after a successful training
      runtimeParams.checkpointParam.foreach {
        cpParam =>
          if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
            val checkpointManager = new ExternalCheckpointManager(
              cpParam.checkpointPath,
              FileSystem.get(sc.hadoopConfiguration))
            checkpointManager.cleanPath()
          }
      }
      (booster, metrics)
    } catch {
      case t: Throwable =>
        // if the job was aborted due to an exception
        logger.error("the job was aborted due to ", t)
        throw t
    } finally {
      optionalCachedRDD.foreach(_.unpersist())
    }
  }

}

class Watches private[scala] (
    val datasets: Array[DMatrix],
    val names: Array[String],
    val cacheDirName: Option[String]) {

  def toMap: Map[String, DMatrix] = {
    names.zip(datasets).toMap.filter { case (_, matrix) => matrix.rowNum > 0 }
  }

  def size: Int = toMap.size

  def delete(): Unit = {
    toMap.values.foreach(_.delete())
    cacheDirName.foreach { name =>
      FileUtils.deleteDirectory(new File(name))
    }
  }

  override def toString: String = toMap.toString
}

private object Watches {

  private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
    val builder = new mutable.ArrayBuilder.ofFloat()
    var nTotal = 0
    var nUndefined = 0
    while (baseMargins.hasNext) {
      nTotal += 1
      val baseMargin = baseMargins.next()
      if (baseMargin.isNaN) {
        nUndefined += 1  // don't waste space for all-NaNs.
      } else {
        builder += baseMargin
      }
    }
    if (nUndefined == nTotal) {
      None
    } else if (nUndefined == 0) {
      Some(builder.result())
    } else {
      throw new IllegalArgumentException(
        s"Encountered a partition with $nUndefined NaN base margin values. " +
          s"If you want to specify base margin, ensure all values are non-NaN.")
    }
  }

  def buildWatches(
      nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
      cachedDirName: Option[String]): Watches = {
    val dms = nameAndLabeledPointSets.map {
      case (name, labeledPoints) =>
        val baseMargins = new mutable.ArrayBuilder.ofFloat
        val duplicatedItr = labeledPoints.map(labeledPoint => {
          baseMargins += labeledPoint.baseMargin
          labeledPoint
        })
        val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
        val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
        if (baseMargin.isDefined) {
          dMatrix.setBaseMargin(baseMargin.get)
        }
        (name, dMatrix)
    }.toArray
    new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
  }

  def buildWatches(
      xgbExecutionParams: XGBoostExecutionParams,
      labeledPoints: Iterator[XGBLabeledPoint],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
    val seed = xgbExecutionParams.xgbInputParams.seed
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
    val testBaseMargins = new mutable.ArrayBuilder.ofFloat
    val trainPoints = labeledPoints.filter { labeledPoint =>
      val accepted = r.nextDouble() <= trainTestRatio
      if (!accepted) {
        testPoints += labeledPoint
        testBaseMargins += labeledPoint.baseMargin
      } else {
        trainBaseMargins += labeledPoint.baseMargin
      }
      accepted
    }
    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
    val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)

    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

    if (xgbExecutionParams.featureNames.isDefined) {
      trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
      testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
    }

    if (xgbExecutionParams.featureTypes.isDefined) {
      trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
      testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
    }

    new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
  }

  def buildWatchesWithGroup(
      nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
      cachedDirName: Option[String]): Watches = {
    val dms = nameAndlabeledPointGroupSets.map {
      case (name, labeledPointsGroups) =>
        val baseMargins = new mutable.ArrayBuilder.ofFloat
        val groupsInfo = new mutable.ArrayBuilder.ofInt
        val weights = new mutable.ArrayBuilder.ofFloat
        val iter = labeledPointsGroups.filter(labeledPointGroup => {
          var groupWeight = -1.0f
          var groupSize = 0
          labeledPointGroup.map { labeledPoint => {
            if (groupWeight < 0) {
              groupWeight = labeledPoint.weight
            } else if (groupWeight != labeledPoint.weight) {
              throw new IllegalArgumentException("the instances in the same group have to be" +
                s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
            }
            baseMargins += labeledPoint.baseMargin
            groupSize += 1
            labeledPoint
          }
          }
          weights += groupWeight
          groupsInfo += groupSize
          true
        })
        val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
        val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
        if (baseMargin.isDefined) {
          dMatrix.setBaseMargin(baseMargin.get)
        }
        dMatrix.setGroup(groupsInfo.result())
        dMatrix.setWeight(weights.result())
        (name, dMatrix)
    }.toArray
    new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
  }

  def buildWatchesWithGroup(
      xgbExecutionParams: XGBoostExecutionParams,
      labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
      cacheDirName: Option[String]): Watches = {
    val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
    val seed = xgbExecutionParams.xgbInputParams.seed
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
    val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
    val testBaseMargins = new mutable.ArrayBuilder.ofFloat

    val trainGroups = new mutable.ArrayBuilder.ofInt
    val testGroups = new mutable.ArrayBuilder.ofInt

    val trainWeights = new mutable.ArrayBuilder.ofFloat
    val testWeights = new mutable.ArrayBuilder.ofFloat

    val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
      val accepted = r.nextDouble() <= trainTestRatio
      if (!accepted) {
        var groupWeight = -1.0f
        var groupSize = 0
        labeledPointGroup.foreach(labeledPoint => {
          testPoints += labeledPoint
          testBaseMargins += labeledPoint.baseMargin
          if (groupWeight < 0) {
            groupWeight = labeledPoint.weight
          } else if (labeledPoint.weight != groupWeight) {
            throw new IllegalArgumentException("the instances in the same group have to be" +
              s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
          }
          groupSize += 1
        })
        testWeights += groupWeight
        testGroups += groupSize
      } else {
        var groupWeight = -1.0f
        var groupSize = 0
        labeledPointGroup.foreach { labeledPoint => {
          if (groupWeight < 0) {
            groupWeight = labeledPoint.weight
          } else if (labeledPoint.weight != groupWeight) {
            throw new IllegalArgumentException("the instances in the same group have to be" +
              s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
          }
          trainBaseMargins += labeledPoint.baseMargin
          groupSize += 1
        }}
        trainWeights += groupWeight
        trainGroups += groupSize
      }
      accepted
    }

    val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
    val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
    trainMatrix.setGroup(trainGroups.result())
    trainMatrix.setWeight(trainWeights.result())

    val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
    if (trainTestRatio < 1.0) {
      testMatrix.setGroup(testGroups.result())
      testMatrix.setWeight(testWeights.result())
    }

    val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
    val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
    if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
    if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)

    if (xgbExecutionParams.featureNames.isDefined) {
      trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
      testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
    }
    if (xgbExecutionParams.featureTypes.isDefined) {
      trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
      testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
    }

    new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy