All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.ml.optimization.GradientDescent.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.flink.ml.optimization

import org.apache.flink.api.scala._
import org.apache.flink.ml.common._
import org.apache.flink.ml.math._
import org.apache.flink.ml.optimization.IterativeSolver.{ConvergenceThreshold, Iterations, LearningRate}
import org.apache.flink.ml.optimization.Solver._
import org.apache.flink.ml._

/** Base class which performs Stochastic Gradient Descent optimization using mini batches.
  *
  * For each labeled vector in a mini batch the gradient is computed and added to a partial
  * gradient. The partial gradients are then summed and divided by the size of the batches. The
  * average gradient is then used to update the weight values, including regularization.
  *
  * At the moment, the whole partition is used for SGD, making it effectively a batch gradient
  * descent. Once a sampling operator has been introduced, the algorithm can be optimized
  *
  *  The parameters to tune the algorithm are:
  *                      [[Solver.LossFunction]] for the loss function to be used,
  *                      [[Solver.RegularizationConstant]] for the regularization parameter,
  *                      [[IterativeSolver.Iterations]] for the maximum number of iterations,
  *                      [[IterativeSolver.LearningRate]] for the learning rate used,
  *                      [[IterativeSolver.ConvergenceThreshold]] when provided the algorithm will
  *                      stop the iterations if the relative change in the value of the objective
  *                      function between successive iterations is smaller than this value.
  */
abstract class GradientDescent extends IterativeSolver {

  /** Provides a solution for the given optimization problem
    *
    * Reads the number of iterations, the optional convergence threshold, the loss function, the
    * learning rate and the regularization constant from `parameters` and then runs the
    * iteration either with or without a convergence criterion, depending on whether a
    * convergence threshold has been set.
    *
    * @param data A Dataset of LabeledVector (label, features) pairs
    * @param initialWeights The initial weights that will be optimized
    * @return The weights, optimized for the provided data.
    */
  override def optimize(
    data: DataSet[LabeledVector],
    initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] = {

    val numberOfIterations: Int = parameters(Iterations)
    val convergenceThresholdOption: Option[Double] = parameters.get(ConvergenceThreshold)
    val lossFunction = parameters(LossFunction)
    val learningRate = parameters(LearningRate)
    val regularizationConstant = parameters(RegularizationConstant)

    // Initialize weights
    val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)

    // Perform the iterations
    convergenceThresholdOption match {
      // No convergence criterion
      case None =>
        optimizeWithoutConvergenceCriterion(
          data,
          initialWeightsDS,
          numberOfIterations,
          regularizationConstant,
          learningRate,
          lossFunction)
      case Some(convergence) =>
        optimizeWithConvergenceCriterion(
          data,
          initialWeightsDS,
          numberOfIterations,
          regularizationConstant,
          learningRate,
          convergence,
          lossFunction
        )
    }
  }

  /** Runs the gradient descent iterations and stops early once the loss has converged.
    *
    * The iteration carries (weights, loss) pairs. After every step the loss of the new weights
    * is compared with the loss of the previous weights; the iteration ends as soon as the
    * relative change drops below `convergenceThreshold`, or after `numberOfIterations` steps.
    *
    * @param dataPoints A Dataset of LabeledVector (label, features) pairs
    * @param initialWeightsDS The weights the iteration starts from
    * @param numberOfIterations The maximum number of iterations
    * @param regularizationConstant The regularization constant handed to [[takeStep]]
    * @param learningRate The initial learning rate
    * @param convergenceThreshold The relative change of the loss below which iterating stops
    * @param lossFunction The loss function used for the gradient and the loss value
    * @return The weights after the last iteration
    */
  def optimizeWithConvergenceCriterion(
      dataPoints: DataSet[LabeledVector],
      initialWeightsDS: DataSet[WeightVector],
      numberOfIterations: Int,
      regularizationConstant: Double,
      learningRate: Double,
      convergenceThreshold: Double,
      lossFunction: LossFunction)
    : DataSet[WeightVector] = {
    // Average loss of the initial weights over all data points; it is the reference value
    // for the convergence check of the first iteration
    val initialLossSumDS = calculateLoss(dataPoints, initialWeightsDS, lossFunction)

    // Combine weight vector with the current loss
    val initialWeightsWithLossSum = initialWeightsDS.mapWithBcVariable(initialLossSumDS){
      (weights, loss) => (weights, loss)
    }

    val resultWithLoss = initialWeightsWithLossSum.iterateWithTermination(numberOfIterations) {
      weightsWithPreviousLossSum =>

        // Extract weight vector and loss
        val previousWeightsDS = weightsWithPreviousLossSum.map{_._1}
        val previousLossSumDS = weightsWithPreviousLossSum.map{_._2}

        val currentWeightsDS = SGDStep(
          dataPoints,
          previousWeightsDS,
          lossFunction,
          regularizationConstant,
          learningRate)

        val currentLossSumDS = calculateLoss(dataPoints, currentWeightsDS, lossFunction)

        // Check if the relative change in the loss is smaller than the
        // convergence threshold. If yes, then terminate i.e. return empty termination data set
        val termination = previousLossSumDS.filterWithBcVariable(currentLossSumDS){
          (previousLoss, currentLoss) => {
            if (previousLoss <= 0) {
              // A non-positive previous loss cannot serve as the denominator of a relative
              // change, so the iteration is stopped
              false
            } else {
              scala.math.abs((previousLoss - currentLoss)/previousLoss) >= convergenceThreshold
            }
          }
        }

        // Result for new iteration
        (currentWeightsDS.mapWithBcVariable(currentLossSumDS)((w, l) => (w, l)), termination)
    }
    // Return just the weights
    resultWithLoss.map{_._1}
  }

  /** Runs exactly `numberOfIterations` gradient descent steps without a convergence check.
    *
    * @param data A Dataset of LabeledVector (label, features) pairs
    * @param initialWeightsDS The weights the iteration starts from
    * @param numberOfIterations The number of iterations to perform
    * @param regularizationConstant The regularization constant handed to [[takeStep]]
    * @param learningRate The initial learning rate
    * @param lossFunction The loss function whose gradient is followed
    * @return The weights after the last iteration
    */
  def optimizeWithoutConvergenceCriterion(
      data: DataSet[LabeledVector],
      initialWeightsDS: DataSet[WeightVector],
      numberOfIterations: Int,
      regularizationConstant: Double,
      learningRate: Double,
      lossFunction: LossFunction)
    : DataSet[WeightVector] = {
    initialWeightsDS.iterate(numberOfIterations) {
      weightVectorDS => {
        SGDStep(data, weightVectorDS, lossFunction, regularizationConstant, learningRate)
      }
    }
  }

  /** Performs one iteration of Stochastic Gradient Descent using mini batches
    *
    * The per-example gradients are summed up, divided by the number of examples and the
    * resulting average gradient is handed to [[takeStep]] together with a learning rate that
    * decays with the square root of the iteration number.
    *
    * @param data A Dataset of LabeledVector (label, features) pairs
    * @param currentWeights A Dataset with the current weights to be optimized as its only element
    * @param lossFunction The loss function whose gradient is computed for every example
    * @param regularizationConstant The regularization constant handed to [[takeStep]]
    * @param learningRate The initial learning rate, before the per-iteration decay
    * @return A Dataset containing the weights after one stochastic gradient descent step
    */
  private def SGDStep(
    data: DataSet[LabeledVector],
    currentWeights: DataSet[WeightVector],
    lossFunction: LossFunction,
    regularizationConstant: Double,
    learningRate: Double)
  : DataSet[WeightVector] = {

    data.mapWithBcVariable(currentWeights){
      (data, weightVector) => (lossFunction.gradient(data, weightVector), 1)
    }.reduce{
      (left, right) =>
        val (leftGradVector, leftCount) = left
        val (rightGradVector, rightCount) = right

        // BLAS.axpy can only add *into* a dense vector, so the vector receiving the sum has to
        // be dense. A gradient that is not dense (e.g. one derived from sparse features) is
        // therefore copied into a fresh dense vector first.
        // NOTE(review): the left operand is used as the accumulator on the assumption that
        // the runtime passes the running partial result as the first argument, so that the
        // densification happens at most once per reduce chain - confirm for the target runtime.
        val gradientSum: DenseVector = leftGradVector.weights match {
          case dense: DenseVector => dense
          case other =>
            val dense = DenseVector.zeros(other.size)
            BLAS.axpy(1.0, other, dense)
            dense
        }

        // Add the right gradient to the accumulated one
        BLAS.axpy(1.0, rightGradVector.weights, gradientSum)
        val gradients = WeightVector(
          gradientSum, leftGradVector.intercept + rightGradVector.intercept)

        (gradients , leftCount + rightCount)
    }.mapWithBcVariableIteration(currentWeights){
      (gradientCount, weightVector, iteration) => {
        val (WeightVector(weights, intercept), count) = gradientCount

        // Turn the summed gradient into the average gradient
        BLAS.scal(1.0/count, weights)

        val gradient = WeightVector(weights, intercept/count)

        // Decay the learning rate with the square root of the iteration number
        val effectiveLearningRate = learningRate/Math.sqrt(iteration)

        val newWeights = takeStep(
          weightVector.weights,
          gradient.weights,
          regularizationConstant,
          effectiveLearningRate)

        // The intercept is updated with the plain gradient step; takeStep only sees the weights
        WeightVector(
          newWeights,
          weightVector.intercept - effectiveLearningRate * gradient.intercept)
      }
    }
  }

  /** Calculates the new weights based on the gradient
    *
    * Concrete solvers decide here how the gradient and the regularization are applied.
    *
    * @param weightVector The current weights
    * @param gradient The average gradient of the loss for the current weights
    * @param regularizationConstant The regularization constant
    * @param learningRate The (already decayed) learning rate of the current iteration
    * @return The weights after the step
    */
  def takeStep(
    weightVector: Vector,
    gradient: Vector,
    regularizationConstant: Double,
    learningRate: Double
    ): Vector

  /** Calculates the average loss of the data for the given weights.
    *
    * The value is the sum of `lossFunction.loss` over all examples divided by the number of
    * examples. No regularization term is added by this method.
    *
    * @param data A Dataset of LabeledVector (label, features) pairs
    * @param weightDS A Dataset with the weights to evaluate, broadcast to every example
    * @param lossFunction The loss function to evaluate
    * @return A Dataset containing the average loss
    */
  private def calculateLoss(
      data: DataSet[LabeledVector],
      weightDS: DataSet[WeightVector],
      lossFunction: LossFunction)
    : DataSet[Double] = {
    data.mapWithBcVariable(weightDS){
      (data, weightVector) => (lossFunction.loss(data, weightVector), 1)
    }.reduce{
      (left, right) => (left._1 + right._1, left._2 + right._2)
    }.map {
      lossCount => lossCount._1 / lossCount._2
    }
  }
}

/** Implementation of a SGD solver with L2 regularization.
  *
  * The regularization function is `1/2 ||w||_2^2` with `w` being the weight vector.
  */
class GradientDescentL2 extends GradientDescent {

  /** Takes one gradient step with the L2 penalty folded into the gradient.
    *
    * Both arguments are modified: the regularization term is added to `gradient` and the
    * step is then applied to `weightVector`.
    *
    * @param weightVector The current weights, updated by this call
    * @param gradient The gradient of the loss, extended by the regularization term
    * @param regularizationConstant The factor applied to the weights when they are added to
    *                               the gradient
    * @param learningRate The step size
    * @return `weightVector` after the step
    */
  override def takeStep(
      weightVector: Vector,
      gradient: Vector,
      regularizationConstant: Double,
      learningRate: Double)
    : Vector = {
    // gradient += regularizationConstant * weightVector (gradient of the L2 regularization)
    BLAS.axpy(regularizationConstant, weightVector, gradient)

    // weightVector -= learningRate * gradient
    val negatedRate = -learningRate
    BLAS.axpy(negatedRate, gradient, weightVector)

    weightVector
  }
}

object GradientDescentL2 {
  /** Creates a new [[GradientDescentL2]] solver instance. */
  def apply(): GradientDescentL2 = new GradientDescentL2
}

/** Implementation of a SGD solver with L1 regularization.
  *
  * The regularization function is `||w||_1` with `w` being the weight vector.
  */
class GradientDescentL1 extends GradientDescent {

  /** Takes one gradient step followed by soft thresholding of every weight.
    *
    * The weights are first moved against the gradient and then each entry is shrunk towards
    * zero by `regularizationConstant * learningRate`; entries whose magnitude is below that
    * value become zero.
    *
    * @param weightVector The current weights, updated by this call
    * @param gradient The gradient of the loss
    * @param regularizationConstant The strength of the L1 regularization
    * @param learningRate The step size
    * @return `weightVector` after the step and the shrinkage
    */
  override def takeStep(
      weightVector: Vector,
      gradient: Vector,
      regularizationConstant: Double,
      learningRate: Double)
    : Vector = {
    // Plain gradient step first. The L1 term contributes no gradient here; it is handled by
    // the proximal operator below.
    val negatedRate = -learningRate
    BLAS.axpy(negatedRate, gradient, weightVector)

    // Proximal operator (soft thresholding), applied entry by entry
    val threshold = learningRate * regularizationConstant
    for (idx <- 0 until weightVector.size) {
      val current = weightVector(idx)
      val shrunkMagnitude = scala.math.max(0.0, scala.math.abs(current) - threshold)
      weightVector(idx) = scala.math.signum(current) * shrunkMagnitude
    }

    weightVector
  }
}

object GradientDescentL1 {
  /** Creates a new [[GradientDescentL1]] solver instance. */
  def apply(): GradientDescentL1 = new GradientDescentL1
}

/** Implementation of a SGD solver without regularization.
  *
  * No regularization is applied.
  */
class SimpleGradientDescent extends GradientDescent {

  /** Takes one plain gradient step.
    *
    * @param weightVector The current weights, updated by this call
    * @param gradient The gradient of the loss
    * @param regularizationConstant Not used by this solver
    * @param learningRate The step size
    * @return `weightVector` after the step
    */
  override def takeStep(
      weightVector: Vector,
      gradient: Vector,
      regularizationConstant: Double,
      learningRate: Double)
    : Vector = {
    // weightVector -= learningRate * gradient
    val negatedRate = -learningRate
    BLAS.axpy(negatedRate, gradient, weightVector)
    weightVector
  }
}

object SimpleGradientDescent {
  /** Creates a new [[SimpleGradientDescent]] solver instance. */
  def apply(): SimpleGradientDescent = new SimpleGradientDescent
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy