package au.csiro.variantspark.algo

import au.csiro.pbdava.ssparkle.common.utils.FastUtilConversions._
import au.csiro.pbdava.ssparkle.common.utils.Logging
import au.csiro.pbdava.ssparkle.common.utils.Timed._
import au.csiro.variantspark.data.Feature
import au.csiro.variantspark.metrics.Metrics
import au.csiro.variantspark.utils.IndexedRDDFunction._
import au.csiro.variantspark.utils.{Sample, defRng}
import it.unimi.dsi.fastutil.longs.{Long2DoubleOpenHashMap, Long2LongOpenHashMap}
import it.unimi.dsi.util.XorShift1024StarRandomGenerator
import org.apache.commons.lang3.builder.ToStringBuilder
import org.apache.spark.rdd.RDD

/** Allows for normalization (scaling) of the input map values.
  */
trait VarImportanceNormalizer {
  def normalize(varImportance: Map[Long, Double]): Map[Long, Double]
}

/** Identity normalizer: returns the raw importance values unchanged.
  */
case object RawVarImportanceNormalizer extends VarImportanceNormalizer {
  override def normalize(varImportance: Map[Long, Double]): Map[Long, Double] = varImportance
}

/** Normalizes variable importance values by scaling them so that they sum to `scale`.
  */
class StandardImportanceNormalizer(val scale: Double) extends VarImportanceNormalizer {
  override def normalize(varImportance: Map[Long, Double]): Map[Long, Double] = {
    val total = varImportance.values.sum / scale
    varImportance.mapValues(_ / total)
  }
}

/** Normalizers that scale the importance values to sum to 100 and to 1, respectively.
  */
case object To100ImportanceNormalizer extends StandardImportanceNormalizer(100.0)
case object ToOneImportanceNormalizer extends StandardImportanceNormalizer(1.0)
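
// Example (sketch, not part of the library): how the normalizers above are intended to be
// used. With raw importances of 3.0 and 1.0, the values are rescaled to sum to 100 and to 1
// respectively:
//
//   val raw = Map(0L -> 3.0, 1L -> 1.0)
//   To100ImportanceNormalizer.normalize(raw) // Map(0L -> 75.0, 1L -> 25.0)
//   ToOneImportanceNormalizer.normalize(raw) // Map(0L -> 0.75, 1L -> 0.25)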

/** Aggregates votes from individual trees into per-sample class counts and predictions.
  *
  * @param nLabels the number of labels
  * @param nSamples the number of samples
  */
case class VotingAggregator(nLabels: Int, nSamples: Int) {
  lazy val votes: Array[Array[Int]] = Array.fill(nSamples)(Array.fill(nLabels)(0))

  /** Adds the votes of one tree for a subset of the samples.
    * @param predictions the predicted label for each of the listed samples
    * @param indexes the sample indexes the predictions refer to
    */
  def addVote(predictions: Array[Int], indexes: Iterable[Int]) {
    require(predictions.length <= nSamples, "Valid number of samples")
    predictions.zip(indexes).foreach { case (v, i) => votes(i)(v) += 1 }
  }

  /** Adds the votes of one tree for all samples.
    * @param predictions the predicted label for each sample
    */
  def addVote(predictions: Array[Int]): VotingAggregator = {
    require(predictions.length == nSamples, "Full prediction range")
    predictions.zipWithIndex.foreach { case (v, i) => votes(i)(v) += 1 }
    this
  }

  /** Returns the majority-vote prediction for each sample.
    */
  def predictions: Array[Int] = votes.map(v => v.indices.maxBy(v))

  /**
    * Computes class probabilities.
    * The result is an array with one item per sample, where
    * each item is a vector with class probabilities for this sample.
    * @return predicted class probabilities for each sample.
    */
  def classProbabilities: Array[Array[Double]] = {
    votes.map { row =>
      val sampleTotal = row.sum.toDouble
      row.map(classCount => classCount / sampleTotal)
    }
  }
}
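
// Example (sketch): aggregating the votes of two trees over three samples and two labels,
// with made-up predictions purely for illustration:
//
//   val agg = VotingAggregator(nLabels = 2, nSamples = 3)
//   agg.addVote(Array(0, 1, 1)) // votes of the first tree
//   agg.addVote(Array(0, 0, 1)) // votes of the second tree
//   agg.predictions             // Array(0, 0, 1) (the tie for the second sample goes to label 0)
//   agg.classProbabilities      // Array(Array(1.0, 0.0), Array(0.5, 0.5), Array(0.0, 1.0))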

/** A single member (tree) of the random forest together with its out-of-bag data.
  * @param predictor the predictor model
  * @param oobIndexes the indexes of the out-of-bag samples for this tree
  * @param oobPred the predictions made for the out-of-bag samples
  */
@SerialVersionUID(2L)
case class RandomForestMember(predictor: PredictiveModelWithImportance,
    oobIndexes: Array[Int] = null, oobPred: Array[Int] = null) {}

/** A trained random forest model.
  * @param members the members (trees) of the forest
  * @param labelCount the number of class labels
  * @param oobErrors the cumulative out-of-bag error after each tree
  * @param params the resolved parameters the forest was trained with
  */
@SerialVersionUID(2L)
case class RandomForestModel(members: List[RandomForestMember], labelCount: Int,
    oobErrors: List[Double] = List.empty, params: RandomForestParams = null) {

  def oobError: Double = oobErrors.last

  def printout() {
    trees.zipWithIndex.foreach {
      case (tree, index) =>
        println(s"Tree: ${index}")
        tree.printout()
    }
  }

  def trees: List[PredictiveModelWithImportance] = members.map(_.predictor)

  def normalizedVariableImportance(
      norm: VarImportanceNormalizer = To100ImportanceNormalizer): Map[Long, Double] =
    norm.normalize(variableImportance)

  /** Computes the variable importance by averaging the importance of each variable over all
    * trees; if a variable is not used in a tree, its importance for that tree is taken to be 0.
    */
  def variableImportance: Map[Long, Double] = {

    trees
      .map(_.variableImportanceAsFastMap)
      .foldLeft(new Long2DoubleOpenHashMap())(_.addAll(_))
      .asScala
      .mapValues(_ / size)
  }

  /**
    * Computes the number of times each variable appears as the splitting variable
    * in the forest.
    * @return map variableIndex -> variableSplitCount
    */
  def variableSplitCount: Map[Long, Long] = {
    trees
      .map(_.variableSplitCountAsFastMap)
      .foldLeft(new Long2LongOpenHashMap())(_.addAll(_))
      .asScala
  }

  def size: Int = members.size

  def predict(indexedData: RDD[(Feature, Long)]): Array[Int] =
    predict(indexedData, indexedData.size)

  def predict(indexedData: RDD[(Feature, Long)], nSamples: Int): Array[Int] = {
    trees
      .map(_.predict(indexedData))
      .foldLeft(VotingAggregator(labelCount, nSamples))(_.addVote(_))
      .predictions
  }

  def predictProb(indexedData: RDD[(Feature, Long)]): Array[Array[Double]] =
    predictProb(indexedData, indexedData.size)

  def predictProb(indexedData: RDD[(Feature, Long)], nSamples: Int): Array[Array[Double]] = {
    val treeVotes = trees
      .map(_.predict(indexedData))
      .foldLeft(VotingAggregator(labelCount, nSamples))(_.addVote(_))
    treeVotes.classProbabilities
  }
}
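
// Example (sketch): querying a trained model. `model` is assumed to be a RandomForestModel
// produced by RandomForest.batchTrain (see below) and `testData` an RDD[(Feature, Long)]
// indexed consistently with the training data:
//
//   val importance: Map[Long, Double] = model.normalizedVariableImportance() // sums to 100
//   val splitCounts: Map[Long, Long] = model.variableSplitCount
//   val oobError: Double = model.oobError
//   val predictions: Array[Int] = model.predict(testData)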

/** Parameters controlling random forest training.
  * @param oob whether to compute the out-of-bag error estimate
  * @param nTryFraction the fraction of variables tried at each split
  *                     (NaN resolves to sqrt(nVariables)/nVariables)
  * @param bootstrap whether samples for each tree are drawn with replacement
  * @param subsample the fraction of samples drawn for each tree
  *                  (NaN resolves to 1.0 with bootstrap, 0.666 otherwise)
  * @param seed the random seed
  * @param maxDepth the maximum tree depth
  * @param minNodeSize the minimum number of samples required to split a node
  */
case class RandomForestParams(oob: Boolean = true, nTryFraction: Double = Double.NaN,
    bootstrap: Boolean = true, subsample: Double = Double.NaN, randomizeEquality: Boolean = true,
    seed: Long = defRng.nextLong, maxDepth: Int = Int.MaxValue, minNodeSize: Int = 1,
    correctImpurity: Boolean = false, airRandomSeed: Long = 0L) {
  def resolveDefaults(nSamples: Int, nVariables: Int): RandomForestParams = {
    RandomForestParams(oob = oob,
      nTryFraction =
        if (!nTryFraction.isNaN) nTryFraction else Math.sqrt(nVariables.toDouble) / nVariables,
      bootstrap = bootstrap,
      subsample = if (!subsample.isNaN) subsample else if (bootstrap) 1.0 else 0.666,
      randomizeEquality = randomizeEquality, seed = seed, maxDepth = maxDepth,
      minNodeSize = minNodeSize, correctImpurity = correctImpurity, airRandomSeed = airRandomSeed)
  }
  def toDecisionTreeParams(seed: Long): DecisionTreeParams = {
    DecisionTreeParams(seed = seed, randomizeEquality = randomizeEquality, maxDepth = maxDepth,
      minNodeSize = minNodeSize, correctImpurity = correctImpurity, airRandomSeed = airRandomSeed)
  }
  override def toString: String = ToStringBuilder.reflectionToString(this)
}
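
// Example (sketch): resolving defaults for a hypothetical dataset with 1000 samples and
// 10000 variables. With nTryFraction and subsample left as NaN, they resolve to
// sqrt(10000)/10000 = 0.01 and (with bootstrap) 1.0 respectively:
//
//   val params = RandomForestParams(seed = 13L)
//   val resolved = params.resolveDefaults(nSamples = 1000, nVariables = 10000)
//   // resolved.nTryFraction == 0.01, resolved.subsample == 1.0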

object RandomForestParams {
  def fromOptions(oob: Option[Boolean] = None, mTryFraction: Option[Double] = None,
      bootstrap: Option[Boolean] = None, subsample: Option[Double] = None,
      seed: Option[Long] = None, maxDepth: Option[Int] = None, minNodeSize: Option[Int] = None,
      correctImpurity: Option[Boolean] = None,
      airRandomSeed: Option[Long] = None): RandomForestParams =
    RandomForestParams(oob.getOrElse(true), mTryFraction.getOrElse(Double.NaN),
      bootstrap.getOrElse(true), subsample.getOrElse(Double.NaN), true,
      seed.getOrElse(defRng.nextLong), maxDepth.getOrElse(Int.MaxValue), minNodeSize.getOrElse(1),
      correctImpurity.getOrElse(false), airRandomSeed.getOrElse(0L))
}

trait RandomForestCallback {
  def onParamsResolved(actualParams: RandomForestParams) {}
  def onTreeComplete(nTrees: Int, oobError: Double, elapsedTimeMs: Long) {}
}
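
// Example (sketch): a simple callback that reports progress. Only the hooks of interest need
// to be overridden since the trait provides no-op defaults; an instance can be supplied
// implicitly to RandomForest.batchTrainTyped below:
//
//   implicit val progressCallback: RandomForestCallback = new RandomForestCallback {
//     override def onParamsResolved(actualParams: RandomForestParams): Unit =
//       println(s"Resolved parameters: ${actualParams}")
//     override def onTreeComplete(nTrees: Int, oobError: Double, elapsedTimeMs: Long): Unit =
//       println(s"Built ${nTrees} tree(s) in batch, oobError: ${oobError}, time: ${elapsedTimeMs} ms")
//   }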

// TODO (Design): Avoid using type casts; change the design
trait BatchTreeModel {
  def batchTrain(indexedData: RDD[TreeFeature], labels: Array[Int], nTryFraction: Double,
      samples: Seq[Sample]): Seq[PredictiveModelWithImportance]
  def batchPredict(indexedData: RDD[TreeFeature], models: Seq[PredictiveModelWithImportance],
      indexes: Seq[Array[Int]]): Seq[Array[Int]]
}

object RandomForest {
  type ModelBuilderFactory = DecisionTreeParams => BatchTreeModel
  val defaultBatchSize: Int = 10

  def wideDecisionTreeBuilder(params: DecisionTreeParams): BatchTreeModel = {
    val decisionTree = new DecisionTree(params)
    new BatchTreeModel() {
      override def batchTrain(indexedData: RDD[TreeFeature], labels: Array[Int],
          nTryFraction: Double, samples: Seq[Sample]): Seq[PredictiveModelWithImportance] =
        decisionTree.batchTrainInt(indexedData, labels, nTryFraction, samples)
      override def batchPredict(indexedData: RDD[TreeFeature],
          models: Seq[PredictiveModelWithImportance], indexes: Seq[Array[Int]]): Seq[Array[Int]] =
        DecisionTreeModel.batchPredict(indexedData.map(tf => (tf, tf.index)),
          models.asInstanceOf[Seq[DecisionTreeModel]], indexes)
    }
  }
}

/** Implements the random forest algorithm.
  * @param params the random forest parameters
  * @param modelBuilderFactory the factory creating the batch tree builder,
  *        e.g. RandomForest.wideDecisionTreeBuilder
  * @param trf the factory creating the tree representation of the input features
  */
class RandomForest(params: RandomForestParams = RandomForestParams(),
    modelBuilderFactory: RandomForest.ModelBuilderFactory = RandomForest.wideDecisionTreeBuilder,
    trf: TreeRepresentationFactory = DefTreeRepresentationFactory)
    extends Logging {

  // TODO (Design): make this class keep random state (could be externalised to implicit random)
  implicit lazy val rng: XorShift1024StarRandomGenerator =
    new XorShift1024StarRandomGenerator(params.seed)
  def batchTrain(indexedData: RDD[(Feature, Long)], labels: Array[Int], nTrees: Int,
      nBatchSize: Int = RandomForest.defaultBatchSize): RandomForestModel = {
    val treeFeatures: RDD[TreeFeature] = trf.createRepresentation(indexedData)
    batchTrainTyped(treeFeatures, labels, nTrees, nBatchSize)
  }

  // TODO (Design): Make a param rather than an extra method
  // TODO (Func): Add OOB Calculation
  def batchTrainTyped(treeFeatures: RDD[TreeFeature], labels: Array[Int], nTrees: Int,
      nBatchSize: Int)(implicit callback: RandomForestCallback = null): RandomForestModel = {

    require(nBatchSize > 0)
    require(nTrees > 0)
    val nSamples = labels.length
    val nVariables = treeFeatures.count().toInt
    val nLabels = labels.max + 1

    logDebug(s"Data:  nSamples:${nSamples}, nVariables: ${nVariables}, nLabels:${nLabels}")

    val actualParams = params.resolveDefaults(nSamples, nVariables)

    Option(callback).foreach(_.onParamsResolved(actualParams))
    logDebug(s"Parameters: ${actualParams}")
    logDebug(s"Batch Training: ${nTrees} with batch size: ${nBatchSize}")

    val oobAggregator =
      if (actualParams.oob) Option(VotingAggregator(nLabels, nSamples)) else None

    val builder = modelBuilderFactory(actualParams.toDecisionTreeParams(rng.nextLong))
    val allSamples = Stream
      .fill(nTrees)(Sample.fraction(nSamples, actualParams.subsample, actualParams.bootstrap))

    val (allTrees, errors) = allSamples
      .sliding(nBatchSize, nBatchSize)
      .flatMap { samplesStream =>
        time {

          val samples = samplesStream.toList
          val predictors =
            builder.batchTrain(treeFeatures, labels, actualParams.nTryFraction, samples)
          val members = if (actualParams.oob) {

            val oobIndexes = samples.map(_.distinctIndexesOut.toArray)
            val oobPredictions = builder.batchPredict(treeFeatures, predictors, oobIndexes)
            predictors.zip(oobIndexes.zip(oobPredictions)).map {
              case (t, (i, p)) => RandomForestMember(t, i, p)
            }

          } else predictors.map(RandomForestMember(_))

          val oobError = oobAggregator
            .map { agg =>
              members.map { m =>
                agg.addVote(m.oobPred, m.oobIndexes)
                Metrics.classificationError(labels, agg.predictions)
              }
            }
            .getOrElse(List.fill(predictors.size)(Double.NaN))
          members.zip(oobError)
        }.withResultAndTime {
          case (treesAndErrors, elapsedTime) =>
            logDebug(s"Trees: ${treesAndErrors.size} >> oobError: ${treesAndErrors.last._2}, "
                + s"time: ${elapsedTime}")
            Option(callback).foreach(_.onTreeComplete(treesAndErrors.size, treesAndErrors.last._2,
                elapsedTime))
        }.result
      }
      .toList
      .unzip

    RandomForestModel(allTrees, nLabels, errors, actualParams)
  }

}
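
// Example (sketch): end-to-end training. `features` (an RDD[(Feature, Long)] of indexed
// features) and `labels` (the per-sample class labels) are assumed to be prepared elsewhere;
// building them is outside the scope of this file:
//
//   val rf = new RandomForest(RandomForestParams(seed = 13L, oob = true))
//   val model: RandomForestModel = rf.batchTrain(features, labels, nTrees = 100, nBatchSize = 10)
//   println(s"OOB error: ${model.oobError}")
//   model.normalizedVariableImportance().toSeq.sortBy(-_._2).take(10).foreach(println)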
