All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.embeddings.WordEmbeddingsRetriever.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.util.LruMap
import org.rocksdb._


/**
  * Read-only lookup of word embedding vectors stored in a RocksDB database.
  *
  * The database at `dbFile` is opened read-only on first use, and opened again on demand
  * after [[close]]. Vectors returned by [[getEmbeddingsVector]] are kept in an LRU cache.
  *
  * Every access to the native handle goes through this instance's monitor. That includes
  * lazy opening, both lookup methods and [[close]]. As a result, two threads can never both
  * open the database (which would leak one handle), and a lookup can never run against a
  * handle that is being closed.
  *
  * @param dbFile        path of the RocksDB database, opened read-only
  * @param nDims         length of the all-zero vector returned for words not in the database
  * @param caseSensitive when true, an exact-case match is preferred over the lower-cased form
  * @param lruCacheSize  size passed to the LRU cache of looked-up vectors
  */
case class WordEmbeddingsRetriever(dbFile: String,
                                   nDims: Int,
                                   caseSensitive: Boolean,
                                   lruCacheSize: Int = 100000) extends AutoCloseable {

  /** Native database handle: `null` until first use and after [[close]]; never serialized. */
  @transient private var prefetchedDB: RocksDB = null

  /** Returns the open database handle, opening it read-only (under the monitor) on first use. */
  private def db: RocksDB = synchronized {
    if (prefetchedDB == null) {
      RocksDB.loadLibrary()
      prefetchedDB = RocksDB.openReadOnly(dbFile)
    }
    prefetchedDB
  }

  /**
    * Vector returned for words absent from the database.
    * NOTE(review): this single array instance is shared across all misses; callers must not mutate it.
    */
  val zeroArray: Array[Float] = Array.fill[Float](nDims)(0f)

  /** Cache of already-fetched vectors, keyed by the word exactly as passed to [[getEmbeddingsVector]]. */
  val lru = new LruMap[String, Array[Float]](lruCacheSize)

  /**
    * Fetches the vector for the trimmed `word` directly from the database.
    *
    * Lookup order: the exact form first when `caseSensitive`, then the lower-cased form,
    * then the exact form, then the upper-cased form. Each form is queried lazily, only if
    * every earlier form was missing.
    *
    * @return the decoded vector, or [[zeroArray]] when no form of the word is present
    */
  private def getEmbeddingsFromDb(word: String): Array[Float] = {
    val trimmed = word.trim
    // NOTE(review): getBytes()/toLowerCase/toUpperCase use the platform default charset and
    // locale; they must match how the index keys were written — confirm against the indexer.
    lazy val resultLower = db.get(trimmed.toLowerCase.getBytes())
    lazy val resultUpper = db.get(trimmed.toUpperCase.getBytes())
    lazy val resultExact = db.get(trimmed.getBytes())

    val bytes =
      if (caseSensitive && resultExact != null) resultExact
      else if (resultLower != null) resultLower
      else if (resultExact != null) resultExact
      else resultUpper

    if (bytes != null) WordEmbeddingsIndexer.fromBytes(bytes) else zeroArray
  }

  /**
    * Returns the embedding vector for `word`, going through the LRU cache.
    * The database fetch is presumably evaluated only on a cache miss (depends on LruMap).
    *
    * @return the stored vector, or [[zeroArray]] if the word is unknown
    */
  def getEmbeddingsVector(word: String): Array[Float] = {
    synchronized {
      lru.getOrElseUpdate(word, getEmbeddingsFromDb(word))
    }
  }

  /**
    * Checks whether the trimmed `word` is present in the database in its exact,
    * lower-cased or upper-cased form, checked in that order and stopping at the first hit.
    * The LRU cache is not consulted.
    */
  def containsEmbeddingsVector(word: String): Boolean = synchronized {
    val trimmed = word.trim
    Seq(trimmed, trimmed.toLowerCase, trimmed.toUpperCase)
      .exists(key => db.get(key.getBytes()) != null)
  }

  /** Closes the native handle if it is open; a later lookup reopens the database. */
  override def close(): Unit = synchronized {
    if (prefetchedDB != null) {
      prefetchedDB.close()
      prefetchedDB = null
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy