All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.trees.FeaturePowerSetTerms.scala Maven / Gradle / Ivy

package io.citrine.lolo.trees

/**
  * How a single feature scales the weight given to a single node
  *
  * A decision tree's prediction can be written as a weighted sum over its leaves.  If every feature is known, one leaf
  * receives all of the weight.  If some features are excluded from the prediction, every decision that depends on an
  * excluded feature sends weight down both of its branches, in proportion to the share of the training data that
  * followed each branch.
  *
  * An instance holds the two multiplicative factors that one feature contributes to the weight of one node:
  *  - `weightWhenIncluded` applies when the feature is included (i.e. known); it must be exactly 0.0 or 1.0,
  *  - `weightWhenExcluded` applies when the feature is excluded (i.e. unknown); it must lie strictly between 0 and 1.
  *
  * Both constraints are checked at construction time, `weightWhenIncluded` first; a violation raises an
  * `IllegalArgumentException` (via `require`).
  *
  * @param weightWhenExcluded factor of the weight of the leaf when the feature is excluded
  * @param weightWhenIncluded factor of the weight of the leaf when the feature is included
  */
case class FeatureWeightFactor(
    weightWhenExcluded: Double,
    weightWhenIncluded: Double
) {
  // A known feature either routes everything through this node or nothing at all.
  require(
    weightWhenIncluded == 1.0 || weightWhenIncluded == 0.0,
    s"Got weightWhenIncluded=$weightWhenIncluded, but should only ever be 0.0 or 1.0"
  )
  // An unknown feature splits the weight, so the factor is a proper fraction (NaN fails both comparisons).
  require(
    0 < weightWhenExcluded && 1 > weightWhenExcluded,
    s"Got weightWhenExcluded=$weightWhenExcluded, but should be > 0 and < 1"
  )
}

/**
  * Weighted combinatorial factors corresponding to the power-set of features that are included
  *
  * The set is extended as new features are encountered.  Each element in `weightBySubsetSize` corresponds to all of the
  * subsets of the features that are included vs excluded of a given size:
  *  - the 0th element corresponds to the term where all of the encountered features are excluded,
  *  - the 1st element sums the terms that have exactly one of the encountered features included,
  *  - the 2nd element sums the terms that have exactly two of the encountered features included,
  * and so on and so forth.  The final element represents the single term where all of the features are "turned on".
  *
  * Only indices `0` to `size` (inclusive) are part of the current state.  Entries above `size` are ignored by
  * `totalWeight`; after an `unwind` the entry at the previous top index keeps its old value.
  *
  * The calculation is performed efficiently using a dynamic programming technique that runs in quadratic time with
  * the number of features.  It is described in the `extend` and `unwind` methods.
  *
  * @param maxFeatures number of features to allocate space for
  */
class FeaturePowerSetTerms(maxFeatures: Int) {
  // pre-allocation of this whole array is an attempted performance optimization.
  val weightBySubsetSize: Array[Double] = Array.fill[Double](maxFeatures + 1)(0.0)
  // with no features encountered yet, the only subset is the empty one and it carries all of the weight
  weightBySubsetSize(0) = 1.0
  var size: Int = 0

  /**
    * Extend the power set by adding a new feature (in-place)
    *
    * Each subset weight is updated with a contribution from the subset of the same size multiplied by the weight when
    * the feature is excluded and the subset of size one smaller multiplied by the weight when the feature is included.
    * The additional terms in the expression track the combinatorial factors |S|! and  (M - |S| - 1)! from eq. 2 of
    * https://arxiv.org/pdf/1802.03888.pdf
    *
    * @param weightWhenExcluded weight factor for this node when the feature is excluded (unknown)
    * @param weightWhenIncluded weight factor for this node when the feature is included (known)
    * @return this (in-place)
    * @throws ArrayIndexOutOfBoundsException if `maxFeatures` features have already been added; in that case neither
    *                                        `size` nor `weightBySubsetSize` is modified
    */
  def extend(
      weightWhenExcluded: Double,
      weightWhenIncluded: Double
  ): FeaturePowerSetTerms = {
    // Check capacity before mutating anything.  The unchecked array write used to raise this same exception type,
    // but only after `size` and most of the weights had already been overwritten.
    if (size >= maxFeatures) {
      throw new ArrayIndexOutOfBoundsException(
        s"Cannot extend a FeaturePowerSetTerms of size=$size beyond maxFeatures=$maxFeatures"
      )
    }
    size += 1
    val denominator = (size + 1).toDouble

    // use a single scratch variable so we can do a forward in-place update:
    // it always holds the pre-update weight of the subset that is one smaller than the one being written
    var register: Double = weightBySubsetSize(0)

    // there is no subset with size -1, so special case i = 0
    weightBySubsetSize(0) = weightWhenExcluded * weightBySubsetSize(0) * ((size - 0).toDouble / denominator)

    var i = 1
    while (i <= size) {
      val contrib = weightWhenIncluded * register * (i.toDouble / denominator)
      // stash the weight for use in the next iteration
      register = weightBySubsetSize(i)
      weightBySubsetSize(i) = weightWhenExcluded * weightBySubsetSize(i) * ((size - i).toDouble / denominator) + contrib
      i += 1
    }

    this
  }

  /**
    * Remove a feature from the power set, reversing the action of the `extend` method
    *
    * This method is probably better called "remove", but it is called unwind in the paper.  It's not pretty, but it
    * simply inverts `extend`.  Think of `extend` as multiplying by a matrix with weightWhenExcluded... on the
    * diagonal and weightWhenIncluded... on the lower diagonal.  This method performs a fast linear solve on that
    * banded system, with the two logical branches corresponding to the banded vs purely diagonal case.
    * The ... above denotes some combinatorial terms also present.
    *
    * The receiver is not modified.  The arguments must be the same factors that were passed to the `extend` call
    * being undone.  Calling this when `size` is 0 changes no weights and returns a set whose `size` is -1.
    *
    * @param weightWhenExcluded weight factor that was used when the feature was excluded (unknown)
    * @param weightWhenIncluded weight factor that was used when the feature was included (known)
    * @return feature power set with a feature removed (out-of-place)
    */
  def unwind(weightWhenExcluded: Double, weightWhenIncluded: Double): FeaturePowerSetTerms = {
    // make a copy so this is out of place
    val out = this.copy()
    val weights = out.weightBySubsetSize
    val denominator = (size + 1).toDouble

    // reverse the procedure in extend; the branch does not depend on j, so it is chosen once up front
    if (weightWhenIncluded != 0.0) {
      // banded case: back-substitute from the largest subset down.  `carry` holds the part of the extended
      // weight one index above j that came from the "included" term, i.e. from the pre-extend weight at j.
      var carry = weights(size)
      var j = size - 1
      while (j >= 0) {
        val extended = weights(j)
        weights(j) = carry * denominator / ((j + 1) * weightWhenIncluded)
        carry = extended - weights(j) * weightWhenExcluded * ((size - j).toDouble / denominator)
        j -= 1
      }
    } else {
      // diagonal case: every entry was only scaled by the "excluded" term, so divide it back out
      var j = size - 1
      while (j >= 0) {
        weights(j) = weights(j) * denominator / (weightWhenExcluded * (size - j))
        j -= 1
      }
    }

    // bookkeeping
    out.size -= 1
    out
  }

  /**
    * Get the total weight of the feature power set, as a sum over the weights by size of the sets in the power set
    *
    * Sums indices `0` to `size` in ascending order without allocating a temporary array.
    */
  def totalWeight: Double = {
    // clamp to the array length, mirroring the behaviour of `take(size + 1)`
    val end = math.min(size + 1, weightBySubsetSize.length)
    var total = 0.0
    var i = 0
    while (i < end) {
      total += weightBySubsetSize(i)
      i += 1
    }
    total
  }

  /**
    * This is an optimized version of unwind().totalWeight.  Because the unwound feature power set isn't used,
    * it is not even allocated.  That avoided allocation is a significant performance improvement.
    *
    * When `weightWhenIncluded` is non-zero the terms are accumulated from the largest subset down, so the result can
    * differ from `unwind(...).totalWeight` by floating-point rounding.
    *
    * @param weightWhenExcluded weight factor that was used when the feature was excluded (unknown)
    * @param weightWhenIncluded weight factor that was used when the feature was included (known)
    * @return sum of the subset weights the power set would have after removing the feature
    */
  def unwoundTotalWeight(weightWhenExcluded: Double, weightWhenIncluded: Double): Double = {
    var res = 0.0
    // reverse the procedure in extend

    if (weightWhenIncluded != 0.0) {
      var n = weightBySubsetSize(size)
      var j = size - 1
      while (j >= 0) {
        // x is the pre-extend weight at index j; only its running sum is kept
        val x = n * (size + 1) / ((j + 1) * weightWhenIncluded)
        res += x
        n = weightBySubsetSize(j) - x * weightWhenExcluded * (size - j) / (size + 1)
        j -= 1
      }
    } else {
      var j = 0
      while (j < size) {
        res += weightBySubsetSize(j) * (size + 1) / (weightWhenExcluded * (size - j))
        j += 1
      }
    }
    res
  }

  /**
    * Create an independent copy with the same capacity, size and weights
    *
    * @return a new power set whose weight array is a separate array from this one's
    */
  def copy(): FeaturePowerSetTerms = {
    val newSet = new FeaturePowerSetTerms(this.maxFeatures)
    newSet.size = this.size
    // bulk primitive copy; avoids building an intermediate array of (value, index) tuples
    Array.copy(this.weightBySubsetSize, 0, newSet.weightBySubsetSize, 0, this.weightBySubsetSize.length)
    newSet
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy