All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.crf.LinearChainCrf.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.crf

import VectorMath._
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import org.slf4j.LoggerFactory
import scala.util.Random


// TODO: estimate c0 before training
/**
  * Linear-chain CRF trainer using stochastic gradient descent with L2 regularization
  * (the decay schedule and lazy weight scaling are handled by [[L2DecayStrategy]]).
  *
  * @param params training hyper-parameters and logging verbosity
  */
class LinearChainCrf(val params: CrfParams) {

  private val logger = LoggerFactory.getLogger("CRF")

  // Training stops (once params.minEpochs is reached) after this many consecutive epochs
  // whose loss did not improve by more than params.lossEps (relative).
  private val maxNotImprovedEpochs = 10

  /**
    * Logs `value` at INFO level when `minLevel` is at least the configured `params.verbose`.
    * `value` is by-name, so the message is only built when it is actually logged.
    */
  def log(value: => String, minLevel: Verbose.Level): Unit = {
    if (minLevel >= params.verbose) {
      logger.info(value)
    }
  }

  /**
    * Trains a model on `dataset` with SGD.
    *
    * Every epoch visits all instances in shuffled order and makes one gradient step per instance.
    * The weights from the epoch with the lowest total loss (log loss + L2 penalty) are returned.
    * Training runs for at most `params.maxEpochs` epochs. It stops earlier once at least
    * `params.minEpochs` epochs have run and the loss has not improved by more than
    * `params.lossEps` (relative) for `maxNotImprovedEpochs` consecutive epochs.
    *
    * @param dataset labelled training instances plus feature metadata; must not be empty
    * @return model built from the best weights seen during training
    * @throws IllegalArgumentException if `dataset` has no instances
    */
  def trainSGD(dataset: CrfDataset): LinearChainCrfModel = {
    require(dataset.instances.nonEmpty, "Cannot train a CRF on an empty dataset")

    val metadata = dataset.metadata
    val weights = Vector(metadata.attrFeatures.length + metadata.transitions.length)
    val labelsCount = metadata.labels.length

    // Seeds the shared scala.util.Random, which drives the per-epoch shuffling below.
    params.randomSeed.foreach(seed => Random.setSeed(seed))

    // 1. Longest sentence, used to size the reusable FbCalculator
    val maxLength = dataset.instances.map(w => w._2.items.size).max

    log(s"labels: $labelsCount", Verbose.TrainingStat)
    log(s"instances: ${dataset.instances.size}", Verbose.TrainingStat)
    log(s"features: ${weights.length}", Verbose.TrainingStat)
    log(s"maxLength: $maxLength", Verbose.TrainingStat)

    // 2. Allocate reusable space
    val context = new FbCalculator(maxLength, metadata)

    val bestW = Vector(weights.length)
    var bestLoss = Float.MaxValue
    var notImprovedEpochs = 0

    val decayStrategy = new L2DecayStrategy(dataset.instances.size, params.l2, params.c0)

    var epoch = 0
    while (epoch < params.maxEpochs &&
      (notImprovedEpochs < maxNotImprovedEpochs || epoch < params.minEpochs)) {

      var loss = 0f

      log(s"\nEpoch: $epoch, eta: ${decayStrategy.eta}", Verbose.Epochs)
      val started = System.nanoTime()

      val shuffled = Random.shuffle(dataset.instances)

      var instancesCount = 0
      for ((instanceLabels, sentence) <- shuffled) {
        decayStrategy.nextStep()

        // 1. Calculate values for further usage
        context.calculate(sentence, weights, decayStrategy.getScale)

        // 2. Make one gradient step
        doSgdStep(sentence, instanceLabels, decayStrategy.alpha, weights, context)

        // 3. Calculate loss
        loss += getLoss(sentence, instanceLabels, context)

        // 4. Periodically fold the accumulated decay scale back into the stored weights
        instancesCount += 1
        if (instancesCount % 1000 == 0)
          decayStrategy.reset(weights)
      }

      // Return weights to normal values
      decayStrategy.reset(weights)

      // Sum of squared weights without allocating an intermediate array
      val l2Loss = params.l2 * weights.foldLeft(0f)((acc, w) => acc + w * w)

      val totalLoss = loss + l2Loss

      log(s"finished, time: ${(System.nanoTime() - started)/1e9}", Verbose.Epochs)
      log(s"Loss = $totalLoss, logLoss = $loss, l2Loss = $l2Loss", Verbose.Epochs)

      // Keep the best solution. An epoch only counts as "improving" when the relative gain
      // exceeds lossEps; the gain must be computed BEFORE bestLoss is overwritten.
      if (totalLoss < bestLoss) {
        val relativeImprovement = (bestLoss - totalLoss) / totalLoss
        bestLoss = totalLoss
        copy(weights, bestW)

        if (relativeImprovement > params.lossEps)
          notImprovedEpochs = 0
        else
          notImprovedEpochs += 1
      } else {
        notImprovedEpochs += 1
      }

      epoch += 1
    }

    new LinearChainCrfModel(bestW, metadata)
  }

  /**
    * Per-instance log loss of the gold `labels` for `sentence`.
    *
    * It uses the values that `context.calculate` prepared for this sentence. For each position
    * `i`, it subtracts the gold transition score `logPhi(i)(prevLabel)(label)` and adds
    * `log(c(i))`. Label 0 is used as the previous label at position 0.
    * NOTE(review): c(i) is presumably the per-position forward-backward normalizer, so the sum
    * of log(c(i)) is log Z — confirm against FbCalculator.
    */
  private def getLoss(sentence: Instance, labels: InstanceLabels, context: FbCalculator): Float = {
    val length = sentence.items.length

    var prevLabel = 0
    var result = 0f
    for (i <- 0 until length) {
      result -= context.logPhi(i)(prevLabel)(labels.labels(i))
      prevLabel = labels.labels(i)

      result += Math.log(context.c(i)).toFloat
    }

    result
  }

  /**
    * Makes one SGD step on `weights` (in stored/scaled form) for a single instance,
    * minimizing the negative log-likelihood.
    *
    * Gradient = [Observed Expectation] - [Model Expectation]; Weights += a * Gradient.
    *
    * @param sentence instance features
    * @param labels   gold labels of the instance
    * @param a        step size (the decay strategy's `alpha`)
    * @param weights  weights updated by the step
    * @param context  values already computed by `context.calculate` for this sentence
    */
  def doSgdStep(sentence: Instance,
                labels: InstanceLabels,
                a: Float,
                weights: Array[Float],
                context: FbCalculator): Unit = {

    // 1. Plus Observed Expectation
    context.addObservedExpectations(weights, sentence, labels, a)

    // 2. Minus Model Expectations
    context.addModelExpectations(weights, sentence, -a)
  }
}

/**
  * Learning-rate schedule and lazy L2 weight decay for SGD.
  *
  * The step size for the true weights is `eta = 1 / (lambda * (step + c0))`, where
  * `lambda = 2 * l2 / instances`. Shrinking every weight on each step would be expensive, so
  * the decay is accumulated into `scale` instead: true weights = stored weights * `scale`.
  * Gradient steps are applied to the stored weights with [[alpha]] (`eta / scale`).
  * [[reset]] folds `scale` back into the stored weights.
  *
  * @param instances number of training instances; must be positive
  * @param l2        L2 regularization coefficient; must be positive because the step size
  *                  divides by it (l2 = 0 would make `eta` infinite and `scale` NaN)
  * @param c0        offset added to the step count in the learning-rate schedule
  */
class L2DecayStrategy(val instances: Int,
                      val l2: Float,
                      val c0: Float = 1000
                     ) {

  // Without these, eta becomes infinite or negative and scale turns NaN,
  // which silently corrupts the weights instead of failing.
  require(l2 > 0f, s"L2 regularization coefficient must be positive, got $l2")
  require(instances > 0, s"Number of training instances must be positive, got $instances")

  // Correct weights is equal weights * scale
  private var scale: Float = 1f

  // Number of SGD steps taken so far
  private var step = 0

  // Regularization for one instance
  private val lambda = 2f * l2 / instances

  /** Current multiplier converting stored weights into true weights. */
  def getScale: Float = scale

  /** Step size to apply to the stored (scaled) weights. */
  def alpha: Float = eta / scale

  /** Step size for the true weights at the current step. */
  def eta: Float = 1f / (lambda * (step + c0))

  /** Advances to the next SGD step and applies that step's L2 decay to `scale`. */
  def nextStep(): Unit = {
    step += 1
    scale = scale * (1f - eta * lambda)
  }

  /**
    * Folds the accumulated `scale` into `weights` via `VectorMath.multiply`
    * (presumably in place — the result is not returned), then resets `scale` to 1.
    */
  def reset(weights: Vector): Unit = {
    VectorMath.multiply(weights, scale)
    scale = 1f
  }
}


/**
  * Hyper-parameters and settings for [[LinearChainCrf]] training.
  *
  * @param minEpochs  minimum number of epochs to run before early stopping may end training
  * @param maxEpochs  maximum number of epochs to train
  * @param l2         L2 regularization coefficient
  * @param c0         offset added to the SGD step count in the learning-rate decay schedule
  * @param lossEps    threshold on the relative (not absolute) epoch-to-epoch loss improvement,
  *                   used by the early-stopping check
  * @param randomSeed optional seed applied to scala.util.Random before training
  *                   (controls instance shuffling)
  * @param verbose    logging verbosity during training
  */
case class CrfParams(
  minEpochs: Int = 10,
  maxEpochs: Int = 1000,
  l2: Float = 1f,
  c0: Int = 1500000,
  lossEps: Float = 1e-4f,
  randomSeed: Option[Int] = None,
  verbose: Verbose.Value = Verbose.Silent
)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy