All downloads are free. The search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, ParamsAndFeaturesWritable}
import com.johnsnowlabs.nlp.annotators.common.{TokenPieceEmbeddings, TokenizedWithSentence, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, SparkSession}


/**
  * Annotator model that looks up an embeddings vector for every token of every sentence and
  * emits the result as WORD_EMBEDDINGS annotations.
  *
  * Embeddings are obtained through the members mixed in from [[HasWordEmbeddings]]
  * (`getEmbeddings`, `getClusterEmbeddings`, `preloadedEmbeddings`, ...).
  *
  * @param uid unique identifier of this pipeline stage
  */
class WordEmbeddingsModel(override val uid: String)
  extends AnnotatorModel[WordEmbeddingsModel]
    with HasWordEmbeddings
    with AutoCloseable
    with ParamsAndFeaturesWritable {

  /** Creates the model with a random uid prefixed by WORD_EMBEDDINGS_MODEL. */
  def this() = this(Identifiable.randomUID("WORD_EMBEDDINGS_MODEL"))

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator type */
  override val outputAnnotatorType: AnnotatorType = WORD_EMBEDDINGS

  /** Annotation types this annotator consumes: documents and tokens. */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Location of the serialized embeddings: the given path with "/embeddings" appended. */
  private def getEmbeddingsSerializedPath(path: String): Path = {
    val modelRoot = new Path(path)
    val embeddingsSuffix = new Path("/embeddings")
    Path.mergePaths(modelRoot, embeddingsSuffix)
  }

  /**
    * Loads the embeddings stored under the embeddings folder of `path` through
    * `EmbeddingsHelper.load`, using the SPARKNLP format together with this model's
    * dimension, case sensitivity and embeddings reference params.
    *
    * @param path  root folder of the saved model
    * @param spark active Spark session
    */
  private[embeddings] def deserializeEmbeddings(path: String, spark: SparkSession): Unit = {
    val embeddingsLocation = getEmbeddingsSerializedPath(path).toUri.toString
    val storageFormat = WordEmbeddingsFormat.SPARKNLP.toString

    EmbeddingsHelper.load(
      embeddingsLocation,
      spark,
      storageFormat,
      $(dimension),
      $(caseSensitive),
      $(embeddingsRef)
    )
  }

  /**
    * Hands the local embeddings index of the current cluster embeddings to
    * `EmbeddingsHelper.save`, targeting the embeddings folder of `path` on the
    * file system resolved from `path` and the Spark Hadoop configuration.
    *
    * @param path  root folder the model is being written to
    * @param spark active Spark session
    */
  private[embeddings] def serializeEmbeddings(path: String, spark: SparkSession): Unit = {
    val localIndexName = getClusterEmbeddings.fileName
    val localIndex = new Path(EmbeddingsHelper.getLocalEmbeddingsPath(localIndexName))

    val targetFs = FileSystem.get(new java.net.URI(path), spark.sparkContext.hadoopConfiguration)

    EmbeddingsHelper.save(targetFs, localIndex, getEmbeddingsSerializedPath(path))
  }

  /** Write hook: serializes the embeddings next to the model only when `includeEmbeddings` is set. */
  override protected def onWrite(path: String, spark: SparkSession): Unit = {
    // Param only useful for runtime execution
    val shouldPersistEmbeddings = $(includeEmbeddings)
    if (shouldPersistEmbeddings) {
      serializeEmbeddings(path, spark)
    }
  }

  /**
    * Closes the local retriever of the preloaded embeddings, if any.
    * Nothing happens when `embeddingsRef` is unset or no embeddings were preloaded.
    */
  override protected def close(): Unit = {
    if (get(embeddingsRef).isDefined) {
      preloadedEmbeddings.foreach { cluster =>
        cluster.getLocalRetriever.close()
      }
    }
  }

  /**
    * Takes the incoming annotations and produces this annotator's annotations: one embeddings
    * entry per token, grouped per sentence.
    *
    * @param annotations annotations matching the input columns, produced by previous annotators if any
    * @return the packed embeddings annotations; not necessarily one per input annotation
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val indexedSentences = TokenizedWithSentence.unpack(annotations).zipWithIndex

    val embeddedSentences = indexedSentences.map { case (sentence, sentenceIndex) =>
      val embeddedTokens = for (indexed <- sentence.indexedTokens) yield {
        val word = indexed.token
        // The retriever is resolved per token, exactly as before, so nothing is loaded for empty input.
        val vector = getEmbeddings.getEmbeddingsVector(word)
        new TokenPieceEmbeddings(word, word, -1, true, vector, indexed.begin, indexed.end)
      }
      WordpieceEmbeddingsSentence(embeddedTokens, sentenceIndex)
    }

    WordpieceEmbeddingsSentence.pack(embeddedSentences)
  }

  /**
    * Post-processing hook: closes the local embeddings retriever and rewrites the output
    * column through `wrapEmbeddingsMetadata` with the dimension and embeddings reference.
    *
    * @param dataset data frame holding the annotated output column
    * @return the data frame with the output column replaced by its metadata-wrapped version
    */
  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
    getClusterEmbeddings.getLocalRetriever.close()

    val outputColumn = getOutputCol
    val withMetadata =
      wrapEmbeddingsMetadata(dataset.col(outputColumn), $(dimension), Some(getEmbeddingsRef))

    dataset.withColumn(outputColumn, withMetadata)
  }

}

/**
  * Companion object of [[WordEmbeddingsModel]].
  *
  * Mixes in [[PretrainedWordEmbeddings]], so `WordEmbeddingsModel.pretrained(...)` is available,
  * and `EmbeddingsReadable` (defined elsewhere; presumably the reader used to load saved
  * models — confirm against its definition).
  */
object WordEmbeddingsModel extends EmbeddingsReadable[WordEmbeddingsModel] with PretrainedWordEmbeddings

/** Adds a `pretrained` factory that obtains a [[WordEmbeddingsModel]] through `ResourceDownloader`. */
trait PretrainedWordEmbeddings {

  /**
    * Downloads a word embeddings model via `ResourceDownloader.downloadModel`.
    *
    * @param name      name of the model to fetch, "glove_100d" by default
    * @param lang      language of the model, "en" by default; a null value is passed on as `None`
    * @param remoteLoc remote location to download from, `ResourceDownloader.publicLoc` by default
    * @return the downloaded model
    */
  def pretrained(
    name: String = "glove_100d",
    lang: String = "en",
    remoteLoc: String = ResourceDownloader.publicLoc
  ): WordEmbeddingsModel = {
    val language = Option(lang)
    ResourceDownloader.downloadModel(WordEmbeddingsModel, name, language, remoteLoc)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy