All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.crf.LinearChainCrfModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.crf

import com.johnsnowlabs.ml.crf.VectorMath._
import com.johnsnowlabs.nlp.annotators.param.{
  SerializedAnnotatorComponent,
  WritableAnnotatorComponent
}

/** Linear-chain CRF model: a flat array of weights plus the dataset metadata that is passed,
  * together with the weights, to [[EdgeCalculator]] when scoring an instance.
  *
  * @param weights
  *   Model weights; `shrink` treats the array index of a weight as its feature id
  * @param metadata
  *   Dataset metadata; the size of its `label2Id` defines the number of labels
  */
class LinearChainCrfModel(val weights: Array[Float], val metadata: DatasetMetadata)
    extends WritableAnnotatorComponent {

  /** Number of labels known to the model (size of `metadata.label2Id`). */
  val labels: Int = metadata.label2Id.size

  /** Predicts the highest-scoring label sequence for an instance.
    *
    * @param instance
    *   Instance whose items are labelled one by one
    * @return
    *   One label id per item; empty labels for an instance without items
    */
  def predict(instance: Instance): InstanceLabels =
    if (instance.items.isEmpty) InstanceLabels(Seq.empty)
    else InstanceLabels(bestLabelSequence(instance))

  /** Finds the best-scoring label id for every item of a non-empty instance using a
    * dynamic-programming forward pass followed by a backtrace.
    */
  private def bestLabelSequence(instance: Instance): Array[Int] = {
    val length = instance.items.length

    // Two score buffers that are swapped after every position.
    var newBestPath = Vector(labels)
    var bestPath = Vector(labels)

    val matrix = Matrix(labels, labels)
    EdgeCalculator.fillLogEdges(instance.items.head.values, weights, 1f, metadata, matrix)
    // The first item is scored with the transitions leaving label 0
    // (presumably the start label -- confirm against DatasetMetadata).
    copy(matrix(0), bestPath)

    // prevIdx(i)(to) is the best predecessor label for label `to` at position i.
    val prevIdx = Array.fill[Int](length, labels)(0)

    // Calculate best path
    for (i <- 1 until length) {
      val features = instance.items(i).values
      EdgeCalculator.fillLogEdges(features, weights, 1f, metadata, matrix)
      fillVector(newBestPath, Float.MinValue)

      for (from <- 0 until labels) {
        for (to <- 0 until labels) {
          val newPath = bestPath(from) + matrix(from)(to)

          if (newBestPath(to) < newPath) {
            newBestPath(to) = newPath
            prevIdx(i)(to) = from
          }
        }
      }

      // Swap buffers so that bestPath holds the scores for position i.
      val tmp = newBestPath
      newBestPath = bestPath
      bestPath = tmp
    }

    // Restore best path
    val result = Array.fill(length)(0)
    // Seed with negative infinity rather than 0: path scores are sums of edge scores and may
    // all be negative, in which case a 0 seed would never be beaten and the last label would
    // silently stay at index 0.
    var best = Float.NegativeInfinity
    for (label <- 0 until labels) {
      if (bestPath(label) > best) {
        best = bestPath(label)
        result(length - 1) = label
      }
    }

    // Walk the predecessor table backwards from the last position.
    for (i <- Range.inclusive(length - 2, 0, -1)) {
      result(i) = prevIdx(i + 1)(result(i + 1))
    }

    result
  }

  /** Converts the model into its serializable representation (weights as a list plus the
    * serialized metadata).
    */
  override def serialize: SerializedLinearChainCrfModel = {
    SerializedLinearChainCrfModel(
      weights.toList,
      metadata.serialize.asInstanceOf[SerializedDatasetMetadata])
  }

  /** Removes features whose absolute weight is less than minW
    * @param minW
    *   Minimum absolute weight required to keep a feature
    * @return
    *   Shrunk model containing only the retained weights and the metadata filtered to the
    *   retained feature ids
    */
  def shrink(minW: Float): LinearChainCrfModel = {
    val (filteredWeights, featureIds) =
      weights.zipWithIndex.filter(p => Math.abs(p._1) >= minW).unzip
    val filteredMeta = metadata.filterFeatures(featureIds)
    new LinearChainCrfModel(filteredWeights, filteredMeta)
  }

}

/** Serializable form of a [[LinearChainCrfModel]]: the weights as a sequence together with the
  * serialized dataset metadata.
  */
case class SerializedLinearChainCrfModel(weights: Seq[Float], metadata: SerializedDatasetMetadata)
    extends SerializedAnnotatorComponent[LinearChainCrfModel] {

  /** Rebuilds the model from the stored weights and the deserialized metadata. */
  override def deserialize: LinearChainCrfModel = {
    val restoredWeights: Array[Float] = weights.toArray
    val restoredMetadata: DatasetMetadata = metadata.deserialize
    new LinearChainCrfModel(restoredWeights, restoredMetadata)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy