All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.ai.ViTClassifier.scala Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor
import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
import com.johnsnowlabs.nlp.annotators.cv.util.transform.ImageResizeUtils

import scala.collection.JavaConverters._

/** Image classifier for ViT-style models, backed by either TensorFlow or ONNX Runtime.
  *
  * Input images are (optionally) resized and normalized according to a [[Preprocessor]], scored
  * by the model, converted to per-class probabilities and mapped back to labels through `tags`.
  * One warm-up inference on the bundled `/image/ox.JPEG` resource runs when an instance is
  * constructed.
  *
  * @param tensorflowWrapper
  *   TensorFlow model; selected whenever it is defined
  * @param onnxWrapper
  *   ONNX model; selected only when `tensorflowWrapper` is empty
  * @param configProtoBytes
  *   optional TensorFlow session configuration bytes
  * @param tags
  *   mapping from label to class index in the model output
  * @param preprocessor
  *   preprocessing settings used for the construction-time warm-up
  * @param signatures
  *   optional TensorFlow signature mapping; defaults to `ModelSignatureManager.apply()`
  */
private[johnsnowlabs] class ViTClassifier(
    val tensorflowWrapper: Option[TensorflowWrapper],
    val onnxWrapper: Option[OnnxWrapper],
    configProtoBytes: Option[Array[Byte]] = None,
    tags: Map[String, BigInt],
    preprocessor: Preprocessor,
    signatures: Option[Map[String, String]] = None)
    extends Serializable {

  // Fail fast with a clear message: with no model at all, the warm-up below would otherwise
  // fail with a bare NoSuchElementException from `tensorflowWrapper.get`.
  require(
    tensorflowWrapper.isDefined || onnxWrapper.isDefined,
    "ViTClassifier requires either a TensorFlow or an ONNX model")

  val _tfViTSignatures: Map[String, String] =
    signatures.getOrElse(ModelSignatureManager.apply())

  /** Inference engine: TensorFlow when its wrapper is defined, otherwise ONNX. */
  val detectedEngine: String =
    if (tensorflowWrapper.isDefined) TensorFlow.name
    else if (onnxWrapper.isDefined) ONNX.name
    else TensorFlow.name

  private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

  /** Runs a single inference on the bundled `/image/ox.JPEG` image. It is invoked from the
    * constructor (right below), presumably so session start-up cost is paid before real requests.
    */
  private def sessionWarmup(): Unit = {
    val resourcePath = "/image/ox.JPEG"
    val stream = getClass.getResourceAsStream(resourcePath)
    // Close the resource stream ourselves; nothing else in this file does.
    val loaded =
      try ImageIOUtils.loadImage(stream)
      finally Option(stream).foreach(_.close())
    val image = loaded.getOrElse(
      throw new IllegalStateException(s"Could not load warm-up image $resourcePath"))
    val bytes = ImageIOUtils.bufferedImageToByte(image)
    // NOTE(review): height/width/nChannels/mode are hard-coded for ox.JPEG -- confirm they match
    // the bundled resource.
    val images =
      Array(AnnotationImage("image", "ox.JPEG", 265, 360, 3, 16, bytes, Map("image" -> "0")))
    tag(encode(images, preprocessor))
  }

  // Must stay after every val used by `tag`/`encode`, otherwise those vals are still null here.
  sessionWarmup()

  /** Runs the TensorFlow model on a batch and returns the flattened logits of all images.
    *
    * @param batch
    *   preprocessed images, one entry per image (as produced by [[encode]])
    * @return
    *   logits for the whole batch, concatenated image after image
    */
  def getRawScoresWithTF(batch: Array[Array[Array[Array[Float]]]]): Array[Float] = {
    val tensors = new TensorResources()
    val imageTensors = tensors.createTensor(batch)

    // try/finally so native tensors are released even when the session run fails.
    try {
      val session = tensorflowWrapper.get.getTFSessionWithSignature(
        configProtoBytes = configProtoBytes,
        savedSignatures = signatures,
        initAllTables = false)
      val runner = session.runner

      runner
        .feed(
          _tfViTSignatures
            .getOrElse(ModelSignatureConstants.PixelValuesInput.key, "missing_pixel_values"),
          imageTensors)
        .fetch(_tfViTSignatures
          .getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

      val outs = runner.run().asScala
      try TensorResources.extractFloats(outs.head)
      finally tensors.clearSession(outs)
    } finally {
      tensors.clearTensors()
      imageTensors.close()
    }
  }

  /** Runs the ONNX model on a batch and returns the flattened `logits` output.
    *
    * The name keeps its historical "Row" spelling for source compatibility; it is the ONNX
    * counterpart of [[getRawScoresWithTF]].
    *
    * @param batch
    *   preprocessed images, one entry per image (as produced by [[encode]])
    * @return
    *   logits for the whole batch, concatenated image after image
    */
  def getRowScoresWithOnnx(batch: Array[Array[Array[Array[Float]]]]): Array[Float] = {
    val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)
    val imageTensors = OnnxTensor.createTensor(env, batch)

    // Nested try/finally so both the input tensor and the run results are closed on failure.
    try {
      val inputs = Map("pixel_values" -> imageTensors).asJava
      val results = runner.run(inputs)
      try {
        results
          .get("logits")
          .get()
          .asInstanceOf[OnnxTensor]
          .getFloatBuffer
          .array()
      } finally results.close()
    } finally imageTensors.close()
  }

  /** Scores a batch and converts the logits of each image into probabilities.
    *
    * @param batch
    *   preprocessed images (as produced by [[encode]])
    * @param activation
    *   `ActivationFunction.sigmoid` applies a per-class sigmoid; any other value (including the
    *   default `ActivationFunction.softmax`) applies softmax
    * @return
    *   one probability vector per image; empty for an empty batch
    */
  def tag(
      batch: Array[Array[Array[Array[Float]]]],
      activation: String = ActivationFunction.softmax): Array[Array[Float]] = {

    // Guard: an empty batch would otherwise divide by zero when computing `dim`.
    if (batch.isEmpty) Array.empty[Array[Float]]
    else {
      val rawScores = detectedEngine match {
        case ONNX.name => getRowScoresWithOnnx(batch)
        case _ => getRawScoresWithTF(batch)
      }
      // Logits come back flattened; split them into one equally sized slice per image.
      val dim = rawScores.length / batch.length
      val activate: Array[Float] => Array[Float] =
        if (activation == ActivationFunction.sigmoid) scores => calculateSigmoid(scores)
        else scores => calculateSoftmax(scores)

      rawScores
        .grouped(dim)
        .map(activate)
        .toArray
    }
  }

  /** Calculate softmax from returned logits.
    *
    * @param scores
    *   logits output from output layer
    * @return
    *   probabilities summing to 1 (empty for empty input)
    */
  def calculateSoftmax(scores: Array[Float]): Array[Float] = {
    if (scores.isEmpty) Array.empty[Float]
    else {
      // Shifting by the max logit keeps exp() from overflowing to Infinity (which would turn the
      // result into NaN); the shift cancels out mathematically.
      val maxScore = scores.max.toDouble
      val exp = scores.map(x => math.exp(x - maxScore))
      val total = exp.sum // computed once instead of once per element
      exp.map(x => (x / total).toFloat)
    }
  }

  /** Calculate sigmoid from returned logits.
    *
    * @param scores
    *   logits output from output layer
    * @return
    *   independent per-class probabilities in (0, 1)
    */
  def calculateSigmoid(scores: Array[Float]): Array[Float] = {
    scores.map(x => 1 / (1 + Math.exp(-x)).toFloat)
  }

  /** Classifies images in batches and returns one CATEGORY annotation per input image.
    *
    * @param images
    *   images to classify
    * @param batchSize
    *   number of images sent to the model per call; must be positive
    * @param preprocessor
    *   preprocessing settings applied by [[encode]]
    * @param activation
    *   activation passed to [[tag]]
    * @return
    *   annotations in input order
    */
  def predict(
      images: Array[AnnotationImage],
      batchSize: Int,
      preprocessor: Preprocessor,
      activation: String = ActivationFunction.softmax): Seq[Annotation] = {

    // Scores per class index for the first 10 tag entries only; hoisted out of the per-image loop.
    val metadataTags = tags.take(10)

    images
      .grouped(batchSize)
      .flatMap { batch =>
        val encoded = encode(batch, preprocessor)
        val probabilities = tag(encoded, activation)
        batch.zip(probabilities).map { case (image, score) =>
          buildAnnotation(image, score, metadataTags)
        }
      }
      // Materialize eagerly: on Scala 2.12 `Iterator.toSeq` returns a lazy Stream, which would
      // defer model execution (and any failure) until the caller traverses the result.
      .toVector
  }

  /** Builds the CATEGORY annotation for one image from its probability vector. */
  private def buildAnnotation(
      image: AnnotationImage,
      score: Array[Float],
      metadataTags: Map[String, BigInt]): Annotation = {
    val bestIndex = score.zipWithIndex.maxBy(_._1)._2
    val label = tags.find(_._2 == bestIndex).map(_._1).getOrElse("NA")

    // NOTE(review): keys are `Option.toString` ("Some(label)" / "None"), so indices outside
    // `metadataTags` collapse into a single "None" entry. Kept as-is to preserve output format.
    val meta = score.zipWithIndex.map { case (prob, idx) =>
      metadataTags.find(_._2 == idx).map(_._1).toString -> prob.toString
    }

    val imageMeta = Map(
      "height" -> image.height.toString,
      "width" -> image.width.toString,
      "nChannels" -> image.nChannels.toString,
      "mode" -> image.mode.toString,
      "origin" -> image.origin)

    Annotation(
      annotatorType = AnnotatorType.CATEGORY,
      begin = 0,
      end = label.length - 1,
      result = label,
      metadata = Map("image" -> "0") ++ imageMeta ++ meta)
  }

  /** Decodes, optionally resizes, and normalizes images into the model's float input layout.
    *
    * @param annotations
    *   raw images with their byte content and dimensions
    * @param preprocessor
    *   resize/normalization settings
    * @return
    *   one converted image per input annotation, in input order
    */
  def encode(
      annotations: Array[AnnotationImage],
      preprocessor: Preprocessor): Array[Array[Array[Array[Float]]]] =
    annotations.map { annot =>
      val bufferedImage = ImageIOUtils.byteToBufferedImage(
        bytes = annot.result,
        w = annot.width,
        h = annot.height,
        nChannels = annot.nChannels)

      val resizedImage =
        if (preprocessor.do_resize)
          ImageResizeUtils.resizeBufferedImage(
            width = preprocessor.size,
            height = preprocessor.size,
            preprocessor.resample)(bufferedImage)
        else bufferedImage

      ImageResizeUtils.normalizeAndConvertBufferedImage(
        img = resizedImage,
        mean = preprocessor.image_mean,
        std = preprocessor.image_std,
        doNormalize = preprocessor.do_normalize,
        doRescale = preprocessor.do_rescale,
        rescaleFactor = preprocessor.rescale_factor)
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy