All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowSpell.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.ml.tensorflow.TensorResources.extractFloats
import com.johnsnowlabs.nlp.annotators.ner.Verbose

/**
  * Wrapper around a TensorFlow graph used for spell checking: it feeds batches of
  * word / context ids into the graph and reads back the per-position loss.
  *
  * @param tensorflow   wrapper that provides the TensorFlow session used to run the graph
  * @param verboseLevel verbosity level (not read by the code in this class)
  */
class TensorflowSpell(
  val tensorflow: TensorflowWrapper,
  val verboseLevel: Verbose.Value
  ) extends Logging with Serializable {

  // Names of graph nodes / keys. Only lossKey, validWords and dropoutRate are
  // used by predict below; the others are kept because they are public members.
  val testInitOp = "test/init"
  val validWords = "valid_words"
  val fileNameTest = "file_name"
  val inMemoryInput = "in-memory-input"
  val batchesKey = "batches"
  val lossKey = "Add:0"
  val dropoutRate = "dropout_rate"

  // these are the inputs to the graph
  val wordIds = "batches:0"
  val contextIds = "batches:1"
  val contextWordIds = "batches:2"

  /**
    * Returns the loss associated with the last word of every row, given the previous history.
    *
    * Each row of `dataset` is fed without its last element, while each row of `cids` and
    * `cwids` is fed without its first element. The flat loss vector fetched from the graph is
    * then split into chunks of `rowLength - 1` values and the last value of each chunk is kept.
    *
    * @param dataset          batch of word-id rows; the row length is taken from the first row
    *                         (presumably all rows share that length - confirm against callers)
    * @param cids             batch of context-id rows, one per row of `dataset`
    * @param cwids            batch of context-word-id rows, one per row of `dataset`
    * @param configProtoBytes optional serialized session configuration forwarded to `getSession`
    * @return one loss value per chunk of the fetched loss vector; empty when `dataset` is empty
    * @throws IllegalArgumentException if the first row of `dataset` has fewer than two elements
    */
  def predict(dataset: Array[Array[Int]], cids: Array[Array[Int]], cwids:Array[Array[Int]], configProtoBytes: Option[Array[Byte]] = None): Iterator[Float] = {

    dataset.headOption match {
      // Nothing to score: avoid dataset.head blowing up and skip the session run entirely.
      case None => Iterator.empty

      case Some(firstRow) =>
        val width = firstRow.length
        // With width <= 1 there is no history left after dropping the last word, and
        // grouped(width - 1) below would fail with a far less helpful message.
        require(width > 1, s"predict requires rows with at least 2 elements, but the first row has $width")

        val tensors = new TensorResources

        val result = try {
          val lossWords = tensorflow.getSession(configProtoBytes=configProtoBytes).runner
            // NOTE(review): 1.0f presumably disables dropout at inference time - confirm against the graph
            .feed(dropoutRate, tensors.createTensor(1.0f))
            .feed(wordIds, tensors.createTensor(dataset.map(_.dropRight(1))))
            .feed(contextIds, tensors.createTensor(cids.map(_.tail)))
            .feed(contextWordIds, tensors.createTensor(cwids.map(_.tail)))
            .fetch(lossKey)
            .fetch(validWords)
            .run()

          // Tensors returned by run() are owned by the caller and must be closed explicitly.
          // extractFloats is assumed to copy the values out before we release them.
          try extractFloats(lossWords.get(0))
          finally {
            val fetched = lossWords.iterator()
            while (fetched.hasNext) fetched.next().close()
          }
        } finally {
          // Release the input tensors even when the session run fails.
          tensors.clearTensors()
        }

        result.grouped(width - 1).map(_.last)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy