All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, SENTENCE_EMBEDDINGS, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}

/** Converts the results from [[WordEmbeddings]], [[BertEmbeddings]], or [[ElmoEmbeddings]] into
  * sentence or document embeddings by either summing up or averaging all the word embeddings in a
  * sentence or a document (depending on the inputCols).
  *
  * This can be configured with `setPoolingStrategy`, which can be either `"AVERAGE"` or `"SUM"`.
  *
  * For more extended examples see the
  * [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.1_Text_classification_examples_in_SparkML_SparkNLP.ipynb Spark NLP Workshop]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/SentenceEmbeddingsTestSpec.scala SentenceEmbeddingsTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
  * import com.johnsnowlabs.nlp.embeddings.SentenceEmbeddings
  * import com.johnsnowlabs.nlp.EmbeddingsFinisher
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols(Array("document"))
  *   .setOutputCol("token")
  *
  * val embeddings = WordEmbeddingsModel.pretrained()
  *   .setInputCols("document", "token")
  *   .setOutputCol("embeddings")
  *
  * val embeddingsSentence = new SentenceEmbeddings()
  *   .setInputCols(Array("document", "embeddings"))
  *   .setOutputCol("sentence_embeddings")
  *   .setPoolingStrategy("AVERAGE")
  *
  * val embeddingsFinisher = new EmbeddingsFinisher()
  *   .setInputCols("sentence_embeddings")
  *   .setOutputCols("finished_embeddings")
  *   .setOutputAsVector(true)
  *   .setCleanAnnotations(false)
  *
  * val pipeline = new Pipeline()
  *   .setStages(Array(
  *     documentAssembler,
  *     tokenizer,
  *     embeddings,
  *     embeddingsSentence,
  *     embeddingsFinisher
  *   ))
  *
  * val data = Seq("This is a sentence.").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
  * +--------------------------------------------------------------------------------+
  * |                                                                          result|
  * +--------------------------------------------------------------------------------+
  * |[-0.22093398869037628,0.25130119919776917,0.41810303926467896,-0.380883991718...|
  * +--------------------------------------------------------------------------------+
  * }}}
  *
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class SentenceEmbeddings(override val uid: String)
    extends AnnotatorModel[SentenceEmbeddings]
    with HasSimpleAnnotate[SentenceEmbeddings]
    with HasEmbeddingsProperties
    with HasStorageRef {

  /** Output annotator type : SENTENCE_EMBEDDINGS
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = SENTENCE_EMBEDDINGS

  /** Input annotator type : DOCUMENT, WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT, WORD_EMBEDDINGS)

  /** Number of embedding dimensions (Default: `100`)
    *
    * @group param
    */
  override val dimension = new IntParam(this, "dimension", "Number of embedding dimensions")

  /** Number of embedding dimensions (Default: `100`)
    *
    * @group getParam
    */
  override def getDimension: Int = $(dimension)

  /** Choose how you would like to aggregate Word Embeddings to Sentence Embeddings (Default:
    * `"AVERAGE"`). Can either be `"AVERAGE"` or `"SUM"`.
    *
    * @group param
    */
  val poolingStrategy = new Param[String](
    this,
    "poolingStrategy",
    "Choose how you would like to aggregate Word Embeddings to Sentence Embeddings: AVERAGE or SUM")

  /** Choose how you would like to aggregate Word Embeddings to Sentence Embeddings (Default:
    * `"AVERAGE"`). Can either be `"AVERAGE"` or `"SUM"`.
    *
    * The value is matched case-insensitively and stored in its upper-case canonical form.
    *
    * @param strategy
    *   `"AVERAGE"` or `"SUM"` (any letter case)
    * @throws scala.MatchError
    *   if `strategy` is neither of the two supported values
    * @group setParam
    */
  def setPoolingStrategy(strategy: String): this.type = {
    strategy.toLowerCase() match {
      case "average" => set(poolingStrategy, "AVERAGE")
      case "sum" => set(poolingStrategy, "SUM")
      // MatchError is kept (rather than IllegalArgumentException) so that existing callers
      // catching it keep working.
      case _ => throw new MatchError("poolingStrategy must be either AVERAGE or SUM")
    }
  }

  setDefault(
    inputCols -> Array(DOCUMENT, WORD_EMBEDDINGS),
    outputCol -> "sentence_embeddings",
    poolingStrategy -> "AVERAGE",
    dimension -> 100)

  /** Internal constructor to submit a random UID */
  def this() = this(Identifiable.randomUID("SENTENCE_EMBEDDINGS"))

  /** Pools a matrix of token embeddings (one row per token) into a single vector by summing each
    * column and, when the pooling strategy is `"AVERAGE"`, dividing by the number of rows.
    *
    * As a side effect, the `dimension` param is set to the length of the first row.
    *
    * @param matrix
    *   token embeddings; every row is expected to be at least as long as the first one
    * @return
    *   the pooled vector, or an empty array when `matrix` has no rows
    */
  private def calculateSentenceEmbeddings(matrix: Array[Array[Float]]): Array[Float] = {
    if (matrix.isEmpty) {
      // Nothing to pool: return an empty vector instead of failing on matrix(0).
      // The dimension param is deliberately left untouched in this case.
      Array.emptyFloatArray
    } else {
      val vectorSize = matrix(0).length
      val pooled = Array.ofDim[Float](vectorSize)
      setDimension(vectorSize)

      // The strategy cannot change while pooling, so resolve it once instead of per column.
      val shouldAverage = $(poolingStrategy) == "AVERAGE"

      pooled.indices.foreach { col =>
        matrix.indices.foreach { row =>
          pooled(col) += matrix(row)(col)
        }
        if (shouldAverage)
          pooled(col) /= matrix.length
      }
      pooled
    }
  }

  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * One output annotation is produced per unpacked sentence. Its embeddings are the pooled token
    * embeddings of every embeddings-sentence whose `sentenceId` equals the sentence index; if
    * several match, their pooled vectors are concatenated, and if none match the embeddings are
    * empty.
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessary one to one
    *   relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = SentenceSplit.unpack(annotations)

    val embeddingsSentences = WordpieceEmbeddingsSentence.unpack(annotations)

    sentences.map { sentence =>
      val matchingEmbeddings =
        embeddingsSentences.filter(candidate => candidate.sentenceId == sentence.index)

      val sentenceEmbeddings = matchingEmbeddings.flatMap { embeddingsSentence =>
        val tokenVectors = embeddingsSentence.tokens.map(_.embeddings)
        calculateSentenceEmbeddings(tokenVectors)
      }.toArray

      Annotation(
        annotatorType = outputAnnotatorType,
        begin = sentence.start,
        end = sentence.end,
        result = sentence.content,
        metadata = Map(
          "sentence" -> sentence.index.toString,
          "token" -> sentence.content,
          "pieceId" -> "-1",
          "isWordStart" -> "true"),
        embeddings = sentenceEmbeddings)
    }
  }

  /** Reads the storage reference from the WORD_EMBEDDINGS input column and adopts it when this
    * annotator has no storage reference set yet. The dataset itself is returned unchanged.
    */
  override protected def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    val ref =
      HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)
    if (get(storageRef).isEmpty)
      setStorageRef(ref)
    dataset
  }

  /** Replaces the output column with the result of `wrapSentenceEmbeddingsMetadata`, passing the
    * current `dimension` and storage reference.
    */
  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
    dataset.withColumn(
      getOutputCol,
      wrapSentenceEmbeddingsMetadata(
        dataset.col(getOutputCol),
        $(dimension),
        Some($(storageRef))))
  }
}

/** This is the companion object of [[SentenceEmbeddings]]. Please refer to that class for the
  * documentation.
  *
  * It extends `DefaultParamsReadable`, Spark ML's helper trait for reading back saved instances of
  * a `Params` class.
  */
object SentenceEmbeddings extends DefaultParamsReadable[SentenceEmbeddings]




© 2015 - 2024 Weber Informatics LLC | Privacy Policy