All downloads are free. Search and download functionality uses the official Maven repository.

com.twitter.algebird.SGDMonoid.scala Maven / Gradle / Ivy

/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.twitter.algebird

object SGD {

  /**
   * Gradient for linear regression, i.e. the gradient with respect to the
   * weights of `0.5 * (w . x - y)^2` for a single data point.
   *
   * A data point (the `Pos` type) is a pair `(y, xs)`: the target value and
   * the feature vector. The feature vector is one element shorter than the
   * weights, because the constant term is not carried in the data. Instead,
   * a feature equal to 1.0 is appended here, so the LAST weight is the
   * constant term.
   *
   * @return `(w . x - y) * x`, where `x` is `xs` followed by 1.0
   */
  val linearGradient: (IndexedSeq[Double], (Double, IndexedSeq[Double])) => IndexedSeq[Double] = {
    case (w, (y, xs)) =>
      val features = xs :+ 1.0
      val residual = dot(w, features) - y
      features.map(_ * residual)
  }

  /**
   * Dot product of `x` and `y`.
   *
   * The vectors are paired with `zip`, so any trailing elements of the longer
   * vector are silently ignored.
   */
  def dot(x: IndexedSeq[Double], y: IndexedSeq[Double]): Double =
    x.view.zip(y).map(pair => pair._1 * pair._2).sum

  // Step-size functions. Each receives (update count, gradient) and returns the step size.

  /** A step function that always returns `s`, ignoring both the count and the gradient. */
  def constantStep(s: Double): (Long, IndexedSeq[Double]) => Double =
    (_: Long, _: IndexedSeq[Double]) => s

  /** The standard decaying schedule `a / (count + b)^c`; the gradient is ignored. */
  def countAdaptiveStep(a: Double, b: Double, c: Double = 1.0): (Long, IndexedSeq[Double]) => Double =
    (count, _) => a / scala.math.pow(count + b, c)

  /** Returns the weight vector if `s` is an [[SGDWeights]], otherwise `None`. */
  def weightsOf[T](s: SGD[T]): Option[IndexedSeq[Double]] =
    Some(s).collect { case SGDWeights(_, w) => w }
}

/**
 * A value combined by [[SGDMonoid]]. The subclasses declared alongside it are
 * [[SGDZero]] (the identity), [[SGDWeights]] (a weight vector with a count) and
 * [[SGDPos]] (a batch of data points not yet applied to any weights).
 *
 * @tparam Pos the type of a single data point, as consumed by the gradient function
 */
sealed abstract class SGD[+Pos]
/** The zero element of [[SGDMonoid]]; `plus` returns the other operand unchanged. */
case object SGDZero extends SGD[Nothing]
/** Companion of [[SGDWeights]]: a convenience constructor and count-weighted merging. */
object SGDWeights {

  /** Wraps an initial weight vector with a count of 1. */
  def apply(w: IndexedSeq[Double]): SGDWeights = new SGDWeights(1L, w)

  /**
   * Merges two weight vectors by taking an element-wise average, with each side
   * weighted by its `count`.
   *
   * If `right.count` is zero, `left` is returned unchanged. Otherwise, if
   * `left.count` is zero, `right` is returned unchanged. The merged result
   * carries the sum of both counts. The vectors are paired with `zip`, so extra
   * trailing entries of the longer one are dropped.
   */
  def average(left: SGDWeights, right: SGDWeights): SGDWeights =
    (left.count, right.count) match {
      case (_, 0L) => left
      case (0L, _) => right
      case (lc, rc) =>
        val total = (lc + rc).toDouble
        val merged = left.weights.view
          .zip(right.weights)
          .map { case (l, r) => (lc * l + rc * r) / total }
          .toIndexedSeq
        SGDWeights(lc + rc, merged)
    }
}

/**
 * A weight vector together with an update count.
 *
 * `SGDWeights(w)` starts the count at 1, and `SGDMonoid.newWeights` increments
 * it by one per applied data point. The count is passed to the step function
 * and also weights this vector in `SGDWeights.average`.
 */
case class SGDWeights(count: Long, weights: IndexedSeq[Double]) extends SGD[Nothing]

/** Companion of [[SGDPos]]. */
object SGDPos {

  /** Wraps a single data point as a one-element batch. */
  def apply[Pos](p: Pos): SGDPos[Pos] = new SGDPos(List(p))
}

/** An ordered batch of data points that have not yet been folded into any weights. */
case class SGDPos[+Pos](pos: List[Pos]) extends SGD[Pos]

/**
 * Runs stochastic gradient descent as a fold; essentially a specific
 * implementation of the RightFoldedMonoid.
 *
 * @param stepfn   given the current update count and the gradient for one
 *                 data point, returns the step size to apply
 * @param gradient given the current weights and one data point, returns the
 *                 gradient of the function to be minimized
 *
 * To use this, an initial [[SGDWeights]] must be inserted before any
 * [[SGDPos]] values are added. Otherwise `plus` only concatenates the point
 * lists, and no weights are ever updated.
 */
class SGDMonoid[Pos](stepfn: (Long, IndexedSeq[Double]) => Double,
  gradient: (IndexedSeq[Double], Pos) => IndexedSeq[Double])

  extends Monoid[SGD[Pos]] {

  /** The identity element. */
  val zero = SGDZero

  /**
   * Combines two values according to these rules:
   *  - `x + SGDZero` is `x`;
   *  - two point batches are concatenated, left points first;
   *  - weights followed by a point batch apply each point, in order, as one
   *    SGD update (see `newWeights`);
   *  - every other combination returns `right`. For example, a batch followed
   *    by weights discards the batch, and `SGDZero + x` is `x`.
   */
  def plus(left: SGD[Pos], right: SGD[Pos]): SGD[Pos] =
    right match {
      case SGDZero => left
      case SGDPos(rightPoints) =>
        left match {
          case SGDPos(leftPoints) => SGDPos(leftPoints ::: rightPoints)
          case weights: SGDWeights => rightPoints.foldLeft(weights)(newWeights)
          // TODO make a RightFolded2 which folds A,B => (B,C), and a group on C.
          case _ => right
        }
      case _ => right
    }

  /**
   * Applies a single SGD update for data point `p`.
   *
   * Computes the gradient at the current weights and asks `stepfn` for a step
   * size, given the current count and that gradient. The result is
   * `weights - step * gradient` with the count incremented by one. Weights and
   * gradient are paired with `zip`, so the result is only as long as the
   * shorter of the two.
   */
  def newWeights(sgdW: SGDWeights, p: Pos): SGDWeights = {
    val grad = gradient(sgdW.weights, p)
    val rate = stepfn(sgdW.count, grad)
    val updated = sgdW.weights.view
      .zip(grad)
      .map { case (wi, gi) => wi - rate * gi }
      .toIndexedSeq
    SGDWeights(sgdW.count + 1L, updated)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy