com.johnsnowlabs.ml.ai.VisionEncoderDecoder.scala

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
import com.johnsnowlabs.ml.onnx.OnnxSession
import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor
import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
import com.johnsnowlabs.nlp.annotators.cv.util.transform.ImageResizeUtils
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.Gpt2Tokenizer
import org.intel.openvino.InferRequest
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._
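
/** Vision encoder-decoder model for image captioning, pairing an image encoder with an
  * autoregressive text decoder (for example a ViT encoder with a GPT-2 decoder).
  *
  * Exactly one backend must be provided: a TensorFlow wrapper or a pair of ONNX
  * encoder/decoder wrappers.
  */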

private[johnsnowlabs] class VisionEncoderDecoder(
    val tensorflowWrapper: Option[TensorflowWrapper],
    val onnxWrappers: Option[EncoderDecoderWithoutPastWrappers],
    configProtoBytes: Option[Array[Byte]] = None,
    tokenizer: Gpt2Tokenizer,
    preprocessor: Preprocessor,
    generationConfig: GenerationConfig,
    signatures: Option[Map[String, String]] = None)
    extends Generate {

  val _tfSignatures: Map[String, String] =
    signatures.getOrElse(ModelSignatureManager.apply())

  val tensorResources = new TensorResources()
  private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

  val detectedEngine: String =
    if (tensorflowWrapper.isDefined) TensorFlow.name
    else if (onnxWrappers.isDefined) ONNX.name
    else throw new IllegalArgumentException("No model engine defined.")
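
  /** Runs a single dummy generation pass on a black image so that the graph is loaded and
    * buffers are allocated at construction time, rather than on the first real request.
    */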
  private def sessionWarmup(): Unit = {
    val nChannels = 3
    val dummyInput = Array(
      AnnotationImage(
        AnnotatorType.IMAGE,
        "",
        preprocessor.size,
        preprocessor.size,
        nChannels,
        0,
        Array.fill[Byte](preprocessor.size * preprocessor.size * nChannels)(0.toByte),
        Map.empty))

    generateFromImage(
      images = dummyInput,
      batchSize = 1,
      maxOutputLength = 1,
      minOutputLength = 0,
      doSample = false,
      beamSize = 1,
      numReturnSequences = 1,
      temperature = 1.0,
      topK = 1,
      topP = 1.0,
      repetitionPenalty = 0,
      noRepeatNgramSize = 0,
      randomSeed = None)
  }

  sessionWarmup()

  private object TfSignatures {
    object InputOps {
      val encoderInput = _tfSignatures
        .getOrElse(ModelSignatureConstants.PixelValuesInput.key, "missing_pixel_values")
      val decoderEncoderState = _tfSignatures.getOrElse(
        ModelSignatureConstants.DecoderEncoderInputIds.key,
        ModelSignatureConstants.DecoderEncoderInputIds.value)
      val decoderInputIds = _tfSignatures.getOrElse(
        ModelSignatureConstants.DecoderInputIds.key,
        ModelSignatureConstants.DecoderInputIds.value)
    }

    object OutputOps {
      val encoderState = _tfSignatures
        .getOrElse(
          ModelSignatureConstants.LastHiddenState.key,
          ModelSignatureConstants.LastHiddenState.value)
      val decoderLogits = _tfSignatures
        .getOrElse(
          ModelSignatureConstants.LogitsOutput.key,
          ModelSignatureConstants.LogitsOutput.value)
    }
  }
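
  /** Tensor names expected by the exported ONNX graphs; these follow the naming convention
    * used by the Hugging Face ONNX export of vision encoder-decoder models.
    */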

  private object OnnxSignatures {
    val encoderInputIdsTensor: String = "pixel_values"
    val encoderOutputKey: String = "last_hidden_state"
    val decoderOutputKey: String = "logits"
    val decoderInputIDs: String = "input_ids"
    val decoderEncoderState: String = "encoder_hidden_states"
  }

  private def preprocessImages(
      annotations: Array[AnnotationImage]): Array[Array[Array[Array[Float]]]] = {

    val batchProcessedImages = annotations.map { annot =>
      val bufferedImage = ImageIOUtils.byteToBufferedImage(
        bytes = annot.result,
        w = annot.width,
        h = annot.height,
        nChannels = annot.nChannels)

      val resizedImage = if (preprocessor.do_resize) {
        ImageResizeUtils.resizeBufferedImage(
          width = preprocessor.size,
          height = preprocessor.size,
          preprocessor.resample)(bufferedImage)
      } else bufferedImage

      val normalizedImage =
        ImageResizeUtils.normalizeAndConvertBufferedImage(
          img = resizedImage,
          mean = preprocessor.image_mean,
          std = preprocessor.image_std,
          doNormalize = preprocessor.do_normalize,
          doRescale = preprocessor.do_rescale,
          rescaleFactor = preprocessor.rescale_factor)

      normalizedImage
    }

    batchProcessedImages
  }
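
  // Worked example of the shape produced above: for a 224x224 RGB input with default
  // preprocessor settings, each image becomes an Array of shape (3, 224, 224) holding
  // normalized, channel-first floats, so a batch of n images has shape (n, 3, 224, 224).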

  /** Encodes processed images with the encoder.
    *
    * Expands the initial encoded images to match the requested beam size.
    *
    * @param batch
    *   Batch of preprocessed images with dimensions (batchSize, nChannels, height, width)
    * @param beamSize
    *   Number of beams; each image is repeated beamSize times along the batch axis
    * @return
    *   Tensor with the encoded representations of the batch, expanded for the beams
    */
  private def encodeImages(
      batch: Array[Array[Array[Array[Float]]]],
      beamSize: Int,
      tfSession: Option[Session],
      onnxSession: Option[(OrtSession, OrtEnvironment)]): AutoCloseable = {

    // Repeat each image beamSize times so the encoded batch matches the number of beams
    val batchForBeams =
      batch.flatMap(imageFloats => Array.fill(beamSize)(imageFloats))

    detectedEngine match {
      case TensorFlow.name =>
        val imageTensors = tensorResources.createTensor(batchForBeams)

        val runner = tfSession.get.runner
          .feed(TfSignatures.InputOps.encoderInput, imageTensors)
          .fetch(TfSignatures.OutputOps.encoderState)

        val outs = runner.run().asScala
        outs.head

      case ONNX.name =>
        val (session, env) = onnxSession.get
        val imageTensors = OnnxTensor.createTensor(env, batchForBeams)
        val encoderResults = session
          .run(Map(OnnxSignatures.encoderInputIdsTensor -> imageTensors).asJava)
        val output = encoderResults
          .get(OnnxSignatures.encoderOutputKey)
          .get()
          .asInstanceOf[OnnxTensor]
        output

      case _ =>
        throw new IllegalArgumentException("Unknown engine type.")
    }
  }
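
  // Beam-expansion sketch: with batch = Array(imgA, imgB) and beamSize = 2, the encoder sees
  // Array(imgA, imgA, imgB, imgB), so each beam hypothesis attends to its own copy of the
  // encoded image during decoding.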

  def generateFromImage(
      images: Array[AnnotationImage],
      batchSize: Int,
      maxOutputLength: Int,
      minOutputLength: Int,
      doSample: Boolean,
      beamSize: Int,
      numReturnSequences: Int,
      temperature: Double,
      topK: Int,
      topP: Double,
      repetitionPenalty: Double,
      noRepeatNgramSize: Int,
      randomSeed: Option[Long]): Seq[Annotation] = {

    val captions: Seq[Annotation] = images
      .grouped(batchSize)
      .flatMap { batch =>
        val currentBatchSize = batch.length
        val preprocessedImages = preprocessImages(batch)

        val batchDecoderStartIds = Array.fill(currentBatchSize, 1)(generationConfig.bosId)
        val encoderIds: Array[Array[Int]] = Array.fill(batchDecoderStartIds.length)(Array.empty)
        val generatedTokenIds =
          detectedEngine match {
            case TensorFlow.name =>
              val session: Session = tensorflowWrapper.get
                .getTFSessionWithSignature(
                  configProtoBytes = configProtoBytes,
                  initAllTables = false)
              val encodedImages = encodeImages(preprocessedImages, beamSize, Some(session), None)
                .asInstanceOf[Tensor]
              generate(
                inputIds = encoderIds,
                decoderEncoderStateTensors = Left(encodedImages),
                encoderAttentionMaskTensors = null,
                decoderInputs = batchDecoderStartIds,
                maxOutputLength,
                minOutputLength,
                doSample,
                beamSize,
                numReturnSequences,
                temperature,
                topK,
                topP,
                repetitionPenalty,
                noRepeatNgramSize,
                generationConfig.vocabSize,
                generationConfig.eosId,
                generationConfig.padId,
                randomSeed,
                Array.empty,
                Left(session))
            case ONNX.name =>
              val (encoderSession, encoderEnv) =
                onnxWrappers.get.encoder.getSession(onnxSessionOptions)
              val (decoderSession, decoderEnv) =
                onnxWrappers.get.decoder.getSession(onnxSessionOptions)
              val encodedImages =
                encodeImages(
                  preprocessedImages,
                  beamSize,
                  None,
                  Some((encoderSession, encoderEnv)))
                  .asInstanceOf[OnnxTensor]
              generate(
                inputIds = batchDecoderStartIds,
                decoderEncoderStateTensors = Right(encodedImages),
                encoderAttentionMaskTensors =
                  Right(OnnxTensor.createTensor(encoderEnv, Array(1))),
                decoderInputs = batchDecoderStartIds,
                maxOutputLength,
                minOutputLength,
                doSample,
                beamSize,
                numReturnSequences,
                temperature,
                topK,
                topP,
                repetitionPenalty,
                noRepeatNgramSize,
                generationConfig.vocabSize,
                generationConfig.eosId,
                generationConfig.padId,
                randomSeed,
                Array.empty,
                Right((decoderEnv, decoderSession)))

          }

        val decodedStringsBatch = generatedTokenIds.map(tokenizer.decodeTokens).map(_.trim)

        batch.zip(decodedStringsBatch).map { case (image, caption) =>
          val imageMeta = Map(
            "height" -> image.height.toString,
            "width" -> image.width.toString,
            "nChannels" -> image.nChannels.toString,
            "mode" -> image.mode.toString,
            "origin" -> image.origin)

          Annotation(
            annotatorType = AnnotatorType.DOCUMENT,
            begin = 0,
            end = caption.length - 1,
            result = caption,
            metadata = Map("image" -> "0") ++ imageMeta)
        }
      }
      .toSeq

    // tensorResources.clearTensors()

    captions
  }

  /** Calls the model and returns the output logits.
    *
    * @param encoderInputIds
    *   Input IDs for the encoder, not used for vision encoder-decoder models
    * @param decoderInputIds
    *   Input IDs for the Decoder
    * @param decoderEncoderStateTensors
    *   Tensor of encoded pixel values for the decoder
    * @param encoderAttentionMaskTensors
    *   Tensor for encoder attention mask, not used
    * @param maxLength
    *   Max length of the input
    * @param session
    *   Either a TensorFlow Session or an ONNX environment and session pair
    * @return
    *   Logits for the input
    */
  override def getModelOutput(
      encoderInputIds: Seq[Array[Int]],
      decoderInputIds: Seq[Array[Int]],
      decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
      encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
      maxLength: Int,
      session: Either[Session, (OrtEnvironment, OrtSession)],
      ovInferRequest: Option[InferRequest]): Array[Array[Float]] = {
    getModelOutput(decoderInputIds, decoderEncoderStateTensors, session)
  }

  def getModelOutput(
      decoderInputIds: Seq[Array[Int]],
      decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
      session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = {

    val decoderEncoderStateTensor = decoderEncoderStateTensors.fold(
      tfTensor => {
        // The TensorFlow branch feeds the tensor directly to the runner, so no conversion
        // is needed here; this value is only used by the ONNX branch.
        null
      },
      onnxTensor => onnxTensor)

    detectedEngine match {
      case TensorFlow.name =>
        val decoderInputIdsTensor = tensorResources.createTensor(decoderInputIds.toArray)

        val runner =
          session.left.get
            .runner()
            .feed(TfSignatures.InputOps.decoderEncoderState, decoderEncoderStateTensors.left.get)
            .feed(TfSignatures.InputOps.decoderInputIds, decoderInputIdsTensor)
            .fetch(TfSignatures.OutputOps.decoderLogits)

        val decoderOuts = runner.run().asScala
        val logitsRaw = TensorResources.extractFloats(decoderOuts.head)
        decoderOuts.head.close()

        val logits = logitsRaw.grouped(generationConfig.vocabSize)

        logits.toArray
      case ONNX.name =>
        val (env, decoderSession) = session.right.get
        val decoderInputIdsLong: Array[Array[Long]] =
          decoderInputIds.toArray.map { tokenIds => tokenIds.map(_.toLong) }

        val decoderInputIdsTensor =
          OnnxTensor.createTensor(env, decoderInputIdsLong)

        val decoderInputs: java.util.Map[String, OnnxTensor] = Map(
          OnnxSignatures.decoderInputIDs -> decoderInputIdsTensor,
          OnnxSignatures.decoderEncoderState -> decoderEncoderStateTensor).asJava
        val sessionOutput = decoderSession.run(decoderInputs)

        val sequenceLength = decoderInputIds.head.length
        val batchSize = decoderInputIds.length

        val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutputKey)
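        // logitsRaw is flattened from (batchSize, sequenceLength, vocabSize); only the logits
        // of the final position of each sequence are needed for the next generation step.
        // Example: batchSize = 2, sequenceLength = 3, vocabSize = 4 => row i's last-step
        // logits occupy indices [i * 12 + 8, i * 12 + 12).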
        val decoderOutputs = (0 until batchSize).map { i =>
          val rowOffset = i * sequenceLength * generationConfig.vocabSize
          val lastStepOffset = rowOffset + (sequenceLength - 1) * generationConfig.vocabSize
          logitsRaw.slice(lastStepOffset, lastStepOffset + generationConfig.vocabSize)
        }
        // Free the native buffers now that the logits have been copied onto the JVM heap
        decoderInputIdsTensor.close()
        sessionOutput.close()
        decoderOutputs.toArray

      case _ =>
        throw new IllegalArgumentException("Unknown engine type.")
    }
  }

}
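
// Minimal usage sketch (hypothetical wiring; in practice the VisionEncoderDecoderForImageCaptioning
// annotator constructs this class from a loaded model). The wrapper, tokenizer, preprocessor and
// generation-config values below are assumed to have been loaded elsewhere:
//
//   val model = new VisionEncoderDecoder(
//     tensorflowWrapper = None,
//     onnxWrappers = Some(onnxWrappers), // EncoderDecoderWithoutPastWrappers for encoder/decoder
//     tokenizer = gpt2Tokenizer,
//     preprocessor = preprocessor,
//     generationConfig = generationConfig)
//
//   val captions: Seq[Annotation] = model.generateFromImage(
//     images = annotationImages,
//     batchSize = 2,
//     maxOutputLength = 50,
//     minOutputLength = 0,
//     doSample = false,
//     beamSize = 1,
//     numReturnSequences = 1,
//     temperature = 1.0,
//     topK = 1,
//     topP = 1.0,
//     repetitionPenalty = 1.0,
//     noRepeatNgramSize = 0,
//     randomSeed = None)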