All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.crf.ForwardBackward.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.crf

import com.johnsnowlabs.ml.crf.VectorMath._

/**
 * Precomputes the values used by the scaled Forward-Backward algorithm, one sentence at a time,
 * and uses them to add observed and model feature expectations to a caller-supplied vector.
 *
 * All buffers are allocated once for sentences of up to `maxLength` items and are overwritten by
 * every call to [[calculate]]. After a call, only positions `0 until sentence.items.length` hold
 * meaningful values. The buffers are mutated without synchronization, so an instance must not be
 * shared across threads.
 */
class FbCalculator(val maxLength: Int, val metadata: DatasetMetadata) {

  /** Number of labels; every potential matrix below is `labels x labels`. */
  val labels = metadata.label2Id.size

  /** Per-position log-potentials, filled by `EdgeCalculator.fillLogEdges`. */
  val logPhi = Array.fill(maxLength)(Matrix(labels, labels))

  /** Per-position potentials: an exponentiated copy of `logPhi`, indexed `phi(i)(from)(to)`. */
  val phi = Array.fill(maxLength)(Matrix(labels, labels))

  /** Scaled forward vectors; each `alpha(i)` is divided by its own sum `c(i)`. */
  val alpha = Array.fill(maxLength)(Vector(labels))

  /** Scaled backward vectors, divided by the same per-position factors `c` as `alpha`. */
  val beta = Array.fill(maxLength)(Vector(labels))

  /** Per-position scale factors: `c(i)` is the sum of the unnormalized forward vector at `i`. */
  val c = Array.fill(maxLength)(1f)

  /**
   * Runs the forward-backward pass for one sentence, overwriting `logPhi`, `phi`, `alpha`, `beta`
   * and `c` for positions `0 until sentence.items.length`.
   *
   * @param sentence sentence to process; must be non-empty and at most `maxLength` items long
   * @param weights  current model weights, indexed by feature id
   * @param scale    multiplier applied to every weight when building the log-potentials
   * @throws IllegalArgumentException if the sentence is empty or longer than `maxLength`, or if
   *                                  a forward normalizer `c(i)` becomes zero
   */
  def calculate(sentence: Instance, weights: Array[Float], scale: Float): Unit = {
    val length = sentence.items.length
    // An empty sentence would otherwise read a stale phi(0) and then index beta(-1).
    require(length > 0, "Sentence must contain at least one item")
    require(length <= maxLength, s"Sentence length $length exceeds maxLength $maxLength")

    calcPhi(sentence, weights, scale)
    calcAlpha(sentence)
    calcBeta(sentence)
  }

  /** Fills `logPhi(i)` from the features at each position, then sets `phi(i)` to its `exp`. */
  private def calcPhi(sentence: Instance, weights: Array[Float], scale: Float): Unit = {
    val length = sentence.items.length

    for (i <- 0 until length) {
      // 1. Calculate log Phi for each edge
      EdgeCalculator.fillLogEdges(sentence.items(i).values, weights, scale, metadata, logPhi(i))

      // 2. Calc exp for each matrix value
      copy(logPhi(i), phi(i))
      exp(phi(i))
    }
  }

  // TODO: try linear-algebra operations on top of matrices and vectors
  /**
   * Scaled forward pass.
   *
   * `alpha(0)` starts as row 0 of `phi(0)`. Each later `alpha(i)` is `alpha(i - 1)` propagated
   * through `phi(i)`. Every vector is then divided by its sum, and that sum is stored in `c(i)`.
   */
  private def calcAlpha(sentence: Instance): Unit = {
    val length = sentence.items.length
    require(length <= phi.length)

    fillMatrix(alpha, 0f)
    fillVector(c, 1f)

    // Label 0 acts as the start state, so position 0 begins from row 0 of phi(0).
    copy(phi(0)(0), alpha(0))
    c(0) = alpha(0).sum
    // Same guard as for later positions: a zero sum would turn every value into Infinity/NaN.
    require(c(0) != 0f, "Forward normalizer is zero at position 0")
    multiply(alpha(0), 1 / c(0))

    var prev = alpha(0)

    for (i <- 1 until length) {
      for (from <- 0 until labels) {
        for (to <- 0 until labels) {
          alpha(i)(to) += prev(from) * phi(i)(from)(to)
        }
      }

      c(i) = alpha(i).sum
      require(c(i) != 0f, s"Forward normalizer is zero at position $i")

      multiply(alpha(i), 1 / c(i))

      prev = alpha(i)
    }
  }

  /**
   * Scaled backward pass.
   *
   * `beta(length - 1)` is `1 / c(length - 1)` in every entry. Each earlier `beta(i)` is
   * `phi(i + 1)` applied to `beta(i + 1)`, then divided by `c(i)`. It reuses the scale factors
   * from `calcAlpha`, so it must run after that method.
   */
  private def calcBeta(sentence: Instance): Unit = {
    val length = sentence.items.length
    require(length <= phi.length)

    fillMatrix(beta, 0f)
    fillVector(beta(length - 1), 1f / c(length - 1))
    var next = beta(length - 1)

    for (i <- Range.inclusive(length - 2, 0, -1)) {
      for (from <- 0 until labels) {
        for (to <- 0 until labels) {
          beta(i)(from) += phi(i + 1)(from)(to) * next(to)
        }
      }

      multiply(beta(i), 1 / c(i))
      next = beta(i)
    }
  }

  /**
   * Adds the feature counts of the given labelling of `instance` to `weights`, each scaled by `c`.
   *
   * For every position with label `y`:
   *  - each attribute feature `(attrId, y)` that exists in the metadata receives `c * value`;
   *  - the transition feature `(previous label, y)` receives `c`.
   * Position 0 uses label 0 as its previous label.
   *
   * @param weights        accumulator indexed by feature id; updated in place
   * @param instance       sentence whose attribute values are read
   * @param instanceLabels label id for each position of `instance`
   * @param c              multiplier for every added count (shadows the field `c` in this method)
   */
  def addObservedExpectations(
      weights: Vector,
      instance: Instance,
      instanceLabels: InstanceLabels,
      c: Float): Unit = {

    val length = instance.items.length

    for (i <- 0 until length) {
      val label = instanceLabels.labels(i)

      // Observed Features
      for ((attrId, value) <- instance.items(i).values) {
        metadata.attrFeatures2Id
          .get((attrId, label))
          .foreach(fId => weights(fId) += c * value)
      }

      // Transition Features
      val fromLabel = if (i > 0) instanceLabels.labels(i - 1) else 0
      val meta = Transition(fromLabel, label)
      metadata.transFeature2Id.get(meta).foreach { fid =>
        weights(fid) += c
      }

    }
  }

  /**
   * Adds the model's expected feature counts for `sentence` to `weights`, each scaled by `const`.
   *
   * Reads `alpha`, `beta`, `phi` and `c`, so [[calculate]] must first be called for the same
   * sentence. Under the scaling used there:
   *  - `c(i) * alpha(i)(y) * beta(i)(y)` is the normalized marginal of label `y` at position `i`;
   *  - `alpha(i - 1)(a) * phi(i)(a)(b) * beta(i)(b)` is the normalized marginal of the
   *    transition `a -> b` into position `i`.
   *
   * @param weights  accumulator indexed by feature id; updated in place
   * @param sentence the sentence last passed to [[calculate]]
   * @param const    multiplier for every added expectation
   */
  def addModelExpectations(weights: Vector, sentence: Instance, const: Float): Unit = {

    val length = sentence.items.length

    // Update Observed
    for (i <- 0 until length) {
      for ((attrId, value) <- sentence.items(i).values) {
        for (feature <- metadata.attr2Features(attrId)) {
          weights(feature.id) +=
            const * c(i) * alpha(i)(feature.label) * beta(i)(feature.label) * value
        }
      }
    }

    // Update Transitions
    for (i <- 1 until length) {
      for ((feature, fid) <- metadata.transFeature2Id) {
        val from = feature.stateFrom
        val to = feature.stateTo

        weights(fid) += const * alpha(i - 1)(from) * phi(i)(from)(to) * beta(i)(to)
      }
    }

    // Update Transition from Start (label 0 is the start state, as in calcAlpha)
    for ((feature, fid) <- metadata.transFeature2Id; if (feature.stateFrom == 0)) {
      val to = feature.stateTo
      weights(fid) += const * phi(0)(0)(to) * beta(0)(to)
    }
  }
}

/** Builds the per-position log-potential matrices consumed by the forward-backward pass. */
object EdgeCalculator {

  /**
   * Overwrites `matrix` with the log-potentials for one sentence position.
   *
   * Cell `(from, to)` receives two kinds of contribution:
   *  - `weights(f.id) * value * scale` for every attribute feature `f` of every
   *    `(attrId, value)` pair with `f.label == to`; this is the same amount in every `from` row;
   *  - `weights(fid) * scale` for each transition feature `from -> to` in the metadata.
   *
   * @param values   (attribute id, attribute value) pairs observed at this position
   * @param weights  model weights indexed by feature id
   * @param scale    multiplier applied to every weight
   * @param metadata supplies the label set, attribute features and transition features
   * @param matrix   destination, reset to zero first; assumed to be at least
   *                 `labels x labels` (TODO confirm against callers)
   */
  def fillLogEdges(
      values: Seq[(Int, Float)],
      weights: Array[Float],
      scale: Float,
      metadata: DatasetMetadata,
      matrix: Matrix): Unit = {

    val labels = metadata.labels.size

    fillMatrix(matrix, 0f)

    for ((attrId, value) <- values) {
      // Attribute features do not depend on the source label, so look them up and compute each
      // increment once, then add it to every row. The additions to each cell happen in the same
      // order as before, so the result is unchanged.
      for (feature <- metadata.attr2Features(attrId)) {
        val delta = weights(feature.id) * value * scale
        val to = feature.label
        for (from <- 0 until labels)
          matrix(from)(to) += delta
      }
    }

    for ((feature, fid) <- metadata.transFeature2Id) {
      matrix(feature.stateFrom)(feature.stateTo) += weights(fid) * scale
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy