All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.sona.tree.gbdt.helper.FeatureImportance.scala Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */
package com.tencent.angel.sona.tree.gbdt.helper

import com.tencent.angel.sona.tree.basic.split.{SplitEntry, SplitType}
import com.tencent.angel.sona.tree.gbdt.GBDTModel

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * Helpers for computing per-feature importance scores from a trained GBDT model.
  *
  * Importance is aggregated over every non-leaf node of every tree in the model,
  * keyed by the feature index of the node's split.
  */
object FeatureImportance {
  // Names of the predefined importance types.
  val WEIGHT: String = "weight"
  val GAIN: String = "gain"
  val TOTAL_GAIN: String = "total_gain"
  // Default importance type: the average split gain.
  val DEFAULT_ML_GBDT_IMPORTANCE_TYPE: String = GAIN

  /**
    * Validate the name of a predefined importance type.
    *
    * @param importanceType Name to validate
    * @throws IllegalArgumentException if the name is not one of
    *                                  `weight`, `gain` or `total_gain`
    */
  def ensureImportanceType(importanceType: String): Unit = {
    val provided = Array(WEIGHT, GAIN, TOTAL_GAIN)
    require(provided.contains(importanceType),
      s"Unsupported feature importance type: $importanceType, " +
        s"choices: [${provided.mkString(", ")}]")
  }

  /**
    * Calculate the feature importance based on a GBDT model
    * using predefined importance type.
    *
    * @param model GBDT model
    * @param importanceType Predefined importance type, three types supported
    *                       - weight: the number of occurrences
    *                       - gain: the average split gain (average over occurrences)
    *                       - total_gain: the total split gain
    * @return Seq of (feature index, importance) pairs, sorted by descending importance;
    *         ties are ordered by ascending feature index.
    * @throws IllegalArgumentException if the importance type is not supported
    */
  def featImportance(model: GBDTModel, importanceType: String): Seq[(Int, Float)] = {
    // Share validation with ensureImportanceType so both entry points reject unknown
    // names with the same exception type and the same message listing the valid choices.
    ensureImportanceType(importanceType)
    importanceType match {
      case WEIGHT =>
        featImportance(model, (_: Int, _: Int, _: SplitEntry) => 1f, average = false)
      case GAIN =>
        featImportance(model, (_: Int, _: Int, s: SplitEntry) => s.getGain, average = true)
      case TOTAL_GAIN =>
        featImportance(model, (_: Int, _: Int, s: SplitEntry) => s.getGain, average = false)
    }
  }

  /**
    * Calculate the feature importance based on a GBDT model
    * with specific importance function.
    *
    * @param model GBDT model
    * @param importanceFunc Function to calculate importance of one tree node, called as
    *                       (tree index, node id, split entry) for every non-leaf node
    * @param average If true, the importance will be divided by number of occurrences
    * @return Seq of (feature index, importance) pairs, sorted by descending importance;
    *         ties are ordered by ascending feature index. Features that never appear
    *         in a split are not included.
    */
  def featImportance(model: GBDTModel, importanceFunc: (Int, Int, SplitEntry) => Float,
                     average: Boolean): Seq[(Int, Float)] = {
    // feature index -> (number of splits on the feature, summed importance)
    val featStats = mutable.Map[Int, (Int, Float)]()

    for (treeId <- 0 until model.numTree) {
      val tree = model.get(treeId)
      tree.getNodes.foreach {
        case (nodeId, node) =>
          if (!node.isLeaf) {
            val split = node.getSplitEntry
            val fid = split.getFid
            val importance = importanceFunc(treeId, nodeId, split)
            // Single lookup instead of contains + apply.
            featStats(fid) = featStats.get(fid) match {
              case Some((cnt, sum)) => (cnt + 1, sum + importance)
              case None => (1, importance)
            }
          }
      }
    }

    val scored = featStats.toSeq.map {
      case (fid, (cnt, sum)) => (fid, if (average) sum / cnt else sum)
    }
    // Secondary key on the feature index keeps the order of equal scores reproducible
    // instead of leaving it to hash-map iteration order.
    scored.sortBy { case (fid, score) => (-score, fid) }
  }

  /**
    * Usage examples for both overloads of `featImportance`.
    *
    * Private and never invoked within this object; the model path is a placeholder
    * and the results of the calls are discarded.
    */
  private def examples(): Unit = {
    import org.apache.spark.SparkContext
    val sc = SparkContext.getOrCreate()

    val modelPath = "hdfs://path/to/model"
    val model = sc.objectFile[GBDTModel](modelPath).collect().head

    // Examples of predefined importance function:
    // Calculate by number of occurrences
    featImportance(model, importanceType = WEIGHT)
    // Calculate by average split gain
    featImportance(model, importanceType = GAIN)
    // Calculate by total split gain
    featImportance(model, importanceType = TOTAL_GAIN)

    // Examples of self-customized importance function:
    // Calculate by total split gain, decayed by tree index
    featImportance(model, (treeId: Int, _: Int, s: SplitEntry) =>
      s.getGain * math.pow(0.9, treeId).toFloat, average = false)
    // Calculate by average split gain, decayed by node depth.
    // depth = floor(log2(nodeId + 1)), computed with integer arithmetic because the
    // floating-point ratio log10(x) / log10(2) is not guaranteed to be exact.
    // NOTE(review): assumes node ids are non-negative with the root at id 0 - confirm.
    featImportance(model, (_: Int, nodeId: Int, s: SplitEntry) =>
      s.getGain * math.pow(0.9, 31 - Integer.numberOfLeadingZeros(nodeId + 1)).toFloat,
      average = true)
    // Only consider continuous features.
    // Splits of any other type contribute 0 but still count as occurrences
    // when averaging.
    featImportance(model, (_: Int, _: Int, s: SplitEntry) =>
      if (s.splitType() == SplitType.SPLIT_POINT) s.getGain else 0f, average = true)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy