All downloads are FREE. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowClassifier.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.nlp.annotators.classifier.dl.ClassifierMetrics
import com.johnsnowlabs.nlp.annotators.ner.Verbose
import com.johnsnowlabs.nlp.util.io.OutputHelper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.apache.spark.ml.util.Identifiable

import scala.collection.mutable
import scala.util.Random

/** Trains and runs a TensorFlow classification graph over fixed-size sentence embeddings.
  *
  * Graph nodes are addressed by name. Inputs, labels, learning rate and dropout are fed as
  * `inputs:0`, `labels:0`, `lr:0` and `dp:0`. The softmax output, Adam step, loss and accuracy
  * nodes carry a suffix equal to the number of classes in `encoder.params.tags`.
  *
  * @param tensorflow
  *   wrapper that creates / returns the TensorFlow session
  * @param encoder
  *   encoder for training and validation labels; also decodes predictions
  * @param testEncoder
  *   optional encoder used for the test dataset (required only when test data is evaluated)
  * @param verboseLevel
  *   verbosity level; overrides the inherited setting
  */
class TensorflowClassifier(
    val tensorflow: TensorflowWrapper,
    val encoder: ClassifierDatasetEncoder,
    val testEncoder: Option[ClassifierDatasetEncoder],
    override val verboseLevel: Verbose.Value)
    extends Serializable
    with ClassifierMetrics {

  private val inputKey = "inputs:0"
  private val labelKey = "labels:0"
  private val learningRateKey = "lr:0"
  private val dropoutKey = "dp:0"

  private val numClasses: Int = encoder.params.tags.length

  private val predictionKey = s"softmax_output_$numClasses/Softmax:0"
  private val optimizer = s"optimizer_adam_$numClasses/Adam/Assign:0"
  private val cost = s"loss_$numClasses/softmax_cross_entropy_with_logits_sg:0"
  private val accuracy = s"accuracy_$numClasses/mean_accuracy:0"
  private val initKey = "init_all_tables"

  /** Trains the classifier for epochs `startEpoch` until `endEpoch` (exclusive).
    *
    * After each epoch it reports the summed loss and mean batch accuracy. It then evaluates
    * on the validation split (if any) and on the test set (if provided).
    *
    * @param inputs
    *   training embeddings and their string labels
    * @param testInputs
    *   optional test embeddings and labels, encoded with `testEncoder`
    * @param classNum
    *   number of classes; only used in the start-of-training message
    * @param lr
    *   initial learning rate; decayed per epoch as `lr / (1 + dropout * epoch)`
    * @param batchSize
    *   number of examples per optimizer step
    * @param dropout
    *   dropout rate fed to the graph
    * @param startEpoch
    *   first epoch; when 0 a new session is created and `init_all_tables` is run first
    * @param endEpoch
    *   exclusive upper bound on the epoch index
    * @param configProtoBytes
    *   optional serialized TensorFlow session configuration
    * @param validationSplit
    *   fraction of the shuffled training data held out for validation (ignored when <= 0)
    * @param evaluationLogExtended
    *   passed as `extended` to [[measure]]
    * @param enableOutputLogs
    *   whether messages are also sent to `outputLog` and the log file is exported at the end
    * @param outputLogsPath
    *   destination used by `outputLog` / `OutputHelper.exportLogFile`
    * @param uuid
    *   identifier passed to `outputLog`
    */
  def train(
      inputs: (Array[Array[Float]], Array[String]),
      testInputs: Option[(Array[Array[Float]], Array[String])],
      classNum: Int,
      lr: Float = 5e-3f,
      batchSize: Int = 64,
      dropout: Float = 0.5f,
      startEpoch: Int = 0,
      endEpoch: Int = 10,
      configProtoBytes: Option[Array[Byte]] = None,
      validationSplit: Float = 0.0f,
      evaluationLogExtended: Boolean = false,
      enableOutputLogs: Boolean = false,
      outputLogsPath: String,
      uuid: String = Identifiable.randomUID("classifierdl")): Unit = {

    // Initialize the graph's tables only on a fresh run (not when resuming at a later epoch).
    if (startEpoch == 0)
      tensorflow
        .createSession(configProtoBytes = configProtoBytes)
        .runner
        .addTarget(initKey)
        .run()

    val (trainSet, validationSet, testSet) = buildDatasets(inputs, testInputs, validationSplit)

    printAndOutputLog(
      s"Training started - epochs: $endEpoch - learning_rate: $lr - batch_size: $batchSize - training_examples: ${trainSet.length} - classes: $classNum",
      uuid,
      enableOutputLogs,
      outputLogsPath)

    // A split was requested but rounded down to zero examples: warn once, not every epoch.
    if (validationSplit > 0.0f && validationSet.isEmpty) {
      printAndOutputLog(
        s"WARNING: Could not create validation set. " +
          s"Number of data points (${inputs._1.length}) not enough for validation split $validationSplit.",
        uuid,
        enableOutputLogs,
        outputLogsPath)
    }

    for (epoch <- startEpoch until endEpoch) {
      val time = System.nanoTime()
      // NOTE(review): the decay factor reuses `dropout`; kept as-is, confirm it is intentional.
      val learningRate = lr / (1 + dropout * epoch)

      val (loss, accSum, batches) =
        trainSet.grouped(batchSize).foldLeft((0f, 0f, 0)) {
          case ((lossTotal, accTotal, count), batch) =>
            val (batchLoss, batchAcc) = trainBatch(batch, learningRate, dropout, configProtoBytes)
            (lossTotal + batchLoss, accTotal + batchAcc, count + 1)
        }

      // Mean over the batches actually run (the last one may be partial); 0 for an empty set.
      val acc = if (batches > 0) (accSum / batches).min(1.0f).max(0.0f) else 0.0f

      val endTime = (System.nanoTime() - time) / 1e9
      printAndOutputLog(
        f"Epoch ${epoch + 1}/$endEpoch - $endTime%.2fs - loss: $loss - acc: $acc - batches: $batches",
        uuid,
        enableOutputLogs,
        outputLogsPath)

      if (validationSplit > 0.0f && validationSet.nonEmpty) {
        printAndOutputLog(
          s"Quality on validation dataset (${validationSplit * 100}%), validation examples = ${validationSet.length}",
          uuid,
          enableOutputLogs,
          outputLogsPath)

        measure(
          validationSet,
          "validation",
          extended = evaluationLogExtended,
          enableOutputLogs = enableOutputLogs,
          outputLogsPath = outputLogsPath)
      }

      if (testSet.nonEmpty) {
        printAndOutputLog("Quality on test dataset: ", uuid, enableOutputLogs, outputLogsPath)

        measure(
          testSet,
          "test",
          extended = evaluationLogExtended,
          enableOutputLogs = enableOutputLogs,
          outputLogsPath = outputLogsPath)
      }
    }

    if (enableOutputLogs) {
      OutputHelper.exportLogFile(outputLogsPath)
    }
  }

  /** Runs one optimizer step on `batch` and returns the (loss, accuracy) fetched from the graph.
    * Tensors created for the step are always released, even if the session run fails.
    */
  private def trainBatch(
      batch: Array[(Array[Float], Array[Int])],
      learningRate: Float,
      dropout: Float,
      configProtoBytes: Option[Array[Byte]]): (Float, Float) = {
    val tensors = new TensorResources()
    try {
      val inputTensor = tensors.createTensor(batch.map(_._1))
      val labelTensor = tensors.createTensor(batch.map(_._2))
      val lrTensor = tensors.createTensor(learningRate)
      val dpTensor = tensors.createTensor(dropout)

      val calculated = tensorflow
        .getTFSession(configProtoBytes = configProtoBytes)
        .runner
        .feed(inputKey, inputTensor)
        .feed(labelKey, labelTensor)
        .feed(learningRateKey, lrTensor)
        .feed(dropoutKey, dpTensor)
        .fetch(optimizer)
        .fetch(predictionKey)
        .fetch(cost)
        .fetch(accuracy)
        .run()

      // Fetch order above: 0 = optimizer, 1 = predictions, 2 = loss, 3 = accuracy.
      (
        TensorResources.extractFloats(calculated.get(2))(0),
        TensorResources.extractFloats(calculated.get(3))(0))
    } finally {
      tensors.clearTensors()
    }
  }

  /** Prints `message` to stdout and forwards the same text to `outputLog`. */
  private def printAndOutputLog(
      message: String,
      uuid: String,
      enableOutputLogs: Boolean,
      outputLogsPath: String): Unit = {
    println(message)
    outputLog(message, uuid, enableOutputLogs, outputLogsPath)
  }

  /** Returns the test encoder, failing with a descriptive message when none was supplied. */
  private def testEncoderOrFail: ClassifierDatasetEncoder =
    testEncoder.getOrElse(
      throw new NoSuchElementException(
        "testEncoder is required to encode or evaluate test data but was not provided"))

  /** Encodes and shuffles the training data and splits off a validation sample.
    *
    * @return
    *   (training set, validation set, test set). The validation set is empty when
    *   `validationSplit <= 0`, and the test set is empty when `testInputs` is `None`.
    */
  private def buildDatasets(
      inputs: (Array[Array[Float]], Array[String]),
      testInputs: Option[(Array[Array[Float]], Array[String])],
      validationSplit: Float): (
      Array[(Array[Float], Array[Int])],
      Array[(Array[Float], Array[Int])],
      Array[(Array[Float], Array[Int])]) = {

    val trainingDataset = Random.shuffle(encodeInputs(inputs, "train").toSeq).toArray

    val (trainSet, validationSet) =
      if (validationSplit > 0f) {
        val validationSize = (trainingDataset.length * validationSplit).toInt
        val (validationSample, remaining) = trainingDataset.splitAt(validationSize)
        (remaining, validationSample)
      } else {
        // No validation split requested: train on everything, with a truly empty validation set.
        (trainingDataset, Array.empty[(Array[Float], Array[Int])])
      }

    val testSet = testInputs
      .map(encodeInputs(_, "test"))
      .getOrElse(Array.empty[(Array[Float], Array[Int])])

    (trainSet, validationSet, testSet)
  }

  /** Pairs each embedding with its encoded label. Uses `encoder` when `sourceData == "train"`
    * and the test encoder otherwise.
    */
  private def encodeInputs(
      inputs: (Array[Array[Float]], Array[String]),
      sourceData: String): Array[(Array[Float], Array[Int])] = {

    val (embeddings, labels) = inputs
    val datasetEncoder = if (sourceData == "train") encoder else testEncoderOrFail
    embeddings.zip(datasetEncoder.encodeTags(labels))
  }

  /** Predicts one CATEGORY annotation per input annotation.
    *
    * Each annotation gets the highest-scoring label ("NA" if there are no scores). Its metadata
    * holds the document's sentence id plus every label's score.
    *
    * @param docs
    *   (sentence id, annotations) pairs
    * @param configProtoBytes
    *   optional serialized TensorFlow session configuration
    */
  def predict(
      docs: Seq[(Int, Seq[Annotation])],
      configProtoBytes: Option[Array[Byte]] = None): Seq[Annotation] = {

    // FixMe: implement batchSize
    val inputs = encoder.extractSentenceEmbeddings(docs)

    if (inputs.isEmpty) Seq.empty[Annotation]
    else {
      val tensors = new TensorResources()
      val tagsName =
        try {
          val calculated = tensorflow
            .getTFSession(configProtoBytes = configProtoBytes)
            .runner
            .feed(inputKey, tensors.createTensor(inputs))
            .fetch(predictionKey)
            .run()

          val tagsId =
            TensorResources.extractFloats(calculated.get(0)).grouped(numClasses).toArray
          encoder.decodeOutputData(tagIds = tagsId)
        } finally {
          tensors.clearTensors()
        }

      // Start row of each document within `tagsName`, so documents after the first do not
      // reuse the first rows. NOTE(review): assumes extractSentenceEmbeddings emits one row per
      // annotation, in `docs` order — confirm.
      val offsets = docs.scanLeft(0)(_ + _._2.length)

      docs.zip(offsets).flatMap { case ((sentenceId, annotations), offset) =>
        annotations.zip(tagsName.slice(offset, offset + annotations.length)).map {
          case (content, score) =>
            val label = if (score.isEmpty) "NA" else score.maxBy(_._2)._1

            Annotation(
              annotatorType = AnnotatorType.CATEGORY,
              begin = content.begin,
              end = content.end,
              result = label,
              metadata = Map("sentence" -> sentenceId.toString) ++ score.map(x =>
                x._1 -> x._2.toString))
        }
      }
    }
  }

  /** Returns the arg-max class index for each row of `inputs`.
    *
    * @param inputs
    *   embeddings to classify
    * @param numClasses
    *   number of scores per row in the softmax output
    * @param configProtoBytes
    *   optional serialized TensorFlow session configuration
    */
  def internalPredict(
      inputs: Array[Array[Float]],
      numClasses: Int,
      configProtoBytes: Option[Array[Byte]] = None): Array[Int] = {

    if (inputs.isEmpty) Array.empty[Int]
    else {
      val tensors = new TensorResources()
      try {
        val calculated = tensorflow
          .getTFSession(configProtoBytes = configProtoBytes)
          .runner
          .feed(inputKey, tensors.createTensor(inputs))
          .fetch(predictionKey)
          .run()

        TensorResources
          .extractFloats(calculated.get(0))
          .grouped(numClasses)
          .map(scores => scores.zipWithIndex.maxBy(_._1)._2)
          .toArray
      } finally {
        tensors.clearTensors()
      }
    }
  }

  /** Evaluates the model on `labeled` data and reports aggregated metrics.
    *
    * @param labeled
    *   (embedding, one-hot label) pairs
    * @param sourceData
    *   "validation" uses `encoder`; anything else uses the test encoder
    * @param extended
    *   forwarded to `aggregatedMetrics`
    * @param enableOutputLogs
    *   forwarded to `aggregatedMetrics`
    * @param outputLogsPath
    *   forwarded to `aggregatedMetrics`
    * @param batchSize
    *   currently unused (see ToDo below)
    * @return
    *   the result of `aggregatedMetrics`
    */
  def measure(
      labeled: Array[(Array[Float], Array[Int])],
      sourceData: String,
      extended: Boolean = false,
      enableOutputLogs: Boolean = false,
      outputLogsPath: String,
      batchSize: Int = 100): (Float, Float) = {

    val started = System.nanoTime()

    val evaluationEncoder = if (sourceData == "validation") encoder else testEncoderOrFail
    val evaluationNumClasses = evaluationEncoder.params.tags.length

    // ToDo: Add batch strategy
    val truePositives = mutable.Map[String, Int]()
    val falsePositives = mutable.Map[String, Int]()
    val falseNegatives = mutable.Map[String, Int]()
    val predicted = mutable.Map[String, Int]()
    val correct = mutable.Map[String, Int]()

    val originalEmbeddings = labeled.map(_._1)
    val originalLabels: Array[Int] = labeled.map { case (_, oneHot) =>
      oneHot.zipWithIndex.maxBy(_._1)._2
    }

    val predictedLabels: Array[Int] = internalPredict(originalEmbeddings, evaluationNumClasses)

    for ((predict, original) <- predictedLabels.zip(originalLabels)) {
      val groundTruthTag = evaluationEncoder.tags(original)
      val predictedTag = evaluationEncoder.tags(predict)

      correct(groundTruthTag) = correct.getOrElse(groundTruthTag, 0) + 1
      predicted(predictedTag) = predicted.getOrElse(predictedTag, 0) + 1

      if (original == predict) {
        truePositives(groundTruthTag) = truePositives.getOrElse(groundTruthTag, 0) + 1
      } else {
        falsePositives(predictedTag) = falsePositives.getOrElse(predictedTag, 0) + 1
        falseNegatives(groundTruthTag) = falseNegatives.getOrElse(groundTruthTag, 0) + 1
      }
    }

    val endTime = (System.nanoTime() - started) / 1e9
    println(f"time to finish evaluation: $endTime%.2fs")

    val labels = (correct.keys ++ predicted.keys).toSeq.distinct
    aggregatedMetrics(
      labels,
      truePositives.toMap,
      falsePositives.toMap,
      falseNegatives.toMap,
      extended,
      enableOutputLogs,
      outputLogsPath)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy