All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.annotators.spell.context.WeightedLevenshtein.scala Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.spell.context

import scala.collection.mutable
import scala.io.Codec
import scala.math.min

/** Edit-distance helpers: a Levenshtein distance whose per-operation costs are supplied by the
  * caller (or looked up in a weights table), a loader for such a table, and a routine that
  * recovers the individual edit operations between two strings.
  */
trait WeightedLevenshtein {

  /** Computes a Levenshtein-style distance between `s11` and `s22` using a pluggable cost.
    *
    * Both inputs are wrapped as `^...` + `_` before comparison, so the first and last characters
    * of the original strings also have a neighbour. Equal characters cost nothing. For differing
    * characters the cheapest of the following three moves is taken, and `cost` is invoked as:
    *   - `cost(<previous s2 char + current s2 char>, <current s1 char> + "Ɛ")`
    *   - `cost(<current s2 char> + "Ɛ", <previous s1 char + current s1 char>)`
    *   - `cost(<current s2 char>, <current s1 char>)`
    *
    * The first row and column of the table are initialised with unit costs and do not consult
    * `cost`.
    *
    * @param s11
    *   first string
    * @param s22
    *   second string
    * @param cost
    *   function returning the price of turning its first argument into its second
    * @return
    *   the accumulated cost of the cheapest alignment
    */
  def levenshteinDist(s11: String, s22: String)(cost: (String, String) => Float): Float = {

    // cope with start of string
    val s1 = s"^${s11}_"
    val s2 = s"^${s22}_"

    // Shifted copies: index k here is index k - 1 above, so substring(k - 1, k + 1) yields the
    // character at k - 1 of the wrapped string together with the character preceding it.
    val s1_ = s"_$s1"
    val s2_ = s"_$s2"

    val dist = unitBorderTable(s2.length + 1, s1.length + 1)

    for (j <- 1 to s2.length; i <- 1 to s1.length)
      dist(j)(i) =
        if (s2(j - 1) == s1(i - 1)) dist(j - 1)(i - 1)
        else {
          minimum(
            dist(j - 1)(i) + cost(s2_.substring(j - 1, j + 1), s"${s1(i - 1)}Ɛ"), // insert in s1
            dist(j)(i - 1) + cost(s"${s2(j - 1)}Ɛ", s1_.substring(i - 1, i + 1)), // insert in s2
            dist(j - 1)(i - 1) + cost(s2(j - 1).toString, s1(i - 1).toString))
        }

    dist(s2.length)(s1.length)
  }

  /** Weighted Levenshtein distance: [[levenshteinDist]] with costs looked up in `weights`.
    *
    * A pair `(a, b)` missing from `weights` costs 0 when `a == b` and 1 otherwise.
    *
    * @param weights
    *   nested map `a -> (b -> cost)`, e.g. as produced by [[loadWeights]]
    */
  def wLevenshteinDist(s1: String, s2: String, weights: Map[String, Map[String, Float]]): Float =
    levenshteinDist(s1, s2)(genCost(weights))

  /** Reads a weights table from a UTF-8 text file.
    *
    * Every non-blank line must have the form `first|second|weight`; the result maps
    * `first -> (second -> weight)`. When a pair occurs more than once the last line wins.
    *
    * @param filename
    *   path of the file to read
    * @return
    *   a strict, immutable nested map of weights
    */
  def loadWeights(filename: String): Map[String, Map[String, Float]] = {
    // first field -> (second field -> weight)
    val weights = mutable.HashMap[String, mutable.Map[String, Float]]()

    implicit val codec: Codec = Codec.UTF8

    val source = scala.io.Source.fromFile(filename)
    try {
      source
        .getLines()
        .filter(_.trim.nonEmpty) // tolerate blank / trailing empty lines
        .foreach { line =>
          val lineFields = line.split("\\|")
          val first = lineFields(0)
          val second = lineFields(1)
          val weight = lineFields(2).toFloat
          weights.getOrElseUpdate(first, mutable.Map[String, Float]()).update(second, weight)
        }
    } finally {
      // always release the file handle, even when a line fails to parse
      source.close()
    }

    // Build a strict immutable result instead of a lazy `mapValues` wrapper.
    weights.iterator.map { case (first, targets) => (first, targets.toMap) }.toMap
  }

  /** Cost lookup: the explicit weight for `(a, b)` if present, else 0 for equal arguments and 1
    * for different ones.
    */
  private def genCost(weights: Map[String, Map[String, Float]])(a: String, b: String): Float =
    weights.get(a).flatMap(_.get(b)).getOrElse(if (a == b) 0.0f else 1.0f)

  private def minimum(i1: Float, i2: Float, i3: Float) = min(min(i1, i2), i3)

  /** Table of the given size whose first row holds 0, 1, 2, ... by column, whose first column
    * holds 0, 1, 2, ... by row, and whose remaining cells are 0.
    */
  private def unitBorderTable(rows: Int, cols: Int): Array[Array[Float]] =
    Array.tabulate(rows, cols) { (j, i) =>
      if (j == 0) i * 1.0f else if (i == 0) j * 1.0f else 0.0f
    }

  /** Lists the edit operations of a unit-cost Levenshtein alignment between `s1` and `s2`.
    *
    * Each operation is a pair `(fragment of s1, fragment of s2)`:
    *   - `("a", "b")`: character `a` of `s1` is aligned with a different character `b` of `s2`
    *   - `("a", "")`: character `a` of `s1` has no counterpart in `s2`
    *   - `("", "b")`: character `b` of `s2` has no counterpart in `s1`
    *
    * Matching characters produce no entry. Operations are listed from the end of the strings
    * towards their start.
    */
  def learnDist(s1: String, s2: String): Seq[(String, String)] = {
    val dist = unitBorderTable(s2.length + 1, s1.length + 1)

    for (j <- 1 to s2.length; i <- 1 to s1.length)
      dist(j)(i) =
        if (s2(j - 1) == s1(i - 1)) dist(j - 1)(i - 1)
        else minimum(dist(j - 1)(i) + 1.0f, dist(j)(i - 1) + 1.0f, dist(j - 1)(i - 1) + 1.0f)

    backTrack(dist, s2, s1, s2.length, s1.length, Seq.empty)
  }

  /** Walks a filled distance table from cell `(j, i)` back to `(0, 0)`, collecting the edit
    * operations met on the way (see [[learnDist]] for their encoding).
    *
    * @param dist
    *   distance table indexed as `dist(position in s2)(position in s1)`
    * @param s2
    *   string laid out along the rows of `dist`
    * @param s1
    *   string laid out along the columns of `dist`
    * @param j
    *   row to start from (number of `s2` characters still to be consumed)
    * @param i
    *   column to start from (number of `s1` characters still to be consumed)
    * @param acc
    *   operations collected so far; new ones are appended
    * @return
    *   `acc` followed by the operations found, in end-to-start order
    */
  def backTrack(
      dist: Array[Array[Float]],
      s2: String,
      s1: String,
      j: Int,
      i: Int,
      acc: Seq[(String, String)]): Seq[(String, String)] = {

    @scala.annotation.tailrec
    def walk(row: Int, col: Int, ops: Seq[(String, String)]): Seq[(String, String)] =
      if (row == 0 && col == 0)
        ops
      else if (row == 0)
        // s2 is exhausted: every remaining s1 character is unmatched
        walk(row, col - 1, ops :+ ((s1(col - 1).toString, "")))
      else if (col == 0)
        // s1 is exhausted: every remaining s2 character is unmatched
        walk(row - 1, col, ops :+ (("", s2(row - 1).toString)))
      else if (s2(row - 1) == s1(col - 1))
        walk(row - 1, col - 1, ops)
      else {
        val substitute = dist(row - 1)(col - 1)
        val dropFromS1 = dist(row)(col - 1)
        val takeFromS2 = dist(row - 1)(col)

        // Ties are resolved in favour of substitution, then the s1-only step, then the s2-only
        // step.
        if (substitute <= dropFromS1 && substitute <= takeFromS2)
          walk(row - 1, col - 1, ops :+ ((s1(col - 1).toString, s2(row - 1).toString)))
        else if (dropFromS1 <= takeFromS2)
          walk(row, col - 1, ops :+ ((s1(col - 1).toString, "")))
        else
          walk(row - 1, col, ops :+ (("", s2(row - 1).toString)))
      }

    walk(j, i, acc)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy