
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, HasSimpleAnnotate}
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.DataFrame

import scala.collection.Map

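/** Helper object holding the pooling strategies supported by [[ChunkEmbeddings]]: `"AVERAGE"`
  * (element-wise mean of the token vectors) and `"SUM"` (element-wise sum).
  */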
object PoolingStrategy {

  object AnnotatorType {
    val AVERAGE = "AVERAGE"
    val SUM = "SUM"
  }

}

/** This annotator utilizes [[WordEmbeddings]], [[BertEmbeddings]], etc. to generate chunk
  * embeddings from either [[com.johnsnowlabs.nlp.annotators.Chunker Chunker]],
  * [[com.johnsnowlabs.nlp.annotators.NGramGenerator NGramGenerator]], or
  * [[com.johnsnowlabs.nlp.annotators.ner.NerConverter NerConverter]] outputs.
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/3.SparkNLP_Pretrained_Models.ipynb Spark NLP Workshop]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala ChunkEmbeddingsTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.annotators.{NGramGenerator, Tokenizer}
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
  * import com.johnsnowlabs.nlp.embeddings.ChunkEmbeddings
  * import org.apache.spark.ml.Pipeline
  *
  * // Extract the Embeddings from the NGrams
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols(Array("sentence"))
  *   .setOutputCol("token")
  *
  * val nGrams = new NGramGenerator()
  *   .setInputCols("token")
  *   .setOutputCol("chunk")
  *   .setN(2)
  *
  * val embeddings = WordEmbeddingsModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *   .setCaseSensitive(false)
  *
  * // Convert the NGram chunks into Word Embeddings
  * val chunkEmbeddings = new ChunkEmbeddings()
  *   .setInputCols("chunk", "embeddings")
  *   .setOutputCol("chunk_embeddings")
  *   .setPoolingStrategy("AVERAGE")
  *
  * val pipeline = new Pipeline()
  *   .setStages(Array(
  *     documentAssembler,
  *     sentence,
  *     tokenizer,
  *     nGrams,
  *     embeddings,
  *     chunkEmbeddings
  *   ))
  *
  * val data = Seq("This is a sentence.").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(chunk_embeddings) as result")
  *   .select("result.annotatorType", "result.result", "result.embeddings")
  *   .show(5, 80)
  * +---------------+----------+--------------------------------------------------------------------------------+
  * |  annotatorType|    result|                                                                      embeddings|
  * +---------------+----------+--------------------------------------------------------------------------------+
  * |word_embeddings|   This is|[-0.55661, 0.42829502, 0.86661, -0.409785, 0.06316501, 0.120775, -0.0732005, ...|
  * |word_embeddings|      is a|[-0.40674996, 0.22938299, 0.50597, -0.288195, 0.555655, 0.465145, 0.140118, 0...|
  * |word_embeddings|a sentence|[0.17417, 0.095253006, -0.0530925, -0.218465, 0.714395, 0.79860497, 0.0129999...|
  * |word_embeddings|sentence .|[0.139705, 0.177955, 0.1887775, -0.45545, 0.20030999, 0.461557, -0.07891501, ...|
  * +---------------+----------+--------------------------------------------------------------------------------+
  * }}}
  *
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ChunkEmbeddings(override val uid: String)
    extends AnnotatorModel[ChunkEmbeddings]
    with HasSimpleAnnotate[ChunkEmbeddings] {

  import com.johnsnowlabs.nlp.AnnotatorType._

  /** Output annotator type : WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = WORD_EMBEDDINGS

  /** Input annotator type : CHUNK, WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(CHUNK, WORD_EMBEDDINGS)

  /** Choose how you would like to aggregate Word Embeddings to Chunk Embeddings: `"AVERAGE"` or
    * `"SUM"` (Default: `"AVERAGE"`)
    *
    * @group param
    */
  val poolingStrategy = new Param[String](
    this,
    "poolingStrategy",
    "Choose how you would like to aggregate Word Embeddings to Chunk Embeddings: AVERAGE or SUM")

  /** Whether to discard default vectors for OOV words from the aggregation / pooling (Default:
    * `true`)
    *
    * @group param
    */
  val skipOOV = new BooleanParam(
    this,
    "skipOOV",
    "Whether to discard default vectors for OOV words from the aggregation / pooling")

  /** Set how Word Embeddings are aggregated into Chunk Embeddings. Must be either AVERAGE or SUM
    *
    * @group setParam
    */
  def setPoolingStrategy(strategy: String): this.type = {
    strategy.toLowerCase() match {
      case "average" => set(poolingStrategy, "AVERAGE")
      case "sum" => set(poolingStrategy, "SUM")
      case _ => throw new MatchError("poolingStrategy must be either AVERAGE or SUM")
    }
  }

  /** Whether to discard default vectors for OOV words from the aggregation / pooling
    *
    * @group setParam
    */
  def setSkipOOV(value: Boolean): this.type = set(skipOOV, value)

  /** Choose how you would like to aggregate Word Embeddings to Chunk Embeddings: AVERAGE or SUM
    *
    * @group getParam
    */
  def getPoolingStrategy: String = $(poolingStrategy)

  /** Whether to discard default vectors for OOV words from the aggregation / pooling
    *
    * @group getParam
    */
  def getSkipOOV: Boolean = $(skipOOV)

  setDefault(
    inputCols -> Array(CHUNK, WORD_EMBEDDINGS),
    outputCol -> "chunk_embeddings",
    poolingStrategy -> "AVERAGE",
    skipOOV -> true)

  /** Internal constructor that assigns a random UID */
  def this() = this(Identifiable.randomUID("CHUNK_EMBEDDINGS"))

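  /** Pools a matrix of token vectors (one row per token of the chunk) into a single vector:
    * each column is summed across rows and, when the pooling strategy is `"AVERAGE"`, divided by
    * the number of rows.
    */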
  private def calculateChunkEmbeddings(matrix: Array[Array[Float]]): Array[Float] = {
    val res = Array.ofDim[Float](matrix(0).length)
    matrix(0).indices.foreach { j =>
      matrix.indices.foreach { i =>
        res(j) += matrix(i)(j)
      }
      if ($(poolingStrategy) == "AVERAGE")
        res(j) /= matrix.length
    }
    res
  }

  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessarily a
    *   one-to-one relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {

    val documentsWithChunks = annotations
      .filter(token => token.annotatorType == CHUNK)

    val embeddingsSentences = WordpieceEmbeddingsSentence.unpack(annotations)

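    // For each chunk, collect the embedded tokens of its sentence that fall within the chunk's
    // character span and pool them into a single chunk vector.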
    documentsWithChunks.flatMap { chunk =>
      val sentenceIdx = chunk.metadata.getOrElse("sentence", "0").toInt
      val chunkIdx = chunk.metadata.getOrElse("chunk", "0").toInt

      if (sentenceIdx < embeddingsSentences.length) {

        val tokensWithEmbeddings = embeddingsSentences(sentenceIdx).tokens.filter(token =>
          token.begin >= chunk.begin && token.end <= chunk.end)

        val allEmbeddings = tokensWithEmbeddings.flatMap(tokenEmbedding =>
          if (!tokenEmbedding.isOOV || ! $(skipOOV))
            Some(tokenEmbedding.embeddings)
          else
            None)

        val finalEmbeddings =
          if (allEmbeddings.length > 0) allEmbeddings else tokensWithEmbeddings.map(_.embeddings)

        /** We may end up with more chunks than word embeddings; this happens when the embeddings
          * model has a maximum sequence length restriction, like BERT, ALBERT, etc.
          */
        if (finalEmbeddings.isEmpty)
          None
        else
          Some(
            Annotation(
              annotatorType = outputAnnotatorType,
              begin = chunk.begin,
              end = chunk.end,
              result = chunk.result,
              metadata = Map(
                "sentence" -> sentenceIdx.toString,
                "chunk" -> chunkIdx.toString,
                "token" -> chunk.result,
                "pieceId" -> "-1",
                "isWordStart" -> "true"),
              embeddings = calculateChunkEmbeddings(finalEmbeddings)))
      } else {
        None
      }
    }
  }

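  /** Propagates the Spark column metadata of the word-embeddings input column onto the output
    * column, so downstream annotators that rely on that metadata keep working.
    */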
  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
    val embeddingsCol = Annotation.getColumnByType(dataset, $(inputCols), WORD_EMBEDDINGS)
    dataset.withColumn(
      getOutputCol,
      dataset.col(getOutputCol).as(getOutputCol, embeddingsCol.metadata))
  }

}

/** This is the companion object of [[ChunkEmbeddings]]. Please refer to that class for the
  * documentation.
  */
object ChunkEmbeddings extends DefaultParamsReadable[ChunkEmbeddings]
