All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.tree.gbdt.helper.NodeIndexer.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.helper

import com.tencent.angel.sona.tree.basic.split.SplitEntry
import com.tencent.angel.sona.tree.gbdt.dataset.Dataset
import com.tencent.angel.sona.tree.util.{ConcurrentUtil, MathUtil, RangeBitSet}
import com.tencent.angel.sona.tree.gbdt.histogram.GradPair
import com.tencent.angel.sona.tree.gbdt.tree.GBDTParam

import java.{util => ju}

object NodeIndexer {

  private[gbdt] def apply(maxDepth: Int, numIns: Int): NodeIndexer = {
    val maxNodeNum = MathUtil.maxNodeNum(maxDepth)
    val nodePosStart = Array.ofDim[Int](maxNodeNum)
    val nodePosEnd = Array.ofDim[Int](maxNodeNum)
    val nodeActualPosEnd = Array.ofDim[Int](maxNodeNum)
    val nodeToIns = Array.ofDim[Int](numIns)
    val nodeGradPairs = Array.ofDim[GradPair](maxNodeNum)
    val nodeGains = Array.ofDim[Float](maxNodeNum)
    NodeIndexer(nodePosStart, nodePosEnd, nodeActualPosEnd, nodeToIns, nodeGradPairs, nodeGains)
  }
}

case class NodeIndexer(nodePosStart: Array[Int], nodePosEnd: Array[Int], nodeActualPosEnd: Array[Int],
                       nodeToIns: Array[Int], nodeGradPairs: Array[GradPair], nodeGains: Array[Float]) {

  private[gbdt] def reset(): Unit = {
    val numIns = numInstance
    for (nid <- nodeGradPairs.indices) {
      nodeGradPairs(nid) = null
      nodeGains(nid) = Float.NaN
    }
    nodePosStart(0) = 0
    nodePosEnd(0) = numIns
    nodeActualPosEnd(0) = numIns
    for (i <- 0 until numIns) {
      nodeToIns(i) = i
    }
  }

  private[gbdt] def sample(ratio: Float, seed: Option[Long] = None): Boolean = {
    val numIns = numInstance
    val numSample = Math.ceil(numIns * ratio).toInt
    if (numSample < numIns) {
      nodePosEnd(0) = numSample
      MathUtil.shuffle(nodeToIns, seed.getOrElse(
        java.lang.Double.doubleToLongBits(Math.random())
      ))
      // sort to be cache-friendly
      ju.Arrays.sort(nodeToIns, 0, numSample)
      ju.Arrays.sort(nodeToIns, numSample, numIns)
      true
    } else {
      false
    }
  }

  private[gbdt] def getSplitResult(nid: Int, fid: Int, splitEntry: SplitEntry, splits: Array[Float],
                                   dataset: Dataset[Int, Int], param: GBDTParam): RangeBitSet = {
    val res = new RangeBitSet(getNodePosStart(nid), getNodeActualPosEnd(nid))

    def getFlowTo(posId: Int): Int = {
      val insId = nodeToIns(posId)
      val binId = dataset.get(insId, fid)
      if (binId >= 0) splitEntry.flowTo(splits(binId))
      else splitEntry.defaultTo()
    }

    def split(start: Int, end: Int): Unit = {
      require(start <= end)
      if (start == end)
        return

      for (posId <- start until end) {
        val flowTo = getFlowTo(posId)
        if (flowTo == 1) {
          res.set(posId)
        }
      }
    }

    if (ConcurrentUtil.numThread == 1) {
      split(getNodePosStart(nid), getNodeActualPosEnd(nid))
    } else {
      ConcurrentUtil.rangeParallel(split, getNodePosStart(nid),
        getNodeActualPosEnd(nid)).foreach(_.get())
    }
    res
  }

  private[gbdt] def updatePos(nid: Int, splitResult: RangeBitSet): Unit = {
    // in-place update position
    def updateRange(start: Int, end: Int): Int = {
      require(start <= end)
      if (start == end)
        return start

      var left = start
      var right = end - 1
      while (left < right) {
        while (left < right && !splitResult.get(left)) left += 1
        while (left < right && splitResult.get(right)) right -= 1
        if (left < right) {
          val leftInsId = nodeToIns(left)
          val rightInsId = nodeToIns(right)
          nodeToIns(left) = rightInsId
          nodeToIns(right) = leftInsId
          left += 1
          right -= 1
        }
      }
      // find cutting position
      val cut = if (left == right) {
        if (splitResult.get(left)) left
        else left + 1
      } else {
        right
      }
      cut
    }

    // update sampled part
    val cut1 = updateRange(getNodePosStart(nid), getNodePosEnd(nid))
    val numLeftSampled = cut1 - getNodePosStart(nid)
    val numRightSampled = getNodeSize(nid) - numLeftSampled
    // update unsampled part
    val cut2 = updateRange(getNodePosEnd(nid), getNodeActualPosEnd(nid))
    val numLeftUnsampled = cut2 - getNodePosEnd(nid)
    val numRightUnsampled = getNodeActualSize(nid) - getNodeSize(nid) - numLeftUnsampled
    // swap right child's sampled part and left child's unsampled part
    MathUtil.swapRange(nodeToIns, cut1, cut2, getNodePosEnd(nid))
    // set edges
    nodePosStart(2 * nid + 1) = getNodePosStart(nid)
    nodePosEnd(2 * nid + 1) = nodePosStart(2 * nid + 1) + numLeftSampled
    nodeActualPosEnd(2 * nid + 1) = nodePosEnd(2 * nid + 1) + numLeftUnsampled
    nodePosStart(2 * nid + 2) = nodeActualPosEnd(2 * nid + 1)
    nodePosEnd(2 * nid + 2) = nodePosStart(2 * nid + 2) + numRightSampled
    nodeActualPosEnd(2 * nid + 2) = nodePosEnd(2 * nid + 2) + numRightUnsampled
  }

  private[gbdt] def updatePos(nid: Int, dataset: Dataset[Int, Int],
                              splitEntry: SplitEntry, splits: Array[Float]): Unit = {
    val fid = splitEntry.getFid

    def getFlowTo(posId: Int): Int = {
      val insId = nodeToIns(posId)
      val binId = dataset.get(insId, fid)
      if (binId >= 0) {
        splitEntry.flowTo(splits(binId))
      } else {
        splitEntry.defaultTo()
      }
    }

    // in-place update position
    def updateRange(start: Int, end: Int): Int = {
      require(start <= end)
      if (start == end)
        return start

      var left = start
      var right = end - 1
      while (left < right) {
        while (left < right && getFlowTo(left) == 0) left += 1
        while (left < right && getFlowTo(right) == 1) right -= 1
        if (left < right) {
          val leftInsId = nodeToIns(left)
          val rightInsId = nodeToIns(right)
          nodeToIns(left) = rightInsId
          nodeToIns(right) = leftInsId
          left += 1
          right -= 1
        }
      }
      // find cutting position
      val cut = if (left == right) {
        if (getFlowTo(left) == 1) left
        else left + 1
      } else {
        right
      }
      cut
    }

    // update sampled part
    val cut1 = updateRange(getNodePosStart(nid), getNodePosEnd(nid))
    val numLeftSampled = cut1 - getNodePosStart(nid)
    val numRightSampled = getNodeSize(nid) - numLeftSampled
    // update unsampled part
    val cut2 = updateRange(getNodePosEnd(nid), getNodeActualPosEnd(nid))
    val numLeftUnsampled = cut2 - getNodePosEnd(nid)
    val numRightUnsampled = getNodeActualSize(nid) - getNodeSize(nid) - numLeftUnsampled
    // swap right child's sampled part and left child's unsampled part
    MathUtil.swapRange(nodeToIns, cut1, cut2, getNodePosEnd(nid))
    // set edges
    nodePosStart(2 * nid + 1) = getNodePosStart(nid)
    nodePosEnd(2 * nid + 1) = nodePosStart(2 * nid + 1) + numLeftSampled
    nodeActualPosEnd(2 * nid + 1) = nodePosEnd(2 * nid + 1) + numLeftUnsampled
    nodePosStart(2 * nid + 2) = nodeActualPosEnd(2 * nid + 1)
    nodePosEnd(2 * nid + 2) = nodePosStart(2 * nid + 2) + numRightSampled
    nodeActualPosEnd(2 * nid + 2) = nodePosEnd(2 * nid + 2) + numRightUnsampled
  }

  lazy val numInstance: Int = nodeToIns.length

  @inline def nodeCanSplit(nid: Int, param: GBDTParam): Boolean =
    getNodeSize(nid) > param.regTParam.minNodeInstance && getNodeGradPair(nid).satisfyWeight(param)

  @inline def getNodePosStart(nid: Int): Int = nodePosStart(nid)

  @inline def getNodePosEnd(nid: Int) = nodePosEnd(nid)

  @inline def getNodeSize(nid: Int): Int = getNodePosEnd(nid) - getNodePosStart(nid)

  @inline def getNodeActualPosEnd(nid: Int): Int = nodeActualPosEnd(nid)

  @inline def getNodeActualSize(nid: Int): Int = getNodeActualPosEnd(nid) - getNodePosStart(nid)

  @inline def getNodeGradPair(nid: Int): GradPair = nodeGradPairs(nid)

  @inline def setNodeGradPair(nid: Int, nodeGradPair: GradPair): Unit = nodeGradPairs(nid) = nodeGradPair

  @inline def getNodeGain(nid: Int): Float = nodeGains(nid)

  @inline def setNodeGain(nid: Int, nodeGain: Float): Unit = nodeGains(nid) = nodeGain
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy