All Downloads are FREE. Search and download functionalities are using the official Maven repository.

breeze.optimize.CachedDiffFunction.scala Maven / Gradle / Ivy

package breeze.optimize

import breeze.linalg.support.CanCopy
import breeze.linalg.copy
import breeze.concurrent.ThreadLocal


/**
 * Memoizing wrapper around a [[DiffFunction]].
 *
 * The most recent evaluation is remembered in a `ThreadLocal`, so each thread has its own
 * single-entry cache holding a copy of the last point it evaluated together with the value
 * and gradient that `obj` returned for it. When a thread asks again for a point that is `==`
 * to its cached point, the stored result is returned and `obj` is not consulted.
 *
 * Note that the cached gradient object itself is handed back to callers (it is not copied).
 *
 * @author dlwh
 */
class CachedDiffFunction[T:CanCopy](obj: DiffFunction[T]) extends DiffFunction[T] {
  // Per-thread (copy of point, value, gradient) of the last evaluation; null before the first call.
  private val lastData = new ThreadLocal[(T, Double, T)](null)

  /** calculates the gradient at a point (served from the cache when possible) */
  override def gradientAt(x: T): T = {
    val (_, grad) = calculate(x)
    grad
  }

  /** calculates the value at a point (served from the cache when possible) */
  override def valueAt(x: T): Double = {
    val (value, _) = calculate(x)
    value
  }

  /** Calculates both the value and the gradient at a point */
  def calculate(x: T): (Double, T) = {
    val previous = lastData()
    val entry =
      if (previous != null && x == previous._1) {
        previous
      } else {
        val (value, grad) = obj.calculate(x)
        // The point is copied so later in-place updates of the caller's `x` cannot alter the cache key.
        val fresh: (T, Double, T) = (copy(x), value, grad)
        lastData.set(fresh)
        fresh
      }
    (entry._2, entry._3)
  }
}

/**
 * Memoizing wrapper around a [[BatchDiffFunction]].
 *
 * The most recent evaluation (point, value, gradient and batch `range`) is remembered in a
 * `ThreadLocal`, so each thread has its own single-entry cache. A call whose `range` and point
 * are both `==` to the cached ones returns the stored value and gradient without calling
 * `obj` again.
 *
 * Note that the cached gradient object itself is handed back to callers (it is not copied).
 *
 * @author dlwh
 */
class CachedBatchDiffFunction[T:CanCopy](obj: BatchDiffFunction[T]) extends BatchDiffFunction[T] {
  /** calculates the gradient at a point on the given batch */
  override def gradientAt(x: T, range: IndexedSeq[Int]): T = calculate(x,range)._2
  /** calculates the value at a point on the given batch */
  override def valueAt(x:T, range: IndexedSeq[Int]): Double = calculate(x,range)._1

  // Per-thread (copy of point, value, gradient, snapshot of range) of the last evaluation;
  // null before the first call.
  private val lastData = new ThreadLocal[(T, Double, T, IndexedSeq[Int])](null)

  def fullRange = obj.fullRange

  /** Calculates both the value and the gradient at a point on the given batch */
  override def calculate(x:T, range: IndexedSeq[Int]):(Double,T) = {
    val last = lastData()
    val current =
      if (last != null && range == last._4 && x == last._1) {
        last
      } else {
        val (v, g) = obj.calculate(x, range)
        // Both cache keys are snapshotted. `x` is copied because callers may update it in place.
        // `range` is converted to an immutable IndexedSeq because a mutable one (e.g. an
        // ArrayBuffer reused between minibatches) would otherwise be compared against itself
        // after mutation, producing a stale cache hit for the wrong batch. For ranges that are
        // already immutable (Range, Vector) `toIndexedSeq` returns the collection itself.
        val entry: (T, Double, T, IndexedSeq[Int]) = (copy(x), v, g, range.toIndexedSeq)
        lastData.set(entry)
        entry
      }
    (current._2, current._3)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy