All downloads are free. The search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.embeddings.WordEmbeddings.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.AnnotatorApproach
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, TOKEN, WORD_EMBEDDINGS}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{Dataset, SparkSession}

/**
  * Word Embeddings lookup annotator that maps tokens to vectors.
  *
  * This is the trainable side of the annotator: [[train]] produces a [[WordEmbeddingsModel]]
  * configured from this instance's params. Before training, embeddings are either loaded from
  * `sourceEmbeddingsPath` or looked up through an already-set `embeddingsRef`.
  *
  * @param uid unique identifier of this annotator instance
  */
class WordEmbeddings(override val uid: String) extends AnnotatorApproach[WordEmbeddingsModel] with HasWordEmbeddings {

  /** Creates an instance with a random uid prefixed by "WORD_EMBEDDINGS". */
  def this() = this(Identifiable.randomUID("WORD_EMBEDDINGS"))

  /** Annotation type produced by this annotator: WORD_EMBEDDINGS. */
  override val outputAnnotatorType: AnnotatorType = WORD_EMBEDDINGS

  /** Annotation types this annotator declares as its inputs: DOCUMENT and TOKEN. */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Path to the word embeddings file. */
  val sourceEmbeddingsPath: Param[String] = new Param[String](this, "sourceEmbeddingsPath", "Word embeddings file")

  /** Word vectors file format, stored as the numeric id of a `WordEmbeddingsFormat` value. */
  val embeddingsFormat: IntParam = new IntParam(this, "embeddingsFormat", "Word vectors file format")

  /** Short human-readable description of this annotator. */
  override val description: String = "Word Embeddings lookup annotator that maps tokens to vectors"

  /**
    * Sets the embeddings source path, vector dimension and file format in one call.
    *
    * @param path   location of the embeddings file
    * @param nDims  number of dimensions of each vector
    * @param format file format as a `WordEmbeddingsFormat` value
    * @return this instance, for call chaining
    */
  def setEmbeddingsSource(path: String, nDims: Int, format: WordEmbeddingsFormat.Format): this.type = {
    set(this.sourceEmbeddingsPath, path)
    set(this.embeddingsFormat, format.id)
    set(this.dimension, nDims)
  }

  /**
    * Sets the embeddings source path, vector dimension and file format in one call,
    * taking the format by name.
    *
    * The format name is converted before any param is modified, so a name that the
    * conversion rejects leaves this instance unchanged.
    *
    * @param path   location of the embeddings file
    * @param nDims  number of dimensions of each vector
    * @param format file format name, converted through the implicits of `WordEmbeddingsFormat`
    * @return this instance, for call chaining
    */
  def setEmbeddingsSource(path: String, nDims: Int, format: String): this.type = {
    import WordEmbeddingsFormat._
    // Convert first: if the name is not accepted we must not leave the path set without a format.
    val formatId = format.id
    set(this.sourceEmbeddingsPath, path)
    set(this.embeddingsFormat, formatId)
    set(this.dimension, nDims)
  }

  /** Sets only the embeddings file path; format and dimension must be set separately. */
  def setSourcePath(path: String): this.type = set(sourceEmbeddingsPath, path)

  /** Returns the configured embeddings file path. */
  def getSourcePath: String = $(sourceEmbeddingsPath)

  /**
    * Sets the embeddings file format by name.
    *
    * @param format file format name, converted through the implicits of `WordEmbeddingsFormat`
    * @return this instance, for call chaining
    */
  def setEmbeddingsFormat(format: String): this.type = {
    import WordEmbeddingsFormat._
    set(embeddingsFormat, format.id)
  }

  /** Returns the configured embeddings file format as the string form of its `WordEmbeddingsFormat` value. */
  def getEmbeddingsFormat: String = {
    import WordEmbeddingsFormat._
    int2frm($(embeddingsFormat)).toString
  }


  /**
    * Makes sure embeddings are available before training starts.
    *
    * - If `sourceEmbeddingsPath` is defined and nothing has been loaded yet, loads the embeddings
    *   through `EmbeddingsHelper.load` and marks this instance as loaded.
    * - Otherwise, if `embeddingsRef` was explicitly set, calls `getClusterEmbeddings`
    *   (presumably to resolve the embeddings registered under that reference).
    * - Otherwise fails.
    *
    * @param sparkSession session handed to `EmbeddingsHelper.load`
    * @throws IllegalArgumentException if neither a source path nor an embeddings reference is
    *                                  available, or if a source path is set without a format
    */
  override def beforeTraining(sparkSession: SparkSession): Unit = {
    if (isDefined(sourceEmbeddingsPath)) {
      if (!isLoaded()) {
        // setSourcePath() can set the path alone; fail with an actionable message rather than
        // an opaque param lookup failure when the format was never provided.
        require(
          isDefined(embeddingsFormat),
          "embeddingsFormat must be set when sourceEmbeddingsPath is set. " +
            "Use setEmbeddingsFormat() or setEmbeddingsSource()."
        )
        EmbeddingsHelper.load(
          $(sourceEmbeddingsPath),
          sparkSession,
          WordEmbeddingsFormat($(embeddingsFormat)).toString,
          $(dimension),
          $(caseSensitive),
          $(embeddingsRef)
        )
        setAsLoaded()
      }
    } else if (isSet(embeddingsRef)) {
      getClusterEmbeddings
    } else
      throw new IllegalArgumentException(
        s"Word embeddings not found. Either sourceEmbeddingsPath not set," +
          s" or not in cache by ref: ${get(embeddingsRef).getOrElse("-embeddingsRef not set-")}. " +
          s"Load using EmbeddingsHelper .loadEmbeddings() and .setEmbeddingsRef() to make them available."
      )
  }

  /**
    * Builds the [[WordEmbeddingsModel]] that carries this annotator's configuration.
    *
    * The dataset is not inspected here: the model is configured purely from params
    * (input columns, embeddings reference, dimension, case sensitivity and the
    * include-embeddings flag).
    *
    * @param dataset           training dataset (unused by this implementation)
    * @param recursivePipeline optional pipeline (unused by this implementation)
    * @return the configured model
    */
  override def train(dataset: Dataset[_], recursivePipeline: Option[PipelineModel]): WordEmbeddingsModel = {
    val model = new WordEmbeddingsModel()
      .setInputCols($(inputCols))
      .setEmbeddingsRef($(embeddingsRef))
      .setDimension($(dimension))
      .setCaseSensitive($(caseSensitive))
      .setIncludeEmbeddings($(includeEmbeddings))

    // Close the local retriever of the cluster embeddings once the model has been configured.
    getClusterEmbeddings.getLocalRetriever.close()

    model
  }

}

/**
  * Companion object of [[WordEmbeddings]]. It mixes in Spark ML's `DefaultParamsReadable`,
  * parameterized on [[WordEmbeddings]], and declares no members of its own.
  */
object WordEmbeddings extends DefaultParamsReadable[WordEmbeddings]




© 2015 - 2025 Weber Informatics LLC | Privacy Policy