All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.HasWordEmbeddings.scala Maven / Gradle / Ivy

There is a newer version: 1.6.2
Show newest version
package com.johnsnowlabs.nlp

import java.io.File
import java.nio.file.{Files, Paths}

import com.johnsnowlabs.nlp.embeddings.{SparkWordEmbeddings, WordEmbeddings, WordEmbeddingsFormat}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkContext, SparkFiles}


/**
  * Base class for models that uses Word Embeddings.
  * This implementation is based on RocksDB so it has a compact RAM usage
  *
  * Corresponding Approach have to implement AnnotatorWithWordEmbeddings
   */

trait HasWordEmbeddings extends AutoCloseable with ParamsAndFeaturesWritable {

  val nDims = new IntParam(this, "nDims", "Number of embedding dimensions")
  val indexPath = new Param[String](this, "indexPath", "File that stores Index")

  def setDims(nDims: Int): this.type = set(this.nDims, nDims)
  def setIndexPath(path: String): this.type = set(this.indexPath, path)

  @transient
  private var sparkEmbeddings: SparkWordEmbeddings = null

  def embeddings: Option[WordEmbeddings] = get(indexPath).map { path =>
    if (sparkEmbeddings == null)
      sparkEmbeddings = new SparkWordEmbeddings(path, $(nDims))

    sparkEmbeddings.wordEmbeddings
  }

  override def close(): Unit = {
    if (embeddings.nonEmpty)
      embeddings.get.close()
  }

  def moveFolderFiles(folderSrc: String, folderDst: String): Unit = {
    for (file <- new File(folderSrc).list()) {
      Files.move(Paths.get(folderSrc, file), Paths.get(folderDst, file))
    }

    Files.delete(Paths.get(folderSrc))
  }

  def deserializeEmbeddings(path: String, spark: SparkContext): Unit = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.hadoopConfiguration)
    val src = getEmbeddingsSerializedPath(path)

    if (fs.exists(src)) {
      val embeddings = SparkWordEmbeddings(spark, src.toUri.toString, 0, WordEmbeddingsFormat.SPARKNLP)
      setIndexPath(embeddings.clusterFilePath.toString)
    }
  }

  def serializeEmbeddings(path: String, spark: SparkContext): Unit = {
    if (isDefined(indexPath)) {
      val index = new Path(SparkFiles.get($(indexPath)))
      val uri = new java.net.URI(path)
      val fs = FileSystem.get(uri, spark.hadoopConfiguration)

      val dst = getEmbeddingsSerializedPath(path)
      fs.copyFromLocalFile(false, true, index, dst)
    }
  }

  def getEmbeddingsSerializedPath(path: String): Path = Path.mergePaths(new Path(path), new Path("/embeddings"))

  override def onWrite(path: String, spark: SparkSession): Unit = {
    serializeEmbeddings(path, spark.sparkContext)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy