All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.validation.CrossValidation.scala Maven / Gradle / Ivy

package io.citrine.lolo.validation

import io.citrine.lolo.{Learner, PredictionResult}

import scala.util.Random

/**
  * Methods that use cross-validation to calculate predicted-vs-actual data and metric estimates
  */
object CrossValidation {

  /**
    * Driver to apply named metrics to k-fold cross-validated predicted-vs-actual data
    *
    * This is a thin wrapper: it produces the per-fold predicted-vs-actual pairs with [[kFoldPvA]]
    * and hands them, as an iterator, to `Merit.estimateMerits` together with the requested metrics.
    *
    * @param trainingData to cross-validate with
    * @param learner      to cross-validate
    * @param metrics      apply to the predicted-vs-actual data
    * @param k            number of folds
    * @param nTrial       number of times to re-fold the data to improve statistics
    * @tparam T type of the response, e.g. Double for scalar regression
    * @return a Map from the name of the metric to its mean value and the error in that mean
    * @throws IllegalArgumentException if `k < 2` or `trainingData` has fewer than 2 rows (see [[kFoldPvA]])
    */
  def kFoldCrossvalidation[T](
      trainingData: Seq[(Vector[Any], T)],
      learner: Learner,
      metrics: Map[String, Merit[T]],
      k: Int = 8,
      nTrial: Int = 1
  ): Map[String, (Double, Double)] = {
    Merit.estimateMerits(
      kFoldPvA(trainingData, learner, k, nTrial).iterator,
      metrics
    )
  }

  /**
    * Use k-fold cross-validation to create predicted vs actual results
    *
    * For each trial the data are shuffled and split into `min(k, trainingData.size)` folds whose
    * sizes differ by at most one row. Every fold is then held out once: a model is trained on the
    * remaining folds and applied to the held-out inputs.
    *
    * @param trainingData to cross-validate with
    * @param learner      to cross-validate
    * @param k            number of folds; if there are fewer rows than `k`, one fold per row is used
    * @param nTrial       number of times to re-fold the data to improve statistics
    * @tparam T type of the response, e.g. Double for scalar regression
    * @return an iterable over predicted-vs-actual for each fold of each trial
    * @throws IllegalArgumentException if `k < 2` or `trainingData` has fewer than 2 rows
    */
  def kFoldPvA[T](
      trainingData: Seq[(Vector[Any], T)],
      learner: Learner,
      k: Int = 8,
      nTrial: Int = 1
  ): Iterable[(PredictionResult[T], Seq[T])] = {
    // With a single fold (or a single row) the held-out fold would be the whole data set,
    // leaving nothing to train on; fail early with a message that names the offending argument.
    require(k >= 2, s"k-fold cross-validation needs at least 2 folds, but k = $k")
    require(
      trainingData.size >= 2,
      s"k-fold cross-validation needs at least 2 training rows, but got ${trainingData.size}"
    )
    val nFolds: Int = math.min(k, trainingData.size)

    (0 until nTrial).flatMap { _ =>
      // A fresh shuffle per trial gives a different assignment of rows to folds each time
      val folds: Seq[Seq[(Vector[Any], T)]] = balancedFolds(Random.shuffle(trainingData), nFolds)

      folds.indices.map { idx =>
        val testData: Seq[(Vector[Any], T)] = folds(idx)
        // Everything except the held-out fold, kept in fold order
        val trainData: Seq[(Vector[Any], T)] = (folds.take(idx) ++ folds.drop(idx + 1)).flatten
        val model = learner.train(trainData).getModel()
        // The model is not parameterized by the response type here, so the result is cast back to T
        val predictions: PredictionResult[T] =
          model.transform(testData.map(_._1)).asInstanceOf[PredictionResult[T]]
        (predictions, testData.map(_._2))
      }
    }
  }

  /**
    * Split rows into exactly nFolds contiguous, non-empty groups whose sizes differ by at most one
    *
    * The first `rows.size % nFolds` groups receive one extra row. Callers must ensure
    * `1 <= nFolds <= rows.size`.
    */
  private def balancedFolds[A](rows: Seq[A], nFolds: Int): Seq[Seq[A]] = {
    val indexed: Vector[A] = rows.toVector
    val base: Int = indexed.size / nFolds
    val extra: Int = indexed.size % nFolds
    val sizes: Seq[Int] = Seq.tabulate(nFolds)(i => if (i < extra) base + 1 else base)
    // Running sum of the sizes gives each fold's start offset; zip drops the trailing total
    val starts: Seq[Int] = sizes.scanLeft(0)(_ + _)
    starts.zip(sizes).map { case (start, len) => indexed.slice(start, start + len) }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy