All downloads are free. Search and download functionality uses the official Maven repository.

io.citrine.lolo.linear.MultiTaskLinearRegression.scala Maven / Gradle / Ivy

package io.citrine.lolo.linear

import io.citrine.lolo.api.{MultiTaskLearner, MultiTaskModel, MultiTaskTrainingResult, ParallelModels, TrainingRow}
import io.citrine.random.Random

/**
  * Multi-task learner that fits one linear regression per label, each trained independently of the others.
  *
  * @param regParam     passed through unchanged to each per-label [[LinearRegressionLearner]]
  * @param fitIntercept passed through unchanged to each per-label [[LinearRegressionLearner]]
  */
case class MultiTaskLinearRegressionLearner(regParam: Option[Double] = None, fitIntercept: Boolean = true)
    extends MultiTaskLearner {

  /**
    * Train one linear regression per label position.
    *
    * The number of labels is taken from the first training row, and only that row's labels are
    * checked for being `Double`. For each label position, rows whose value at that position is NaN
    * are dropped before the single-task learner is invoked.
    *
    * @param trainingData rows whose label is a vector holding one entry per task; must be non-empty
    *                     because the first row is inspected via `head`
    * @param rng          random number generator handed to every single-task training call
    * @return the per-label training results, in label order
    */
  override def train(
      trainingData: Seq[TrainingRow[Vector[Any]]],
      rng: Random
  ): MultiTaskLinearRegressionTrainingResult = {
    // The first row serves as the template for how many tasks there are and what type they hold.
    val firstRowLabels = trainingData.head.label
    assert(
      firstRowLabels.forall(_.isInstanceOf[Double]),
      "Multi-task linear regression can only be performed for real-valued labels."
    )

    val baseLearner = LinearRegressionLearner(regParam, fitIntercept)

    // Each task is fit on its own so that a missing (NaN) value for one label
    // does not remove the row from the training set of every other label.
    def trainTask(labelIndex: Int): LinearRegressionTrainingResult = {
      val taskRows = trainingData
        .map(row => row.mapLabel(labels => labels(labelIndex).asInstanceOf[Double]))
        .filter(row => !row.label.isNaN)
      baseLearner.train(taskRows, rng)
    }

    val perTaskResults = firstRowLabels.indices.map(trainTask).toVector
    MultiTaskLinearRegressionTrainingResult(perTaskResults)
  }
}

/**
  * Training result that bundles the independent per-label linear regression results.
  *
  * @param linearResults one single-task training result per label, in label order
  */
case class MultiTaskLinearRegressionTrainingResult(linearResults: Vector[LinearRegressionTrainingResult])
    extends MultiTaskTrainingResult {

  /** The fitted single-task model for each label, in the same order as `linearResults`. */
  override val models: Vector[LinearRegressionModel] = linearResults.map(_.model)

  /**
    * The per-label models wrapped as a single multi-task model.
    * The second argument is `true` for every label.
    * NOTE(review): presumably this flags each label as real-valued; confirm against `ParallelModels`.
    */
  override val model: MultiTaskModel = ParallelModels(models, models.map(_ => true))

  /**
    * Feature importance combined across labels.
    *
    * Importances from every label that reports one are summed feature-by-feature and then divided
    * by the grand total. Returns `None` when no label reports an importance.
    */
  override lazy val featureImportance: Option[Vector[Double]] = {
    // One importance vector per label that reports one: (# labels) x (# features).
    val perLabel = linearResults.flatMap(_.featureImportance)
    if (perLabel.isEmpty) {
      None
    } else {
      // Transposing gives one row per feature; summing a row pools that feature across labels.
      val pooled = perLabel.transpose.map(_.sum)
      val grandTotal = pooled.sum
      Some(pooled.map(_ / grandTotal))
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy