All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowSentenceDetectorDL.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.nlp.annotators.ner.Verbose
import com.johnsnowlabs.nlp.util.io.OutputHelper
import org.apache.spark.ml.util.Identifiable
import org.tensorflow.Graph
import org.tensorflow.proto.framework.GraphDef

import scala.collection.JavaConverters._
import scala.util.Random

/** Trains, evaluates and runs a TensorFlow sentence-boundary classifier held in a
  * [[TensorflowWrapper]].
  *
  * The wrapped graph must expose the node names used below: `init`, `inputs`, `targets`,
  * `class_weights`, `dropout`, `learning_rate`, `optimizer`, `loss`, `outputs`, `predictions`
  * and `accuracy`. The widths used to pad feature and label vectors are read from dimension 1 of
  * the first output of the `inputs` and `outputs` operations.
  *
  * @param model
  *   wrapper holding the serialized graph (`model.graph`) and the session used for all runs
  * @param verboseLevel
  *   verbosity setting; not referenced directly by the code in this class
  * @param outputLogsPath
  *   when defined, progress messages are also passed to `outputLog` with this path, and
  *   `OutputHelper.exportLogFile` is called on it once training completes
  */
class TensorflowSentenceDetectorDL(
    val model: TensorflowWrapper,
    val verboseLevel: Verbose.Value = Verbose.All,
    val outputLogsPath: Option[String] = None)
    extends Serializable
    with Logging {

  private val initKey = "init"
  private val inputsKey = "inputs"
  private val targetsKey = "targets"
  private val classWeightsKey = "class_weights"
  private val dropoutKey = "dropout"
  private val learningRateKey = "learning_rate"
  private val trainingKey = "optimizer"
  private val lossKey = "loss"
  private val outputsKey = "outputs"
  private val predictionsKey = "predictions"
  private val accuracyKey = "accuracy"

  /** Width (padded length) of a feature vector, read from the `inputs` operation. */
  private lazy val _inputDim: Int = graphNodeWidth(inputsKey, "Can't find input tensor")

  /** Width (padded length) of a label / class-score vector, read from the `outputs` operation. */
  private lazy val _outputDim: Int = graphNodeWidth(outputsKey, "Can't find output tensor")

  /** Reads dimension 1 of output 0 of the graph operation named `opName`.
    *
    * The serialized graph is imported into a temporary [[Graph]], which is always closed
    * before returning so its native memory is released. The shape is read while the graph is
    * still open. Assumes the output has rank >= 2 (TODO confirm for custom graphs).
    *
    * @throws Exception
    *   with `errorMessage` when no operation with that name exists
    */
  private def graphNodeWidth(opName: String, errorMessage: String): Int = {
    val graph = new Graph()
    try {
      graph.importGraphDef(GraphDef.parseFrom(model.graph))
      graph
        .operations()
        .asScala
        .find(op => op.name() == opName)
        .map(op => op.output(0).shape().size(1).toInt)
        .getOrElse(throw new Exception(errorMessage))
    } finally {
      graph.close()
    }
  }

  def getTFModel: TensorflowWrapper = this.model

  /** Forwards `message` to `outputLog` when `outputLogsPath` is defined; otherwise does nothing. */
  protected def logMessage(message: String, uuid: String): Unit =
    outputLogsPath.foreach(path => outputLog(message, uuid, true, path))

  /** Prints `message` to stdout and forwards it to [[logMessage]]. */
  private def report(message: String, uuid: String): Unit = {
    println(message)
    logMessage(message, uuid)
  }

  /** Initializes the session variables and trains the model.
    *
    * Feature rows are right-padded with `0.0f` to the graph's input width, and label rows and
    * `classWeights` to its output width. Features and labels are paired with `zip`, so surplus
    * rows of the longer array are dropped. After each epoch a line is printed and logged. It
    * reports elapsed time, the SUM of per-batch losses and the MEAN per-batch accuracy (NaN when
    * there are no training batches). When a non-empty validation split exists, the line also
    * reports the accuracy returned by [[internalPredict]].
    *
    * @param features
    *   one feature vector per example
    * @param labels
    *   one target vector per example
    * @param batchSize
    *   number of examples per optimizer step
    * @param epochsNumber
    *   number of passes over the training portion
    * @param learningRate
    *   fed to the `learning_rate` node on every step
    * @param validationSplit
    *   fraction of the shuffled data held out for validation; `0` disables validation
    * @param classWeights
    *   per-class weights fed to `class_weights`
    * @param dropout
    *   fed to the `dropout` node during training steps only
    * @param configProtoBytes
    *   optional serialized session config passed to the wrapper
    * @param uuid
    *   identifier passed along with every log message
    */
  def train(
      features: Array[Array[Float]],
      labels: Array[Array[Float]],
      batchSize: Int,
      epochsNumber: Int,
      learningRate: Float = 0.001f,
      validationSplit: Float = 0.0f,
      classWeights: Array[Float],
      dropout: Float = 0.0f,
      configProtoBytes: Option[Array[Byte]] = None,
      uuid: String = Identifiable.randomUID("annotator")): Unit = {

    model.createSession(configProtoBytes).runner.addTarget(initKey).run()

    val outputClassWeights = classWeights.padTo(_outputDim, 0.0f)

    val zippedDataset = features
      .map(row => row.padTo(_inputDim, 0.0f))
      .zip(labels.map(row => row.padTo(_outputDim, 0.0f)))
      .toSeq

    val allData = Random.shuffle(zippedDataset)

    // Split on the size of the zipped data rather than `features.length`: `zip` truncates to
    // the shorter input, so `features.length` could over-count the training portion.
    val (trainDataset, validationDataset) =
      if (validationSplit > 0f) {
        allData.splitAt((allData.length * (1 - validationSplit)).toInt)
      } else {
        // No validation split requested: train on everything.
        (allData, Seq.empty[(Array[Float], Array[Float])])
      }

    report(f"Training $epochsNumber epochs", uuid)

    for (epoch <- 1 to epochsNumber) {
      val time = System.nanoTime()
      var loss = 0.0f
      var accSum = 0.0f
      var batches = 0

      for (batch <- Random.shuffle(trainDataset).toArray.grouped(batchSize)) {
        val (batchLoss, batchAcc) =
          trainBatch(batch, learningRate, outputClassWeights, dropout, configProtoBytes)
        loss += batchLoss
        accSum += batchAcc
        batches += 1
      }

      val acc = accSum / batches

      // Skip validation when the split left no examples (e.g. a very small dataset); feeding
      // empty tensors to the graph would otherwise be attempted.
      val validAcc: Option[Float] =
        if (validationSplit > 0.0 && validationDataset.nonEmpty) {
          val (validationFeatures, validationLabels) = validationDataset.toArray.unzip
          val (_, accuracy) = internalPredict(
            validationFeatures,
            validationLabels,
            configProtoBytes,
            outputClassWeights)
          Some(accuracy)
        } else {
          None
        }

      val endTime = (System.nanoTime() - time) / 1e9
      val epochLine = f"Epoch $epoch/$epochsNumber\t$endTime%.2fs\tLoss: $loss\tACC: $acc"
      report(validAcc.fold(epochLine)(v => s"$epochLine\tValidation ACC: $v"), uuid)
    }

    report("Training completed.", uuid)

    outputLogsPath.foreach(path => OutputHelper.exportLogFile(path))
  }

  /** Runs one optimizer step (`optimizer` target) on `batch` and returns its (loss, accuracy).
    *
    * The input tensors and the fetched output tensors are always released, even if the run
    * fails.
    */
  private def trainBatch(
      batch: Array[(Array[Float], Array[Float])],
      learningRate: Float,
      classWeights: Array[Float],
      dropout: Float,
      configProtoBytes: Option[Array[Byte]]): (Float, Float) = {
    val tensors = new TensorResources()
    try {
      val inputTensor = tensors.createTensor(batch.map(pair => pair._1))
      val labelTensor = tensors.createTensor(batch.map(pair => pair._2))
      val lrTensor = tensors.createTensor(learningRate)
      val classWeightsTensor = tensors.createTensor(classWeights)
      val dropoutTensor = tensors.createTensor(dropout)

      val calculated = model
        .getTFSession(configProtoBytes)
        .runner
        .feed(inputsKey, inputTensor)
        .feed(targetsKey, labelTensor)
        .feed(learningRateKey, lrTensor)
        .feed(classWeightsKey, classWeightsTensor)
        .feed(dropoutKey, dropoutTensor)
        .addTarget(trainingKey)
        .fetch(lossKey)
        .fetch(accuracyKey)
        .run()
      try {
        (
          TensorResources.extractFloats(calculated.get(0))(0),
          TensorResources.extractFloats(calculated.get(1))(0))
      } finally {
        // Fetched outputs are not owned by `tensors`; close them explicitly.
        calculated.asScala.foreach(_.close())
      }
    } finally {
      tensors.clearTensors()
    }
  }

  /** Evaluates the graph on already padded `features` and `labels` and returns (loss, accuracy).
    *
    * Only `inputs`, `targets` and `class_weights` are fed. `dropout` and `learning_rate` are not
    * fed, so the graph presumably supplies defaults for them (verify against the graph
    * definition).
    */
  protected def internalPredict(
      features: Array[Array[Float]],
      labels: Array[Array[Float]],
      configProtoBytes: Option[Array[Byte]] = None,
      classWeights: Array[Float]): (Float, Float) = {

    val tensors = new TensorResources()
    try {
      val inputTensor = tensors.createTensor(features)
      val labelTensor = tensors.createTensor(labels)
      val classWeightsTensor = tensors.createTensor(classWeights)

      val calculated = model
        .getTFSession(configProtoBytes)
        .runner
        .feed(inputsKey, inputTensor)
        .feed(targetsKey, labelTensor)
        .feed(classWeightsKey, classWeightsTensor)
        .fetch(lossKey)
        .fetch(accuracyKey)
        .run()
      try {
        (
          TensorResources.extractFloats(calculated.get(0))(0),
          TensorResources.extractFloats(calculated.get(1))(0))
      } finally {
        calculated.asScala.foreach(_.close())
      }
    } finally {
      tensors.clearTensors()
    }
  }

  /** Predicts a class index for each feature row.
    *
    * Rows are right-padded with `0.0f` to the graph's input width. The `outputs` tensor is read
    * as consecutive groups of `_outputDim` floats, one group per row.
    *
    * @return
    *   the `predictions` values, and for each row the `outputs` value at the predicted index
    */
  def predict(
      features: Array[Array[Float]],
      configProtoBytes: Option[Array[Byte]] = None): (Array[Long], Array[Float]) = {

    val tensors = new TensorResources()
    try {
      val inputTensor = tensors.createTensor(features.map(row => row.padTo(_inputDim, 0.0f)))

      val calculated = model
        .getTFSession(configProtoBytes)
        .runner
        .feed(inputsKey, inputTensor)
        .fetch(predictionsKey)
        .fetch(outputsKey)
        .run()
      try {
        val prediction = TensorResources.extractLongs(calculated.get(0))
        val outputs =
          TensorResources.extractFloats(calculated.get(1)).grouped(_outputDim).toArray
        val confidence = prediction.indices.map(i => outputs(i)(prediction(i).toInt)).toArray
        (prediction, confidence)
      } finally {
        calculated.asScala.foreach(_.close())
      }
    } finally {
      tensors.clearTensors()
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy