All downloads are FREE. Search and download functionality uses the official Maven repository.

io.citrine.lolo.trees.splits.Splitter.scala Maven / Gradle / Ivy

package io.citrine.lolo.trees.splits

import io.citrine.lolo.api.TrainingRow
import io.citrine.lolo.trees.impurity.ImpurityCalculator
import io.citrine.random.Random

/**
  * Interface for strategies that choose how to split a set of training rows.
  *
  * @tparam T type parameter of the [[TrainingRow]]s being split (presumably the label type; confirm against TrainingRow)
  */
trait Splitter[T] {
  /**
    * Select the best split for the given training data.
    *
    * @param data         training rows to split
    * @param numFeatures  NOTE(review): presumably the number of features to consider for the split; confirm against implementations
    * @param minInstances NOTE(review): presumably the minimum number of rows allowed in each resulting partition; confirm
    * @param rng          random number generator; defaults to `Random()`
    * @return the chosen split paired with a Double
    *         (NOTE(review): presumably the impurity of that split, as `Splitter.getBestRealSplit` returns; confirm)
    */
  def getBestSplit(
      data: Seq[TrainingRow[T]],
      numFeatures: Int,
      minInstances: Int,
      rng: Random = Random()
  ): (Split, Double)
}

object Splitter {

  /** Relative difference above which two same-signed, non-zero values are treated as distinct. */
  private val RelativeTolerance = 1.0e-9

  /**
    * Decide whether two feature values are far enough apart to place a pivot between them.
    *
    *  - If either value is exactly zero, the values are different whenever they are not equal.
    *  - Values of opposite sign are always different.
    *  - Otherwise they are different when the relative difference |x1 - x2| / |x1 + x2| exceeds 1e-9.
    *
    * @param x1 first value
    * @param x2 second value
    * @return true if the two values are considered different
    */
  def isDifferent(x1: Double, x2: Double): Boolean = {
    if (x1 == 0 || x2 == 0) {
      x1 != x2
    } else if (x1 * x2 < 0) {
      true
    } else {
      Math.abs((x1 - x2) / (x1 + x2)) > RelativeTolerance
    }
  }

  /**
    * Find the best split on a continuous variable.
    * This is a typical implementation for decision trees: it finds the pivot that minimizes the impurity
    * reported by the calculator as rows are moved, in sorted order, from one partition to the other.
    * Some splitters might forego this method and implement different ways of choosing a pivot point.
    *
    * If no admissible pivot exists (for example the data are empty, all feature values are effectively equal,
    * or `minCount` cannot be satisfied), the result is `(RealSplit(index, Double.MinValue), Double.MaxValue)`.
    *
    * @param data  to split
    * @param calculator impurity calculator; it is reset here and then fed one (label, weight) pair per step
    * @param index of the feature to split on
    * @param minCount minimum number of data points to allow in each of the resulting splits;
    *                 values below 1 impose no constraint beyond both partitions being non-empty
    * @param randomizePivotLocation whether to generate the pivot by drawing a random value uniformly between the
    *                               two neighbouring feature values instead of taking their midpoint.
    *                               This can improve generalizability, particularly as part of an ensemble.
    * @param rng  random number generator for reproducibility
    * @tparam T type parameter shared by the training rows and the impurity calculator
    * @return the best split of this feature and the impurity associated with it
    */
  def getBestRealSplit[T](
      data: Seq[TrainingRow[T]],
      calculator: ImpurityCalculator[T],
      index: Int,
      minCount: Int,
      randomizePivotLocation: Boolean = false,
      rng: Random = Random()
  ): (Split, Double) = {
    /* Pull out the feature that's considered here and sort by it.
       The result is materialized as an indexed sequence so that the positional lookups in the scan below
       stay O(1) even when `data` is a linear sequence such as a List. */
    val thinData = data
      .map(dat => (dat.inputs(index).asInstanceOf[Double], dat.label, dat.weight))
      .sortBy(_._1)
      .toIndexedSeq
    val features = thinData.map(_._1)
    val numRows = thinData.size

    var bestImpurity = Double.MaxValue
    var bestPivot = Double.MinValue

    /* Step j moves row j across and looks at the gap between rows j and j + 1, so j may never exceed numRows - 2.
       For minCount >= 1 the usual bound (numRows - minCount) already guarantees that; for a non-positive minCount
       the bound is capped explicitly so that features(j + 1) cannot run off the end of the data. */
    val scanEnd = if (minCount < 1) numRows - 1 else numRows - minCount

    /* Move the data from the right to the left partition one value at a time */
    calculator.reset()
    (0 until scanEnd).foreach { j =>
      val row = thinData(j)
      val totalImpurity = calculator.add(row._2, row._3)
      val lower = features(j)
      val upper = features(j + 1)
      /* Keep all checks in a single condition: it is evaluated on every row and is usually false.
         The isDifferent check avoids placing a pivot inside a run of (nearly) constant feature values. */
      if (totalImpurity < bestImpurity && j + 1 >= minCount && isDifferent(upper, lower)) {
        bestImpurity = totalImpurity
        /* Place the pivot between the two neighbouring values: a uniform draw if requested, else the midpoint */
        bestPivot = if (randomizePivotLocation) {
          (upper - lower) * rng.nextDouble() + lower
        } else {
          (upper + lower) / 2.0
        }
      }
    }
    (RealSplit(index, bestPivot), bestImpurity)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy