All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.transformers.standardizer.MultiTaskStandardizer.scala Maven / Gradle / Ivy

package io.citrine.lolo.transformers.standardizer

import io.citrine.lolo.api.{
  Model,
  MultiTaskLearner,
  MultiTaskModel,
  MultiTaskTrainingResult,
  ParallelModels,
  TrainingRow
}
import io.citrine.random.Random

/**
  * Multi-task learner wrapper that standardizes the training data (zero mean, unit variance)
  * and then delegates training to another multi-task learner.
  *
  * @param baseLearner multi-task learner that is trained on the standardized data.
  */
case class MultiTaskStandardizer(baseLearner: MultiTaskLearner) extends MultiTaskLearner {

  /**
    * Fit input and label standardizations on the training data, train the base learner on the
    * standardized rows, and bundle the result together with the fitted transformations.
    *
    * @param trainingData rows of inputs, multi-task labels, and weights
    * @param rng random number generator forwarded unchanged to the base learner
    * @return the base learner's training result together with the output and input transformations
    */
  override def train(trainingData: Seq[TrainingRow[Vector[Any]]], rng: Random): MultiTaskStandardizerTrainingResult = {
    val rows = trainingData.map(_.asTuple)

    // Both transformations are fitted on the raw data before any row is rescaled.
    val inputTrans = Standardization.buildMulti(rows.map(_._1))
    val outputTrans = Standardization.buildMulti(rows.map(_._2))

    // Weights pass through untouched; only inputs and labels are standardized.
    val standardTrainingData = rows.map {
      case (rawInput, rawLabel, weight) =>
        TrainingRow(
          Standardization.applyMulti(rawInput, inputTrans),
          Standardization.applyMulti(rawLabel, outputTrans),
          weight
        )
    }

    MultiTaskStandardizerTrainingResult(baseLearner.train(standardTrainingData, rng), outputTrans, inputTrans)
  }
}

/**
  * Training result wrapping the base learner's training result next to the transformations
  * that were applied to the training data.
  *
  * @param baseTrainingResult result of training on the standardized inputs and output
  * @param outputTrans sequence of optional transformation (rescale, offset) of output labels
  * @param inputTrans sequence of optional transformations (rescale, offset) of inputs
  */
case class MultiTaskStandardizerTrainingResult(
    baseTrainingResult: MultiTaskTrainingResult,
    outputTrans: Seq[Option[Standardization]],
    inputTrans: Seq[Option[Standardization]]
) extends MultiTaskTrainingResult {

  /** Multi-task model assembled from the per-label wrapped [[models]] and the base model's real-label flags. */
  override def model: MultiTaskModel = ParallelModels(models, baseTrainingResult.model.realLabels)

  /**
    * One wrapped model per model of the base training result, in the same order.
    *
    * Models whose label is flagged as real-valued are wrapped together with that label's output
    * transformation and the input transformations; all other models are wrapped with the input
    * transformations only.
    *
    * @throws NoSuchElementException if a label flagged as real-valued has no output transformation
    */
  override def models: Seq[StandardizerModel[Any]] = {
    val realLabels = baseTrainingResult.model.realLabels
    baseTrainingResult.models.zipWithIndex.map {
      case (model, idx) =>
        if (realLabels(idx)) {
          // A real-valued label is expected to carry a transformation; fail with a message that
          // identifies the offending label instead of an anonymous `None.get`.
          val labelTrans = outputTrans(idx).getOrElse(
            throw new NoSuchElementException(
              s"No output standardization available for real-valued label at index $idx"
            )
          )
          RegressionStandardizerModel(model.asInstanceOf[Model[Double]], labelTrans, inputTrans)
        } else {
          ClassificationStandardizerModel(model, inputTrans)
        }
    }
  }

  /** Feature importance exactly as reported by the base training result. */
  override def featureImportance: Option[Vector[Double]] = baseTrainingResult.featureImportance

  /**
    * Predicted-vs-actual triples of the base training result, with the inputs passed through the
    * inverse input transformations and the predicted/actual labels through the inverse output
    * transformations. Empty if the base training result provides none.
    */
  override def predictedVsActual: Option[Seq[(Vector[Any], Vector[Option[Any]], Vector[Option[Any]])]] = {
    baseTrainingResult.predictedVsActual.map { pva =>
      pva.map {
        case (inputs, predOpt, actualOpt) =>
          val invertedInputs = Standardization.invertMulti(inputs, inputTrans)
          val invertedPred = Standardization.invertMultiOption(predOpt, outputTrans)
          val invertedActual = Standardization.invertMultiOption(actualOpt, outputTrans)
          (invertedInputs, invertedPred, invertedActual)
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy