All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.transformers.standardizer.StandardizerTrainingResult.scala Maven / Gradle / Ivy

package io.citrine.lolo.transformers.standardizer

import io.citrine.lolo.api.TrainingResult

/**
  * Training result that keeps the base learner's training result alongside the transformations.
  *
  * @tparam T type parameter forwarded to the underlying [[TrainingResult]]
  */
trait StandardizerTrainingResult[+T] extends TrainingResult[T] {

  /** Training result of the wrapped base learner. */
  def baseTrainingResult: TrainingResult[T]

  /** Narrows the model type to a [[StandardizerModel]]. */
  override def model: StandardizerModel[T]

  /** Feature importances are passed through unchanged from the base training result. */
  override def featureImportance: Option[Vector[Double]] = {
    val wrapped = baseTrainingResult
    wrapped.featureImportance
  }
}

/**
  * Regression training result that carries the input and output transformations next to the
  * base learner's training result.
  *
  * @param baseTrainingResult result of training on the standardized inputs and output
  * @param outputTrans transformation (rescale, offset) of the output label
  * @param inputTrans sequence of optional transformations (rescale, offset) of the inputs
  */
case class RegressionStandardizerTrainingResult(
    baseTrainingResult: TrainingResult[Double],
    outputTrans: Standardization,
    inputTrans: Seq[Option[Standardization]]
) extends StandardizerTrainingResult[Double] {

  /** Wraps the base model together with the stored output and input transformations. */
  override def model: RegressionStandardizerModel = {
    val baseModel = baseTrainingResult.model
    RegressionStandardizerModel(baseModel, outputTrans, inputTrans)
  }

  /**
    * Predicted-vs-actual rows of the base training result, if it has any, with the input
    * transformations inverted on the inputs and the output transformation inverted on both the
    * predicted and the actual value.
    */
  override def predictedVsActual: Option[Seq[(Vector[Any], Double, Double)]] =
    for (rows <- baseTrainingResult.predictedVsActual) yield rows.map { row =>
      val (features, predicted, observed) = row
      (
        Standardization.invertMulti(features, inputTrans),
        outputTrans.invert(predicted),
        outputTrans.invert(observed)
      )
    }
}

/**
  * Classification training result that carries the input transformations next to the base
  * learner's training result.
  *
  * @param baseTrainingResult result of training on the standardized inputs
  * @param inputTrans sequence of optional transformations (rescale, offset) of the inputs
  */
case class ClassificationStandardizerTrainingResult[T](
    baseTrainingResult: TrainingResult[T],
    inputTrans: Seq[Option[Standardization]]
) extends StandardizerTrainingResult[T] {

  /** Wraps the base model together with the stored input transformations. */
  override def model: ClassificationStandardizerModel[T] =
    ClassificationStandardizerModel(baseTrainingResult.model, inputTrans)

  /**
    * Predicted-vs-actual rows of the base training result, if it has any, with the input
    * transformations inverted on the inputs; predicted and actual values are left untouched.
    */
  override def predictedVsActual: Option[Seq[(Vector[Any], T, T)]] =
    for (rows <- baseTrainingResult.predictedVsActual) yield rows.map { row =>
      val (features, predicted, observed) = row
      (Standardization.invertMulti(features, inputTrans), predicted, observed)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy