All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.tree.gbdt.helper.SplitFinder.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.helper

import com.tencent.angel.sona.tree.basic.split._
import com.tencent.angel.sona.tree.gbdt.helper.HistManager.NodeHist
import com.tencent.angel.sona.tree.gbdt.histogram._
import com.tencent.angel.sona.tree.gbdt.metadata.FeatureInfo
import com.tencent.angel.sona.tree.gbdt.tree.{GBDTParam, GBTSplit}

import scala.collection.mutable.ArrayBuffer

object SplitFinder {
  private[gbdt] def findBestSplitOfOneFeature(param: GBDTParam, featInfo: FeatureInfo, fid: Int, histogram: Histogram,
                                              sumGradPair: GradPair, nodeGain: Float): GBTSplit = {
    if (featInfo.isCategorical(fid)) {
      findBestSplitSet(param, featInfo, fid, histogram, sumGradPair, nodeGain)
    } else {
      findBestSplitPoint(param, featInfo, fid, histogram, sumGradPair, nodeGain)
    }
  }

  private def findBestSplitPoint(param: GBDTParam, featInfo: FeatureInfo, fid: Int, histogram: Histogram,
                                 sumGradPair: GradPair, nodeGain: Float): GBTSplit = {
    val splitPoint = new SplitPoint()
    val leftStat = sumGradPair.copy()
    leftStat.clear()
    val rightStat = sumGradPair.copy()
    var bestLeftStat = null.asInstanceOf[GradPair]
    var bestRightStat = null.asInstanceOf[GradPair]
    val numBin = histogram.getNumBin
    val splits = featInfo.getSplits(fid)
    for (binId <- 0 until numBin - 1) {
      histogram.scan(binId, leftStat, rightStat)
      if (leftStat.satisfyWeight(param) && rightStat.satisfyWeight(param)) {
        val leftGain = leftStat.calcGain(param)
        val rightGain = rightStat.calcGain(param)
        val lossChg = leftGain + rightGain - nodeGain - param.regLambda
        if (splitPoint.needReplace(lossChg)) {
          splitPoint.setFid(fid)
          splitPoint.setFvalue(splits(binId + 1))
          splitPoint.setGain(lossChg)
          bestLeftStat = leftStat.copy()
          bestRightStat = rightStat.copy()
        }
      }
    }
    new GBTSplit(splitPoint, bestLeftStat, bestRightStat)
  }

  private def findBestSplitSet(param: GBDTParam, featInfo: FeatureInfo, fid: Int, histogram: Histogram,
                               sumGradPair: GradPair, nodeGain: Float): GBTSplit = {
    val splits = featInfo.getSplits(fid)
    val defaultBin = featInfo.getDefaultBin(fid)

    def binFlowTo(left: GradPair, bin: GradPair): Int = {
      if (!param.isLeafVector) {
        val sumGrad = sumGradPair.asInstanceOf[BinaryGradPair].getGrad
        val leftGrad = left.asInstanceOf[BinaryGradPair].getGrad
        val binGrad = bin.asInstanceOf[BinaryGradPair].getGrad
        if (binGrad * (2 * leftGrad + binGrad - sumGrad) >= 0.0) 0 else 1
      } else {
        val sumGrad = sumGradPair.asInstanceOf[MultiGradPair].getGrad
        val leftGrad = left.asInstanceOf[MultiGradPair].getGrad
        val binGrad = bin.asInstanceOf[MultiGradPair].getGrad
        var dot = 0.0
        for (i <- 0 until param.numClass)
          dot += binGrad(i) * (2 * leftGrad(i) + binGrad(i) - sumGrad(i))
        if (dot >= 0.0) 0 else 1
      }
    }

    // 1. set default bin to left child
    val leftStat = histogram.get(defaultBin).copy()
    // 2. for other bins, find its location
    var firstFlow = -1
    var curFlow = -1
    var hasRight = false
    var curSplitId = 0
    val edges = ArrayBuffer[Float]()
    edges.sizeHint(FeatureInfo.ENUM_THRESHOLD)
    edges += splits.head
    val binGradPair = leftStat.copy()
    binGradPair.clear()
    val numBin = histogram.getNumBin
    for (binId <- 0 until numBin) {
      if (binId != defaultBin) { // skip default bin
        histogram.get(binId, binGradPair)  // re-use
        val flowTo = binFlowTo(leftStat, binGradPair)
        if (flowTo == 0) leftStat.plusBy(binGradPair)
        else hasRight = true
        if (firstFlow == -1) {
          firstFlow = flowTo
          curFlow = flowTo
        } else if (flowTo != curFlow) {
          edges += splits(curSplitId)
          curFlow = flowTo
        }
        curSplitId += 1
      }
    }
    // 3. create split set
    if (hasRight) {  // whether all bins go the left
      if (edges.length == 1 || edges.last != splits.length)
        edges += splits.last
      val rightStat = sumGradPair.subtract(leftStat)
      if (leftStat.satisfyWeight(param) && rightStat.satisfyWeight(param)) {
        val leftGain = leftStat.calcGain(param)
        val rightGain = rightStat.calcGain(param)
        val lossChg = leftGain + rightGain - nodeGain - param.regLambda
        if (lossChg > 0.0f) {
          val splitSet = new SplitSet(fid, lossChg, edges.toArray, firstFlow, 0)
          return new GBTSplit(splitSet, leftStat, rightStat)
        }
      }
    }
    new GBTSplit()
  }

  private[gbdt] def apply(param: GBDTParam, featInfo: FeatureInfo): SplitFinder =
    new SplitFinder(param, featInfo)

}

import SplitFinder._
class SplitFinder(param: GBDTParam, featInfo: FeatureInfo) {

  private[gbdt] def findBestSplits(nids: Seq[Int], nodeIndexer: NodeIndexer,
                                   histManager: HistManager): Seq[(Int, GBTSplit)] = {
    nids.flatMap(nid => {
      val nodeHist = histManager.getNodeHist(nid)
      if (nodeHist != null) {
        val split = findBestSplit(nodeHist,
          nodeIndexer.getNodeGradPair(nid), nodeIndexer.getNodeGain(nid))
        if (split.isValid(param.regTParam.minSplitGain)) Iterator((nid, split)) else Iterator.empty
      } else {
        Iterator.empty
      }
    })
  }

  private[gbdt] def findBestSplit(nodeHist: NodeHist, sumGradPair: GradPair,
                                  nodeGain: Float): GBTSplit = {
    val best = new GBTSplit()
    for (fid <- nodeHist.indices) {
      if (nodeHist(fid) != null) {
        best.update(findBestSplitOfOneFeature(param, featInfo, fid,
          nodeHist(fid), sumGradPair, nodeGain))
      }
    }
    best
  }

  private[gbdt] def findBestSplit(nodeHist: NodeHist, fids: Seq[Int],
                                  sumGradPair: GradPair, nodeGain: Float): GBTSplit = {
    require(nodeHist.length == fids.length)
    val best = new GBTSplit()
    for (i <- nodeHist.indices) {
      if (nodeHist(i) != null) {
        best.update(findBestSplitOfOneFeature(param, featInfo, fids(i),
          nodeHist(i), sumGradPair, nodeGain))
      }
    }
    best
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy