All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.trees.multitask.MultiTaskTrainingNode.scala Maven / Gradle / Ivy

package io.citrine.lolo.trees.multitask

import io.citrine.lolo.PredictionResult
import io.citrine.lolo.linear.GuessTheMeanLearner
import io.citrine.lolo.trees.splits.{MultiTaskSplitter, NoSplit}
import io.citrine.lolo.trees.{InternalModelNode, ModelNode, TrainingLeaf}

/**
  * Node in a multi-task training tree, which can produce nodes for its model trees
  *
  * @param inputs data on which to select splits and form models
  */
class MultiTaskTrainingNode(inputs: Seq[(Vector[AnyVal], Array[AnyVal], Double)]) {

  // Compute a split
  val (split, _) = MultiTaskSplitter.getBestSplit(inputs, inputs.head._1.size, 1)

  // Try to construct left and right children
  val (leftChild: Option[MultiTaskTrainingNode], rightChild: Option[MultiTaskTrainingNode]) = split match {
    case _: NoSplit => (None, None)
    case _: Any =>
      val (leftData, rightData) = inputs.partition(row => split.turnLeft(row._1))
      (Some(new MultiTaskTrainingNode(leftData)), Some(new MultiTaskTrainingNode(rightData)))
  }

  // Construct the model node for the `index`th label
  def getNode(index: Int): ModelNode[PredictionResult[Any]] = {

    // Filter out "missing" values, which are NaN for regression and 0 for encoded categoricals
    val label = inputs.head._2(index)
    val reducedData = if (label.isInstanceOf[Double]) {
      inputs.map(x => (x._1, x._2(index).asInstanceOf[Double], x._3)).filterNot(_._2.isNaN)
    } else {
      inputs.map(x => (x._1, x._2(index).asInstanceOf[Char], x._3)).filter(_._2 > 0)
    }
    // Compute the valid data for each child
    val (left, right) = reducedData.partition(r => split.turnLeft(r._1))

    // Construct an internal node if the children are defined and actually have valid data
    val node = if (leftChild.isDefined && rightChild.isDefined && left.nonEmpty && right.nonEmpty) {
      // Not a case because of erasure :-(
      if (label.isInstanceOf[Double]) {
        new InternalModelNode[PredictionResult[Double]](
          split,
          leftChild.get.getNode(index).asInstanceOf[ModelNode[PredictionResult[Double]]],
          rightChild.get.getNode(index).asInstanceOf[ModelNode[PredictionResult[Double]]]
        )
      } else {
        if (!label.isInstanceOf[Char]) throw new IllegalArgumentException("Training data wasn't double or char")
        new InternalModelNode[PredictionResult[Char]](
          split,
          leftChild.get.getNode(index).asInstanceOf[ModelNode[PredictionResult[Char]]],
          rightChild.get.getNode(index).asInstanceOf[ModelNode[PredictionResult[Char]]]
        )
      }
    } else if (leftChild.isDefined && left.nonEmpty) {
      leftChild.get.getNode(index)
    } else if (rightChild.isDefined && right.nonEmpty) {
      rightChild.get.getNode(index)
    } else{
      // If there are no children or the children don't have valid data for this label, emit a leaf (with GTM)
      if (label.isInstanceOf[Double]) {
        new TrainingLeaf[Double](reducedData.asInstanceOf[Seq[(Vector[AnyVal], Double, Double)]], new GuessTheMeanLearner(), 1).getNode()
      } else {
        new TrainingLeaf[Char](reducedData.asInstanceOf[Seq[(Vector[AnyVal], Char, Double)]], new GuessTheMeanLearner(), 1).getNode()
      }
    }
    node.asInstanceOf[ModelNode[PredictionResult[Any]]]
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy