All downloads are free. Search and download functionality uses the official Maven repository.

breeze.optimize.BatchDiffFunction.scala Maven / Gradle / Ivy

package breeze.optimize

import breeze.stats.distributions.Rand

/**
 * A [[DiffFunction]] that can also be evaluated on a subset ("batch") of the data.
 *
 * Batches are expressed as indices of data points. The single-argument overloads
 * (`calculate(x)`, `valueAt(x)`, `gradientAt(x)`) evaluate on every index in [[fullRange]],
 * i.e. on all the data.
 *
 * @tparam T the type of the evaluation point (and of the gradient)
 */
trait BatchDiffFunction[T] extends DiffFunction[T] with ((T, IndexedSeq[Int]) => Double) { outer =>

  /**
   * Calculates the gradient of the function on a subset of the data.
   *
   * @param x     the point at which to evaluate
   * @param batch indices of the data points to include
   * @return the gradient part of `calculate(x, batch)`
   */
  def gradientAt(x: T, batch: IndexedSeq[Int]): T = calculate(x, batch)._2

  /**
   * Calculates the value of the function on a subset of the data.
   *
   * @param x     the point at which to evaluate
   * @param batch indices of the data points to include
   * @return the value part of `calculate(x, batch)`
   */
  def valueAt(x: T, batch: IndexedSeq[Int]): Double = calculate(x, batch)._1

  /**
   * Calculates the value and gradient of the function on a subset of the data.
   *
   * @param x     the point at which to evaluate
   * @param batch indices of the data points to include
   * @return `(value, gradient)` at `x` restricted to `batch`
   */
  def calculate(x: T, batch: IndexedSeq[Int]): (Double, T)

  // Full-data evaluation is simply batch evaluation over every index.
  override def calculate(x: T): (Double, T) = calculate(x, fullRange)
  override def valueAt(x: T): Double = valueAt(x, fullRange)
  override def gradientAt(x: T): T = gradientAt(x, fullRange)

  /** Function-call syntax for `valueAt(x, batch)`. */
  def apply(x: T, batch: IndexedSeq[Int]): Double = valueAt(x, batch)

  /**
   * Indices of all data points. This is the batch used by the single-argument overloads
   * and the pool from which the stochastic views below draw their batches.
   */
  def fullRange: IndexedSeq[Int]

  /**
   * A stochastic view of this function. Each evaluation uses a fresh random subset of
   * `size` indices drawn from [[fullRange]] via `Rand.subsetsOfSize`.
   *
   * @param size number of indices per batch
   */
  def withRandomBatches(size: Int): StochasticDiffFunction[T] = new StochasticDiffFunction[T] {
    private val batchSampler = Rand.subsetsOfSize(outer.fullRange, size)

    def calculate(x: T): (Double, T) = outer.calculate(x, batchSampler.get)
  }

  /**
   * A stochastic view of this function that walks through [[fullRange]] in order.
   * Each evaluation uses the next `size` consecutive entries. It wraps around cyclically
   * at the end, so a batch may repeat indices when `size > fullRange.size`.
   *
   * @param size number of indices per batch
   * @throws IllegalArgumentException on evaluation, if `fullRange` is empty
   */
  def withScanningBatches(size: Int): StochasticDiffFunction[T] = new StochasticDiffFunction[T] {
    // Position in fullRange where the next batch starts; only read/written under `synchronized`.
    private var lastStop = 0

    private def nextBatch: IndexedSeq[Int] = synchronized {
      // Evaluate the (possibly computed) range once per batch rather than once per element.
      val all = outer.fullRange
      val n = all.size
      require(n > 0, "withScanningBatches requires a non-empty fullRange")
      val start = lastStop
      lastStop = (lastStop + size) % n
      // Build an IndexedSeq directly; passing an Array relied on a deprecated copying conversion.
      IndexedSeq.tabulate(size)(i => all((i + start) % n))
    }

    def calculate(x: T): (Double, T) = outer.calculate(x, nextBatch)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy