All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.citrine.lolo.hypers.HyperOptimizer.scala Maven / Gradle / Ivy

package io.citrine.lolo.hypers

import io.citrine.lolo.api.{Learner, TrainingRow}
import io.citrine.random.Random

/**
  * Base class for hyperparameter optimizers
  *
  * They take the candidate values for each hyper as a Map[String, Seq[Any]] and output the best hyper map and its loss
  */
abstract class HyperOptimizer() {

  /**
    * The search space: each hyper name mapped to the sequence of candidate values registered for it.
    *
    * Starts out empty and is replaced with an extended map on every call to [[addHyperGrid]].
    */
  var hyperGrids: Map[String, Seq[Any]] = Map.empty

  /**
    * Register the candidate values of a single hyper in the search space
    *
    * Registering a name that is already present replaces the values previously stored for it.
    *
    * @param name   of the hyper
    * @param values the hyper can take, enumerated as a seq
    * @return this instance, so that calls can be chained
    */
  def addHyperGrid(name: String, values: Seq[Any]): this.type = {
    // The map is immutable, so build the extended copy and re-assign the field.
    hyperGrids = hyperGrids.updated(name, values)
    this
  }

  /**
    * Optimize the hypers over the defined search space
    *
    * Abstract: the search strategy is supplied by concrete subclasses.
    *
    * @param trainingData  the data to train/test on
    * @param numIterations to take before terminating (defaults to 8)
    * @param builder       function that turns a hyper map into a learner
    * @param rng           random number generator (defaults to a newly constructed Random())
    * @tparam T type parameter shared by the training rows and the learner produced by builder
    * @return the best hyper map found in the given iterations and the corresponding loss
    */
  def optimize[T](
      trainingData: Seq[TrainingRow[T]],
      numIterations: Int = 8,
      builder: Map[String, Any] => Learner[T],
      rng: Random = Random()
  ): (Map[String, Any], Double)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy