/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.spell.context

import com.github.liblevenshtein.transducer.{Candidate, ITransducer}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import com.johnsnowlabs.nlp.annotators.spell.context.parser._
import com.johnsnowlabs.nlp.serialization._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{BooleanParam, FloatParam, IntArrayParam, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Dataset, SparkSession}
import org.slf4j.LoggerFactory

import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable

/** Implements a deep-learning based Noisy Channel Model Spell Algorithm. Correction candidates
  * are extracted combining context information and word information.
  *
  * Spell Checking is a sequence to sequence mapping problem. Given an input sequence, potentially
  * containing a certain number of errors, `ContextSpellChecker` will rank correction sequences
  * according to three things:
  *   1. Different correction candidates for each word — '''word level'''.
  *   1. The surrounding text of each word, i.e. its context — '''sentence level'''.
  *   1. The relative cost of different correction candidates according to the edit operations at
  *      the character level it requires — '''subword level'''.
  *
  * For an in-depth explanation of the module see the article
  * [[https://medium.com/spark-nlp/applying-context-aware-spell-checking-in-spark-nlp-3c29c46963bc Applying Context Aware Spell Checking in Spark NLP]].
  *
  * This is the instantiated model of the [[ContextSpellCheckerApproach]]. For training your own
  * model, please see the documentation of that class.
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val spellChecker = ContextSpellCheckerModel.pretrained()
  *   .setInputCols("token")
  *   .setOutputCol("checked")
  * }}}
  * The default model is `"spellcheck_dl"`, if no name is provided. For available pretrained
  * models please see the [[https://sparknlp.org/models?task=Spell+Check Models Hub]].
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/italian/Training_Context_Spell_Checker_Italian.ipynb Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/context/ContextSpellCheckerTestSpec.scala ContextSpellCheckerTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("doc")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols(Array("doc"))
  *   .setOutputCol("token")
  *
  * val spellChecker = ContextSpellCheckerModel
  *   .pretrained()
  *   .setTradeOff(12.0f)
  *   .setInputCols("token")
  *   .setOutputCol("checked")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   tokenizer,
  *   spellChecker
  * ))
  *
  * val data = Seq("It was a cold , dreary day and the country was white with smow .").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.select("checked.result").show(false)
  * +--------------------------------------------------------------------------------+
  * |result                                                                          |
  * +--------------------------------------------------------------------------------+
  * |[It, was, a, cold, ,, dreary, day, and, the, country, was, white, with, snow, .]|
  * +--------------------------------------------------------------------------------+
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel NorvigSweetingModel]] and
  *   [[com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel SymmetricDeleteModel]]
  *   for alternative approaches to spell checking
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ContextSpellCheckerModel(override val uid: String)
    extends AnnotatorModel[ContextSpellCheckerModel]
    with HasSimpleAnnotate[ContextSpellCheckerModel]
    with WeightedLevenshtein
    with WriteTensorflowModel
    with ParamsAndFeaturesWritable
    with HasTransducerFeatures
    with HasEngine {

  private val logger = LoggerFactory.getLogger("ContextSpellCheckerModel")

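  /** Transducer for the main vocabulary, used to generate correction candidates.
    *
    * @group param
    */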
  val transducer = new TransducerFeature(this, "mainVocabularyTransducer")

  /** @group setParam */
  def setVocabTransducer(trans: ITransducer[Candidate]): this.type = {
    val main = new MainVocab()
    main.transducer = trans

    set(transducer, main)
  }

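  /** Transducers for the special token classes (e.g. numbers, dates), used to generate
    * class-specific correction candidates.
    *
    * @group param
    */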
  val specialTransducers = new TransducerSeqFeature(this, "specialClassesTransducers")

  /** @group setParam */
  def setSpecialClassesTransducers(transducers: Seq[SpecialClassParser]): this.type = {
    set(specialTransducers, transducers.toArray)
  }

  /** Frequencies of the words in the vocabulary
    *
    * @group param
    */
  val vocabFreq = new MapFeature[String, Double](this, "vocabFreq")

  /** @group setParam */
  def setVocabFreq(v: Map[String, Double]): this.type = set(vocabFreq, v)

  /** Mapping of ids to vocabulary
    *
    * @group param
    */
  val idsVocab = new MapFeature[Int, String](this, "idsVocab")

  /** Mapping of vocabulary to ids
    *
    * @group param
    */
  val vocabIds = new MapFeature[String, Int](this, "vocabIds")

  /** @group setParam */
  def setVocabIds(v: Map[String, Int]): this.type = {
    set(idsVocab, v.map(_.swap))
    set(vocabIds, v)
  }

  /** Classes the spell checker recognizes, as a mapping from word id to the pair (class id,
    * word id within the class) used when querying the language model
    *
    * @group param
    */
  val classes: MapFeature[Int, (Int, Int)] = new MapFeature(this, "classes")

  /** @group setParam */
  def setClasses(c: Map[Int, (Int, Int)]): this.type = set(classes, c)

  /** Maximum distance for the generated candidates for every word, minimum 1.
    *
    * @group param
    */
  val wordMaxDistance = new IntParam(
    this,
    "wordMaxDistance",
    "Maximum distance for the generated candidates for every word, minimum 1.")

  /** @group setParam */
  def setWordMaxDistance(k: Int): this.type = set(wordMaxDistance, k)

  /** Maximum number of candidates for every word (Default: `6`).
    *
    * @group param
    */
  val maxCandidates =
    new IntParam(this, "maxCandidates", "Maximum number of candidates for every word.")

  /** @group setParam */
  def setMaxCandidates(k: Int): this.type = set(maxCandidates, k)

  /** What case combinations to try when generating candidates (Default: `CandidateStrategy.ALL`).
    *
    * @group param
    */
  val caseStrategy = new IntParam(
    this,
    "caseStrategy",
    "What case combinations to try when generating candidates.")

  /** @group setParam */
  def setCaseStrategy(k: Int): this.type = set(caseStrategy, k)

  /** Threshold perplexity for a word to be considered as an error.
    *
    * @group param
    */
  val errorThreshold = new FloatParam(
    this,
    "errorThreshold",
    "Threshold perplexity for a word to be considered as an error.")

  /** @group setParam */
  def setErrorThreshold(t: Float): this.type = set(errorThreshold, t)

  /** Tradeoff between the cost of a word and a transition in the language model (Default:
    * `18.0f`).
    *
    * @group param
    */
  val tradeoff = new FloatParam(
    this,
    "tradeoff",
    "Tradeoff between the cost of a word and a transition in the language model.")

  /** @group setParam */
  def setTradeOff(lambda: Float): this.type = set(tradeoff, lambda)

  /** Controls the influence of individual word frequency in the decision (Default: `120.0f`).
    *
    * @group param
    */
  val gamma = new FloatParam(
    this,
    "gamma",
    "Controls the influence of individual word frequency in the decision.")

  /** @group setParam */
  def setGamma(g: Float): this.type = set(gamma, g)

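  /** Levenshtein edit weights used to re-rank correction candidates (see
    * [[WeightedLevenshtein]]).
    *
    * @group param
    */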
  val weights: MapFeature[String, Map[String, Float]] =
    new MapFeature[String, Map[String, Float]](this, "levenshteinWeights")

  /** @group setParam */
  def setWeights(w: Map[String, Map[String, Float]]): this.type = set(weights, w)

  /** Overload taking Java collections, for access from Python.
    *
    * @group setParam
    */
  def setWeights(w: util.HashMap[String, util.HashMap[String, Double]]): this.type = {

    val ws = w.asScala.mapValues(_.asScala.mapValues(_.toFloat).toMap).toMap
    set(weights, ws)
  }

  /** When set to `true`, newlines are treated as any other character (Default: `false`). When
    * set to `false`, correction is applied to paragraphs as delimited by newline characters.
    *
    * @group param
    */
  val useNewLines = new BooleanParam(
    this,
    "trim", // note: the internal param name differs from the field name
    "When set to true newlines are treated as any other character, when set to false" +
      " correction is applied to paragraphs as delimited by newline characters.")

  /** @group setParam */
  def setUseNewLines(useIt: Boolean): this.type = set(useNewLines, useIt)

  /** Maximum size for the window used to remember history prior to every correction (Default:
    * `5`).
    *
    * @group param
    */
  val maxWindowLen = new IntParam(
    this,
    "maxWindowLen",
    "Maximum size for the window used to remember history prior to every correction.")

  /** @group setParam */
  def setMaxWindowLen(w: Int): this.type = set(maxWindowLen, w)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** @group setParam */
  def setConfigProtoBytes(bytes: Array[Int]): ContextSpellCheckerModel.this.type =
    set(this.configProtoBytes, bytes)

  /** @group getParam */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  /** Whether to correct special symbols or skip spell checking for them
    *
    * @group param
    */
  val correctSymbols: BooleanParam = new BooleanParam(
    this,
    "correctSymbols",
    "Whether to correct special symbols or skip spell checking for them")

  /** @group setParam */
  def setCorrectSymbols(value: Boolean): this.type = set(correctSymbols, value)

  setDefault(correctSymbols -> false)

  /** If `true`, tokens are additionally compared to the vocabulary in lowercase (Default:
    * `false`)
    *
    * @group param
    */
  val compareLowcase: BooleanParam = new BooleanParam(
    this,
    "compareLowcase",
    "If true, tokens are additionally compared to the vocabulary in lowercase")

  /** @group setParam */
  def setCompareLowcase(value: Boolean): this.type = set(compareLowcase, value)

  setDefault(compareLowcase -> false)

  /** @group getParam */
  def getWordClasses: Seq[(String, AnnotatorType)] = $$(specialTransducers).map {
    case transducer: RegexParser =>
      (transducer.label, "RegexParser")
    case transducer: VocabParser =>
      (transducer.label, "VocabParser")
  }

  /** Replaces the regular expression of an existing regex class and regenerates its transducer. */
  def updateRegexClass(label: String, regex: String): ContextSpellCheckerModel = {
    val classes = $$(specialTransducers)
    require(
      classes.count(_.label == label) == 1,
      s"Regex class $label not found. You can only update existing classes.")

    classes.filter(_.label.equals(label)).head match {
      case r: RegexParser =>
        r.regex = regex
        r.transducer = r.generateTransducer
      case _ => require(requirement = false, s"Class $label is not a regex class.")
    }
    this
  }
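  /* A minimal usage sketch for updateRegexClass; the "_DATE_" label and the pattern are
   * hypothetical and must correspond to a regex class already present in the model:
   *
   *   spellChecker.updateRegexClass("_DATE_", "[0-9]{2}/[0-9]{2}/[0-9]{4}")
   */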

  /** Updates the vocabulary of an existing vocabulary class and regenerates its transducer. */
  def updateVocabClass(
      label: String,
      vocabList: util.ArrayList[String],
      append: Boolean = true): ContextSpellCheckerModel = {
    val vocab = mutable.Set(vocabList.asScala: _*)
    val classes = $$(specialTransducers)
    require(
      classes.count(_.label == label) == 1,
      s"Vocab class $label not found. You can only update existing classes.")

    classes.filter(_.label.equals(label)).head match {
      case v: VocabParser =>
        if (v.vocab.eq(null)) v.vocab = mutable.Set.empty[String]

        val newSet = if (append) v.vocab ++ vocab else vocab
        v.vocab = newSet
        v.transducer = v.generateTransducer
      case _ => require(requirement = false, s"Class $label is not a vocabulary class.")
    }
    this
  }
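  /* A minimal usage sketch for updateVocabClass; the "_NAME_" label is hypothetical and must
   * correspond to a vocabulary class already present in the model. With append = true the new
   * terms extend the existing vocabulary instead of replacing it:
   *
   *   val names = new java.util.ArrayList[String](java.util.Arrays.asList("Alice", "Bob"))
   *   spellChecker.updateVocabClass("_NAME_", names, append = true)
   */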

  setDefault(
    tradeoff -> 18.0f,
    gamma -> 120.0f,
    useNewLines -> false,
    maxCandidates -> 6,
    maxWindowLen -> 5,
    caseStrategy -> CandidateStrategy.ALL)

  // scores for the EOS (end of sentence) and BOS (beginning of sentence) placeholders
  private val eosScore = .01
  private val bosScore = 1.0

  private var _model: Option[Broadcast[TensorflowSpell]] = None

  def getModelIfNotSet: TensorflowSpell = _model.get.value

  def setModelIfNotSet(spark: SparkSession, tensorflow: TensorflowWrapper): this.type = {
    if (_model.isEmpty) {
      _model = Some(spark.sparkContext.broadcast(new TensorflowSpell(tensorflow, Verbose.Silent)))
    }
    this
  }

  /* Decodes the best correction path through the trellis. The trellis holds, for each token
   * position, an array of candidate tuples of the form (label, weight, candidate). */
  def decodeViterbi(trellis: Array[Array[(String, Double, String)]]): (Array[String], Double) = {

    // encode words with ids
    val encTrellis = Array(Array(($$(vocabIds)("_BOS_"), bosScore, "_BOS_"))) ++
      trellis.map(_.map { case (label, weight, cand) =>
        // at this point we keep only those candidates that are in the vocabulary
        ($$(vocabIds).get(label), weight, cand)
      }.filter(_._1.isDefined).map { case (x, y, z) => (x.get, y, z) }) ++
      Array(Array(($$(vocabIds)("_EOS_"), eosScore, "_EOS_")))

    // init
    var pathsIds = Array(Array($$(vocabIds)("_BOS_")))
    var pathWords = Array(Array("_BOS_"))
    var costs = Array(bosScore) // cost for each of the paths

    for (i <- 1 until encTrellis.length if pathsIds.forall(_.nonEmpty)) {

      var newPaths: Array[Array[Int]] = Array()
      var newWords: Array[Array[String]] = Array()
      var newCosts = Array[Double]()

      /* compute all the costs for all transitions in the current step */
      val expPaths = pathsIds
        .map(p => p :+ p.head) // we need a placeholder, so repeat the head
        .map(_.takeRight($(maxWindowLen)))
      val cids = expPaths.map(_.map(id => $$(classes)(id)._1))
      val cwids = expPaths.map(_.map(id => $$(classes)(id)._2))

      val candCids = encTrellis(i).map(_._1).map(id => $$(classes)(id)._1)
      val candWids = encTrellis(i).map(_._1).map(id => $$(classes)(id)._2)
      val expPathsCosts_ = getModelIfNotSet
        .predict_(
          pathsIds.map(_.takeRight($(maxWindowLen))),
          cids,
          cwids,
          candCids,
          candWids,
          configProtoBytes = getConfigProtoBytes)
        .toArray

      for { ((state, wcost, cand), idx) <- encTrellis(i).zipWithIndex } {
        var minCost = Double.MaxValue
        var minPath = Array[Int]()
        var minWords = Array[String]()

        val z = (pathsIds, costs, pathWords).zipped.toList

        for (((path, pathCost, cands), pi) <- z.zipWithIndex) {
          // compute the cost to arrive at this 'state' coming from that 'path'
          val ppl_ = expPathsCosts_(encTrellis(i).size * pi + idx)

          val cost = pathCost + ppl_
          logger.debug(s"${$$(idsVocab).apply(path.last)} -> $cand, $ppl_, $cost")

          if (cost < minCost) {
            minCost = cost
            minPath = path :+ state
            minWords = cands :+ cand
          }
        }
        newPaths = newPaths :+ minPath
        newWords = newWords :+ minWords
        newCosts = newCosts :+ minCost + wcost * getOrDefault(tradeoff)
      }
      pathsIds = newPaths
      pathWords = newWords
      costs = newCosts

      // log paths and costs
      pathWords.zip(costs).foreach { case (path, cost) =>
        logger.debug(s"${path.toList}, $cost")
      }

    }
    // return the path with the lowest cost, and the cost
    val (minPath, minCost) = pathWords.zip(costs).minBy(_._2)

    if (minPath.nonEmpty)
      (minPath.tail.dropRight(1), minCost)
    else
      (minPath, minCost)
  }
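  /* Illustrative sketch (hypothetical candidates and weights): given the single-token trellis
   *
   *   Array(Array(("snow", 0.5, "snow"), ("show", 1.0, "show")))
   *
   * decodeViterbi returns the words on the cheapest path with the BOS/EOS placeholders
   * stripped, e.g. (Array("snow"), cost). */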

  def getClassCandidates(
      transducer: ITransducer[Candidate],
      token: String,
      label: String,
      maxDist: Int,
      limit: Int = 2) = {
    transducer
      .transduce(token, maxDist)
      .asScala
      .map { cand =>
        // if weights are available, we use them
        val weight = weights.get
          .map(ws => wLevenshteinDist(cand.term, token, ws))
          .getOrElse(cand.distance.toFloat)

        (cand.term, label, weight)
      }
      .toSeq
      .sortBy(_._3)
      .take(limit)
  }
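  /* For instance, with a hypothetical "_NUM_" class transducer, getClassCandidates(numTrans,
   * "10O", "_NUM_", 1) could yield candidates such as ("100", "_NUM_", 1.0f): the `limit`
   * cheapest terms within `maxDist` edits of the token. */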

  def getVocabCandidates(token: String, maxDist: Int) = {
    val trans = $$(transducer).transducer
    // we use all case information as it comes
    val plainCandidates =
      trans
        .transduce(token, maxDist)
        .asScala
        .toList
        .map(c => (c.term, c.term, c.distance.toFloat))

    // We evaluate some case variations
    val tryUpperCase = getOrDefault(caseStrategy) == CandidateStrategy.ALL_UPPER_CASE ||
      getOrDefault(caseStrategy) == CandidateStrategy.ALL

    val tryFirstCapitalized =
      getOrDefault(caseStrategy) == CandidateStrategy.FIRST_LETTER_CAPITALIZED ||
        getOrDefault(caseStrategy) == CandidateStrategy.ALL

    val caseCandidates = if (token.isUpperCase && tryUpperCase) {
      trans
        .transduce(token.toLowerCase)
        .asScala
        .toList
        .map(c => (c.term.toUpperCase, c.term, c.distance.toFloat))
    } else if (token.isFirstLetterCapitalized && tryFirstCapitalized) {
      trans
        .transduce(token.toLowerCase)
        .asScala
        .toList
        .map(c => (c.term.capitalizeFirstLetter, c.term, c.distance.toFloat))
    } else Seq.empty

    plainCandidates ++ caseCandidates
  }
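  /* For example, for the all-uppercase token "SMOW" with caseStrategy ALL, candidates are
   * generated both verbatim and from the lowercased form "smow"; a (hypothetical) vocabulary
   * entry "snow" can then surface with its casing restored, as ("SNOW", "snow", 1.0f). */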

  implicit class StringTools(s: String) {
    def isUpperCase: Boolean = s.toUpperCase.equals(s)

    def isLowerCase: Boolean = s.toLowerCase.equals(s)

    def isFirstLetterCapitalized: Boolean =
      s.headOption.exists(fl => fl.isUpper && s.tail.isLowerCase)

    def capitalizeFirstLetter: String = s.head.toUpper + s.tail
  }
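  // e.g. "NASA".isUpperCase == true, "Snow".isFirstLetterCapitalized == true,
  // and "snow".capitalizeFirstLetter == "Snow"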

  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    require(_model.isDefined, "Tensorflow model has not been initialized")
    dataset
  }

  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation; not necessarily a
    *   one-to-one relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val decodedSentPaths =
      annotations.groupBy(_.metadata.getOrElse("sentence", "0").toInt).mapValues { sentTokens =>
        val (decodedPath, cost) = toOption(getOrDefault(useNewLines))
          .map { _ =>
            val idxs = Seq(-1) ++ sentTokens.zipWithIndex
              .filter { case (a, _) =>
                a.result.equals(System.lineSeparator) || a.result.equals(System.lineSeparator * 2)
              }
              .map(_._2) ++ Seq(annotations.length)
            idxs
              .zip(idxs.tail)
              .map { case (s, e) =>
                decodeViterbi(
                  computeTrellis(
                    sentTokens.slice(s + 1, e),
                    computeMask(sentTokens.slice(s + 1, e))))
              }
              .reduceLeft[(Array[String], Double)]({ case ((dPathA, pCostA), (dPathB, pCostB)) =>
                (dPathA ++ Seq(System.lineSeparator) ++ dPathB, pCostA + pCostB)
              })
          }
          .getOrElse(decodeViterbi(computeTrellis(sentTokens, computeMask(sentTokens))))
        // ToDo: This is a backup plan for empty DecodedPath -- fix me!!
        if (decodedPath.nonEmpty)
          sentTokens.zip(decodedPath).map { case (orig, correct) =>
            orig.copy(result = correct, metadata = orig.metadata.updated("cost", cost.toString))
          }
        else
          sentTokens.map(orig => orig.copy(metadata = orig.metadata.updated("cost", "0")))
      }

    decodedSentPaths.values.toList.reverse.flatten
  }

  def toOption(boolean: Boolean): Option[Boolean] = {
    if (boolean)
      Some(boolean)
    else
      None
  }

  /* detects which tokens need correction
   *
   * returns a mask with a boolean flag for each word indicating whether it needs correction
   *
   * a word needs correction for one of two reasons: 1. high perplexity, or 2. out of vocabulary
   */
  def computeMask(annotations: Seq[Annotation]): Array[Boolean] = {
    val threshold = getOrDefault(errorThreshold)
    val unkCode = $$(vocabIds)("_UNK_")

    /* try to decide whether words need correction or not */
    // first pass - perplexities
    val encodedSent = Array($$(vocabIds)("_BOS_")) ++ annotations.map { ann =>
      if ($(compareLowcase))
        $$(vocabIds).getOrElse(ann.result, $$(vocabIds).getOrElse(ann.result.toLowerCase, unkCode))
      else
        $$(vocabIds).getOrElse(ann.result, unkCode)
    } ++ Array($$(vocabIds)("_EOS_"))

    val cids = encodedSent.map(id => $$(classes)(id)._1)
    val cwids = encodedSent.map(id => $$(classes)(id)._2)

    val exceedsThreshold = getModelIfNotSet
      .pplEachWord(Array(encodedSent), Array(cids), Array(cwids))
      .map(_ > threshold)

    exceedsThreshold
      .zip(exceedsThreshold.tail)
      .zip(encodedSent.tail)
      // if the word to the right needs correction, this word needs it too; also flag words
      // that are out of the vocabulary
      .map { case ((needCorrection, nextNeedCorrection), code) =>
        if (nextNeedCorrection) true else needCorrection || code == unkCode
      }
  }
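  /* For example, for the tokens ["the", "smow"], if "smow" is out of the vocabulary it maps to
   * _UNK_ and gets flagged regardless of perplexity, giving a mask like Array(false, true)
   * (assuming "the" stays below the error threshold). */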

  def computeTrellis(annotations: Seq[Annotation], mask: Seq[Boolean]) = {
    annotations
      .zip(mask)
      .map { case (annotation, needCorrection) =>
        val token = annotation.result
        // unless symbol correction is enabled, only correct tokens containing alphanumerics
        var correctionCondition = needCorrection
        if (!$(correctSymbols))
          correctionCondition = needCorrection && token.replaceAll("[^A-Za-z0-9]+", "").nonEmpty

        if (correctionCondition) {
          // ask each token class for candidates and keep the cheapest ones
          var candLabelWeight = $$(specialTransducers).flatMap { specialParser =>
            if (specialParser.transducer == null)
              throw new RuntimeException(
                s"Transducer for class ${specialParser.label} is not initialized.")
            getClassCandidates(
              specialParser.transducer,
              token,
              specialParser.label,
              getOrDefault(wordMaxDistance) - 1)
          } ++ getVocabCandidates(token, getOrDefault(wordMaxDistance) - 1)

          // now try to relax distance requirements for candidates
          if (token.length > 4 && candLabelWeight.isEmpty)
            candLabelWeight = $$(specialTransducers).flatMap { specialParser =>
              getClassCandidates(
                specialParser.transducer,
                token,
                specialParser.label,
                getOrDefault(wordMaxDistance))
            } ++ getVocabCandidates(token, getOrDefault(wordMaxDistance))

          if (candLabelWeight.isEmpty)
            candLabelWeight = Array((token, "_UNK_", 3.0f))

          // label is a dictionary word for the main transducer, or a label such as _NUM_ for special classes
          val labelWeightCand = candLabelWeight
            .map { case (term, label, dist) =>
              // optional re-ranking of candidates according to special distance
              val d = get(weights)
                .map { w =>
                  wLevenshteinDist(term, token, w)
                }
                .getOrElse(dist)
              val weight = d - $$(vocabFreq).getOrElse(label, 0.0) / getOrDefault(gamma)
              (label, weight, term)
            }
            .sortBy(_._2)
            .take(getOrDefault(maxCandidates))
          logger.debug(
            s"""$token -> ${labelWeightCand.toList.take(getOrDefault(maxCandidates))}""")
          labelWeightCand.toArray // [(String, Double, String)]
        } else {
          Array(("_UNK_", .2, token))
        }
      }
      .toArray
  }
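  /* For instance, for a masked token "smow" a trellis row could look like
   * Array(("snow", 0.5, "snow"), ("show", 1.0, "show")) (hypothetical weights), while
   * unmasked tokens get the passthrough row Array(("_UNK_", .2, token)). */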

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("SPELL"))

  /** Input Annotator Types: TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(AnnotatorType.TOKEN)

  /** Output Annotator Types: TOKEN
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = AnnotatorType.TOKEN

  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)
    writeTensorflowModel(
      path,
      spark,
      getModelIfNotSet.tensorflow,
      "_langmodeldl",
      ContextSpellCheckerModel.tfFile,
      configProtoBytes = getConfigProtoBytes)
  }
}

trait ReadsLanguageModelGraph
    extends ParamsAndFeaturesReadable[ContextSpellCheckerModel]
    with ReadTensorflowModel {

  override val tfFile = "tensorflow_lm"

  def readLanguageModelGraph(
      instance: ContextSpellCheckerModel,
      path: String,
      spark: SparkSession): Unit = {
    val tf = readTensorflowModel(path, spark, "_langmodeldl")
    instance.setModelIfNotSet(spark, tf)
  }

  addReader(readLanguageModelGraph)
}

trait ReadablePretrainedContextSpell
    extends ReadsLanguageModelGraph
    with HasPretrained[ContextSpellCheckerModel] {
  override val defaultModelName: Some[String] = Some("spellcheck_dl")

  /** Java-compliant overrides */
  override def pretrained(): ContextSpellCheckerModel = super.pretrained()

  override def pretrained(name: String): ContextSpellCheckerModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): ContextSpellCheckerModel =
    super.pretrained(name, lang)

  override def pretrained(
      name: String,
      lang: String,
      remoteLoc: String): ContextSpellCheckerModel = super.pretrained(name, lang, remoteLoc)
}

/** This is the companion object of [[ContextSpellCheckerModel]]. Please refer to that class for
  * the documentation.
  */
object ContextSpellCheckerModel extends ReadablePretrainedContextSpell