All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.embeddings.WordEmbeddingsIndexer.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.embeddings

import java.io._
import java.nio.{ByteBuffer, ByteOrder}

import org.slf4j.LoggerFactory

import scala.io.Source


object WordEmbeddingsIndexer {

  /**
    * Serializes a float vector into bytes, 4 bytes per float, little-endian.
    *
    * @param embeddings vector to serialize
    * @return byte array of length `embeddings.length * 4`
    */
  private[embeddings] def toBytes(embeddings: Array[Float]): Array[Byte] = {
    val out = ByteBuffer.allocate(embeddings.length * 4).order(ByteOrder.LITTLE_ENDIAN)
    embeddings.foreach(component => out.putFloat(component))
    out.array()
  }

  /**
    * Deserializes little-endian bytes back into a float vector.
    * Any trailing bytes that do not make up a complete 4-byte float are ignored.
    *
    * @param source raw bytes, 4 per float
    * @return array of `source.length / 4` floats
    */
  private[embeddings] def fromBytes(source: Array[Byte]): Array[Float] = {
    val reader = ByteBuffer.wrap(source).order(ByteOrder.LITTLE_ENDIAN)
    val floatCount = source.length / 4
    // Absolute reads: the buffer position is never advanced.
    Array.tabulate[Float](floatCount)(idx => reader.getFloat(idx * 4))
  }

  /**
    * Indexes word embeddings given as text lines
    * (delegates to `TextIndexer.index`).
    *
    * @param source lines to index
    * @param dbFile destination handed to the indexer
    */
  def indexText(source: Iterator[String], dbFile: String): Unit =
    TextIndexer.index(source, dbFile)

  /**
    * Indexes word embeddings stored in a text file
    * (delegates to `TextIndexer.index`).
    *
    * @param source path of the text file
    * @param dbFile destination handed to the indexer
    */
  def indexText(source: String, dbFile: String): Unit =
    TextIndexer.index(source, dbFile)

  /**
    * Indexes word embeddings read from a binary stream
    * (delegates to `BinaryIndexer.index`).
    *
    * @param source stream positioned at the start of the binary data
    * @param dbFile destination handed to the indexer
    */
  def indexBinary(source: DataInputStream, dbFile: String): Unit =
    BinaryIndexer.index(source, dbFile)

  /**
    * Indexes word embeddings stored in a binary file
    * (delegates to `BinaryIndexer.index`).
    *
    * @param source path of the binary file
    * @param dbFile destination handed to the indexer
    */
  def indexBinary(source: String, dbFile: String): Unit =
    BinaryIndexer.index(source, dbFile)
}


private[embeddings] object TextIndexer {

  /**
    * Indexes embeddings supplied as text lines.
    *
    * Each line is split on single spaces: the first token is the word and
    * every remaining token is parsed as a float component of its vector.
    * The indexer is closed even if parsing or insertion fails.
    *
    * @param source lines to index
    * @param dbFile destination passed to `RocksDbIndexer`
    * @throws NumberFormatException if a vector component is not a valid float
    */
  def index(source: Iterator[String], dbFile: String): Unit = {
    val indexer = RocksDbIndexer(dbFile, Some(1000))

    try {
      for (line <- source) {
        val items = line.split(" ")
        val word = items(0)
        val embeddings = items.drop(1).map(i => i.toFloat)
        indexer.add(word, embeddings)
      }
    } finally {
      indexer.close()
    }
  }

  /**
    * Indexes embeddings stored in a UTF-8 text file.
    *
    * @param source path of the text file to read
    * @param dbFile destination passed to `RocksDbIndexer`
    */
  def index(source: String, dbFile: String): Unit = {
    val input = Source.fromFile(source)("UTF-8")

    try {
      index(input.getLines(), dbFile)
    } finally {
      // Release the file handle; previously the Source was never closed.
      input.close()
    }
  }
}


private[embeddings] object BinaryIndexer {

  private val logger = LoggerFactory.getLogger("WordEmbeddings")

  /**
    * Indexes embeddings read from a binary stream.
    *
    * The stream starts with two textual integers (number of words, vector
    * size), followed by that many records, each a textual word and then
    * `vecSize` little-endian 4-byte floats.
    * NOTE(review): this looks like the word2vec binary layout - confirm.
    * The indexer is closed even if reading fails; the stream is left open
    * for the caller to close.
    *
    * @param source stream positioned at the file header
    * @param dbFile destination passed to `RocksDbIndexer`
    * @throws java.io.EOFException  if the stream ends before all declared data is read
    * @throws NumberFormatException if a header field is not an integer
    */
  def index(source: DataInputStream, dbFile: String): Unit = {
    val indexer = RocksDbIndexer(dbFile, Some(1000))

    try {
      // File Header
      val numWords = Integer.parseInt(readString(source))
      val vecSize = Integer.parseInt(readString(source))

      // File Body
      for (_ <- 0 until numWords) {
        val word = readString(source)
        val vector = readFloatVector(source, vecSize)
        indexer.add(word, vector)
      }

      logger.info(s"Loaded $numWords words, vector size $vecSize")
    } finally {
      indexer.close()
    }
  }

  /**
    * Indexes embeddings stored in a binary file. The file is read through a
    * 32 KiB buffered stream that is always closed afterwards.
    *
    * @param source path of the binary file
    * @param dbFile destination passed to `RocksDbIndexer`
    */
  def index(source: String, dbFile: String): Unit = {

    val ds = new DataInputStream(new BufferedInputStream(new FileInputStream(source), 1 << 15))

    try {
      index(ds, dbFile)
    } finally {
      ds.close()
    }
  }

  /**
    * Reads one token from the binary model, decoded as UTF-8.
    *
    * Leading space (32) and newline (10) bytes are skipped; the token ends at
    * the first space or newline after at least one byte has been collected.
    *
    * @throws java.io.EOFException if the stream ends before a delimiter is found
    */
  private def readString(ds: DataInputStream): String = {
    val byteBuffer = new ByteArrayOutputStream()

    var isEnd = false
    while (!isEnd) {
      val byteValue = ds.readByte()
      if ((byteValue != 32) && (byteValue != 10)) {
        byteBuffer.write(byteValue)
      } else if (byteBuffer.size() > 0) {
        isEnd = true
      }
    }

    // Decode explicitly as UTF-8 instead of relying on the platform default charset.
    val word = byteBuffer.toString("UTF-8")
    byteBuffer.close()
    word
  }

  /**
    * Reads a vector of `vectorSize` floats (4 bytes each) from the binary model.
    *
    * @throws java.io.EOFException if fewer than `4 * vectorSize` bytes remain
    */
  private def readFloatVector(ds: DataInputStream, vectorSize: Int): Array[Float] = {
    // readFully blocks until the whole buffer is filled; a plain read() may
    // return early with a partially filled buffer and desynchronise the stream.
    val vectorBuffer = new Array[Byte](4 * vectorSize)
    ds.readFully(vectorBuffer)

    // Convert Bytes to Floats
    WordEmbeddingsIndexer.fromBytes(vectorBuffer)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy