All downloads are free. Search and download functionality uses the official Maven repository.

au.csiro.variantspark.algo.Split.scala Maven / Gradle / Ivy

The newest version!
package au.csiro.variantspark.algo

/** Immutable record describing a split and the impurity values associated with it.
  *
  * @constructor creates a record holding the information about a split
  * @param splitPoint the exact point in the dataset at which it was split
  * @param gini the general gini value of the dataset
  * @param leftGini the gini impurity of the left part of the split
  * @param rightGini the gini impurity of the right part of the split
  */
case class SplitInfo(
    splitPoint: Double,
    gini: Double,
    leftGini: Double,
    rightGini: Double
)

/**
  * An aggregator for calculating split impurity for two sets of labels or values
  * indirectly referenced by their indexes.
  *
  * Implementations provide the `left` and `right` impurity aggregators and define how
  * a single index is folded in, both initially (`init`) and on update (`update`).
  */
trait IndexedSplitAggregator {

  /** Impurity aggregator for the left side of the split. */
  def left: ImpurityAggregator

  /** Impurity aggregator for the right side of the split. */
  def right: ImpurityAggregator

  /** Resets both the left and the right aggregator. */
  def reset(): Unit = {
    left.reset()
    right.reset()
  }

  /**
    * Adds `agg` to the left aggregator and subtracts it from the right one.
    *
    * @param agg the partial aggregate to transfer from the right side to the left side
    */
  def update(agg: ImpurityAggregator): Unit = {
    left.add(agg)
    right.sub(agg)
  }

  /**
    * Whether this is a valid split, that is one that does not put
    * all elements on one side (neither aggregator is empty).
    */
  def hasProperSplit: Boolean = !left.isEmpty && !right.isEmpty

  /**
    * Gets the split impurity value by delegating to `left.splitValue(right, outSplitImp)`.
    *
    * @param outSplitImp passed through to `splitValue`; presumably an output holder for
    *                    the per-side impurities - confirm against `ImpurityAggregator`
    * @return the split impurity value
    */
  def getValue(outSplitImp: SplitImpurity): Double = {
    left.splitValue(right, outSplitImp)
  }

  /**
    * Resets both sides and then initialises the aggregator with every index in `indexes`,
    * in array order.
    *
    * @param indexes the indexes to initialise the aggregator with
    */
  def init(indexes: Array[Int]): Unit = {
    reset()
    indexes.foreach(i => init(i))
  }

  /** Folds the element referenced by `index` into the initial state. */
  def init(index: Int): Unit

  /** Updates the split state for the element referenced by `index`. */
  def update(index: Int): Unit
}

/**
  * Split aggregator for classification. The indexes refer to nominal labels.
  *
  * @param labels the label for each index; `init(index)` and `update(index)` look the
  *               label up as `labels(index)`
  * @param left impurity aggregator for the left side of the split
  * @param right impurity aggregator for the right side of the split
  */
class ClassificationSplitAggregator private (val labels: Array[Int],
    val left: ClassificationImpurityAggregator, val right: ClassificationImpurityAggregator)
    extends IndexedSplitAggregator {

  /** Adds `label` to the right aggregator; the left aggregator is not touched. */
  def initLabel(label: Int): Unit = {
    right.addLabel(label)
  }

  /** Moves `label` from the right side to the left side (add to left, subtract from right). */
  def updateLabel(label: Int): Unit = {
    left.addLabel(label)
    right.subLabel(label)
  }

  /** Initialises with the label found at `labels(index)`. */
  override def init(index: Int): Unit = initLabel(labels(index))

  /** Updates with the label found at `labels(index)`. */
  override def update(index: Int): Unit = updateLabel(labels(index))
}

object ClassificationSplitAggregator {

  /**
    * Builds a split aggregator whose left and right sides are two separate aggregators
    * obtained from `impurity.createAggregator(nCategories)`.
    *
    * @param impurity the impurity used to create both side aggregators
    * @param labels the labels referenced by index
    * @param nCategories the value passed to `createAggregator` for each side
    */
  def apply(impurity: ClassficationImpurity, labels: Array[Int],
      nCategories: Int): ClassificationSplitAggregator = {
    // The left aggregator is created first, then the right one.
    val leftAgg = impurity.createAggregator(nCategories)
    val rightAgg = impurity.createAggregator(nCategories)
    new ClassificationSplitAggregator(labels, leftAgg, rightAgg)
  }
}

/**
  * Fast but memory intensive split aggregator keeping partial impurity statistics for
  * all the unique values of the feature (only makes sense with indexed features).
  *
  * @param matrix one impurity aggregator per level
  * @param labels the labels referenced by index in `updateAt`
  */
class ConfusionAggregator private (val matrix: Array[ClassificationImpurityAggregator],
    val labels: Array[Int]) {

  /**
    * Creates an aggregator with `size` levels, each backed by its own aggregator
    * obtained from `impurity.createAggregator(nCategories)`.
    */
  def this(impurity: ClassficationImpurity, size: Int, nCategories: Int, labels: Array[Int]) =
    this(Array.fill(size)(impurity.createAggregator(nCategories)), labels)

  /**
    * Resets the first `nLevels` aggregators of the matrix.
    *
    * @param nLevels number of leading levels to reset; asserted not to exceed `matrix.length`
    */
  def reset(nLevels: Int): Unit = {
    assert(nLevels <= matrix.length)
    matrix.iterator.take(nLevels).foreach(_.reset())
  }

  /**
    * Adds the label `labels(yIndex)` to the aggregator at the given level.
    *
    * @param level the level (row of the matrix) to update
    * @param yIndex index into `labels` of the response to add
    */
  def updateAt(level: Int, yIndex: Int): Unit = matrix(level).addLabel(labels(yIndex))

  /** Returns the aggregator held at `level`. */
  def apply(level: Int): ClassificationImpurityAggregator = matrix(level)
}

/**
  * The base interface for finding the best split in a set of indexed values.
  */
trait IndexedSplitter {

  /**
    * Finds a split for the values referenced by `splitIndices`.
    *
    * NOTE(review): what is returned when no proper split exists is not visible
    * from this interface - confirm against the implementations.
    *
    * @param splitIndices the indexes of the values to split
    * @return the information about the split found
    */
  def findSplit(splitIndices: Array[Int]): SplitInfo
}

/**
  * A helper trait for IndexedSplitter that selects the actual implementation
  * based on the set of indexes themselves.
  */
trait SwitchingIndexedSplitter extends IndexedSplitter {

  /** Chooses the splitter that should handle the given set of indexes. */
  def select(splitIndices: Array[Int]): IndexedSplitter

  /** Delegates to whichever splitter `select` returns for `splitIndices`. */
  override def findSplit(splitIndices: Array[Int]): SplitInfo = {
    val delegate = select(splitIndices)
    delegate.findSplit(splitIndices)
  }
}

/**
  * Base interface for entities capable of producing indexed splitters.
  */
trait SplitterProvider {

  /**
    * Creates a splitter that uses the given split aggregator.
    *
    * @param impCalc the split aggregator the created splitter works with
    */
  def createSplitter(impCalc: IndexedSplitAggregator): IndexedSplitter
}

/**
  * Base interface for entities capable of producing fast but memory intensive confusion splitters
  */
trait FastSplitterProvider extends SplitterProvider {

  /**
    * The size of the required confusion aggregator
    */
  def confusionSize: Int

  /**
    * Creates a splitter that uses both the split aggregator and a confusion aggregator.
    *
    * @param impCalc the split aggregator the created splitter works with
    * @param confusionAgg the confusion aggregator the created splitter works with;
    *                     presumably it must hold at least `confusionSize` levels - confirm
    *                     against the implementations
    */
  def createSplitter(impCalc: IndexedSplitAggregator,
      confusionAgg: ConfusionAggregator): IndexedSplitter
}

/**
  * Base interface for the strategy for creating indexed splitters for a provider.
  */
trait IndexedSplitterFactory {

  /**
    * Creates an indexed splitter for the given provider.
    *
    * @param sf the provider to create the splitter from
    */
  def create(sf: SplitterProvider): IndexedSplitter
}

/**
  * Depending on the size of the current subset, selects either the fast, memory
  * consuming splitter or the slower but memory efficient one.
  *
  * The way ranger does it is:
  * {{{
  *  if (sampleSize / numOfUniqueValues < Q_THRESHOLD) {
  *    useSlowAlgorithm()
  *  } else {
  *    useFastAlgorithm() // if available, presumably
  *  }
  * }}}
  * The value of Q_THRESHOLD is 0.02.
  *
  * @param fastSplitter the splitter used when the ratio reaches the threshold
  * @param confusionSize the divisor of the ratio `splitIndices.length / confusionSize`
  * @param defaultSplitter the splitter used otherwise
  * @param qThreshold the threshold the ratio is compared against
  */
case class ThresholdIndexedSplitter(fastSplitter: IndexedSplitter, confusionSize: Int,
    defaultSplitter: IndexedSplitter,
    qThreshold: Double = ThresholdIndexesSplitter.DefaultQThredhold)
    extends SwitchingIndexedSplitter {

  /** Picks `fastSplitter` when `splitIndices.length / confusionSize >= qThreshold`. */
  override def select(splitIndices: Array[Int]): IndexedSplitter = {
    val samplesPerLevel = splitIndices.length.toDouble / confusionSize
    // Keep the comparison as `>=`: a NaN ratio must fall through to the default splitter.
    val useFast = samplesPerLevel >= qThreshold
    if (useFast) fastSplitter else defaultSplitter
  }
}

/**
  * Holder for the default threshold used by the threshold-based splitter selection.
  */
object ThresholdIndexesSplitter {

  /** The default q threshold (0.02). */
  val DefaultQThredhold: Double = 0.02
}

/**
  * The default implementation of the [[IndexedSplitterFactory]] for classification.
  *
  * The split aggregator and the confusion aggregator are created lazily, once per factory,
  * and the same instances are handed to every splitter this factory creates.
  *
  * @param impurity the impurity used to create the aggregators
  * @param labels the labels referenced by index
  * @param nCategories the value passed on when creating the aggregators
  * @param maxConfusionSize the size of the confusion aggregator, and the largest
  *                         `confusionSize` for which a fast splitter is created
  * @param qThreshold the threshold passed to [[ThresholdIndexedSplitter]]
  */
class DefStatefullIndexedSpliterFactory(val impurity: ClassficationImpurity,
    val labels: Array[Int], val nCategories: Int, val maxConfusionSize: Int = 10,
    val qThreshold: Double = ThresholdIndexesSplitter.DefaultQThredhold)
    extends IndexedSplitterFactory {

  lazy val splitAggregator: ClassificationSplitAggregator =
    ClassificationSplitAggregator(impurity, labels, nCategories)

  lazy val confusionAgg: ConfusionAggregator =
    new ConfusionAggregator(impurity, maxConfusionSize, nCategories, labels)

  /**
    * Creates a splitter for `sf`. A fast provider whose `confusionSize` fits within
    * `maxConfusionSize` gets a threshold-switching splitter combining its fast and
    * default splitters; any other provider gets just its default splitter.
    */
  def create(sf: SplitterProvider): IndexedSplitter = sf match {
    case fast: FastSplitterProvider if fast.confusionSize <= maxConfusionSize =>
      val confusionSplitter = fast.createSplitter(splitAggregator, confusionAgg)
      val levels = fast.confusionSize
      val fallbackSplitter = sf.createSplitter(splitAggregator)
      ThresholdIndexedSplitter(confusionSplitter, levels, fallbackSplitter, qThreshold)
    case _ =>
      sf.createSplitter(splitAggregator)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy