All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.trees.splits.BoltzmannSplitter.scala Maven / Gradle / Ivy

package io.citrine.lolo.trees.splits

import io.citrine.lolo.api.TrainingRow
import io.citrine.random.Random
import io.citrine.lolo.trees.impurity.VarianceCalculator
import io.citrine.lolo.trees.splits.BoltzmannSplitter.SplitterResult

import scala.collection.mutable

/**
  * Find a split for a regression problem
  *
  * The splits are picked with a probability that is related to the reduction in variance:
  * P(split) ~ exp[ - {remaining variance} / ({temperature} * {total variance}) ]
  * recalling that the "variance" here is weighted by the sample size (so it's really the sum of the square difference
  * from the mean of that side of the split).  This is analogous to simulated annealing and Metropolis-Hastings.
  *
  * The motivation here is to reduce the correlation of the trees by making random choices between splits that are
  * almost as good as the strictly optimal one.  Reducing the correlation between trees will reduce the variance
  * in an ensemble method (e.g. random forests): the variance will both decrease more quickly with the tree count and
  * will reach a lower floor.  In this paragraph, we're using "variance" as in "bias-variance trade-off".
  *
  * Division by the local total variance makes the splitting behavior invariant to data size and the scale of the labels.
  * That means, however, that you can't set the temperature based on a known absolute noise scale.  For that, you'd want
  * to divide by the total weight rather than the total variance.
  *
  * TODO: allow the rescaling to happen based on the total weight instead of the total variance, as an option
  *
  * Created by maxhutch on 11/29/16.
  *
  * @param temperature used to control how sensitive the probability of a split is to its change in variance.
  *                    The temperature can be thought of as a hyperparameter.
  */
case class BoltzmannSplitter(temperature: Double) extends Splitter[Double] {
  require(
    temperature >= Float.MinPositiveValue,
    s"Temperature must be >= ${Float.MinPositiveValue} to avoid numerical underflows"
  )

  /**
    * Get a split probabilistically, considering numFeatures random features (w/o replacement), ensuring that the
    * resulting partitions have at least minInstances in them.
    *
    * Each considered feature contributes at most one candidate split (itself Boltzmann-sampled within the feature);
    * one candidate is then drawn with probability proportional to its total weight, rebasedScore * exp(base).
    *
    * @param data         to split
    * @param numFeatures  to consider, randomly
    * @param minInstances minimum instances permitted in a post-split partition
    * @param rng          random number generator for reproducibility
    * @return the sampled split and the reduction in (weighted) variance it achieves; (NoSplit(), 0.0) when data is
    *         empty, has no variance to reduce, or admits no valid split
    */
  def getBestSplit(
      data: Seq[TrainingRow[Double]],
      numFeatures: Int,
      minInstances: Int,
      rng: Random
  ): (Split, Double) = {
    if (data.isEmpty) {
      // Nothing to split; this also guards the data.head lookup in candidateSplits
      (NoSplit(), 0.0)
    } else {
      /* Pre-compute these for the variance calculation */
      val calculator = VarianceCalculator.build(data.map(_.label), data.map(_.weight))
      val initialVariance = calculator.getImpurity

      if (initialVariance == 0) {
        // Don't split if there is no impurity to reduce
        (NoSplit(), 0.0)
      } else {
        val beta = 1.0 / (temperature * initialVariance)
        val possibleSplits = candidateSplits(data, calculator, numFeatures, minInstances, beta, rng)
        if (possibleSplits.isEmpty) {
          // If we couldn't find a split, then return NoSplit with no variance reduction
          (NoSplit(), 0.0)
        } else {
          val chosen = drawCandidate(possibleSplits, rng)
          (chosen.split, initialVariance - chosen.variance)
        }
      }
    }
  }

  /**
    * Ask each of numFeatures randomly-chosen features for its sampled candidate split.
    *
    * The splitter used for a feature is chosen by the runtime type of the first row's value for it:
    * Double features are treated as real-valued, Char features as categorical.
    */
  private def candidateSplits(
      data: Seq[TrainingRow[Double]],
      calculator: VarianceCalculator,
      numFeatures: Int,
      minInstances: Int,
      beta: Double,
      rng: Random
  ): Seq[SplitterResult] = {
    val rep = data.head

    /* Try every feature index */
    val featureIndices: Seq[Int] = rep.inputs.indices

    rng.shuffle(featureIndices).take(numFeatures).flatMap { index =>
      /* Use different splitters for each type */
      rep.inputs(index) match {
        case _: Double => BoltzmannSplitter.getBestRealSplit(data, calculator, index, minInstances, beta, rng)
        case _: Char   => BoltzmannSplitter.getBestCategoricalSplit(data, calculator, index, minInstances, beta, rng)
        case _: Any    => throw new IllegalArgumentException("Trying to split unknown feature type")
      }
    }
  }

  /**
    * Draw one candidate with probability proportional to rebasedScore * exp(base).
    *
    * The bases are shifted by their maximum before exponentiation so the largest weight is order-1.0; this keeps
    * every probability from underflowing.  The chosen candidate is the first whose cumulative weight (the CDF)
    * exceeds a uniform draw on [0, totalProbability).  Consumes exactly one rng.nextDouble().
    */
  private def drawCandidate(possibleSplits: Seq[SplitterResult], rng: Random): SplitterResult = {
    val rebase = possibleSplits.map(_.base).max
    val weights = possibleSplits.map(x => x.rebasedScore * Math.exp(x.base - rebase))
    val totalProbability = weights.sum

    val draw = rng.nextDouble() * totalProbability
    // Running sums accumulate in the same order as `sum`, so the last entry equals totalProbability exactly
    val chosenIndex = weights.scanLeft(0.0)(_ + _).tail.indexWhere(draw < _)
    if (chosenIndex < 0) {
      // This shouldn't ever be hit
      throw new RuntimeException(s"Draw was beyond all the probabilities ${draw} ${totalProbability}")
    }
    possibleSplits(chosenIndex)
  }
}

object BoltzmannSplitter {

  /**
    * Container for function returns, like a decorated tuple
    *
    * The true score (proportional to draw probability) is rebasedScore * Math.exp(base).  This decomposition
    * is such that rebasedScore should always be >= 1.0.
    *
    * @param split        the candidate split
    * @param variance     the (weighted) impurity remaining after applying the split
    * @param rebasedScore sum of exp(score - base) over all candidate splits of the feature
    * @param base         the largest candidate score of the feature
    */
  protected case class SplitterResult(split: Split, variance: Double, rebasedScore: Double, base: Double) {
    // The rebasing procedure should result in rebasedScores that are >= 1.0 with finite bases
    // Otherwise, None should have been returned
    require(rebasedScore >= 1.0)
    require(!base.isNegInfinity)
  }

  /**
    * Sample a split on a continuous variable.
    *
    * Every boundary between distinct sorted feature values that leaves at least minCount rows on each side is a
    * candidate; one is drawn with probability proportional to exp(-remainingVariance * beta), and its pivot is drawn
    * uniformly between the two neighboring feature values.
    *
    * @param data       to split
    * @param calculator that will efficiently compute the impurity (variance in this case)
    * @param index      of the feature to split on
    * @param minCount   minimum number of training instances to leave in each of the children nodes
    * @param beta       the inverse temperature (1.0 / (temperature * initial variance)) to scale the variances by
    * @param rng        random number generator used for the pivots and for the draw
    * @return the sampled split of this feature, along with its score, base, and result variance;
    *         None if no valid split exists
    */
  def getBestRealSplit(
      data: Seq[TrainingRow[Double]],
      calculator: VarianceCalculator,
      index: Int,
      minCount: Int,
      beta: Double,
      rng: Random
  ): Option[SplitterResult] = {
    /* Pull out the feature that's considered here and sort by it.
       Converted to an IndexedSeq so thinData(j) is O(1) even when the caller passes a List. */
    val thinData = data
      .map(dat => (dat.inputs(index).asInstanceOf[Double], dat.label, dat.weight))
      .sortBy(_._1)
      .toIndexedSeq

    /* Exclusive bound on the last row moved left: the right partition keeps at least minCount rows and always at
       least one, which also keeps thinData(j + 1) in bounds when minCount <= 0 */
    val endIndex = data.size - math.max(minCount, 1)

    /* Move the data from the right to the left partition one value at a time */
    calculator.reset()
    val possibleSplits: IndexedSeq[(Double, Double, Double)] = (0 until endIndex).flatMap { j =>
      val totalVariance = calculator.add(thinData(j)._2, thinData(j)._3)

      /* Keep track of the candidate splits, avoiding splits in the middle of constant sets of feature values
         It is really important for performance to keep these checks together so
         1) there is only one branch and
         2) it is usually false
       */
      val upper = thinData(j + 1)._1
      val lower = thinData(j)._1
      if (j + 1 >= minCount && Splitter.isDifferent(upper, lower)) {
        val score = -totalVariance * beta
        val pivot = (upper - lower) * rng.nextDouble() + lower
        Some((score, pivot, totalVariance))
      } else {
        None
      }
    }

    sampleCandidate(possibleSplits, rng)(pivot => RealSplit(index, pivot))
  }

  /**
    * Sample a categorical split.
    *
    * Categories are ordered by their weighted-average label and the candidate splits are the prefixes of that order
    * (each prefix forms the included set).  One candidate is drawn with probability proportional to
    * exp(-remainingVariance * beta).
    *
    * @param data       to split
    * @param calculator that will efficiently compute the impurity (variance in this case)
    * @param index      of the feature to split on
    * @param minCount   minimum number of training instances to leave in each of the children nodes
    * @param beta       the inverse temperature (1.0 / (temperature * initial variance)) to scale the variances by
    * @param rng        random number generator used for the draw
    * @return the sampled split of this feature, along with its score, base, and result variance;
    *         None if no valid split exists
    */
  def getBestCategoricalSplit(
      data: Seq[TrainingRow[Double]],
      calculator: VarianceCalculator,
      index: Int,
      minCount: Int,
      beta: Double,
      rng: Random
  ): Option[SplitterResult] = {
    /* Extract the features at the index */
    val thinData = data.map(dat => (dat.inputs(index).asInstanceOf[Char], dat.label, dat.weight))
    val totalWeight = thinData.map(_._3).sum

    /* Group the data by categorical feature and compute the weighted sum, the sum of the weights, and the count */
    val groupedData =
      thinData.groupBy(_._1).view.mapValues(g => (g.map(v => v._2 * v._3).sum, g.map(_._3).sum, g.size)).toMap

    /* Require that at least half of the total weight lies in categories with more than one member */
    val nonTrivial: Double = groupedData.filter(_._2._3 > 1).map(_._2._2).sum
    if (nonTrivial / totalWeight < 0.5) {
      None
    } else {
      /* Compute the average label for each categorical value */
      val categoryAverages: Map[Char, Double] = groupedData.view.mapValues(p => p._1 / p._2).toMap

      /* Create an ordered list of the categories by average label */
      val orderedNames = categoryAverages.toSeq.sortBy(_._2).map(_._1)

      /* Number of rows currently in the left partition */
      var leftNum: Int = 0

      /* Add the categories one at a time in order of their average label */
      calculator.reset()
      val possibleSplits: IndexedSeq[(Double, mutable.BitSet, Double)] = (0 until orderedNames.size - 1).flatMap {
        j =>
          val dat = groupedData(orderedNames(j))
          // The category is added as a single point at its mean, carrying the category's total weight
          val totalVariance = calculator.add(dat._1 / dat._2, dat._2)
          leftNum += dat._3

          if (leftNum >= minCount && (thinData.size - leftNum) >= minCount) {
            val score = -totalVariance * beta
            val includeSet: mutable.BitSet = new mutable.BitSet() ++ orderedNames.slice(0, j + 1).map(_.toInt)
            Some((score, includeSet, totalVariance))
          } else {
            None
          }
      }

      sampleCandidate(possibleSplits, rng)(includeSet => CategoricalSplit(index, includeSet))
    }
  }

  /**
    * Draw one candidate with probability proportional to exp(score) and package it as a SplitterResult.
    *
    * Scores are shifted by their maximum (returned as the base) so the largest weight is exactly 1.0; the rebased
    * total is therefore >= 1.0, as SplitterResult requires.  The chosen candidate is the first whose cumulative
    * weight (the CDF) exceeds a uniform draw on [0, total).
    *
    * @param candidates (score, payload, remaining variance) triples, where score is a log-weight
    * @param rng        random number generator; consumes exactly one nextDouble() when candidates is non-empty
    * @param makeSplit  builds the Split from the chosen candidate's payload
    * @return the chosen candidate, or None when there are no candidates
    */
  private def sampleCandidate[A](candidates: IndexedSeq[(Double, A, Double)], rng: Random)(
      makeSplit: A => Split
  ): Option[SplitterResult] = {
    if (candidates.isEmpty) {
      None
    } else {
      val base: Double = candidates.map(_._1).max
      val weights = candidates.map { case (score, _, _) => Math.exp(score - base) }
      val totalScore = weights.sum

      val draw = rng.nextDouble() * totalScore
      // Running sums accumulate in the same order as `sum`, so the last entry equals totalScore exactly
      val cumulative = weights.scanLeft(0.0)(_ + _).tail
      val chosen = cumulative.indexWhere(draw < _)
      if (chosen < 0) {
        // This should never be hit; it would mean there's a bug in the logic above ^^
        throw new RuntimeException(s"Draw was beyond all the probabilities: $draw > ${cumulative.last}")
      }
      val (_, payload, variance) = candidates(chosen)
      Some(SplitterResult(makeSplit(payload), variance, totalScore, base))
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy