All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.tree.gbdt.train.GBDTTrainer.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.train

import com.tencent.angel.sona.tree.gbdt.GBDTConf._
import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.tree.gbdt.GBDTModel
import com.tencent.angel.sona.tree.gbdt.GBDTModel.GBTTree
import com.tencent.angel.sona.tree.gbdt.dataset.Dataset
import com.tencent.angel.sona.tree.gbdt.helper._
import com.tencent.angel.sona.tree.gbdt.histogram.GradPair
import com.tencent.angel.sona.tree.gbdt.metadata.{FeatureInfo, InstanceInfo}
import com.tencent.angel.sona.tree.gbdt.tree.{GBDTParam, GBTNode, GBTSplit}
import com.tencent.angel.sona.tree.objective.ObjectiveFactory
import com.tencent.angel.sona.tree.objective.metric.EvalMetric
import com.tencent.angel.sona.tree.regression.RegTParam
import com.tencent.angel.sona.tree.stats.hash.Mix64Hash
import com.tencent.angel.sona.tree.stats.quantile.QuantileSketch
import com.tencent.angel.sona.tree.util.MathUtil
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.util.SparkUtil

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => AB}

object GBDTTrainer {
  // String identifiers of the supported parallel modes; `main` matches the
  // requested mode against DATA_PARALLEL_MODE and FEATURE_PARALLEL_MODE.
  private[gbdt] val DATA_PARALLEL_MODE = "dp"
  private[gbdt] val FEATURE_PARALLEL_MODE = "fp"
  // NOTE(review): `main` falls back to DEFAULT_ML_PARALLEL_MODE (not this value)
  // when no mode is given -- confirm the two defaults agree.
  private[gbdt] val DEFAULT_PARALLEL_MODE = FEATURE_PARALLEL_MODE
  // TODO: add 'auto' parallel mode

  /**
   * Command-line entry point: configures Spark, parses the hyper-parameters,
   * trains a GBDT model in the requested parallel mode and saves the result.
   *
   * The Spark settings are applied to the `SparkConf` before the `SparkContext`
   * is obtained, so the order of the first section must be kept.
   *
   * @param args raw arguments, handed to `SparkUtil.parse` to obtain the
   *             key/value lookup used for every hyper-parameter below
   * @throws IllegalArgumentException if dynamic allocation is enabled, the task
   *                                  type or parallel mode is unknown, or both
   *                                  full-hessian and multi-tree are requested
   */
  def main(args: Array[String]): Unit = {
    // spark conf
    @transient val conf = new SparkConf()
    require(!conf.getBoolean("spark.dynamicAllocation.enabled", false),
      "'spark.dynamicAllocation.enabled' should not be true")
    // both keys are read without a default, so they must be set explicitly
    val numExecutors = conf.get("spark.executor.instances").toInt
    val numCores = conf.get("spark.executor.cores").toInt
    // one task claims all cores of an executor
    conf.set("spark.task.cpus", numCores.toString)
    conf.set("spark.locality.wait", "0")
    conf.set("spark.memory.fraction", "0.7")
    conf.set("spark.memory.storageFraction", "0.8")
    conf.set("spark.rpc.message.maxSize", "2000")
    conf.set("spark.scheduler.maxRegisteredResourcesWaitingTime", "600s")
    conf.set("spark.scheduler.minRegisteredResourcesRatio", "1.0")
    // no retries: a failed task or application attempt is not re-run
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.yarn.maxAppAttempts", "1")
    // NOTE(review): both values below carry no time unit; Spark may apply a
    // different default unit to each key (seconds vs. milliseconds) -- confirm
    // the intended durations and consider spelling the unit out.
    conf.set("spark.network.timeout", "1000")
    conf.set("spark.executor.heartbeatInterval", "500")
    @transient implicit val sc = SparkContext.getOrCreate(conf)

    // parse hyper-parameters
    val params = SparkUtil.parse(args)
    val isRegression = params.getOrElse(ML_TASK_TYPE, DEFAULT_ML_TASK_TYPE) match {
      case "classification" => false
      case "regression" => true
      case s => throw new IllegalArgumentException(s"Unsupported task type: $s, " +
        s"please input 'classification' or 'regression'")
    }
    // regression is configured with numClass = 2, the same value as binary classification
    val numClass = if (isRegression) 2 else params.getOrElse(ML_NUM_CLASS, DEFAULT_ML_NUM_CLASS.toString).toInt
    // full-hessian / multi-tree are only read for multi-class classification
    val (fullHessian, multiTree) = if (!isRegression && numClass > 2) {
      val fh = params.getOrElse(ML_GBDT_FULL_HESSIAN, DEFAULT_ML_GBDT_FULL_HESSIAN.toString).toBoolean
      val mt = params.getOrElse(ML_GBDT_MULTI_TREE, DEFAULT_ML_GBDT_MULTI_TREE.toString).toBoolean
      require(!(fh && mt), "For multi-class tasks, full-hessian and multi-tree are exclusive")
      (fh, mt)
    } else {
      (false, false)
    }
    // mandatory: there is no default for the number of features
    val numFeature = params(ML_NUM_FEATURE).toInt
    val numRound = params.getOrElse(ML_GBDT_ROUND_NUM, DEFAULT_ML_GBDT_ROUND_NUM.toString).toInt
    val initLearningRate = params.getOrElse(ML_INIT_LEARN_RATE, DEFAULT_ML_INIT_LEARN_RATE.toString).toFloat
    val maxDepth = params.getOrElse(ML_GBDT_MAX_DEPTH, DEFAULT_ML_GBDT_MAX_DEPTH.toString).toInt
    // the node budget can never exceed what a tree of `maxDepth` can hold
    val treeNodeNum = MathUtil.maxNodeNum(maxDepth)
    val maxNodeNum = params.getOrElse(ML_GBDT_MAX_NODE_NUM, treeNodeNum.toString).toInt min treeNodeNum
    val numSplit = params.getOrElse(ML_GBDT_SPLIT_NUM, DEFAULT_ML_GBDT_SPLIT_NUM.toString).toInt
    val minChildWeight = params.getOrElse(ML_GBDT_MIN_CHILD_WEIGHT, DEFAULT_ML_GBDT_MIN_CHILD_WEIGHT.toString).toFloat
    val leafwise = params.getOrElse(ML_GBDT_LEAF_WISE, DEFAULT_ML_GBDT_LEAF_WISE.toString).toBoolean
    val minNodeInstance = params.getOrElse(ML_GBDT_MIN_NODE_INSTANCE, DEFAULT_ML_GBDT_MIN_NODE_INSTANCE.toString).toInt
    val minSplitGain = params.getOrElse(ML_GBDT_MIN_SPLIT_GAIN, DEFAULT_ML_GBDT_MIN_SPLIT_GAIN.toString).toFloat
    val regAlpha = params.getOrElse(ML_GBDT_REG_ALPHA, DEFAULT_ML_GBDT_REG_ALPHA.toString).toFloat
    // NOTE(review): `max 1.0f` raises any smaller lambda to 1.0 -- confirm this floor is intended
    val regLambda = params.getOrElse(ML_GBDT_REG_LAMBDA, DEFAULT_ML_GBDT_REG_LAMBDA.toString).toFloat max 1.0f
    val maxLeafWeight = params.getOrElse(ML_GBDT_MAX_LEAF_WEIGHT, DEFAULT_ML_GBDT_MAX_LEAF_WEIGHT.toString).toFloat
    val insSampleRatio = params.getOrElse(ML_INSTANCE_SAMPLE_RATIO, DEFAULT_ML_INSTANCE_SAMPLE_RATIO.toString).toFloat
    val featSampleRatio = params.getOrElse(ML_FEATURE_SAMPLE_RATIO, DEFAULT_ML_FEATURE_SAMPLE_RATIO.toString).toFloat
    val (lossFunc, evalMetrics) = if (isRegression) {
      // for regression task, use RMSE loss
      ("rmse", Array("rmse"))
    } else {
      // get loss function (mandatory for classification)
      val loss = params(ML_LOSS_FUNCTION)
      // ensure that the loss function fits #class & get default eval metric
      val defaultMetric = (if (numClass == 2) ObjectiveFactory.getBinaryLoss(loss)
      else ObjectiveFactory.getMultiLoss(loss)).defaultEvalMetric().toString
      // get eval metric
      var metrics = params.getOrElse(ML_EVAL_METRIC, defaultMetric)
        .split(",").map(_.trim).filter(_.nonEmpty)
      // the default metric always comes first (and only once), since the
      // learning rate may be scheduled w.r.t. the default metric
      metrics = defaultMetric +: metrics.filter(_ != defaultMetric)
      // ensure that all evaluation metrics fit #class: evaluate each of them
      // once on a single dummy instance
      val labels = Array.ofDim[Float](1)
      val preds = Array.ofDim[Float](labels.length * (if (numClass == 2) 1 else numClass))
      ObjectiveFactory.getEvalMetrics(metrics).foreach(_.eval(preds, labels))
      (loss, metrics)
    }

    // RegTParam
    val regTParam = new RegTParam()
      .setLeafwise(leafwise)
      .setMinSplitGain(minSplitGain)
      .setMinNodeInstance(minNodeInstance)
    // TParam
    regTParam.setNumFeature(numFeature)
      .setMaxDepth(maxDepth)
      .setMaxNodeNum(maxNodeNum)
      .setNumSplit(numSplit)
      .setInsSampleRatio(insSampleRatio)
      .setFeatSampleRatio(featSampleRatio)
    // GBDTParam
    val param = new GBDTParam(regTParam)
      .setRegression(isRegression)
      .setNumClass(numClass)
      .setFullHessian(fullHessian)
      .setMultiTree(multiTree)
      .setNumRound(numRound)
      .setInitLearningRate(initLearningRate)
      .setMinChildWeight(minChildWeight)
      .setRegAlpha(regAlpha)
      .setRegLambda(regLambda)
      .setMaxLeafWeight(maxLeafWeight)
      .setLossFunc(lossFunc)
      .setEvalMetrics(evalMetrics)
    println(s"Hyper-parameters:\n$param")
    // pick the trainer implementation for the requested parallel mode
    val trainer = params.getOrElse(ML_PARALLEL_MODE, DEFAULT_ML_PARALLEL_MODE) match {
      case FEATURE_PARALLEL_MODE =>
        println("Training in feature parallel mode")
        new FPGBDTTrainer(param)
      case DATA_PARALLEL_MODE =>
        println("Training in data parallel mode")
        new DPGBDTTrainer(param)
      case mode =>
        throw new IllegalArgumentException(s"Unrecognizable parallel mode: $mode")
    }

    // input/output locations are mandatory
    val trainInput = params(ML_TRAIN_DATA_PATH)
    val validInput = params(ML_VALID_DATA_PATH)
    val modelPath = params(ML_MODEL_PATH)
    // check the importance type before training, not only when saving
    val importanceType = params.getOrElse(ML_GBDT_IMPORTANCE_TYPE, DEFAULT_ML_GBDT_IMPORTANCE_TYPE)
    FeatureImportance.ensureImportanceType(importanceType)

    val model = trainer.fit(trainInput, validInput, numExecutors, numCores)
    save(model, modelPath, importanceType)
  }

  /**
   * Writes the model and its feature importance below `basePath`.
   *
   * Anything already present at `basePath` is deleted recursively first. The
   * model goes to `basePath/model` as an object file, the feature importance to
   * `basePath/feature_importance` as text, one "key value" pair per line. Both
   * outputs are written from a single partition.
   *
   * @param model          trained model to persist
   * @param basePath       output directory
   * @param importanceType kind of importance passed to `FeatureImportance.featImportance`
   * @param sc             Spark context used for the file system and the writes
   */
  def save(model: GBDTModel, basePath: String, importanceType: String)
                        (implicit sc: SparkContext): Unit = {
    // start from a clean output directory
    val outputDir = new Path(basePath)
    val fileSystem = outputDir.getFileSystem(sc.hadoopConfiguration)
    if (fileSystem.exists(outputDir)) {
      fileSystem.delete(outputDir, true)
    }

    // 1. the model itself
    val modelPath = s"$basePath/model"
    println(s"Saving model to $modelPath")
    val modelRdd = sc.parallelize(Seq(model), numSlices = 1)
    modelRdd.saveAsObjectFile(modelPath)

    // 2. feature importance
    val featurePath = s"$basePath/feature_importance"
    println(s"Saving feature importance ($importanceType) to $featurePath")
    val importanceLines = FeatureImportance.featImportance(model, importanceType)
      .map(entry => s"${entry._1} ${entry._2}")
    sc.parallelize(importanceLines, numSlices = 1).saveAsTextFile(featurePath)
  }

  /**
   * Assigns every feature id in `[0, numFeature)` to one of `numGroup` groups
   * using a `Mix64Hash`.
   *
   * @param numFeature number of features
   * @param numGroup   number of groups
   * @return `(fidToGroupId, groupIdToFid)`: the group of each feature, and for
   *         each group its feature ids in ascending order
   */
  private[gbdt] def hashFeatureGrouping(numFeature: Int, numGroup: Int): (Array[Int], Array[Array[Int]]) = {
    val groupOfFeature = new Array[Int](numFeature)
    val builders = new Array[AB.ofInt](numGroup)
    // pre-size every builder to 1.5x the average group size
    val expectedGroupSize = (1.5 * numFeature / numGroup).toInt
    builders.indices.foreach { groupId =>
      val builder = new AB.ofInt
      builder.sizeHint(expectedGroupSize)
      builders(groupId) = builder
    }

    val hasher = new Mix64Hash(numGroup)
    var fid = 0
    while (fid < numFeature) {
      // NOTE(review): assumes `hash` never returns a negative value, otherwise
      // the remainder is negative and the lookups below fail -- confirm
      val groupId = hasher.hash(fid) % numGroup
      groupOfFeature(fid) = groupId
      builders(groupId) += fid
      fid += 1
    }

    (groupOfFeature, builders.map(_.result()))
  }

  /**
   * Derives the candidate splits of one feature from its quantile sketch.
   *
   * @param sketch   sketch of the feature values, may be null or empty
   * @param numSplit number of quantiles requested from the sketch
   * @return `(isCategorical, candidate splits, nnz)`;
   *         `(false, null, 0)` when there is no sketch data
   */
  private[gbdt] def getSplits(sketch: QuantileSketch, numSplit: Int): (Boolean, Array[Float], Int) = {
    val noData = sketch == null || sketch.isEmpty
    if (noData) {
      (false, null, 0)
    } else {
      val nnz = sketch.getN.toInt
      sketch.tryDistinct(FeatureInfo.ENUM_THRESHOLD) match {
        case null =>
          // no distinct-value list returned: fall back to de-duplicated quantiles
          val quantiles = MathUtil.unique(sketch.getQuantiles(numSplit))
          val candidates: Array[Float] =
            if (quantiles.length > 1) {
              quantiles
            } else {
              // a single quantile: pair it with zero (or with EPSILON when it
              // is zero itself) so that two ascending candidates exist
              val only = quantiles.head
              if (only > 0) Array(0, only)
              else if (only < 0) Array(only, 0)
              else Array(0, MathUtil.EPSILON)
            }
          (false, candidates, nnz)
        case distinct =>
          // a distinct-value list was returned: reported as categorical
          (true, distinct, nnz)
      }
    }
  }
}

abstract class GBDTTrainer(param: GBDTParam) {

  /**
   * Trains a GBDT model: initializes the trainer, runs the boosting loop and
   * shuts the trainer down again.
   *
   * `shutdown()` is invoked on the failure path as well, so whatever
   * `initialize` set up is not left behind when `train()` throws. The original
   * exception is rethrown; a failure of `shutdown()` during that cleanup is
   * attached to it as a suppressed exception instead of replacing it.
   *
   * @param trainInput path of the training data
   * @param validInput path of the validation data
   * @param numWorker  number of workers, forwarded to `initialize`
   * @param numThread  number of threads per worker, forwarded to `initialize`
   * @param sc         Spark context
   * @return the trained model
   */
  def fit(trainInput: String, validInput: String, numWorker: Int, numThread: Int)
         (implicit sc: SparkContext): GBDTModel = {
    println(s"Training data path: $trainInput")
    println(s"Validation data path: $validInput")
    println(s"#workers[$numWorker], #threads[$numThread]")
    initialize(trainInput, validInput, numWorker, numThread)
    val model = try {
      train()
    } catch {
      case scala.util.control.NonFatal(cause) =>
        // best-effort cleanup; never hide the training failure
        try {
          shutdown()
        } catch {
          case scala.util.control.NonFatal(cleanupFailure) => cause.addSuppressed(cleanupFailure)
        }
        throw cause
    }
    shutdown()
    model
  }

  /**
   * Hook called by `fit` before training starts. The default body is `???`
   * (throws `NotImplementedError`), so concrete trainers must override it.
   *
   * @param trainInput path of the training data
   * @param validInput path of the validation data
   * @param numWorker  number of workers
   * @param numThread  number of threads per worker
   */
  protected def initialize(trainInput: String, validInput: String,
                           numWorker: Int, numThread: Int)
                          (implicit sc: SparkContext): Unit = ???

  /**
   * Hook called by `fit` once training is over. The default body is `???`,
   * so concrete trainers must override it.
   */
  protected def shutdown(): Unit = ???

  /**
   * Driver-side boosting loop.
   *
   * Each round computes the gradient pairs, grows `numTreePerRound` trees (one
   * per class for multi-class/multi-tree, otherwise one), evaluates, and lets
   * the `LRScheduler` update the learning rate. Training stops after
   * `param.numRound` rounds or as soon as the scheduled learning rate drops
   * below `MathUtil.EPSILON`. If the best-checkpoint option is enabled in the
   * Spark conf, only the trees up to the best round reported by the scheduler
   * are kept.
   *
   * A tree keeps growing while one more split (two new nodes) still fits into
   * `maxNodeNum`. The bound is inclusive, in line with the per-iteration split
   * budget `(maxNodeNum - tree.size()) / 2`; a strict bound would drop the last
   * split that fits (e.g. `maxNodeNum = 3` would never split the root).
   *
   * @param sc Spark context; the learning-rate and checkpoint options are read from its conf
   * @return the trained model
   */
  protected def train()(implicit sc: SparkContext): GBDTModel = {
    println("Start to train a GBDT model")
    val model = new GBDTModel(param)
    val trainStart = System.currentTimeMillis()

    // tree/node meta data
    val numTreePerRound = if (param.isMultiClassMultiTree) param.numClass else 1
    val maxInnerNum = MathUtil.maxInnerNodeNum(param.regTParam.maxDepth)
    // nodes whose split has not been searched yet
    val activeNodes = ArrayBuffer[Int]()
    // nodes with a split found but not applied yet
    val toSplits = mutable.Map[Int, GBTSplit]()
    // number of instances per node, -1 while unknown
    val nodeSizes = Array.ofDim[Int](MathUtil.maxNodeNum(param.regTParam.maxDepth))
    // lr scheduler
    val lrScheduler = if (!sc.getConf.getBoolean(ML_REDUCE_LR_ON_PLATEAU,
      DEFAULT_ML_REDUCE_LR_ON_PLATEAU)) {
      new LRScheduler(param.numRound, param.evalMetrics, patient = -1)
    } else {
      val patient = sc.getConf.getInt(ML_REDUCE_LR_ON_PLATEAU_PATIENT,
        DEFAULT_ML_REDUCE_LR_ON_PLATEAU_PATIENT)
      val threshold = sc.getConf.getDouble(ML_REDUCE_LR_ON_PLATEAU_THRESHOLD,
        DEFAULT_ML_REDUCE_LR_ON_PLATEAU_THRESHOLD.toDouble).toFloat
      val decayFactor = sc.getConf.getDouble(ML_REDUCE_LR_ON_PLATEAU_DECAY_FACTOR,
        DEFAULT_ML_REDUCE_LR_ON_PLATEAU_DECAY_FACTOR.toDouble).toFloat
      val earlyStop = sc.getConf.getInt(ML_REDUCE_LR_ON_PLATEAU_EARLY_STOP,
        DEFAULT_ML_REDUCE_LR_ON_PLATEAU_EARLY_STOP)
      println(s"Train with ReduceLROnPlateau strategy: patient[$patient] " +
        s"threshold[$threshold] decay factor[$decayFactor] early stop[$earlyStop]")
      new LRScheduler(param.numRound, param.evalMetrics, patient = patient,
        threshold = threshold, decayFactor = decayFactor, earlyStop = earlyStop)
    }

    var currentLR = param.initLearningRate
    var earlyStop = false
    var round = 0
    while (round < param.numRound && !earlyStop) {
      println(s"Start to train round[${round + 1}]")
      val roundStart = System.currentTimeMillis()

      // 1. calc grad pairs
      calcGradPairs()
      // 2. train trees in current round
      for (treeId <- 0 until numTreePerRound) {
        val treeStart = System.currentTimeMillis()
        activeNodes.clear()
        toSplits.clear()
        for (nid <- nodeSizes.indices)
          nodeSizes(nid) = -1
        val classIdOpt = if (numTreePerRound > 1) Some(treeId) else None

        // 2.1. create new tree, reset workers
        val tree = createNewTree(nodeSizes, classIdOpt = classIdOpt)
        // 2.2. iteratively build one tree
        activeNodes += 0
        // a split adds exactly two nodes, so continue while two more nodes fit
        while (tree.size() + 2 <= param.regTParam.maxNodeNum
          && (activeNodes.nonEmpty || toSplits.nonEmpty)) {
          // 2.2.1. build histograms and find splits
          if (activeNodes.nonEmpty) {
            findSplits(tree, activeNodes, nodeSizes,
              classIdOpt = classIdOpt).foreach { case (nid, split) => toSplits(nid) = split }
          }
          // leaf-wise: one split per iteration; level-wise: as many as still fit
          val splitNodeNum = if (param.regTParam.leafwise) 1 else Math.min(
            toSplits.size, (param.regTParam.maxNodeNum - tree.size()) / 2)
          val (toSplit, toSetLeaf) = chooseSplits(tree, toSplits, activeNodes, splitNodeNum)
          activeNodes.clear()
          // 2.2.2. set leaves
          setAsLeaves(tree, toSetLeaf)
          // 2.2.3. split nodes
          if (toSplit.nonEmpty) {
            splitNodes(tree, toSplit, nodeSizes, classIdOpt = classIdOpt)
            toSplit.foreach {
              case (nid, _) =>
                // children become active only while they can still be inner nodes
                if (2 * nid + 1 < maxInnerNum) {
                  activeNodes += 2 * nid + 1
                  activeNodes += 2 * nid + 2
                }
            }
          }
        }
        // 2.3. finish current tree & update preds
        finishTree(tree, currentLR, classIdOpt = classIdOpt)
        model.add(tree, currentLR)
        println(s"Train tree[${treeId + 1}/$numTreePerRound] of round[${round + 1}] " +
          s"with ${tree.size()} nodes (${(tree.size() - 1) / 2 + 1} leaves) " +
          s"cost ${System.currentTimeMillis() - treeStart} ms, " +
          s"${System.currentTimeMillis() - trainStart} ms elapsed")
      }
      // 3. evaluation
      val metrics = evaluate(round, lrScheduler)
      // 4. finish training of current round
      println(s"Train round[${round + 1}] cost ${System.currentTimeMillis() - roundStart} ms, " +
        s"${System.currentTimeMillis() - trainStart} ms elapsed")
      round += 1
      // 5. schedule learning rate or early stop
      currentLR = lrScheduler.step(metrics, currentLR)
      earlyStop = currentLR < MathUtil.EPSILON
    }

    if (sc.getConf.getBoolean(ML_GBDT_BEST_CHECKPOINT, DEFAULT_ML_GBDT_BEST_CHECKPOINT)) {
      println(s"Train model cost ${System.currentTimeMillis() - trainStart} ms")
      // best = (metric kind, round, metric value), as used in the message below
      val best = lrScheduler.getBest
      model.keepFirstTrees(best._2 * numTreePerRound)
      println(s"Best performance occurs in round[${best._2}] with ${best._1}[${best._3}], " +
        s"so the final model contains ${model.numTree} trees")
    } else {
      println(s"Train model with ${model.numTree} " +
        s"cost ${System.currentTimeMillis() - trainStart} ms")
    }
    model
  }

  /**
   * Step 1 of a boosting round: delegates to `doCalcGradPairs`.
   * The start time is only used by the disabled timing log below.
   */
  protected def calcGradPairs(): Unit = {
    val calcStart = System.currentTimeMillis()
    doCalcGradPairs()
    // println(s"Calc gradients cost ${System.currentTimeMillis() - calcStart} ms")
  }

  /**
   * Creates a tree consisting of a root node only and initializes the root
   * from the result of `doNewTree`.
   *
   * @param nodeSizes  per-node instance counts; entry 0 receives the root size
   * @param classIdOpt class of the tree, forwarded to `doNewTree`
   * @return the new tree, with the root's gradient sum set and its gain computed
   */
  protected def createNewTree(nodeSizes: Array[Int], classIdOpt: Option[Int] = None): GBTTree = {
    val startedAt = System.currentTimeMillis()
    val newTree = new GBTTree(param.regTParam)
    val rootNode = new GBTNode(0, null)
    newTree.setRoot(rootNode)

    doNewTree(classIdOpt = classIdOpt) match {
      case (size, gradPair) =>
        nodeSizes(0) = size
        rootNode.setSumGradPair(gradPair)
    }
    rootNode.calcGain(param)
    // println(s"Create new tree cost ${System.currentTimeMillis() - startedAt} ms")
    newTree
  }

  /**
   * Finds the best splits of the given active nodes.
   *
   * When the first active node is the root (id 0) only the root is handled, via
   * `doFindRootSplit`. Otherwise a histogram build scheme is obtained from
   * `HistBuilder.buildScheme` and the search is delegated to `doFindSplits`.
   *
   * @param tree       tree under construction
   * @param nids       ids of the active nodes; must not be empty
   * @param nodeSizes  number of instances per node id
   * @param classIdOpt class of the tree, forwarded to the worker hooks
   * @return the splits found, keyed by node id
   */
  protected def findSplits(tree: GBTTree, nids: Seq[Int], nodeSizes: Seq[Int],
                           classIdOpt: Option[Int] = None): Map[Int, GBTSplit] = {
    val startedAt = System.currentTimeMillis()

    // a node may be split only if it is large enough and its gradient sum
    // passes the weight check
    val splittable: Int => Boolean = nid =>
      nodeSizes(nid) > param.regTParam.minNodeInstance &&
        tree.getNode(nid).getSumGradPair.satisfyWeight(param)

    val found: Map[Int, GBTSplit] =
      if (nids.head != 0) {
        val flags = nids.map(splittable)
        val gradPairs = nids.map(nid => tree.getNode(nid).getSumGradPair)
        val (toBuild, toSubtract, toRemove) = HistBuilder.buildScheme(
          nids, flags, nodeSizes, param)
        doFindSplits(nids, gradPairs, flags, toBuild, toSubtract, toRemove,
          classIdOpt = classIdOpt)
      } else if (splittable(0)) {
        // a null result means that no split was found for the root
        doFindRootSplit(tree.getRoot.getSumGradPair, classIdOpt = classIdOpt) match {
          case null => Map.empty[Int, GBTSplit]
          case rootSplit => Map(0 -> rootSplit)
        }
      } else {
        Map.empty[Int, GBTSplit]
      }

    // println(s"Build histograms and find best splits cost " +
    //   s"${System.currentTimeMillis() - startedAt} ms")

    found
  }

  /**
   * Chooses which pending splits are applied now and which active nodes become leaves.
   *
   * When more splits are pending than `splitNodeNum` allows, the ones with the
   * highest gain are taken; otherwise all of them. The chosen splits are
   * removed from `toSplits`, the others stay pending.
   *
   * @param tree         tree under construction (not used here)
   * @param toSplits     pending splits by node id; modified in place
   * @param activeNodes  nodes just searched for a split
   * @param splitNodeNum maximum number of splits to apply
   * @return `(splits to apply, active nodes without a pending split)`
   */
  protected def chooseSplits(tree: GBTTree, toSplits: mutable.Map[Int, GBTSplit], activeNodes: Seq[Int],
                             splitNodeNum: Int): (Seq[(Int, GBTSplit)], Seq[Int]) = {
    // must be computed before any entry is removed from `toSplits`
    val newLeaves = activeNodes.filterNot(toSplits.contains)
    val pending = toSplits.toSeq
    val chosen =
      if (splitNodeNum < pending.size) {
        pending.sortBy(-_._2.getSplitEntry.getGain).take(splitNodeNum)
      } else {
        pending
      }
    chosen.foreach { case (nid, _) => toSplits.remove(nid) }
    (chosen, newLeaves)
  }

  /**
   * Applies the given splits: attaches two children to each split node on the
   * driver-side tree, then lets `doSplitNodes` perform the split and records
   * the node sizes it reports.
   *
   * @param tree       tree under construction
   * @param toSplit    `(node id, split)` pairs to apply
   * @param nodeSizes  per-node instance counts; every reported node must still
   *                   be -1 (unknown) and is filled in here
   * @param classIdOpt class of the tree, forwarded to `doSplitNodes`
   */
  protected def splitNodes(tree: GBTTree, toSplit: Seq[(Int, GBTSplit)], nodeSizes: Array[Int],
                           classIdOpt: Option[Int] = None): Unit = {
    val startedAt = System.currentTimeMillis()
    for ((parentId, split) <- toSplit) {
      // children of node i live at 2i + 1 and 2i + 2
      val leftId = 2 * parentId + 1
      val rightId = leftId + 1
      val parent = tree.getNode(parentId)
      parent.setSplitEntry(split.getSplitEntry)
      val left = new GBTNode(leftId, parent)
      val right = new GBTNode(rightId, parent)
      tree.setNode(leftId, left)
      tree.setNode(rightId, right)
      parent.setLeftChild(left)
      parent.setRightChild(right)
      left.setSumGradPair(split.getLeftGradPair)
      right.setSumGradPair(split.getRightGradPair)
      left.calcGain(param)
      right.calcGain(param)
    }
    val reportedSizes = doSplitNodes(toSplit, classIdOpt = classIdOpt)
    reportedSizes.foreach { case (nid, size) =>
      require(nodeSizes(nid) == -1)
      nodeSizes(nid) = size
    }
    // println(s"Split nodes cost ${System.currentTimeMillis() - startedAt} ms")
  }

  /**
   * Marks the given nodes as leaves: first through `doSetAsLeaves`, then on
   * the driver-side tree.
   *
   * @param tree tree under construction
   * @param nids ids of the nodes to turn into leaves
   */
  protected def setAsLeaves(tree: GBTTree, nids: Seq[Int]): Unit = {
    doSetAsLeaves(nids)
    for (nid <- nids) {
      tree.getNode(nid).chgToLeaf()
    }
  }

  /**
   * Finalizes a tree: every node without a split entry becomes a leaf, leaf
   * weights are computed, and `doFinishTree` is invoked.
   *
   * @param tree         the tree that has just been grown
   * @param learningRate learning rate of the current round
   * @param classIdOpt   class of the tree, forwarded to `doFinishTree`
   */
  protected def finishTree(tree: GBTTree, learningRate: Float,
                           classIdOpt: Option[Int] = None): Unit = {
    val startedAt = System.currentTimeMillis()
    tree.getNodes.foreach { case (_, node) =>
      // nodes that were never split end up as leaves
      val unsplitInner = node.getSplitEntry == null && !node.isLeaf
      if (unsplitInner) {
        node.chgToLeaf()
      }
      if (node.isLeaf) {
        // vector-valued leaves get one weight per class, the others a scalar
        if (param.isLeafVector) node.calcWeights(param)
        else node.calcWeight(param)
      }
    }
    doFinishTree(tree, learningRate, classIdOpt = classIdOpt)
    // println(s"Finish tree cost ${System.currentTimeMillis() - startedAt} ms")
  }

  /**
   * Evaluates the current model through `doEvaluate` and prints the train and
   * validation metrics.
   *
   * @param round       zero-based index of the round just finished
   * @param lrScheduler learning-rate scheduler (not used here)
   * @return `(metric kind, train value, validation value)` per metric
   */
  protected def evaluate(round: Int, lrScheduler: LRScheduler)
  : Seq[(EvalMetric.Kind, Double, Double)] = {
    val results = doEvaluate()
    val finishedRounds = round + 1
    val trainSummary = results.map(r => s"${r._1}[${r._2}]").mkString(", ")
    val validSummary = results.map(r => s"${r._1}[${r._3}]").mkString(", ")
    println(s"Evaluation on train data after $finishedRounds round(s): $trainSummary")
    println(s"Evaluation on valid data after $finishedRounds round(s): $validSummary")
    results
  }

  /////////////////////////////////////////////////////////////
  //                       Worker APIs                       //
  /////////////////////////////////////////////////////////////
  // All hooks below default to `???` (throw NotImplementedError) and are
  // expected to be overridden by concrete trainers.

  /** Called once per round by `calcGradPairs`. */
  protected def doCalcGradPairs(): Unit = ???

  /**
   * Called by `createNewTree` for every new tree.
   *
   * @return `(root node size, gradient sum of the root)`
   */
  protected def doNewTree(classIdOpt: Option[Int] = None): (Int, GradPair) = ???

  /**
   * Called by `findSplits` when the root is the active node.
   *
   * @return the root split; `findSplits` treats `null` as "no split found"
   */
  protected def doFindRootSplit(rootGradPair: GradPair, classIdOpt: Option[Int] = None): GBTSplit = ???

  /**
   * Called by `findSplits` for non-root nodes.
   *
   * @param nids          ids of the active nodes
   * @param nodeGradPairs gradient sum of each node in `nids`
   * @param canSplits     whether each node in `nids` may be split
   * @param toBuild       first part of the scheme from `HistBuilder.buildScheme`
   * @param toSubtract    second part of that scheme
   * @param toRemove      third part of that scheme
   * @return the splits found, keyed by node id
   */
  protected def doFindSplits(nids: Seq[Int], nodeGradPairs: Seq[GradPair],
                             canSplits: Seq[Boolean], toBuild: Seq[Int],
                             toSubtract: Seq[Boolean], toRemove: Seq[Int],
                             classIdOpt: Option[Int] = None): Map[Int, GBTSplit] = ???

  /**
   * Called by `splitNodes` after the driver-side tree has been updated.
   *
   * @return `(node id, node size)` pairs, which `splitNodes` stores in `nodeSizes`
   */
  protected def doSplitNodes(splits: Seq[(Int, GBTSplit)],
                             classIdOpt: Option[Int] = None): Seq[(Int, Int)] = ???

  /** Called by `setAsLeaves` before the driver-side nodes are changed to leaves. */
  protected def doSetAsLeaves(nids: Seq[Int]): Unit = ???

  /** Called by `finishTree` after the leaf weights have been computed. */
  protected def doFinishTree(tree: GBTTree, learningRate: Float,
                             classIdOpt: Option[Int] = None): Unit = ???

  /**
   * Called by `evaluate` once per round.
   *
   * @return `(metric kind, train value, validation value)` per metric
   */
  protected def doEvaluate(): Seq[(EvalMetric.Kind, Double, Double)] = ???
}


private[gbdt] abstract class GBDTWorker(val workerId: Int, val param: GBDTParam) {
  // Training and validation data; all unset until initialize() assigns them.
  @transient private[gbdt] var trainLabels: Array[Float] = _
  @transient private[gbdt] var trainData: Dataset[Int, Int] = _
  @transient private[gbdt] var validLabels: Array[Float] = _
  @transient private[gbdt] var validData: Array[Vector] = _
  // validation predictions: one slot per instance when numClass == 2,
  // numClass slots per instance otherwise (allocated in initialize())
  @transient private[gbdt] var validPreds: Array[Float] = _

  // lazy: they read trainData / validData, so they must not be accessed
  // before initialize() has been called
  private[gbdt] lazy val numTrain: Int = trainData.size
  private[gbdt] lazy val numValid: Int = validData.length

  // instance-level and feature-level metadata, assigned in initialize()
  @transient private[gbdt] var insInfo: InstanceInfo = _
  @transient private[gbdt] var featInfo: FeatureInfo = _

  // helpers created in initialize()
  @transient private[gbdt] var nodeIndexer: NodeIndexer = _
  @transient private[gbdt] var histManager: HistManager = _
  @transient private[gbdt] var histBuilder: HistBuilder = _
  @transient private[gbdt] var splitFinder: SplitFinder = _

  // loss and evaluation metrics resolved once from the hyper-parameters
  private[gbdt] val loss = ObjectiveFactory.getLoss(param.lossFunc)
  private[gbdt] val evalMetrics = ObjectiveFactory.getEvalMetricsOrDefault(param.evalMetrics, loss)

  /**
   * Stores the worker's data and metadata and creates its helpers.
   *
   * @param trainLabels labels of the training instances
   * @param trainData   training dataset
   * @param validLabels labels of the validation instances
   * @param validData   validation instances
   * @param insInfo     instance-level metadata
   * @param featInfo    feature-level metadata
   */
  private[gbdt] def initialize(trainLabels: Array[Float], trainData: Dataset[Int, Int],
                               validLabels: Array[Float], validData: Array[Vector],
                               insInfo: InstanceInfo, featInfo: FeatureInfo): Unit = {
    // data
    this.trainLabels = trainLabels
    this.trainData = trainData
    this.validLabels = validLabels
    this.validData = validData
    // one prediction per instance when numClass == 2, numClass otherwise
    val predsPerInstance = if (param.numClass == 2) 1 else param.numClass
    this.validPreds = new Array[Float](validData.length * predsPerInstance)
    // metadata
    this.insInfo = insInfo
    this.featInfo = featInfo
    // helpers
    val maxDepth = param.regTParam.maxDepth
    this.nodeIndexer = NodeIndexer(maxDepth, trainData.size)
    this.histManager = HistManager(param, featInfo)
    this.histBuilder = HistBuilder(param, insInfo, featInfo)
    this.splitFinder = SplitFinder(param, featInfo)
    this.histManager.reset(featInfo.isFeatUsed)
  }

  /**
   * Computes the gradient pairs of all training instances and stores their sum
   * as the gradient pair of the root (node 0).
   *
   * Uses the worker's `loss`, which is resolved once from `param.lossFunc`
   * when the worker is constructed, instead of looking the loss up again
   * through `ObjectiveFactory` on every round.
   */
  private[gbdt] def calcGradPairs(): Unit = {
    val rootGradPair = insInfo.calcGradPairs(trainLabels, loss, param)
    nodeIndexer.setNodeGradPair(0, rootGradPair)
  }

  /**
   * Prepares the worker for a new tree: resets the node indexing, applies
   * instance and feature sampling, and refreshes the root statistics.
   *
   * @param insSampleSeed  seed for instance sampling
   * @param featSampleSeed seed for feature sampling
   * @param classIdOpt     class of the tree, used when the root gradient sum
   *                       has to be recomputed
   */
  private[gbdt] def reset(insSampleSeed: Option[Long] = None,
                          featSampleSeed: Option[Long] = None,
                          classIdOpt: Option[Int] = None): Unit = {
    // 1. remember the root grad pair; it must be read before the indexer is reset
    val previousRoot = nodeIndexer.getNodeGradPair(0)

    // 2. reset node indexing & instance sampling; when `sample` reports true
    //    the root grad pair is recomputed from the root's instance range
    nodeIndexer.reset()
    val insSampled = nodeIndexer.sample(param.regTParam.insSampleRatio, seed = insSampleSeed)
    val rootGradPair =
      if (insSampled) {
        insInfo.sumGradPairs(nodeIndexer.nodeToIns,
          nodeIndexer.getNodePosStart(0), nodeIndexer.getNodePosEnd(0),
          param, classIdOpt = classIdOpt)
      } else {
        previousRoot
      }
    nodeIndexer.setNodeGradPair(0, rootGradPair)
    nodeIndexer.setNodeGain(0, rootGradPair.calcGain(param))

    // 3. feature sampling & clear stored histograms
    val featSampled = featInfo.sample(param.regTParam.featSampleRatio, seed = featSampleSeed)
    if (featSampled) {
      histManager.reset(featInfo.isFeatUsed)
    } else {
      histManager.removeAll()
    }
  }

  /**
   * Adds the contribution of a finished tree to the training and validation
   * predictions. Not applicable to multi-class training with one tree per
   * class (see `updatePredsMultiTree`).
   *
   * @param tree         the finished tree
   * @param learningRate factor applied to the leaf weights
   */
  private[gbdt] def updatePreds(tree: GBTTree, learningRate: Float): Unit = {
    require(param.numClass == 2 || !param.multiTree)
    // scalar leaf weights for regression / two classes, one weight per class otherwise
    val scalarLeaf = param.isRegression || param.numClass == 2

    // training instances: updated leaf by leaf through the node indexer
    tree.getNodes.foreach {
      case (nid, node) if node.isLeaf =>
        if (scalarLeaf) {
          insInfo.updatePreds(nid, nodeIndexer, node.getWeight, learningRate)
        } else {
          insInfo.updatePreds(nid, nodeIndexer, node.getWeights, learningRate)
        }
      case _ =>
    }

    // validation instances: each one is routed through the tree
    if (scalarLeaf) {
      validData.indices.foreach { i =>
        val weight = tree.predictBinary(validData(i))
        validPreds(i) += weight * learningRate
      }
    } else {
      val numClass = param.numClass
      validData.indices.foreach { i =>
        val weights = tree.predictMulti(validData(i))
        val offset = i * numClass
        (0 until numClass).foreach { k =>
          validPreds(offset + k) += weights(k) * learningRate
        }
      }
    }
  }

  /**
   * Adds the contribution of one per-class tree to the predictions; only valid
   * for multi-class training with one tree per class.
   *
   * @param tree         the finished tree
   * @param treeId       index of the tree within the round, used as the class offset
   * @param learningRate factor applied to the leaf weights
   */
  private[gbdt] def updatePredsMultiTree(tree: GBTTree, treeId: Int,
                                         learningRate: Float): Unit = {
    require(param.isMultiClassMultiTree)

    // training instances: updated leaf by leaf through the node indexer
    tree.getNodes.foreach {
      case (nid, node) if node.isLeaf =>
        insInfo.updatePredsMultiTree(nid, treeId, param.numClass, nodeIndexer,
          node.getWeight, learningRate)
      case _ =>
    }

    // validation instances: only the slot of this tree's class is touched
    validData.indices.foreach { i =>
      val weight = tree.predictBinary(validData(i))
      validPreds(i * param.numClass + treeId) += weight * learningRate
    }
  }

  /**
   * Evaluates every metric on the training and validation predictions.
   *
   * Uses the worker's `evalMetrics`, which is built once at construction via
   * `ObjectiveFactory.getEvalMetricsOrDefault(param.evalMetrics, loss)`, rather
   * than rebuilding the metrics from `param.evalMetrics` on each call. As a
   * result this method uses the same metrics as the rest of the worker.
   * NOTE(review): this presumably falls back to the loss's default metric when
   * none is configured -- confirm against `ObjectiveFactory`.
   *
   * @return `(metric kind, train value, validation value)` per metric
   */
  private[gbdt] def evaluate(): Seq[(EvalMetric.Kind, Double, Double)] = {
    evalMetrics.map(evalMetric => {
      val kind = evalMetric.getKind
      val trainMetric = evalMetric.eval(insInfo.predictions, trainLabels)
      val validMetric = evalMetric.eval(validPreds, validLabels)
      (kind, trainMetric, validMetric)
    })
  }
}


/**
 * Registry of the workers living in the current JVM, keyed by worker id.
 *
 * Registration and lookups all go through the same lock (`workers`), so a
 * lookup can never observe the mutable map while `putWorker` is modifying it.
 */
private[gbdt] object GBDTWorker {
  @transient private val workers = scala.collection.mutable.Map[Int, GBDTWorker]()

  /**
   * Registers a worker under its `workerId`.
   *
   * @throws IllegalArgumentException if a worker with the same id is already registered
   */
  private[gbdt] def putWorker(worker: GBDTWorker): Unit = {
    workers.synchronized {
      val workerId = worker.workerId
      require(!workers.contains(workerId), s"Worker with ID $workerId already exists")
      workers += workerId -> worker
    }
  }

  /**
   * Returns the worker registered under `workerId`, cast to `T`.
   *
   * @throws RuntimeException if no such worker is registered
   */
  @inline private[gbdt] def getWorker[T <: GBDTWorker](workerId: Int): T = {
    // only the map access is guarded; the error message is built outside the lock
    val found = workers.synchronized {
      workers.get(workerId)
    }
    found match {
      case Some(worker) => worker.asInstanceOf[T]
      case None => throw new RuntimeException(
        s"Cannot fetch worker[$workerId] on executor[${SparkEnv.get.executorId}], " +
          "this may be due to executor lost or memory insufficient")
    }
  }

  /**
   * Returns any one of the registered workers, cast to `T`.
   *
   * @throws RuntimeException if no worker is registered
   */
  @inline private[gbdt] def getArbitrary[T <: GBDTWorker]: T = {
    val found = workers.synchronized {
      workers.headOption
    }
    found match {
      case Some(worker) => worker._2.asInstanceOf[T]
      case None => throw new RuntimeException(
        s"Cannot fetch worker on executor[${SparkEnv.get.executorId}], " +
          "this may be due to executor lost or memory insufficient")
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy