All downloads are free. The search and download functionality uses the official Maven repository.

com.tencent.angel.sona.tree.gbdt.train.FPGBDTTrainer.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.train

import org.apache.spark.SparkContext
import com.tencent.angel.sona.tree.basic.split.SplitEntry
import com.tencent.angel.sona.tree.gbdt.GBDTModel.GBTTree
import com.tencent.angel.sona.tree.gbdt.dataset.Dataset
import com.tencent.angel.sona.tree.gbdt.histogram.GradPair
import com.tencent.angel.sona.tree.gbdt.metadata.{FeatureInfo, InstanceInfo}
import com.tencent.angel.sona.tree.gbdt.tree.{GBDTParam, GBTSplit}
import com.tencent.angel.sona.tree.objective.ObjectiveFactory
import com.tencent.angel.sona.tree.objective.metric.EvalMetric
import com.tencent.angel.sona.tree.util.{ConcurrentUtil, DataLoader, MathUtil, RangeBitSet}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

import scala.collection.mutable

object FPGBDTTrainer {

  /** Fetches the worker registered under `workerId`, typed as [[FPGBDTWorker]]. */
  @inline private def getWorker(workerId: Int): FPGBDTWorker =
    GBDTWorker.getWorker[FPGBDTWorker](workerId)

  /**
    * Builds a [[FeatureInfo]] restricted to the features of one group.
    *
    * Entry `i` of the result is copied from the entry of `featureInfo` for the
    * feature id `groupIdToFid(i)`.
    *
    * @param featureInfo  feature info the entries are read from
    * @param groupId      id of the group; not used by this method
    * @param groupIdToFid feature ids of the group, indexed by in-group feature id
    * @return feature info holding `groupIdToFid.length` entries
    */
  private[train] def featureInfoOfGroup(featureInfo: FeatureInfo, groupId: Int,
                                        groupIdToFid: Array[Int]): FeatureInfo = {
    val numGroupFeat = groupIdToFid.length
    val categoricalFlags = new Array[Boolean](numGroupFeat)
    val binCounts = new Array[Int](numGroupFeat)
    val candidateSplits = new Array[Array[Float]](numGroupFeat)
    val defaultBinIds = new Array[Int](numGroupFeat)
    for (localFid <- groupIdToFid.indices) {
      val globalFid = groupIdToFid(localFid)
      categoricalFlags(localFid) = featureInfo.isCategorical(globalFid)
      binCounts(localFid) = featureInfo.getNumBin(globalFid)
      candidateSplits(localFid) = featureInfo.getSplits(globalFid)
      defaultBinIds(localFid) = featureInfo.getDefaultBin(globalFid)
    }
    FeatureInfo(categoricalFlags, binCounts, candidateSplits, defaultBinIds)
  }
}

/**
  * Driver-side trainer for feature-parallel GBDT.
  *
  * Features are partitioned into `numWorker` groups and each group is handled by one
  * [[FPGBDTWorker]] living in one partition of `workers`. Every step of tree building
  * is a Spark job over `workers` whose per-worker results are merged on the driver.
  *
  * @param param training hyper-parameters
  */
class FPGBDTTrainer(param: GBDTParam) extends GBDTTrainer(param) {
  import FPGBDTTrainer._

  // feature id -> id of the group (worker) owning the feature
  @transient private[gbdt] var bcFidToGroupId: Broadcast[Array[Int]] = _
  // group id -> feature ids of the group, indexed by in-group feature id
  @transient private[gbdt] var bcGroupIdToFid: Broadcast[Array[Array[Int]]] = _
  // feature id -> feature id inside its group
  @transient private[gbdt] var bcFidToNewFid: Broadcast[Array[Int]] = _
  // group id -> number of features in the group
  @transient private[gbdt] var bcGroupSizes: Broadcast[Array[Int]] = _

  // one element per partition: the id of the worker initialized in that partition
  @transient private[gbdt] var workers: RDD[Int] = _

  /**
    * Groups features, loads the vertically partitioned training data, collects labels,
    * computes candidate splits and finally creates one [[FPGBDTWorker]] per partition.
    *
    * The broadcast fields and `workers` are only assigned at the very end, so they stay
    * unset if any step throws.
    *
    * @param trainInput path of the training data
    * @param validInput path of the validation data, loaded via `DataLoader.loadLibsvm`
    * @param numWorker  number of feature groups / workers
    * @param numThread  value passed to `ConcurrentUtil.reset` on every worker
    * @param sc         Spark context used for broadcasting
    */
  override protected def initialize(trainInput: String, validInput: String,
                                    numWorker: Int, numThread: Int)
                                   (implicit sc: SparkContext): Unit = {
    println("Start to initialize workers")
    val initStart = System.currentTimeMillis()
    val bcParam = sc.broadcast(param)
    val numFeature = param.regTParam.numFeature
    val numClass = param.numClass
    val numSplit = param.regTParam.numSplit

    // 1. Partition features into groups,
    // get feature id to group id mapping and inverted indexing
    val featGroupStart = System.currentTimeMillis()
    val (fidToGroupId, groupIdToFid) = GBDTTrainer.hashFeatureGrouping(numFeature, numWorker)
    val groupSizes = Array.ofDim[Int](numWorker)
    val fidToNewFid = Array.ofDim[Int](numFeature)
    // in-group ids are handed out in increasing order of the original feature id
    for (fid <- 0 until numFeature) {
      val groupId = fidToGroupId(fid)
      val newFid = groupSizes(groupId)
      groupSizes(groupId) += 1
      fidToNewFid(fid) = newFid
    }
    // NOTE: these locals shadow the fields of the same name on purpose; the closures
    // below reference the locals, the fields are assigned once initialization succeeded
    val bcFidToGroupId = sc.broadcast(fidToGroupId)
    val bcGroupIdToFid = sc.broadcast(groupIdToFid)
    val bcFidToNewFid = sc.broadcast(fidToNewFid)
    val bcGroupSizes = sc.broadcast(groupSizes)
    println(s"Hash feature grouping cost ${System.currentTimeMillis() - featGroupStart} ms")

    // 2. load data from hdfs
    val loadStart = System.currentTimeMillis()
    val trainSet = Dataset.verticalPartition(
      trainInput, fidToGroupId, fidToNewFid, numWorker
    ).cache()
    // each partition carries a slice of the labels plus the dataset of one feature group
    val numTrain = trainSet.map(_._1.length).collect().sum
    // every feature group must cover all training instances
    require(trainSet.map(_._2.size).collect().forall(_ == numTrain))
    println(s"Load $numTrain training data cost " +
      s"${System.currentTimeMillis() - loadStart} ms")

    // 3. collect labels, ensure 0-based indexed and broadcast
    val labelStart = System.currentTimeMillis()
    val labels = new Array[Float](numTrain)
    val partLabels = trainSet.map {
      case (partLabel, dataset) => (dataset.id, partLabel)
    }.collect()
    // dataset ids must be unique, they define the order of the label slices
    require(partLabels.map(_._1).distinct.length == partLabels.length)
    var offset = 0
    partLabels.sortBy(_._1).map(_._2).foreach(partLabel => {
      Array.copy(partLabel, 0, labels, offset, partLabel.length)
      offset += partLabel.length
    })
    if (!bcParam.value.isRegression) InstanceInfo.ensureLabel(labels, numClass)
    val bcLabels = sc.broadcast(labels)
    println(s"Collect labels cost ${System.currentTimeMillis() - labelStart} ms")

    // 4. build quantile sketches, get candidate splits,
    // and create feature info, finally broadcast info to all workers
    val getSplitsStart = System.currentTimeMillis()
    val isCategorical = Array.ofDim[Boolean](numFeature)
    val splits = Array.ofDim[Array[Float]](numFeature)
    val groupNNZ = Array.ofDim[Int](numWorker)
    trainSet.map {
      case (_, dataset) =>
        val workerMaxDim = bcGroupSizes.value(dataset.id)
        val sketches = Dataset.createSketches(dataset, workerMaxDim)
        val groupSplits = sketches.map(sketch => {
          GBDTTrainer.getSplits(sketch, numSplit)
        })
        (dataset.id, groupSplits)
    }.collect().foreach {
      case (groupId, groupSplits) =>
        // restore feature id based on column grouping info
        // and set splits to corresponding feature
        val newFidToFid = groupIdToFid(groupId)
        groupSplits.view.zipWithIndex.foreach {
          case ((fIsCategorical, fSplits, nnz), newFid) =>
            val fid = newFidToFid(newFid)
            isCategorical(fid) = fIsCategorical
            splits(fid) = fSplits
            groupNNZ(groupId) += nnz
        }
    }
    val featureInfo = FeatureInfo(isCategorical, splits)
    val bcFeatureInfo = sc.broadcast(featureInfo)
    println(s"Create feature info cost ${System.currentTimeMillis() - getSplitsStart} ms")
    println("Feature info: " + (groupSizes, groupNNZ, 0 until numWorker).zipped.map {
      case (size, nnz, groupId) => s"(group[$groupId] #feature[$size] #nnz[$nnz])"
    }.mkString(" "))

    // 5. initialize workers
    val initWorkerStart = System.currentTimeMillis()
    val validSet = DataLoader.loadLibsvm(validInput, numFeature)
      .repartition(numWorker)
    val workers = trainSet.zipPartitions(validSet, preservesPartitioning = true)(
      (trainIter, validIter) => {
        // training dataset, turn feature values into histogram bin indexes
        val train = trainIter.toArray
        // exactly one feature group per partition
        require(train.length == 1)
        val workerId = train.head._2.id
        val featureInfo = featureInfoOfGroup(bcFeatureInfo.value,
          workerId, bcGroupIdToFid.value(workerId))
        val trainData = Dataset.binning(train.head._2, featureInfo)
        // every worker gets the labels of all training instances
        val trainLabels = bcLabels.value
        // validation dataset, ensure labels are 0-based indexed
        val valid = validIter.toArray
        val validData = valid.map(_._2)
        val validLabels = valid.map(_._1.toFloat)
        if (!bcParam.value.isRegression) InstanceInfo.ensureLabel(validLabels, numClass)
        // meta data
        val insInfo = InstanceInfo(bcParam.value, trainData.size)
        val featInfo = featureInfo
        // init worker
        val worker = new FPGBDTWorker(workerId, bcParam.value)
        worker.initialize(trainLabels, trainData, validLabels, validData,
          insInfo, featInfo)
        GBDTWorker.putWorker(worker)
        ConcurrentUtil.reset(numThread)
        Iterator(workerId)
      }
    ).cache()
    // first action on the cached RDD, this is where the workers actually get created
    workers.foreach(workerId =>
      println(s"Worker[$workerId] initialization done"))
    val numValid = workers.map(workerId => {
      val worker = getWorker(workerId)
      require(worker.numTrain == numTrain)
      worker.numValid
    }).collect().sum
    println(s"Load valid data and initialize worker cost " +
      s"${System.currentTimeMillis() - initWorkerStart} ms, " +
      s"$numTrain train data, $numValid valid data")
    trainSet.unpersist()

    this.bcFidToGroupId = bcFidToGroupId
    this.bcGroupIdToFid = bcGroupIdToFid
    this.bcFidToNewFid = bcFidToNewFid
    this.bcGroupSizes = bcGroupSizes
    this.workers = workers
    println(s"Initialization done, cost ${System.currentTimeMillis() - initStart} ms in total")
  }

  /**
    * Shuts down `ConcurrentUtil` on every worker, unpersists `workers` and destroys the
    * grouping broadcasts.
    *
    * Fields that were never assigned (e.g. because `initialize` failed) are skipped, and
    * all fields are cleared afterwards so that calling this method again is a no-op.
    */
  override protected def shutdown(): Unit = {
    Option(workers).foreach { workerRdd =>
      workerRdd.foreach(_ => ConcurrentUtil.shutdown())
      workerRdd.unpersist()
    }
    workers = null

    Seq[Broadcast[_]](bcFidToGroupId, bcGroupIdToFid, bcFidToNewFid, bcGroupSizes)
      .filter(_ != null)
      .foreach(_.destroy())
    bcFidToGroupId = null
    bcGroupIdToFid = null
    bcFidToNewFid = null
    bcGroupSizes = null
  }

  /** Asks every worker to compute its gradient pairs. */
  override protected def doCalcGradPairs(): Unit = {
    workers.foreach(workerId => getWorker(workerId).calcGradPairs())
  }

  /**
    * Resets every worker for a new tree and returns size and gradient pair of the root.
    *
    * @param classIdOpt forwarded to `FPGBDTWorker.reset`
    * @return (number of instances on the root node, gradient pair of the root node) as
    *         reported by the first collected worker
    * @throws IllegalArgumentException if workers report different root sizes
    */
  override protected def doNewTree(classIdOpt: Option[Int] = None): (Int, GradPair) = {
    // a single seed is drawn on the driver and handed to all workers as insSampleSeed
    val seed = Option(java.lang.Double.doubleToLongBits(Math.random()))
    workers.map(workerId => {
      val worker = getWorker(workerId)
      worker.reset(insSampleSeed = seed, classIdOpt = classIdOpt)
      (worker.nodeIndexer.getNodeSize(0),
        worker.nodeIndexer.getNodeGradPair(0))
    }).collect().reduceLeft((r1, r2) => {
      // report the two instance counts (the gradient pair was printed here by mistake)
      require(r1._1 == r2._1, "Number of instances are different on workers " +
        s"(${r1._1} vs. ${r2._1}), which is incorrect in feature parallel")
      r1
    })
  }

  /**
    * Replaces the in-group feature id stored in `split` by the original feature id.
    * Runs on the driver only; mutates the split entry in place.
    *
    * @param workerId id of the worker (= feature group) that produced the split
    * @param split    split whose entry carries an in-group feature id
    * @return the same `split` instance
    */
  private def restoreGlobalFid(workerId: Int, split: GBTSplit): GBTSplit = {
    val entry = split.getSplitEntry
    entry.setFid(bcGroupIdToFid.value(workerId)(entry.getFid))
    split
  }

  /**
    * Finds the split of the root node among the candidates proposed by the workers.
    *
    * @param rootGradPair gradient pair of the root; not used by this implementation
    * @param classIdOpt   forwarded to `FPGBDTWorker.findRootSplit`
    * @return the chosen split with its feature id mapped back to the original id,
    *         or `null` if no worker proposed a split
    */
  override protected def doFindRootSplit(rootGradPair: GradPair, classIdOpt: Option[Int] = None): GBTSplit = {
    var best: GBTSplit = null
    workers.map(workerId =>
      (workerId, getWorker(workerId).findRootSplit(classIdOpt = classIdOpt))
    ).collect().foreach {
      case (workerId, split) =>
        // workers return null when they have no split for the root
        if (split != null && (best == null || best.needReplace(split)))
          best = restoreGlobalFid(workerId, split)
    }
    best
  }

  /**
    * Finds, for each node, the split to use among the candidates proposed by the workers.
    *
    * All sequence arguments are forwarded unchanged to `FPGBDTWorker.findSplits`.
    *
    * @param nodeGradPairs not used by this implementation
    * @return node id -> chosen split, feature ids mapped back to the original ids;
    *         nodes for which no worker proposed a split are absent
    */
  override protected def doFindSplits(nids: Seq[Int], nodeGradPairs: Seq[GradPair],
                                      canSplits: Seq[Boolean], toBuild: Seq[Int],
                                      toSubtract: Seq[Boolean], toRemove: Seq[Int],
                                      classIdOpt: Option[Int] = None): Map[Int, GBTSplit] = {
    val best = mutable.Map[Int, GBTSplit]()
    workers.map(workerId => {
      (workerId, getWorker(workerId).findSplits(nids, canSplits, toBuild, toSubtract,
        toRemove, classIdOpt = classIdOpt))
    }).collect().foreach {
      case (workerId, workerSplits) =>
        workerSplits.foreach {
          case (nid, split) =>
            if (!best.contains(nid) || best(nid).needReplace(split))
              best(nid) = restoreGlobalFid(workerId, split)
        }
    }
    best.toMap
  }

  /**
    * Splits the given nodes on all workers.
    *
    * First the worker owning the split feature of a node computes the split result of
    * that node, then the collected results are broadcast and applied by every worker.
    * The per-call broadcasts are destroyed before returning.
    *
    * @param splits     (node id, split) pairs, feature ids being original feature ids
    * @param classIdOpt not used by this implementation
    * @return (child node id, child node size) pairs as reported by the first worker
    */
  override protected def doSplitNodes(splits: Seq[(Int, GBTSplit)],
                                      classIdOpt: Option[Int] = None): Seq[(Int, Int)] = {
    // (owning worker, in-group feature id) of every split, aligned with `splits`
    val fids = splits.map {
      case (_, split) =>
        val fid = split.getSplitEntry.getFid
        val ownerId = bcFidToGroupId.value(fid)
        val fidInWorker = bcFidToNewFid.value(fid)
        (ownerId, fidInWorker)
    }
    val sparkContext = workers.sparkContext
    val bcSplits = sparkContext.broadcast(splits)
    val bcFids = sparkContext.broadcast(fids)
    try {
      val splitResults = workers.flatMap(workerId => {
        getWorker(workerId).getSplitResults(bcSplits.value, bcFids.value).iterator
      }).collect().toMap
      val bcSplitResults = sparkContext.broadcast(splitResults)
      try {
        workers.map(workerId => {
          getWorker(workerId).splitNodes(bcSplits.value, bcSplitResults.value)
        }).collect().head
      } finally {
        bcSplitResults.destroy()
      }
    } finally {
      // both jobs above have finished, nothing references these broadcasts any more
      bcSplits.destroy()
      bcFids.destroy()
    }
  }

  /**
    * Drops the histograms that are no longer needed once `nids` became leaves: the
    * histogram of each node and, unless the node is the root, that of its parent.
    */
  override protected def doSetAsLeaves(nids: Seq[Int]): Unit = {
    workers.foreach(workerId => {
      val worker = getWorker(workerId)
      nids.foreach(nid => {
        worker.histManager.removeNodeHist(nid)
        if (nid != 0)
          worker.histManager.removeNodeHist(MathUtil.parent(nid))
      })
    })
  }

  /**
    * Clears all histograms on the workers and lets them update their predictions with
    * the finished tree. The broadcast of the tree is destroyed before returning.
    *
    * @param tree         the finished tree
    * @param learningRate forwarded to the prediction update
    * @param classIdOpt   id of the tree's class; required if
    *                     `param.isMultiClassMultiTree` is set
    * @throws NoSuchElementException if `param.isMultiClassMultiTree` is set but
    *                                `classIdOpt` is empty
    */
  override protected def doFinishTree(tree: GBTTree, learningRate: Float,
                                      classIdOpt: Option[Int] = None): Unit = {
    val bcTree = workers.sparkContext.broadcast(tree)
    try {
      if (!param.isMultiClassMultiTree) {
        workers.foreach(workerId => {
          val worker = getWorker(workerId)
          worker.histManager.removeAll()
          worker.updatePreds(bcTree.value, learningRate)
        })
      } else {
        val treeId = classIdOpt.getOrElse(throw new NoSuchElementException(
          "classIdOpt must be defined when param.isMultiClassMultiTree is set"))
        workers.foreach(workerId => {
          val worker = getWorker(workerId)
          worker.histManager.removeAll()
          worker.updatePredsMultiTree(bcTree.value, treeId, learningRate)
        })
      }
    } finally {
      bcTree.destroy()
    }
  }

  /**
    * Evaluates all configured metrics on every worker and averages them over workers.
    *
    * @return one (metric kind, train metric, valid metric) triple per configured metric
    */
  override protected def doEvaluate(): Seq[(EvalMetric.Kind, Double, Double)] = {
    val evalMetrics = param.evalMetrics.map(ObjectiveFactory.getEvalMetricKind)
    val trainMetrics = Array.ofDim[Double](evalMetrics.length)
    val validMetrics = Array.ofDim[Double](evalMetrics.length)
    val workerMetrics = workers.map(workerId => getWorker(workerId).evaluate()).collect()
    workerMetrics.foreach(_.zipWithIndex.foreach {
      case ((kind, train, valid), index) =>
        // workers must report the metrics in the configured order
        require(kind == evalMetrics(index))
        trainMetrics(index) += train
        validMetrics(index) += valid
    })
    // one worker per partition, see initialize
    val numWorker = workers.getNumPartitions
    evalMetrics.indices.foreach(index => {
      trainMetrics(index) /= numWorker
      validMetrics(index) /= numWorker
    })
    (evalMetrics, trainMetrics, validMetrics).zipped.toSeq
  }
}


/**
  * Worker of feature-parallel GBDT. It builds histograms and proposes splits for the
  * features in `trainData`, and applies the chosen splits to `nodeIndexer`.
  *
  * @param workerId id of this worker
  * @param param    training hyper-parameters
  */
private[gbdt] class FPGBDTWorker(workerId: Int, param: GBDTParam) extends GBDTWorker(workerId, param) {

  /**
    * Builds the histogram of the root node and searches for its best split.
    *
    * @param classIdOpt forwarded to the histogram builder
    * @return the split found for the root, or `null` if the split finder returned none
    */
  private[gbdt] def findRootSplit(classIdOpt: Option[Int] = None): GBTSplit = {
    histBuilder.buildRootHist(trainData, histManager, nodeIndexer, classIdOpt = classIdOpt)
    val rootSplits = splitFinder.findBestSplits(Array(0), nodeIndexer, histManager)
    if (rootSplits.isEmpty) null else rootSplits.head._2
  }

  /**
    * Updates the histograms and searches for the best splits of the nodes in `nids`.
    *
    * @param nids       nodes to find splits for
    * @param canSplits  aligned with `nids`; histograms of nodes flagged `false` are
    *                   removed before the search
    * @param toBuild    forwarded to the histogram builder
    * @param toSubtract forwarded to the histogram builder
    * @param toRemove   nodes whose histograms are removed up front
    * @param classIdOpt forwarded to the histogram builder
    * @return the (node id, split) pairs returned by the split finder
    */
  private[gbdt] def findSplits(nids: Seq[Int], canSplits: Seq[Boolean], toBuild: Seq[Int],
                               toSubtract: Seq[Boolean], toRemove: Seq[Int],
                               classIdOpt: Option[Int] = None): Seq[(Int, GBTSplit)] = {
    // step 1: drop the histograms of nodes whose children cannot be split
    for (nid <- toRemove) histManager.removeNodeHist(nid)
    // step 2: build the histograms of the active nodes
    histBuilder.buildHist(toBuild, toSubtract, trainData, histManager,
      nodeIndexer, classIdOpt = classIdOpt)
    // step 3: some histograms were only built to serve histogram subtraction although
    // their nodes cannot be split; drop them now
    for ((nid, splittable) <- nids.zip(canSplits) if !splittable)
      histManager.removeNodeHist(nid)
    // step 4: search the best splits of the active nodes
    splitFinder.findBestSplits(nids, nodeIndexer, histManager)
  }

  /**
    * Computes the split results of all splits whose feature is owned by this worker.
    *
    * @param splits (node id, split) pairs
    * @param fids   aligned with `splits`: (owning worker id, in-worker feature id)
    * @return (node id, split result) for every split owned by this worker
    * @throws IllegalArgumentException if `splits` and `fids` differ in length
    */
  private[train] def getSplitResults(splits: Seq[(Int, GBTSplit)],
                                     fids: Seq[(Int, Int)]): Seq[(Int, RangeBitSet)] = {
    require(splits.length == fids.length)
    splits.zip(fids).collect {
      case ((nid, split), (ownerId, fidInWorker)) if ownerId == this.workerId =>
        (nid, getSplitResult(nid, fidInWorker, split.getSplitEntry))
    }
  }

  /**
    * Computes the split result of one node via `nodeIndexer.getSplitResult`.
    *
    * @param nid         id of the node to split
    * @param fidInWorker in-worker id of the split feature
    * @param splitEntry  split to apply
    * @throws IllegalArgumentException if the entry is empty or its gain does not exceed
    *                                  `param.regTParam.minSplitGain`
    */
  private[train] def getSplitResult(nid: Int, fidInWorker: Int, splitEntry: SplitEntry): RangeBitSet = {
    val entryUsable = !splitEntry.isEmpty && splitEntry.getGain > param.regTParam.minSplitGain
    require(entryUsable)
    val candidateSplits = featInfo.getSplits(fidInWorker)
    nodeIndexer.getSplitResult(nid, fidInWorker, splitEntry, candidateSplits, trainData, param)
  }

  /**
    * Applies the given splits and reports the sizes of the resulting child nodes.
    *
    * @param splits       (node id, split) pairs
    * @param splitResults node id -> split result; must contain every node of `splits`
    * @return for every split, (left child id, size) followed by (right child id, size)
    */
  private[train] def splitNodes(splits: Seq[(Int, GBTSplit)],
                                splitResults: Map[Int, RangeBitSet]): Seq[(Int, Int)] = {
    val innerNodeLimit = MathUtil.maxInnerNodeNum(param.regTParam.maxDepth)
    splits.flatMap {
      case (nid, split) =>
        val leftChild = 2 * nid + 1
        val rightChild = 2 * nid + 2
        splitNode(nid, splitResults(nid), split.getLeftGradPair, split.getRightGradPair)
        // once the left child id reaches the inner node limit, drop this node's histogram
        if (leftChild >= innerNodeLimit) {
          histManager.removeNodeHist(nid)
        }
        List(
          (leftChild, nodeIndexer.getNodeSize(leftChild)),
          (rightChild, nodeIndexer.getNodeSize(rightChild))
        )
    }
  }

  /**
    * Applies one split result to `nodeIndexer` and records gradient pair and gain of
    * both children.
    *
    * @param nid           id of the node being split
    * @param splitResult   split result passed to `nodeIndexer.updatePos`
    * @param leftGradPair  gradient pair of the left child (node `2 * nid + 1`)
    * @param rightGradPair gradient pair of the right child (node `2 * nid + 2`)
    */
  private[train] def splitNode(nid: Int, splitResult: RangeBitSet, leftGradPair: GradPair,
                               rightGradPair: GradPair): Unit = {
    val leftChild = 2 * nid + 1
    val rightChild = 2 * nid + 2
    nodeIndexer.updatePos(nid, splitResult)
    nodeIndexer.setNodeGradPair(leftChild, leftGradPair)
    nodeIndexer.setNodeGradPair(rightChild, rightGradPair)
    nodeIndexer.setNodeGain(leftChild, leftGradPair.calcGain(param))
    nodeIndexer.setNodeGain(rightChild, rightGradPair.calcGain(param))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy