// io.citrine.lolo.trees.classification.ClassificationTree.scala

package io.citrine.lolo.trees.classification

import breeze.linalg.DenseMatrix
import io.citrine.lolo.api.{Learner, Model, PredictionResult, TrainingResult, TrainingRow}
import io.citrine.random.Random
import io.citrine.lolo.encoders.CategoricalEncoder
import io.citrine.lolo.linear.GuessTheModeLearner
import io.citrine.lolo.trees.splits.{ClassificationSplitter, Splitter}
import io.citrine.lolo.trees.{ModelNode, TrainingNode, TreeMeta}

import scala.collection.mutable

/**
  * A learner that trains a single classification decision tree.
  *
  * @param numFeatures number of features to consider for each split (all features if non-positive)
  * @param maxDepth maximum depth of the tree
  * @param minLeafInstances minimum number of training instances per leaf
  * @param leafLearner learner to train on the data in each leaf (defaults to GuessTheModeLearner)
  * @param splitter used to select splits
  */
case class ClassificationTreeLearner(
    numFeatures: Int = -1,
    maxDepth: Int = 30,
    minLeafInstances: Int = 1,
    leafLearner: Option[Learner[Char]] = None,
    splitter: Splitter[Char] = ClassificationSplitter()
) extends Learner[Any] {

  // Default to predicting the most frequent class when no leaf learner is provided
  @transient private lazy val myLeafLearner: Learner[Char] = leafLearner.getOrElse(GuessTheModeLearner())

  /**
    * Train a classification tree.
    *
    * @param trainingData to train on
    * @param rng          random number generator for reproducibility
    * @return a classification tree
    */
  override def train(trainingData: Seq[TrainingRow[Any]], rng: Random): ClassificationTreeTrainingResult = {
    require(trainingData.size > 4, s"We need more than 4 rows; only ${trainingData.size} were given")
    val repInput = trainingData.head.inputs

    // Create encoders for any categorical features
    val inputEncoders: Seq[Option[CategoricalEncoder[Any]]] = repInput.zipWithIndex.map {
      case (v, i) =>
        if (v.isInstanceOf[Double]) {
          None
        } else {
          Some(CategoricalEncoder.buildEncoder(trainingData.map(_.inputs(i))))
        }
    }
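    // For example, inputs like Vector(0.1, "red") produce encoders Seq(None, Some(...)):
    // the Double at index 0 passes through unencoded, while the String at index 1
    // is mapped to a compact encoded value by its CategoricalEncoder.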

    val outputEncoder = CategoricalEncoder.buildEncoder(trainingData.map(_.label))

    // Encode the training data
    val encodedTraining = trainingData.map { row =>
      val encodedInputs = CategoricalEncoder.encodeInput(row.inputs, inputEncoders)
      val encodedLabels = outputEncoder.encode(row.label)
      TrainingRow(encodedInputs, encodedLabels, row.weight)
    }

    /* Remove any rows with zero weight; each TrainingRow already carries its weight */
    val finalTraining = encodedTraining.filter(_.weight > 0.0)

    /* If the number of features isn't specified, use all of them */
    val numFeaturesActual = if (numFeatures > 0) {
      numFeatures
    } else {
      finalTraining.head.inputs.size
    }

    // Recursively build the tree via its nodes and wrap the top node in a ClassificationTreeTrainingResult
    val rootTrainingNode = ClassificationTrainingNode.build(
      trainingData = finalTraining,
      leafLearner = myLeafLearner,
      splitter = splitter,
      numFeatures = numFeaturesActual,
      minLeafInstances = minLeafInstances,
      remainingDepth = maxDepth,
      maxDepth = maxDepth,
      numClasses = trainingData.map(_.label).distinct.length,
      rng = rng
    )
    ClassificationTreeTrainingResult(rootTrainingNode, inputEncoders, outputEncoder)
  }
}
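
/*
 * Usage sketch (illustrative; not part of the original file): train a tree on a
 * toy dataset with one real and one categorical feature, then read back decoded
 * predictions and normalized feature importances. The TrainingRow and Random
 * factories follow their use elsewhere in lolo; adjust the Random construction
 * to the actual io.citrine.random API if it differs.
 */
object ClassificationTreeLearnerExample {
  def main(args: Array[String]): Unit = {
    val rows: Seq[TrainingRow[Any]] = Seq(
      TrainingRow[Any](Vector(0.1, "red"), "A", 1.0),
      TrainingRow[Any](Vector(0.2, "red"), "A", 1.0),
      TrainingRow[Any](Vector(0.3, "red"), "A", 1.0),
      TrainingRow[Any](Vector(0.7, "blue"), "B", 1.0),
      TrainingRow[Any](Vector(0.8, "blue"), "B", 1.0),
      TrainingRow[Any](Vector(0.9, "blue"), "B", 1.0)
    )
    // train requires more than 4 rows
    val result = ClassificationTreeLearner(maxDepth = 8).train(rows, Random())
    val prediction = result.model.transform(rows.map(_.inputs))
    println(prediction.expected)      // decoded labels, e.g. List(A, A, A, B, B, B)
    println(result.featureImportance) // Some(Vector(...)), normalized to sum to 1
  }
}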

case class ClassificationTreeTrainingResult(
    rootTrainingNode: TrainingNode[Char],
    inputEncoders: Seq[Option[CategoricalEncoder[Any]]],
    outputEncoder: CategoricalEncoder[Any]
) extends TrainingResult[Any] {

  // Build the model by converting the training tree into model nodes
  override lazy val model: ClassificationTree =
    ClassificationTree(rootTrainingNode.modelNode, inputEncoders, outputEncoder)

  // Unnormalized feature importances accumulated from the training splits
  lazy val nodeImportance: mutable.ArraySeq[Double] = rootTrainingNode.featureImportance

  override lazy val featureImportance: Option[Vector[Double]] = Some {
    val totalImportance = nodeImportance.sum
    if (math.abs(totalImportance) > 0) {
      nodeImportance.map(_ / totalImportance).toVector
    } else {
      // No informative splits were made, so fall back to uniform importance
      nodeImportance.map(_ => 1.0 / nodeImportance.size).toVector
    }
  }
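
  // Worked example of the normalization above (illustrative): unnormalized
  // importances of (3.0, 1.0) become Some(Vector(0.75, 0.25)); if every entry
  // is 0.0 (no informative splits), the fallback is the uniform Vector(0.5, 0.5).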
}

/**
  * Classification tree model, which encodes inputs, routes them through the tree, and decodes the predicted labels
  */
case class ClassificationTree(
    rootModelNode: ModelNode[Char],
    inputEncoders: Seq[Option[CategoricalEncoder[Any]]],
    outputEncoder: CategoricalEncoder[Any]
) extends Model[Any] {

  /**
    * Apply the model to a seq of inputs
    *
    * @param inputs to apply the model to
    * @return a PredictionResult that includes, at minimum, the expected outputs
    */
  override def transform(inputs: Seq[Vector[Any]]): ClassificationTreePrediction = {
    ClassificationTreePrediction(
      inputs.map(inp => rootModelNode.transform(CategoricalEncoder.encodeInput(inp, inputEncoders))),
      outputEncoder
    )
  }

  /**
    * Compute Shapley feature attributions for a given input
    *
    * @param input for which to compute feature attributions.
    * @param omitFeatures feature indices to omit in computing Shapley values
    * @return optionally, a matrix of Shapley feature attributions, with one row per input
    *         feature and one column per output dimension. The output dimension is 1 for
    *         single-task regression, or equal to the number of classification categories.
    */
  override def shapley(input: Vector[Any], omitFeatures: Set[Int] = Set()): Option[DenseMatrix[Double]] = {
    rootModelNode.shapley(CategoricalEncoder.encodeInput(input, inputEncoders), omitFeatures)
  }
}
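
/*
 * Shapley usage sketch (illustrative; not part of the original file). Assumes a
 * trained `tree`, e.g. `result.model` from the example above. Per the scaladoc,
 * a defined result has one row per input feature and one column per class:
 *
 *   val attributions = tree.shapley(Vector(0.1, "red"))
 *   attributions.foreach(m => println(s"${m.rows} features x ${m.cols} classes"))
 */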

/**
  * Prediction result from a classification tree
  */
case class ClassificationTreePrediction(
    predictions: Seq[(PredictionResult[Char], TreeMeta)],
    outputEncoder: CategoricalEncoder[Any]
) extends PredictionResult[Any] {

  /**
    * Get the expected values for this prediction
    *
    * @return expected value of each prediction
    */
  override def expected: Seq[Any] = predictions.map(p => outputEncoder.decode(p._1.expected.head))

  /**
    * Get the depth of the leaf that produced each prediction
    *
    * @return tree depth at which each prediction terminated
    */
  def depth: Seq[Int] = predictions.map(_._2.depth)
}
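
/*
 * Prediction-metadata sketch (illustrative; not part of the original file).
 * `depth` can serve as a quick diagnostic alongside the decoded labels:
 *
 *   val pred = tree.transform(inputs)
 *   pred.expected.zip(pred.depth).foreach { case (label, d) =>
 *     println(s"$label (leaf depth $d)")
 *   }
 */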
