All downloads are FREE. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowAssertion.scala Maven / Gradle / Ivy

There is a newer version: 1.6.2
Show newest version
package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.nlp.annotators.datasets.AssertionAnnotationWithLabel
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import com.johnsnowlabs.ml.tensorflow.TensorResources._

/**
  * Runs and trains a TensorFlow graph, held by a [[TensorflowWrapper]], that predicts one label
  * for each example made of a sentence (an array of strings) plus start/end span positions.
  *
  * An [[AssertionDatasetEncoder]] turns sentences and span positions into the graph's
  * sentence-length and word-embedding inputs, one-hot encodes training labels, and decodes the
  * predicted label ids back into label strings.
  *
  * Created by jose on 15/03/18.
  *
  * @param tensorflow   wrapper exposing the TensorFlow `session` the graph is run in
  * @param encoder      converts examples to graph inputs and predicted ids back to labels
  * @param batchSize    number of examples fed to the graph per step in `predict`
  * @param verboseLevel verbosity setting; presumably consulted by the `Logging` trait to filter
  *                     `log` calls — NOTE(review): confirm against `Logging`
  */
class TensorflowAssertion (
  val tensorflow: TensorflowWrapper,
  val encoder: AssertionDatasetEncoder,
  val batchSize: Int,
  val verboseLevel: Verbose.Value
) extends Logging {

  // Names of the graph placeholders / ops this class feeds, fetches or targets.
  private val sentenceLengthsKey = "word_repr/sentence_lengths"
  private val wordEmbeddingsKey = "word_repr/word_embeddings"
  private val labelsKey = "training/labels"
  private val dropoutKey = "training/dropout"
  private val learningRateKey = "training/lr"
  private val trainingKey = "training/Adam"
  private val lossKey = "training/loss"
  private val outputKey = "output_label"

  /** Multiplicative learning-rate decay applied at the end of every training epoch. */
  private val learningRateDecay = 0.95f

  /**
    * Predicts labels for every example, feeding the graph at most `batchSize` examples at a time.
    *
    * @param dataset sentences, each given as an array of strings
    * @param start   start position of each example's span (one entry per sentence in `dataset`)
    * @param end     end position of each example's span (one entry per sentence in `dataset`)
    * @return the labels decoded from the predicted ids, in input order (assumes
    *         `encoder.decodeOutputData` maps ids to labels one-to-one)
    */
  def predict(dataset: Array[Array[String]], start: Array[Int], end: Array[Int]): Array[String] = {

    val result = ArrayBuffer[String]()

    for ((sents, batchStart, batchEnd)
         <- (dataset.grouped(batchSize).toList, start.grouped(batchSize).toList, end.grouped(batchSize).toList).zipped) {

      val batchInput = encoder.encodeInputData(sents, batchStart, batchEnd)
      val tensors = new TensorResources()

      val tagIds = try {
        val outputTensor = tensorflow.session.runner
          .feed(sentenceLengthsKey, tensors.createTensor(batchInput.sentenceLengths))
          .feed(wordEmbeddingsKey, tensors.createTensor(batchInput.wordEmbeddings))
          // Training feeds this placeholder as a keep probability (1 - dropout);
          // 1.0 disables dropout for inference.
          .feed(dropoutKey, tensors.createTensor(1.0f))
          .fetch(outputKey)
          .run()
          .get(0)

        // The final batch may hold fewer than `batchSize` examples, so request exactly one id
        // per example in this batch. Asking for `batchSize` ids could yield extra entries that
        // would be decoded into spurious labels (verify against `extractInts`).
        // The output tensor comes from `run()`, not `createTensor`, so close it explicitly.
        try extractInts(outputTensor, sents.length) finally outputTensor.close()
      } finally {
        // Release the input tensors even if the session run fails.
        tensors.clearTensors()
      }

      result.appendAll(encoder.decodeOutputData(tagIds))
    }
    result.toArray
  }

  /**
    * Trains the graph for epochs `startEpoch` until `endEpoch` (exclusive).
    *
    * Each epoch shuffles the dataset and runs the `training/Adam` target once per mini-batch,
    * summing the fetched `training/loss` values. After each epoch the learning rate is multiplied
    * by `learningRateDecay`.
    *
    * @param trainDataset examples: a sentence paired with its labelled span annotation
    * @param lr           learning rate used for the first epoch of this call
    * @param batchSize    training mini-batch size (independent of the constructor's `batchSize`)
    * @param dropout      dropout rate; fed to the graph as `1 - dropout`
    * @param startEpoch   first epoch index; when 0 the graph's `init` op is run first
    * @param endEpoch     epoch index to stop before (exclusive)
    */
  def train(trainDataset: Array[(Array[String], AssertionAnnotationWithLabel)],
            lr: Float,
            batchSize: Int,
            dropout: Float,
            startEpoch: Int,
            endEpoch: Int
           ): Unit = {

    log(s"Training started, trainExamples: ${trainDataset.length} ", Verbose.TrainingStat)

    // Initialize variables only on a fresh run; resumed runs keep the current graph state.
    if (startEpoch == 0)
      tensorflow.session.runner.addTarget("init").run()

    // NOTE(review): when resuming (startEpoch > 0) the schedule restarts from `lr`;
    // confirm callers pass an already-decayed rate.
    var learningRate = lr

    for (epoch <- startEpoch until endEpoch) {

      val epochDataset = Random.shuffle(trainDataset.toList).toArray
      log(s"Epoch: $epoch started, learning rate: $learningRate, dataset size: ${epochDataset.length}", Verbose.Epochs)

      val time = System.nanoTime()
      var batches = 0
      var loss = 0f

      for (batchData <- epochDataset.grouped(batchSize)) {

        val (sentences, annotations) = batchData.unzip
        val labels = annotations.map(r => encoder.encodeOneHot(r.label))

        val start = annotations.map(_.start)
        val end = annotations.map(_.end)
        val batchInput = encoder.encodeInputData(sentences, start, end)

        val tensors = new TensorResources()
        val batchLoss = try {
          val lossTensor = tensorflow.session.runner
            .feed(sentenceLengthsKey, tensors.createTensor(batchInput.sentenceLengths))
            .feed(wordEmbeddingsKey, tensors.createTensor(batchInput.wordEmbeddings))
            .feed(labelsKey, tensors.createTensor(labels))
            .feed(dropoutKey, tensors.createTensor(1.0f - dropout))
            .feed(learningRateKey, tensors.createTensor(learningRate))
            .fetch(lossKey)
            .addTarget(trainingKey)
            .run()
            .get(0)

          // Fetched tensors are not created through `tensors`, so close this one explicitly.
          try lossTensor.floatValue() finally lossTensor.close()
        } finally {
          // Release the input tensors even if the session run fails.
          tensors.clearTensors()
        }

        loss += batchLoss
        batches += 1
      }

      learningRate = learningRate * learningRateDecay
      // Route through the class logger, like the other training messages, instead of stdout.
      log(s"Done, ${(System.nanoTime() - time)/1e9} loss: $loss, batches: $batches", Verbose.Epochs)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy