// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel.scala Maven / Gradle / Ivy
//
// There is a newer version: 5.5.0
// Show newest version
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common.{
  TokenPieceEmbeddings,
  TokenizedWithSentence,
  WordpieceEmbeddingsSentence
}
import com.johnsnowlabs.nlp.util.io.ResourceHelper.spark.implicits._
import com.johnsnowlabs.storage.Database.Name
import com.johnsnowlabs.storage._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.IntParam
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/** Word Embeddings lookup annotator that maps tokens to vectors
  *
  * This is the instantiated model of [[WordEmbeddings]].
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val embeddings = WordEmbeddingsModel.pretrained()
  *     .setInputCols("document", "token")
  *     .setOutputCol("embeddings")
  * }}}
  * The default model is `"glove_100d"`, if no name is provided. For available pretrained models
  * please see the [[https://sparknlp.org/models?task=Embeddings Models Hub]].
  *
  * There are also two convenient functions to retrieve the embeddings coverage with respect to
  * the transformed dataset:
  *   - `withCoverageColumn(dataset, embeddingsCol, outputCol)`: Adds a custom column with word
  *     coverage stats for the embedded field: (`covered`, `total`, `percentage`). This creates
  *     a new column with statistics for each row.
  *     {{{
  *     val wordsCoverage = WordEmbeddingsModel.withCoverageColumn(resultDF, "embeddings", "cov_embeddings")
  *     wordsCoverage.select("text","cov_embeddings").show(false)
  *     +-------------------+--------------+
  *     |text               |cov_embeddings|
  *     +-------------------+--------------+
  *     |This is a sentence.|[5, 5, 1.0]   |
  *     +-------------------+--------------+
  *     }}}
  *   - `overallCoverage(dataset, embeddingsCol)`: Calculates overall word coverage for the whole
  *     data in the embedded field. This returns a single coverage object considering all rows in
  *     the field.
  *     {{{
  *     val wordsOverallCoverage = WordEmbeddingsModel.overallCoverage(wordsCoverage,"embeddings").percentage
  *     1.0
  *     }}}
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/quick_start_offline.ipynb Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala WordEmbeddingsTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
  * import com.johnsnowlabs.nlp.EmbeddingsFinisher
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols(Array("document"))
  *   .setOutputCol("token")
  *
  * val embeddings = WordEmbeddingsModel.pretrained()
  *   .setInputCols("document", "token")
  *   .setOutputCol("embeddings")
  *
  * val embeddingsFinisher = new EmbeddingsFinisher()
  *   .setInputCols("embeddings")
  *   .setOutputCols("finished_embeddings")
  *   .setOutputAsVector(true)
  *   .setCleanAnnotations(false)
  *
  * val pipeline = new Pipeline()
  *   .setStages(Array(
  *     documentAssembler,
  *     tokenizer,
  *     embeddings,
  *     embeddingsFinisher
  *   ))
  *
  * val data = Seq("This is a sentence.").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
  * +--------------------------------------------------------------------------------+
  * |                                                                          result|
  * +--------------------------------------------------------------------------------+
  * |[-0.570580005645752,0.44183000922203064,0.7010200023651123,-0.417129993438720...|
  * |[-0.542639970779419,0.4147599935531616,1.0321999788284302,-0.4024400115013122...|
  * |[-0.2708599865436554,0.04400600120425224,-0.020260000601410866,-0.17395000159...|
  * |[0.6191999912261963,0.14650000631809235,-0.08592499792575836,-0.2629800140857...|
  * |[-0.3397899866104126,0.20940999686717987,0.46347999572753906,-0.6479200124740...|
  * +--------------------------------------------------------------------------------+
  * }}}
  *
  * @see
  *   [[SentenceEmbeddings]] to combine embeddings into a sentence-level representation
  * @see
  *   [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer
  *   based embeddings
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class WordEmbeddingsModel(override val uid: String)
    extends AnnotatorModel[WordEmbeddingsModel]
    with HasSimpleAnnotate[WordEmbeddingsModel]
    with HasEmbeddingsProperties
    with HasStorageModel
    with ParamsAndFeaturesWritable
    with ReadsFromBytes {

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("WORD_EMBEDDINGS_MODEL"))

  /** Output annotator type : WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = WORD_EMBEDDINGS

  /** Input annotator type : DOCUMENT, TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Cache size for items retrieved from storage. Increase for performance but higher memory
    * consumption
    *
    * @group param
    */
  val readCacheSize = new IntParam(
    this,
    "readCacheSize",
    "cache size for items retrieved from storage. Increase for performance but higher memory consumption")

  /** Set cache size for items retrieved from storage. Increase for performance but higher memory
    * consumption
    *
    * @group setParam
    */
  def setReadCacheSize(value: Int): this.type = set(readCacheSize, value)

  /** Broadcast copy of the embeddings storage. Only populated by [[beforeAnnotate]] when
    * `enableInMemoryStorage` is set.
    */
  private var memoryStorage: Option[Broadcast[Map[BytesKey, Array[Byte]]]] = None

  /** Instance-local copy of the embeddings storage, used when no broadcast is available (i.e.
    * `annotate` is invoked without `beforeAnnotate` having run on this instance).
    *
    * It is lazy so the export happens at most once per instance instead of once per lookup, and
    * `@transient` so the exported map is never serialized together with the model.
    *
    * NOTE(review): the map is exported once and then reused; this assumes the underlying storage
    * does not change during the lifetime of this instance -- confirm.
    */
  @transient private lazy val localMemoryStorage: Map[BytesKey, Array[Byte]] =
    getReader(Database.EMBEDDINGS).exportStorageToMap()

  /** Returns the in-memory view of the embeddings storage.
    *
    * Resolution order: the broadcast value if one was created, otherwise the lazily exported
    * local copy when `enableInMemoryStorage` is set, otherwise an empty map.
    */
  private def getInMemoryStorage: Map[BytesKey, Array[Byte]] = {
    memoryStorage.map(_.value).getOrElse {
      if ($(enableInMemoryStorage)) localMemoryStorage
      else Map.empty[BytesKey, Array[Byte]]
    }
  }

  /** Broadcasts the exported storage once per instance when `enableInMemoryStorage` is set.
    *
    * @param dataset
    *   dataset about to be annotated; only its Spark context is used here
    * @return
    *   the dataset, unchanged
    */
  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    if (this.memoryStorage.isEmpty && $(enableInMemoryStorage)) {
      this.memoryStorage = Some(
        dataset.sparkSession.sparkContext.broadcast(localMemoryStorage))
    }
    dataset
  }

  /** Takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessarily a
    *   one-to-one relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = TokenizedWithSentence.unpack(annotations)
    val withEmbeddings = sentences.map { s =>
      val tokens = s.indexedTokens.map { indexedToken =>
        val (embeddings, zeroArray) = retrieveEmbeddings(indexedToken.token)
        // Every token is looked up as a whole, so the same token text is passed twice and each
        // entry is flagged as a word start.
        TokenPieceEmbeddings(
          indexedToken.token,
          indexedToken.token,
          -1,
          isWordStart = true,
          embeddings,
          zeroArray,
          indexedToken.begin,
          indexedToken.end)
      }
      WordpieceEmbeddingsSentence(tokens, s.sentenceIndex)
    }

    WordpieceEmbeddingsSentence.pack(withEmbeddings)
  }

  /** Looks up the vector for a single token.
    *
    * With `enableInMemoryStorage` the trimmed token is searched in the in-memory map: first as
    * is, then -- only when `caseSensitive` is false -- lower-cased and finally upper-cased.
    * Entries holding an empty byte array are treated as missing. Otherwise the lookup is
    * delegated to the storage reader.
    *
    * @param token
    *   token text to look up
    * @return
    *   the embeddings if the token was found, paired with the fallback vector: a zero-filled
    *   array of length `dimension` for in-memory lookups, the reader's `emptyValue` otherwise
    */
  private def retrieveEmbeddings(token: String): (Option[Array[Float]], Array[Float]) = {
    if ($(enableInMemoryStorage)) {
      // Resolve the storage once per token rather than once per candidate key.
      val storage = getInMemoryStorage
      val trimmed = token.trim

      def find(key: String): Option[Array[Byte]] =
        storage.get(new BytesKey(key.getBytes())).filter(_.nonEmpty)

      val exact = find(trimmed)
      // `orElse` takes its argument by name, so the case variants are only probed when needed.
      val matched =
        if ($(caseSensitive)) exact
        else exact.orElse(find(trimmed.toLowerCase)).orElse(find(trimmed.toUpperCase))

      val zeroArray = Array.fill[Float]($(dimension))(0f)
      (matched.map(bytes => fromBytes(bytes)), zeroArray)

    } else {
      val storageReader = getReader(Database.EMBEDDINGS)
      val embeddings = storageReader.lookup(token)
      val zeroArray: Array[Float] = storageReader.emptyValue
      (embeddings, zeroArray)
    }
  }

  /** Re-attaches the output column with embeddings metadata built from `dimension` and
    * `storageRef`.
    */
  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
    dataset.withColumn(
      getOutputCol,
      wrapEmbeddingsMetadata(dataset.col(getOutputCol), $(dimension), Some($(storageRef))))
  }

  /** Default reader cache size when `readCacheSize` is not set: `(100.0 / dimension) * 200000`,
    * capped at 50000 entries, so the cache shrinks as the dimension grows.
    */
  private def bufferSizeFormula: Int = {
    scala.math
      .min( // LRU Cache Size, pick the smallest value up to 50k to reduce memory footprint as dimension grows
        (100.0 / $(dimension)) * 200000,
        50000)
      .toInt
  }

  /** Creates the embeddings reader for the given connection, configured with `caseSensitive`,
    * `dimension` and a cache size of `readCacheSize` if set, else [[bufferSizeFormula]].
    */
  override protected def createReader(
      database: Database.Name,
      connection: RocksDBConnection): WordEmbeddingsReader = {
    new WordEmbeddingsReader(
      connection,
      $(caseSensitive),
      $(dimension),
      get(readCacheSize).getOrElse(bufferSizeFormula))
  }

  /** Storage databases used by this model, as declared by the companion object. */
  override val databases: Array[Database.Name] = WordEmbeddingsModel.databases
}

/** Combines `StorageReadable` and `HasPretrained` for [[WordEmbeddingsModel]], fixing the
  * storage database list and the default pretrained model name.
  */
trait ReadablePretrainedWordEmbeddings
    extends StorageReadable[WordEmbeddingsModel]
    with HasPretrained[WordEmbeddingsModel] {

  /** Name of the model used when `pretrained()` is called without arguments. */
  override val defaultModelName: Option[String] = Some("glove_100d")

  /** Storage databases of the model: only the embeddings database. */
  override val databases: Array[Database.Name] = Array(Database.EMBEDDINGS)

  // The overloads below only narrow the return type to WordEmbeddingsModel so that they can be
  // called from Java; each one forwards its arguments to the inherited implementation.

  /** Java compliant-overrides */
  override def pretrained(): WordEmbeddingsModel = {
    super.pretrained()
  }

  override def pretrained(name: String): WordEmbeddingsModel = {
    super.pretrained(name)
  }

  override def pretrained(name: String, lang: String): WordEmbeddingsModel = {
    super.pretrained(name, lang)
  }

  override def pretrained(name: String, lang: String, remoteLoc: String): WordEmbeddingsModel = {
    super.pretrained(name, lang, remoteLoc)
  }
}

/** Helpers that measure how many annotations of an embeddings column are covered by the
  * embeddings vocabulary.
  *
  * An annotation counts as covered when its metadata has no `isOOV` entry or the entry equals
  * `"false"`.
  */
trait EmbeddingsCoverage {

  /** Coverage statistics.
    *
    * @param covered
    *   number of annotations counted as covered
    * @param total
    *   number of annotations inspected
    * @param percentage
    *   `covered / total` as a fraction between 0 and 1; `0` when `total` is 0
    */
  case class CoverageResult(covered: Long, total: Long, percentage: Float)

  /** Adds a column holding the per-row [[CoverageResult]] of an embeddings column.
    *
    * A row whose embeddings cell is null or empty gets `CoverageResult(0, 0, 0.0)` instead of a
    * `NaN` percentage.
    *
    * @param dataset
    *   DataFrame containing the embeddings annotations
    * @param embeddingsCol
    *   name of the column with the embeddings annotations
    * @param outputCol
    *   name of the column to add, `"coverage"` by default
    * @return
    *   the input DataFrame with `outputCol` appended
    */
  def withCoverageColumn(
      dataset: DataFrame,
      embeddingsCol: String,
      outputCol: String = "coverage"): DataFrame = {
    val coverageFn = udf((annotatorProperties: Seq[Row]) => {
      // A null cell is handled like an empty list of annotations.
      val annotations =
        Option(annotatorProperties).getOrElse(Seq.empty[Row]).map(Annotation(_))
      val covered = annotations.count(_.metadata.getOrElse("isOOV", "false") == "false")
      val total = annotations.size
      // Guard the division: 0f / 0 would otherwise produce NaN.
      val percentage = if (total == 0) 0f else 1f * covered / total
      CoverageResult(covered, total, percentage)
    })
    dataset.withColumn(outputCol, coverageFn(col(embeddingsCol)))
  }

  /** Computes a single [[CoverageResult]] over all annotations of an embeddings column.
    *
    * Covered and total counts are accumulated together in one pass. A dataset without any
    * annotation yields `CoverageResult(0, 0, 0.0)` rather than failing.
    *
    * @param dataset
    *   DataFrame containing the embeddings annotations
    * @param embeddingsCol
    *   name of the column with the embeddings annotations
    * @return
    *   coverage over every annotation of every row
    */
  def overallCoverage(dataset: DataFrame, embeddingsCol: String): CoverageResult = {
    // One flag per annotation: 1 when covered, 0 when out of vocabulary.
    val coveredFlags = dataset
      .select(embeddingsCol)
      .flatMap { row =>
        Option(row.getAs[Seq[Row]](embeddingsCol))
          .getOrElse(Seq.empty[Row])
          .map { annotation =>
            val metadata = annotation.getAs[Map[String, String]]("metadata")
            if (metadata.getOrElse("isOOV", "false") == "false") 1L else 0L
          }
      }

    // `aggregate` starts from a zero value, so it is safe on empty input (unlike `reduce`),
    // and it yields both counters from a single job.
    val (covered, total) = coveredFlags.rdd.aggregate((0L, 0L))(
      (acc, flag) => (acc._1 + flag, acc._2 + 1L),
      (left, right) => (left._1 + right._1, left._2 + right._2))

    val percentage = if (total == 0L) 0f else 1f * covered / total
    CoverageResult(covered, total, percentage)
  }
}

/** This is the companion object of [[WordEmbeddingsModel]]. Please refer to that class for the
  * documentation.
  *
  * It mixes in [[ReadablePretrainedWordEmbeddings]], which provides the `pretrained` overloads,
  * and [[EmbeddingsCoverage]], which provides `withCoverageColumn` and `overallCoverage`.
  */
object WordEmbeddingsModel extends ReadablePretrainedWordEmbeddings with EmbeddingsCoverage




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy