
io.citrine.lolo.api.Learner.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of lolo_2.13 Show documentation
A random forest-centered machine learning library in Scala.
package io.citrine.lolo.api
import io.citrine.random.Random
/**
 * A supervised learner: consumes labeled training rows and produces a trained model.
 *
 * Learners are `Serializable`, since this trait extends [[Serializable]].
 *
 * @tparam T label type of the rows this learner is trained on
 */
trait Learner[T] extends Serializable {

  /**
   * Fit a model to the given training data.
   *
   * @param trainingData rows to train on
   * @param rng          random number generator, supplied so training can be reproduced;
   *                     when omitted, `Random()` is evaluated afresh for each call
   * @return the training result, which contains the trained model
   */
  def train(
      trainingData: Seq[TrainingRow[T]],
      rng: Random = Random()
  ): TrainingResult[T]
}
/**
 * A learner that trains on multiple labels at once, outputting a single model that makes
 * predictions for all of the labels.
 *
 * Each training row is labeled with a `Vector[Any]`. Presumably this holds one entry per
 * label, with `Any` allowing label types to differ; confirm against the implementations.
 */
trait MultiTaskLearner extends Learner[Vector[Any]] {

  /**
   * Train a single multi-task model on the provided training data.
   *
   * This refines `Learner.train` by narrowing the declared return type from
   * `TrainingResult[Vector[Any]]` to [[MultiTaskTrainingResult]]. `override` is written
   * explicitly so that any future drift from the parent signature is reported by the compiler
   * at this declaration.
   *
   * @param trainingData rows to train on, each labeled with a vector of label values
   * @param rng          random number generator for reproducibility; when omitted,
   *                     `Random()` is evaluated afresh for each call
   * @return training result containing the multi-task model
   */
  override def train(
      trainingData: Seq[TrainingRow[Vector[Any]]],
      rng: Random = Random()
  ): MultiTaskTrainingResult
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy