All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowNer.scala Maven / Gradle / Ivy

There is a newer version: 1.6.2
Show newest version
package com.johnsnowlabs.ml.tensorflow

import java.nio.file.{Files, Paths}
import com.johnsnowlabs.ml.crf.TextSentenceLabels
import com.johnsnowlabs.nlp.annotators.common.TokenizedSentence
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import org.tensorflow.{Graph, Session}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random


class TensorflowNer
(
  val tensorflow: TensorflowWrapper,
  val encoder: NerDatasetEncoder,
  val batchSize: Int,
  override val verboseLevel: Verbose.Value
) extends Logging {

  override def getLogName: String = "NerDL"

  private val charIdsKey = "char_repr/char_ids"
  private val wordLengthsKey = "char_repr/word_lengths"
  private val wordEmbeddingsKey = "word_repr_1/word_embeddings"
  private val sentenceLengthsKey = "word_repr/sentence_lengths"
  private val dropoutKey = "training/dropout"

  private val learningRateKey = "training/lr"
  private val labelsKey = "training/labels"

  private val lossKey = "inference/loss"
  private val trainingKey = "training_1/Momentum"
  private val predictKey = "context_repr/predicted_labels"

  def predict(dataset: Array[TokenizedSentence]): Array[Array[String]] = {

    val result = ArrayBuffer[Array[String]]()

    for (slice <- dataset.grouped(batchSize)) {
      val sentences = slice.map(r => r.tokens)

      val batchInput = encoder.encodeInputData(sentences)

      val tensors = new TensorResources()

      val calculated = tensorflow.session.runner
        .feed(sentenceLengthsKey, tensors.createTensor(batchInput.sentenceLengths))
        .feed(wordEmbeddingsKey, tensors.createTensor(batchInput.wordEmbeddings))

        .feed(wordLengthsKey, tensors.createTensor(batchInput.wordLengths))
        .feed(charIdsKey, tensors.createTensor(batchInput.charIds))

        .feed(dropoutKey, tensors.createTensor(1.1f))
        .fetch(predictKey)
        .run()

      tensors.clearTensors()

      val tagIds = TensorResources.extractInts(calculated.get(0), batchSize * batchInput.maxLength)
      val tags = encoder.decodeOutputData(tagIds)
      val sentenceTags = encoder.convertBatchTags(tags, batchInput.sentenceLengths)

      result.appendAll(sentenceTags)
    }

    result.toArray
  }

  def train(trainDataset: Array[(TextSentenceLabels, TokenizedSentence)],
            lr: Float,
            po: Float,
            batchSize: Int,
            dropout: Float,
            startEpoch: Int,
            endEpoch: Int,
            validation: Array[(TextSentenceLabels, TokenizedSentence)] = Array.empty,
            test: Array[(TextSentenceLabels, TokenizedSentence)] = Array.empty
           ): Unit = {

    log(s"Training started, trainExamples: ${trainDataset.length}, " +
      s"labels: ${encoder.tags.length} " +
      s"chars: ${encoder.chars.length}, ", Verbose.TrainingStat)

    // Initialize
    if (startEpoch == 0)
      tensorflow.session.runner.addTarget("training_1/init").run()

    // Train
    for (epoch <- startEpoch until endEpoch) {

      val epochDataset = Random.shuffle(trainDataset.toList).toArray
      val learningRate = lr / (1 + po * epoch)

      log(s"Epoch: $epoch started, learning rate: $learningRate, dataset size: ${epochDataset.length}", Verbose.Epochs)

      val time = System.nanoTime()
      var batches = 0
      var loss = 0f
      for (slice <- epochDataset.grouped(batchSize)) {
        val sentences = slice.map(r => r._2.tokens)
        val tags = slice.map(r => r._1.labels.toArray)

        val batchInput = encoder.encodeInputData(sentences)
        val batchTags = encoder.encodeTags(tags)

        val tensors = new TensorResources()
        val calculated = tensorflow.session.runner
          .feed(sentenceLengthsKey, tensors.createTensor(batchInput.sentenceLengths))
          .feed(wordEmbeddingsKey, tensors.createTensor(batchInput.wordEmbeddings))

          .feed(wordLengthsKey, tensors.createTensor(batchInput.wordLengths))
          .feed(charIdsKey, tensors.createTensor(batchInput.charIds))
          .feed(labelsKey, tensors.createTensor(batchTags))

          .feed(dropoutKey, tensors.createTensor(dropout))
          .feed(learningRateKey, tensors.createTensor(learningRate))

          .fetch(lossKey)
          .addTarget(trainingKey)
          .run()

        loss += calculated.get(0).floatValue()

        tensors.clearTensors()
        batches += 1
      }

      log(s"Done, ${(System.nanoTime() - time)/1e9} loss: $loss, batches: $batches", Verbose.Epochs)

      if (validation.nonEmpty) {
        log("Quality on train dataset: ", Verbose.Epochs)
        measure(trainDataset, (s: String) => log(s, Verbose.Epochs))
      }

      if (validation.nonEmpty) {
        log("Quality on validation dataset: ", Verbose.Epochs)
        measure(validation, (s: String) => log(s, Verbose.Epochs))
      }

      if (test.nonEmpty) {
        log("Quality on test dataset: ", Verbose.Epochs)
        measure(test, (s: String) => log(s, Verbose.Epochs))
      }
    }
  }


  def calcStat(correct: Int, predicted: Int, predictedCorrect: Int): (Float, Float, Float) = {
    // prec = (predicted & correct) / predicted
    // rec = (predicted & correct) / correct
    val prec = predictedCorrect.toFloat / predicted
    val rec = predictedCorrect.toFloat / correct
    val f1 = 2 * prec * rec / (prec + rec)

    (prec, rec, f1)
  }

  def measure(labeled: Array[(TextSentenceLabels, TokenizedSentence)],
                  log: (String => Unit),
                  extended: Boolean = false,
                  nErrorsToPrint: Int = 0
                 ): Unit = {

    val started = System.nanoTime()

    val predictedCorrect = mutable.Map[String, Int]()
    val predicted = mutable.Map[String, Int]()
    val correct = mutable.Map[String, Int]()

    val sentence = labeled.map(pair => pair._2).toList
    val sentenceLabels = labeled.map(pair => pair._1.labels.toArray).toList
    val sentencePredictedTags = labeled.map(pair => predict(Array(pair._2)).head).toList

    var errorsPrinted = 0
    var linePrinted = false
    (sentence, sentenceLabels, sentencePredictedTags).zipped.foreach {
      case (tokens, labels, tags) =>
        for (i <- 0 until labels.length) {
          val label = labels(i)
          val tag = tags(i)
          val iWord = tokens.indexedTokens(i)

          correct(label) = correct.getOrElse(label, 0) + 1
          predicted(tag) = predicted.getOrElse(tag, 0) + 1

          if (label == tag)
            predictedCorrect(tag) = predictedCorrect.getOrElse(tag, 0) + 1
          else if (errorsPrinted < nErrorsToPrint) {
            log(s"label: $label, predicted: $tag, word: $iWord")
            linePrinted = false
            errorsPrinted += 1
          }
        }

        if (errorsPrinted < nErrorsToPrint && !linePrinted) {
          log("")
          linePrinted = true
        }
    }

    if (extended)
      log(s"time: ${(System.nanoTime() - started)/1e9}")

    val labels = (correct.keys ++ predicted.keys).toSeq.distinct

    val notEmptyLabels = labels.filter(label => label != "O" && label.nonEmpty)

    val totalCorrect = correct.filterKeys(label => notEmptyLabels.contains(label)).values.sum
    val totalPredicted = predicted.filterKeys(label => notEmptyLabels.contains(label)).values.sum
    val totalPredictedCorrect = predictedCorrect.filterKeys(label => notEmptyLabels.contains(label)).values.sum
    val (prec, rec, f1) = calcStat(totalCorrect, totalPredicted, totalPredictedCorrect)
    log(s"Total stat, prec: $prec\t, rec: $rec\t, f1: $f1")

    if (!extended)
      return

    log("label\tprec\trec\tf1")

    for (label <- notEmptyLabels) {
      val (prec, rec, f1) = calcStat(
        correct.getOrElse(label, 0),
        predicted.getOrElse(label, 0),
        predictedCorrect.getOrElse(label, 0)
      )

      log(s"$label\t$prec\t$rec\t$f1")
    }
  }

}






© 2015 - 2025 Weber Informatics LLC | Privacy Policy