All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.ai.Janus.scala Maven / Gradle / Ivy

There is a newer version: 6.0.3
Show newest version
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.ai
import java.lang.Math

import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.JanusWrappers
import com.johnsnowlabs.nlp.annotators.common.Sentence
import com.johnsnowlabs.ml.util.{ONNX, Openvino}
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common.SentenceSplit
import com.johnsnowlabs.nlp.annotators.cv.util.transform.ImageResizeUtils
import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor
import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, JanusTokenizer, SpecialTokens}
import org.intel.openvino.{InferRequest, Tensor}

import javax.imageio.ImageIO
import scala.util.Random
import scala.reflect.ClassTag
import java.awt.{Color, Graphics2D}
import java.awt.image.BufferedImage
import java.io.ByteArrayOutputStream
import scala.collection.JavaConverters._

private[johnsnowlabs] class Janus(
    val onnxWrappers: Option[DecoderWrappers],
    val openvinoWrapper: Option[JanusWrappers],
    merges: Map[(String, String), Int],
    vocabulary: Map[String, Int],
    addedTokens: Map[String, Int],
    preprocessor: Preprocessor,
    generationConfig: GenerationConfig,
    imageTokenLength: Int,
    imageToken: Int)
    extends Serializable {

  // Name of the inference backend in use: ONNX when ONNX wrappers were supplied; every other
  // case (including neither wrapper being defined) reports the OpenVINO name.
  val detectedEngine: String =
    if (onnxWrappers.isDefined) ONNX.name
    else if (openvinoWrapper.isDefined) Openvino.name
    else Openvino.name

  // Unpacks the generation settings into individual private fields.
  // NOTE(review): beginSuppressTokens, suppressTokenIds and forcedDecoderIds are not referenced by
  // any code visible in this class — confirm whether they are still needed.
  private val GenerationConfig(
    bosTokenId: Int,
    paddingTokenId: Int,
    eosTokenId: Int,
    vocabSize: Int,
    beginSuppressTokens,
    suppressTokenIds,
    forcedDecoderIds) =
    generationConfig
  // Id -> token string lookup, built by swapping the key/value pairs of the vocabulary.
  val reversedVocabulary: Map[Int, String] = vocabulary.map(_.swap)

  // Special-token strings resolved from their ids. The unknown and mask tokens are both mapped to
  // the EOS token string. Each lookup throws if the id is missing from the vocabulary.
  val specialTokens: SpecialTokens = SpecialTokens(
    vocabulary,
    startTokenString = reversedVocabulary(bosTokenId),
    endTokenString = reversedVocabulary(eosTokenId),
    unkTokenString = reversedVocabulary(eosTokenId),
    maskTokenString = reversedVocabulary(eosTokenId),
    padTokenString = reversedVocabulary(paddingTokenId),
    additionalStrings = addedTokens.keys.toArray)

  // BPE tokenizer built for the "Janus" model name from the merges, vocabulary and special tokens.
  // The factory result is downcast to JanusTokenizer.
  val bpeTokenizer: JanusTokenizer = BpeTokenizer
    .forModel(
      "Janus",
      merges = merges,
      vocab = vocabulary,
      specialTokens = Some(specialTokens),
      addPrefixSpaceToSentence = true,
      alwaysAddPrefix = false)
    .asInstanceOf[JanusTokenizer]

  // Random generator that `predict` reassigns in image-generation mode (seeded when a seed is
  // supplied). NOTE(review): no other read of this field is visible in this class.
  var randomSeedGenerator = new Random()

  /** Decodes batches of token ids back into text.
    *
    * @param sentences
    *   One array of token ids per sequence
    * @return
    *   The decoded string for each sequence, in input order
    */
  def decode(sentences: Array[Array[Int]]): Seq[String] = {
    // The ids are already Ints, so they go to the tokenizer directly rather than through an
    // identity `map(_.toInt)` that only copied each array.
    sentences.map(tokenIds => bpeTokenizer.decodeTokens(tokenIds))
  }

  /** Encodes prompts for image generation.
    *
    * Every sentence unpacked from the annotations is BPE-tokenized, prefixed with the BOS token
    * id and terminated with the begin-of-image token id.
    *
    * @param sentences
    *   Prompt annotations
    * @return
    *   One token-id array per unpacked sentence: `bos ++ pieces ++ beginOfImage`
    */
  private def encodeTextForGeneration(sentences: Seq[Annotation]): Seq[Array[Int]] = {
    // The tag is looked up in the vocabulary first; 100016 is the fallback id used when the
    // vocabulary does not contain it.
    val startOfImage = "<begin_of_image>"
    val startOfImageToken = vocabulary.getOrElse(startOfImage, 100016)

    SentenceSplit
      .unpack(sentences)
      .map { sentence =>
        val pieceIds = bpeTokenizer
          .tokenize(sentence)
          .map(bpeTokenizer.encode)
          .flatMap(_.map(_.pieceId))
        Array(bosTokenId) ++ pieceIds ++ Array(startOfImageToken)
      }
  }

  /** Encodes multimodal prompts, expanding each `<image_placeholder>` into image tokens.
    *
    * Each annotation's text is split on `<image_placeholder>`, every chunk is BPE-tokenized, and
    * the chunks are joined again with `beginOfImage ++ imageToken * n ++ endOfImage` in place of
    * the placeholder. The BOS token id is prepended to the result.
    *
    * Sentences are paired positionally with `imgTokenLen`; sentences without a matching entry
    * are dropped by the zip.
    *
    * @param sentences
    *   Prompt annotations; must be non-empty
    * @param imgTokenLen
    *   Number of image tokens to inject per placeholder, one entry per sentence
    * @return
    *   One token-id array per (sentence, length) pair
    * @throws IllegalArgumentException
    *   if the first sentence does not contain `<image_placeholder>`
    */
  def encodeText(sentences: Seq[Annotation], imgTokenLen: List[Int]): Seq[Array[Int]] = {

    val pattern = raw"<image_placeholder>".r

    // Tags are resolved through the vocabulary; the numeric ids are fallbacks for vocabularies
    // that do not list them.
    val startOfImage = "<begin_of_image>"
    val endOfImage = "<end_of_image>"
    val startOfImageToken = vocabulary.getOrElse(startOfImage, 100016)
    val endOfImageToken = vocabulary.getOrElse(endOfImage, 100593)

    // Fail early when the prompt has no slot for the image.
    if (pattern.findFirstIn(sentences.head.result).isEmpty) {
      throw new IllegalArgumentException(
        "The pattern <image_placeholder> is not found in the text")
    }

    // Split every prompt on the placeholder and tokenize the pieces independently, mirroring
    // prompt_chunks = [tokenizer(chunk).input_ids for chunk in re.split(pattern, text)].
    // NOTE(review): unlike Python's re.split, the JVM split drops trailing empty chunks, so a
    // placeholder at the very end of the text gets no image tokens injected, and a text made up
    // only of the placeholder yields no chunks (the reduce below then throws) — confirm whether
    // such prompts need support.
    val promptChunks = sentences.map { annotation =>
      var offset = 0
      pattern
        .split(annotation.result)
        .zipWithIndex
        .map { case (chunkText, chunkIndex) =>
          val chunk = Sentence(
            content = chunkText,
            start = offset,
            end = offset + chunkText.length,
            index = chunkIndex)
          offset += chunkText.length
          bpeTokenizer
            .tokenize(chunk)
            .map(bpeTokenizer.encode)
            .flatMap(_.map(_.pieceId))
        }
    }

    // Re-join the chunks of each prompt with the image span standing in for each placeholder.
    promptChunks
      .zip(imgTokenLen)
      .map { case (chunks, placeholderLength) =>
        val imageSpan =
          Array(startOfImageToken) ++ Array.fill(placeholderLength)(imageToken) ++ Array(
            endOfImageToken)
        val merged = chunks
          .map(_.toArray)
          .reduce(_ ++ imageSpan ++ _)
        Array(bosTokenId) ++ merged
      }
  }

  /** Prepares both modalities for the understanding (image-to-text) path.
    *
    * @param imageAnnotations
    *   Images to preprocess
    * @param sentences
    *   Prompts containing the image placeholder
    * @param preprocessor
    *   Image preprocessing settings
    * @param imageTokenLength
    *   Number of image tokens injected per placeholder
    * @return
    *   The encoded prompts and the preprocessed image tensors, in that order
    */
  def encode(
      imageAnnotations: Seq[AnnotationImage],
      sentences: Seq[Annotation],
      preprocessor: Preprocessor,
      imageTokenLength: Int = imageTokenLength)
      : (Seq[Array[Int]], Array[Array[Array[Array[Array[Float]]]]]) = {
    // Images are handled before the text, so an image failure surfaces first.
    val pixelBatches = encodeImage(imageAnnotations.toArray, preprocessor)
    // A single length is supplied, so only the first sentence is paired with image tokens.
    val placeholderLengths = imageTokenLength :: Nil
    val tokenIds = encodeText(sentences, placeholderLengths).toArray

    (tokenIds, pixelBatches)
  }

  /** Runs greedy text generation for a batch of encoded prompts and their images.
    *
    * Only `batch`, `images` and `maxOutputLength` influence the result. The remaining sampling
    * and beam parameters are accepted for interface compatibility but are not consulted by the
    * greedy decoder used here.
    *
    * @param batch
    *   Encoded prompts; must be non-empty
    * @param images
    *   Preprocessed pixel values
    * @param maxOutputLength
    *   Maximum number of tokens generated per sequence
    * @return
    *   The generated token ids for every prompt
    * @throws IllegalArgumentException
    *   if `batch` is empty
    * @throws NoSuchElementException
    *   if no OpenVINO wrapper was supplied to this instance
    */
  def tag(
      batch: Seq[Array[Int]],
      images: Array[Array[Array[Array[Array[Float]]]]],
      minOutputLength: Int,
      maxOutputLength: Int,
      doSample: Boolean,
      temperature: Double,
      topK: Int,
      topP: Double,
      repetitionPenalty: Double,
      noRepeatNgramSize: Int,
      randomSeed: Option[Long],
      ignoreTokenIds: Array[Int] = Array(),
      beamSize: Int,
      maxInputLength: Int,
      stopTokenIds: Array[Int]): Array[Array[Int]] = {

    // Fail with a clear message instead of the `empty.max` raised by a length computation whose
    // result was never used.
    require(batch.nonEmpty, "Cannot run generation on an empty batch")

    val wrappers = openvinoWrapper.get
    val prompts = batch.toArray

    // Fresh infer requests per call; the prompts serve as both the encoder reference and the
    // initial decoder input.
    generateGreedy(
      prompts,
      prompts,
      images,
      maxOutputLength,
      wrappers.languageModel.getCompiledModel().create_infer_request(),
      wrappers.visionEmbeddingsModel.getCompiledModel().create_infer_request(),
      wrappers.textEmbeddingsModel.getCompiledModel().create_infer_request(),
      wrappers.lmHeadModel.getCompiledModel().create_infer_request(),
      wrappers.mergeModel.getCompiledModel().create_infer_request())
  }

  /** Greedy decoding loop: repeatedly asks the model for next-token scores, picks the highest
    * scoring token for every sequence and appends it, until `greedyGenerationFinished` reports
    * completion. At least one step is always executed.
    *
    * @param encoderInputIds
    *   The original prompts (used downstream to recognise the first pass)
    * @param decoderInputIds
    *   Initial decoder input, extended by one token per step
    * @param pixelValues
    *   Preprocessed images
    * @param maxOutputLength
    *   Maximum number of generated tokens per sequence
    * @return
    *   Only the newly generated token ids, one array per sequence
    */
  def generateGreedy(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      pixelValues: Array[Array[Array[Array[Array[Float]]]]],
      maxOutputLength: Int,
      inferRequestLanguageModel: InferRequest,
      inferRequestVisionEmbeddingsModel: InferRequest,
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestLMHeadModel: InferRequest,
      inferRequestMergeModel: InferRequest): Array[Array[Int]] = {

    var produced: Array[Array[Int]] = Array()
    var runningInputs = decoderInputIds

    while (!greedyGenerationFinished(produced, eosTokenId, maxOutputLength)) {
      val stepScores = getModelOutputs(
        encoderInputIds,
        runningInputs,
        pixelValues,
        inferRequestLanguageModel,
        inferRequestVisionEmbeddingsModel,
        inferRequestTextEmbeddingsModel,
        inferRequestLMHeadModel,
        inferRequestMergeModel)

      val chosen = stepScores.map(argmax)

      // The first step seeds one single-token array per sequence; later steps append.
      produced =
        if (produced.isEmpty) chosen.map(Array(_))
        else produced.zip(chosen).map { case (soFar, token) => soFar :+ token }

      // The decoder input grows in lock-step with the generated output.
      runningInputs = runningInputs.zip(chosen).map { case (soFar, token) => soFar :+ token }
    }
    produced
  }

  /** Entry point for both Janus modes.
    *
    * In image-generation mode the first prompt is turned into `numOfParallelImages` images. The
    * images are returned base64-encoded (PNG) in the metadata of a DOCUMENT annotation under the
    * keys `generated_image_<i>`, with the prompt text as the result.
    *
    * Otherwise the prompts and images are encoded, text is generated greedily and every decoded
    * string is returned as a DOCUMENT annotation. The annotations carry consecutive,
    * non-overlapping offsets: each begins right after the previous one ends.
    *
    * @param sentences
    *   Prompt annotations
    * @param imageAnnotations
    *   Input images (ignored in image-generation mode)
    * @param imageGenerateMode
    *   `true` to generate images, `false` to generate text
    * @param maxOutputLength
    *   Maximum number of generated text tokens
    * @param randomSeed
    *   Optional seed for image-token sampling
    * @param numOfParallelImages
    *   Number of images generated for the prompt
    * @return
    *   The resulting annotations
    */
  def predict(
      sentences: Seq[Annotation],
      imageAnnotations: Seq[AnnotationImage],
      imageGenerateMode: Boolean,
      batchSize: Int,
      minOutputLength: Int,
      maxOutputLength: Int,
      doSample: Boolean,
      temperature: Double,
      topK: Int,
      topP: Double,
      repetitionPenalty: Double,
      noRepeatNgramSize: Int,
      randomSeed: Option[Long] = None,
      ignoreTokenIds: Array[Int] = Array(),
      beamSize: Int,
      maxInputLength: Int,
      numOfParallelImages: Int): Seq[Annotation] = {

    if (imageGenerateMode) {
      randomSeedGenerator = randomSeed.map(s => new Random(s)).getOrElse(new Random())
      val encodedText: Array[Array[Int]] = encodeTextForGeneration(sentences).toArray
      val parallelSize = numOfParallelImages
      val prompt = encodedText.head

      // Two rows per requested image: even rows carry the prompt, odd rows carry padding with
      // only the prompt's first and last token kept. Downstream these pairs are read as the
      // conditional / unconditional inputs.
      val tokens: Array[Array[Int]] = Array.tabulate(parallelSize * 2) { row =>
        if (row % 2 == 0) prompt
        else {
          val unconditional = Array.fill(prompt.length)(paddingTokenId)
          unconditional(0) = prompt.head
          unconditional(prompt.length - 1) = prompt.last
          unconditional
        }
      }

      val generatedImages = generateImage(
        tokens,
        tokens,
        parallelSize = parallelSize,
        patchSize = 16,
        imageSize = preprocessor.size,
        randomSeed = randomSeed,
        inferRequestTextEmbeddingsModel =
          openvinoWrapper.get.textEmbeddingsModel.getCompiledModel().create_infer_request(),
        inferRequestGenEmbeddingsModel =
          openvinoWrapper.get.genEmbeddingsModel.getCompiledModel().create_infer_request(),
        inferRequestGenHeadModel =
          openvinoWrapper.get.genHeadModel.getCompiledModel().create_infer_request(),
        inferRequestLanguageModel =
          openvinoWrapper.get.languageModel.getCompiledModel().create_infer_request(),
        inferRequestGenDecoderModel =
          openvinoWrapper.get.genDecoderModel.getCompiledModel().create_infer_request())

      // Group the generated images per prompt and attach them to the prompt's annotation.
      val parallelSizeBatchedImages: Array[Array[BufferedImage]] =
        generatedImages.grouped(parallelSize).toArray

      val annotations = parallelSizeBatchedImages.zip(sentences).map { case (imgs, sent) =>
        val metadata: Map[String, String] = imgs.zipWithIndex.map { case (img, i) =>
          val pngBytes = new ByteArrayOutputStream()
          ImageIO.write(img, "png", pngBytes)
          s"generated_image_$i" -> java.util.Base64.getEncoder.encodeToString(
            pngBytes.toByteArray)
        }.toMap
        new Annotation(
          annotatorType = DOCUMENT,
          begin = 0,
          end = 0,
          result = sent.result,
          metadata = metadata)
      }
      annotations
    } else {
      val (encodedText, preprocessedImages) = encode(imageAnnotations, sentences, preprocessor)
      val tagged = tag(
        encodedText,
        preprocessedImages,
        minOutputLength,
        maxOutputLength,
        doSample,
        temperature,
        topK,
        topP,
        repetitionPenalty,
        noRepeatNgramSize,
        randomSeed,
        ignoreTokenIds,
        beamSize,
        maxInputLength,
        Array(eosTokenId))
      val decoded = decode(tagged)

      // Running offset: every annotation spans [begin, begin + length - 1] and the next one
      // starts immediately after it.
      var sentBegin = 0
      val annotations = decoded.map { content =>
        val sentEnd = sentBegin + content.length - 1
        val annotation = new Annotation(
          annotatorType = DOCUMENT,
          begin = sentBegin,
          end = sentEnd,
          result = content,
          metadata = Map())
        sentBegin = sentEnd + 1
        annotation
      }
      annotations
    }
  }

  /** Runs one decoding step and returns the next-token scores of every sequence.
    *
    * The step is treated as the first pass when the decoder input still has the prompt's length;
    * then the full prompt positions are fed. On later passes only the position of the newest
    * token is fed. The attention mask always covers the whole decoder input.
    *
    * @param encoderInputIds
    *   The original prompts
    * @param decoderInputIds
    *   Prompt plus tokens generated so far; all rows are assumed to have the same length
    * @param pixelValues
    *   Preprocessed images (only consumed on the first pass)
    * @param inferRequestMergeModel
    *   Infer request used to merge image and text embeddings on the first pass
    * @return
    *   For each sequence, the `vocabSize` scores at the last fed position
    */
  def getModelOutputs(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      pixelValues: Array[Array[Array[Array[Array[Float]]]]],
      inferRequestLanguageModel: InferRequest,
      inferRequestVisionEmbeddingsModel: InferRequest,
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestLMHeadModel: InferRequest,
      inferRequestMergeModel: InferRequest): Array[Array[Float]] = {

    // Use the merge request supplied by the caller rather than allocating a new one per step.
    val inputEmbeds = getMultimodalEmbeddings(
      encoderInputIds,
      decoderInputIds,
      pixelValues,
      inferRequestVisionEmbeddingsModel,
      inferRequestTextEmbeddingsModel,
      inferRequestMergeModel)

    val firstPass = encoderInputIds.head.length == decoderInputIds.head.length
    val positionIds: Array[Long] =
      if (firstPass) decoderInputIds.flatMap(ids => Array.tabulate(ids.length)(_.toLong))
      else decoderInputIds.map(ids => (ids.length - 1).toLong)
    val attentionMask: Array[Long] = Array.fill(decoderInputIds.map(_.length).sum)(1L)

    val batchSize: Int = decoderInputIds.length
    // Number of positions fed per sequence in this step.
    val sequenceLength = positionIds.length / batchSize
    val beamIdx: Array[Int] = new Array[Int](batchSize)
    val shape: Array[Int] = Array(batchSize, sequenceLength)

    val decoderAttentionMask: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask)
    val decoderPositionIDs: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(shape, positionIds)
    val beamIdxTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize), beamIdx)

    inferRequestLanguageModel.set_tensor("inputs_embeds", inputEmbeds)
    inferRequestLanguageModel.set_tensor("attention_mask", decoderAttentionMask)
    inferRequestLanguageModel.set_tensor("position_ids", decoderPositionIDs)
    inferRequestLanguageModel.set_tensor("beam_idx", beamIdxTensor)
    inferRequestLanguageModel.infer()

    // Project the hidden states onto the vocabulary.
    inferRequestLMHeadModel.set_input_tensor(
      inferRequestLanguageModel.get_tensor("last_hidden_state"))
    inferRequestLMHeadModel.infer()
    val logitsRaw = inferRequestLMHeadModel.get_output_tensor().data()

    // Keep only the scores of the last position of every sequence.
    Array.tabulate(batchSize) { row =>
      val lastPositionOffset = (row * sequenceLength + sequenceLength - 1) * vocabSize
      logitsRaw.slice(lastPositionOffset, lastPositionOffset + vocabSize)
    }
  }

  /** Generates images from encoded prompts.
    *
    * Image tokens are sampled with `getImageModelOutputs`, decoded by the generation decoder
    * model, moved to a channels-last layout and mapped from the decoder's value range to
    * integer pixel values clamped to 0..255.
    *
    * @param encoderInputIds
    *   Prompt rows (two per image, see `predict`)
    * @param decoderInputIds
    *   Initial decoder rows
    * @param parallelSize
    *   Number of images to decode
    * @param patchSize
    *   Side length of one image patch in pixels
    * @param imageSize
    *   Side length of the generated image in pixels
    * @param randomSeed
    *   Optional sampling seed
    * @return
    *   One `BufferedImage` per decoded image
    */
  def generateImage(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      parallelSize: Int = 1,
      patchSize: Int = 16,
      imageSize: Int = preprocessor.size,
      randomSeed: Option[Long] = None,
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestGenEmbeddingsModel: InferRequest,
      inferRequestGenHeadModel: InferRequest,
      inferRequestLanguageModel: InferRequest,
      inferRequestGenDecoderModel: InferRequest): Array[BufferedImage] = {

    val imageCodes = getImageModelOutputs(
      encoderInputIds,
      decoderInputIds,
      randomSeed,
      inferRequestTextEmbeddingsModel,
      inferRequestGenEmbeddingsModel,
      inferRequestGenHeadModel,
      inferRequestLanguageModel)

    // Sampled codes, one row per image.
    val codeTensor = new org.intel.openvino.Tensor(
      Array(imageCodes.length, imageCodes.head.length),
      imageCodes.flatten.map(_.toLong))
    inferRequestGenDecoderModel.set_tensor("code_b", codeTensor)

    // Shape handed to the decoder: (images, 8, patches per side, patches per side).
    val patchesPerSide = imageSize / patchSize
    val shapeTensor = new org.intel.openvino.Tensor(
      Array(4),
      Array(parallelSize, 8, patchesPerSide, patchesPerSide).map(_.toLong))
    inferRequestGenDecoderModel.set_tensor("shape", shapeTensor)

    inferRequestGenDecoderModel.infer()

    val decoded = inferRequestGenDecoderModel.get_output_tensor()
    val decodedShape = decoded.get_shape()

    // Move axis 1 to the end, then rebuild the nested structure in the permuted order.
    val channelsLast = transposeArray(decoded.data(), decodedShape, Array(0, 2, 3, 1))
    val nested =
      reshape4D(channelsLast, decodedShape(0), decodedShape(2), decodedShape(3), decodedShape(1))

    // Shift and scale each value by (v + 1) / 2 * 255, then clamp to 0..255.
    def toPixel(value: Float): Int =
      Math.min(Math.max(((value + 1) / 2) * 255, 0), 255).toInt

    val pixelImages: Array[Array[Array[Array[Int]]]] =
      nested.map(image => image.map(row => row.map(pixel => pixel.map(toPixel))))

    pixelImages.map(image => ImageIOUtils.arrayToBufferedImage(image))
  }

  /** Samples `imageTokenLength` image tokens per image, one position per iteration.
    *
    * Each iteration samples the next token of every image, embeds those tokens with the
    * generation-embeddings model (each token duplicated so both rows of an image's pair receive
    * it) and feeds the embeddings back as the input of the following iteration.
    *
    * When a seed is given, sampling is reproducible. Every iteration receives its own seed drawn
    * from a single generator seeded with `randomSeed`; passing `randomSeed` itself to every
    * iteration would rebuild the same generator each time and repeat the same draws at every
    * position.
    *
    * @param encoderInputIds
    *   Prompt rows (two per image)
    * @param decoderInputIds
    *   Initial decoder rows; the outer array is copied, not modified
    * @param randomSeed
    *   Optional sampling seed
    * @return
    *   The sampled image tokens, one array of length `imageTokenLength` per image
    */
  def getImageModelOutputs(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      randomSeed: Option[Long] = None,
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestGenEmbeddingsModel: InferRequest,
      inferRequestGenHeadModel: InferRequest,
      inferRequestLanguageModel: InferRequest): Array[Array[Int]] = {

    // Source of per-iteration seeds; absent when sampling is unseeded.
    val stepSeedSource: Option[Random] = randomSeed.map(seed => new Random(seed))

    var generatedTokens: Array[Array[Int]] = Array()
    var nextInputEmbedsTensor: Option[org.intel.openvino.Tensor] = None
    var decoderInputIdsCopied = decoderInputIds.clone()

    for (_ <- 0 until imageTokenLength) {
      val nextTokenIds = getNextImageTokens(
        encoderInputIds,
        decoderInputIdsCopied,
        cfgWeight = 5.0f,
        temperature = 1.0f,
        randomSeed = stepSeedSource.map(_.nextLong()),
        inputEmbeds = nextInputEmbedsTensor,
        inferRequestTextEmbeddingsModel = inferRequestTextEmbeddingsModel,
        inferRequestGenHeadModel = inferRequestGenHeadModel,
        inferRequestLanguageModel = inferRequestLanguageModel)

      // Each sampled token is repeated twice: once for each row of the image's pair.
      val repeatedNextTokenIds = nextTokenIds.flatMap(x => Array(x, x))

      val nextTokenIdsTensor = new org.intel.openvino.Tensor(
        Array(nextTokenIds.length * 2),
        repeatedNextTokenIds.map(_.toLong))

      inferRequestGenEmbeddingsModel.set_input_tensor(nextTokenIdsTensor)
      inferRequestGenEmbeddingsModel.infer()

      val imageEmbeddings = inferRequestGenEmbeddingsModel.get_output_tensor()
      val embeddingShape = imageEmbeddings.get_shape()

      // Copy the embeddings into a new tensor with an added length-1 sequence axis.
      nextInputEmbedsTensor = Some(
        new org.intel.openvino.Tensor(
          Array(embeddingShape(0), 1, embeddingShape(1)),
          imageEmbeddings.data()))

      generatedTokens =
        if (generatedTokens.isEmpty) nextTokenIds.map(Array(_))
        else
          generatedTokens.zip(nextTokenIds).map { case (currentIds, nextId) =>
            currentIds ++ Array(nextId)
          }

      // Extend every decoder row with the token sampled for its image.
      decoderInputIdsCopied =
        decoderInputIdsCopied.zip(repeatedNextTokenIds).map { case (currentIds, nextId) =>
          currentIds ++ Array(nextId)
        }
    }
    generatedTokens
  }

  /** Samples the next image token for every image.
    *
    * Rows come in pairs: even rows are the conditional inputs, odd rows the unconditional ones.
    * The language model is run on all rows, the last hidden state of each row goes through the
    * generation head, and each pair's logits are combined as
    * `uncond + cfgWeight * (cond - uncond)` before a token is sampled from the softmax.
    *
    * @param encoderInputIds
    *   Prompt rows
    * @param decoderInputIds
    *   Prompt rows plus tokens sampled so far
    * @param cfgWeight
    *   Weight applied to the difference between conditional and unconditional logits
    * @param temperature
    *   Not used by the current implementation
    * @param randomSeed
    *   Optional seed forwarded to the sampler
    * @param inputEmbeds
    *   Precomputed input embeddings; when absent the token ids are embedded with the text model
    * @return
    *   One sampled token id per row pair
    */
  private def getNextImageTokens(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      cfgWeight: Float = 5.0f,
      temperature: Float = 1.0f,
      randomSeed: Option[Long] = None,
      inputEmbeds: Option[Tensor],
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestGenHeadModel: InferRequest,
      inferRequestLanguageModel: InferRequest): Array[Int] = {

    // First pass: the decoder rows still have the prompt's length, so whole rows are fed.
    // Later passes feed only the newest token and its position.
    val promptPass = encoderInputIds.head.length == decoderInputIds.head.length

    val tokenIdsFlat: Array[Long] =
      if (promptPass) decoderInputIds.flatMap(ids => ids.map(_.toLong))
      else decoderInputIds.map(ids => ids.last.toLong)
    val positionIdsFlat: Array[Long] =
      if (promptPass) decoderInputIds.flatMap(ids => Array.tabulate(ids.length)(_.toLong))
      else decoderInputIds.map(ids => (ids.length - 1).toLong)

    // The mask always spans the complete decoder rows.
    val attentionMask: Array[Long] = Array.fill(decoderInputIds.map(_.length).sum)(1L)

    val batchSize: Int = decoderInputIds.length
    val beamIdx: Array[Int] = new Array[Int](batchSize)
    val stepShape: Array[Int] = Array(batchSize, tokenIdsFlat.length / batchSize)

    val attentionTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask)
    val positionTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(stepShape, positionIdsFlat)
    val beamIdxTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize), beamIdx)

    val embeddings: org.intel.openvino.Tensor = inputEmbeds match {
      case Some(precomputed) => precomputed
      case None =>
        val tokenIdsTensor: org.intel.openvino.Tensor =
          new org.intel.openvino.Tensor(stepShape, tokenIdsFlat)
        inferRequestTextEmbeddingsModel.set_input_tensor(tokenIdsTensor)
        inferRequestTextEmbeddingsModel.infer()
        inferRequestTextEmbeddingsModel.get_output_tensor()
    }

    inferRequestLanguageModel.set_tensor("inputs_embeds", embeddings)
    inferRequestLanguageModel.set_tensor("attention_mask", attentionTensor)
    inferRequestLanguageModel.set_tensor("position_ids", positionTensor)
    inferRequestLanguageModel.set_tensor("beam_idx", beamIdxTensor)
    inferRequestLanguageModel.infer()

    // Hidden states are read as (rows, sequence length, hidden size); keep the last position.
    val hidden = inferRequestLanguageModel.get_tensor("last_hidden_state")
    val hiddenShape = hidden.get_shape()
    val lastPosition = hiddenShape(1) - 1
    val lastHiddenRows =
      reshape3D(hidden.data(), hiddenShape(0), hiddenShape(1), hiddenShape(2))
        .map(sequence => sequence(lastPosition))
    val headInput =
      new org.intel.openvino.Tensor(Array(hiddenShape(0), hiddenShape(2)), lastHiddenRows.flatten)

    inferRequestGenHeadModel.set_input_tensor(headInput)
    inferRequestGenHeadModel.infer()

    val logits = inferRequestGenHeadModel.get_output_tensor()
    val logitsShape = logits.get_shape()
    val perRowLogits: Array[Array[Float]] =
      reshape2D(logits.data(), logitsShape(0), logitsShape(1))

    // Walk the rows pairwise (conditional first, unconditional second) and blend their logits.
    // A trailing row without a partner is ignored.
    val guidedLogits: Array[Array[Float]] = perRowLogits
      .grouped(2)
      .collect { case Array(cond, uncond) =>
        cond.zip(uncond).map { case (c, u) =>
          u + cfgWeight * (c - u)
        }
      }
      .toArray

    val probabilities = guidedLogits.map(softmax)
    multinomial(probabilities, numSamples = 1, seed = randomSeed).map(_.head)
  }

  /** Draws category indices from each probability distribution by inverting its cumulative sum.
    *
    * A single generator is created per call (seeded when `seed` is given) and shared by all
    * distributions in `probs`.
    *
    * @param probs
    *   One distribution per row; each must be non-empty, non-negative and sum to roughly 1
    * @param numSamples
    *   Number of indices drawn per distribution
    * @param seed
    *   Optional seed for reproducible draws
    * @return
    *   For every distribution, `numSamples` sampled indices
    * @throws IllegalArgumentException
    *   if a distribution violates the requirements above
    */
  private def multinomial(
      probs: Array[Array[Float]],
      numSamples: Int = 1,
      seed: Option[Long] = None): Array[Array[Int]] = {
    val rng = seed match {
      case Some(value) => new Random(value)
      case None => new Random()
    }

    probs.map { distribution =>
      require(distribution.nonEmpty, "Probability array cannot be empty")
      require(distribution.forall(_ >= 0.0f), "Probabilities must be non-negative")
      require(
        Math.abs(distribution.sum - 1.0f) < 1e-3,
        "Probabilities must sum to approximately 1.0")
      require(distribution.exists(_ > 0.0f), "Probability array cannot contain all zeros")

      // Running totals: entry i holds the probability mass of categories 0..i.
      val cumulative = distribution.scanLeft(0.0f)(_ + _).tail

      Array.fill(numSamples) {
        // Nudge the draw upwards so it is strictly positive.
        val threshold = Math.nextAfter(rng.nextFloat(), Float.PositiveInfinity)
        val hit = cumulative.indexWhere(_ > threshold)
        // No running total exceeded the draw: fall back to the last category.
        if (hit == -1) cumulative.length - 1 else hit
      }
    }
  }

  /** Returns the index of the highest score; on ties the earliest index is returned. Throws on
    * an empty array.
    */
  private def argmax(scores: Array[Float]): Int =
    scores.indices.maxBy(position => scores(position))

  /** Tells whether greedy decoding can stop.
    *
    * Decoding is finished once something has been generated and every sequence has either
    * reached `maxOutputLength` tokens or currently ends with `eosTokenId`.
    */
  private def greedyGenerationFinished(
      decoderIds: Seq[Array[Int]],
      eosTokenId: Int,
      maxOutputLength: Int): Boolean = {
    def sequenceDone(ids: Array[Int]): Boolean =
      ids.length >= maxOutputLength || ids.last == eosTokenId

    // Nothing generated yet means generation has not even started.
    decoderIds.nonEmpty && decoderIds.forall(sequenceDone)
  }

  /** Computes the size an image should be resized to so that its longer side equals
    * `imageSize` while the shorter side is scaled by the same factor (truncated), but never
    * below `minSize`.
    *
    * @param width
    *   Original width
    * @param height
    *   Original height
    * @param minSize
    *   Lower bound for either side
    * @param imageSize
    *   Target length of the longer side
    * @return
    *   `(newHeight, newWidth)` — note that height comes first
    */
  def getResizeSizes(
      width: Int,
      height: Int,
      minSize: Int = 14,
      imageSize: Int = 384): (Int, Int) = {
    val longestSide = math.max(width, height)

    def scaled(side: Int): Int =
      math.max((side.toFloat / longestSide * imageSize).toInt, minSize)

    (scaled(height), scaled(width))
  }

  /** Pads an image to a square canvas filled with the given RGB colour, centring the original.
    *
    * @param img
    *   Image to pad
    * @param r
    *   Red component of the fill colour (0..255)
    * @param g
    *   Green component of the fill colour (0..255)
    * @param b
    *   Blue component of the fill colour (0..255)
    * @return
    *   `img` itself when it is already square, otherwise a new square image whose side equals
    *   the longer side of `img`
    */
  def expandToSquare(img: BufferedImage, r: Int, g: Int, b: Int): BufferedImage = {
    // Built up-front, so an invalid colour is rejected even for square inputs.
    val fillColor = new Color(r, g, b)
    val (originalWidth, originalHeight) = (img.getWidth, img.getHeight)

    if (originalWidth != originalHeight) {
      val side = Math.max(originalWidth, originalHeight)
      val canvas = new BufferedImage(side, side, img.getType)
      val graphics: Graphics2D = canvas.createGraphics()

      // Paint the whole canvas with the fill colour first.
      graphics.setColor(fillColor)
      graphics.fillRect(0, 0, side, side)

      // The offset is zero along the longer axis and half the gap along the shorter one.
      val offsetX = (side - originalWidth) / 2
      val offsetY = (side - originalHeight) / 2

      graphics.drawImage(img, offsetX, offsetY, null)
      graphics.dispose()

      canvas
    } else {
      img
    }
  }
  /** Preprocesses images for the vision model.
    *
    * Every image is decoded, optionally resized so that its longer side matches the
    * preprocessor's size while keeping the aspect ratio, padded to a square with the
    * preprocessor's mean colour, and finally rescaled / normalised as configured.
    *
    * @param annotations
    *   Images to preprocess
    * @param preprocessor
    *   Preprocessing settings
    * @return
    *   One single-element array per input image, each wrapping that image's normalised values
    */
  private def encodeImage(
      annotations: Array[AnnotationImage],
      preprocessor: Preprocessor): Array[Array[Array[Array[Array[Float]]]]] = {

    annotations.map { annot =>
      val bufferedImage = ImageIOUtils.byteToBufferedImage(
        bytes = annot.result,
        w = annot.width,
        h = annot.height,
        nChannels = annot.nChannels)

      val resizedImage =
        if (preprocessor.do_resize) {
          // getResizeSizes returns (height, width); each value must reach the parameter of the
          // same name, otherwise non-square images are stretched along the wrong axis.
          val (targetHeight, targetWidth): (Int, Int) = getResizeSizes(
            width = bufferedImage.getWidth,
            height = bufferedImage.getHeight,
            imageSize = preprocessor.size)

          ImageResizeUtils.resizeBufferedImage(
            width = targetWidth,
            height = targetHeight,
            preprocessor.resample)(bufferedImage)
        } else bufferedImage

      // Pad with the mean colour, scaled from fractions to the 0..255 range.
      val resizedImageSquare = expandToSquare(
        resizedImage,
        (preprocessor.image_mean(0) * 255).toInt,
        (preprocessor.image_mean(1) * 255).toInt,
        (preprocessor.image_mean(2) * 255).toInt)

      val normalizedImage =
        ImageResizeUtils.normalizeAndConvertBufferedImage(
          img = resizedImageSquare,
          mean = preprocessor.image_mean,
          std = preprocessor.image_std,
          doNormalize = preprocessor.do_normalize,
          doRescale = preprocessor.do_rescale,
          rescaleFactor = preprocessor.rescale_factor)

      Array(normalizedImage)
    }
  }

  /** Produces the input embeddings for one decoding step.
    *
    * On the first pass (decoder rows still have the prompt's length) the images are embedded
    * with the vision model, the full prompt with the text model, and both are combined by the
    * merge model. On later passes only the newest token of every row is embedded with the text
    * model.
    *
    * @param encoderInputIds
    *   The original prompts
    * @param decoderInputIds
    *   Prompt plus tokens generated so far
    * @param pixelValues
    *   Preprocessed images; only read on the first pass
    * @return
    *   The merged embeddings on the first pass, plain text embeddings afterwards
    */
  def getMultimodalEmbeddings(
      encoderInputIds: Array[Array[Int]],
      decoderInputIds: Array[Array[Int]],
      pixelValues: Array[Array[Array[Array[Array[Float]]]]],
      inferRequestVisionEmbeddingsModel: InferRequest,
      inferRequestTextEmbeddingsModel: InferRequest,
      inferRequestMergeModel: InferRequest): org.intel.openvino.Tensor = {

    val firstPass = encoderInputIds.head.length == decoderInputIds.head.length

    val flatTokenIds: Array[Long] =
      if (firstPass) decoderInputIds.flatMap(ids => ids.map(_.toLong))
      else decoderInputIds.map(ids => ids.last.toLong)

    val batchSize: Int = decoderInputIds.length
    val tokenIdsTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(
        Array(batchSize, flatTokenIds.length / batchSize),
        flatTokenIds)

    // Runs the text-embedding model on the current token ids.
    def embedText(): org.intel.openvino.Tensor = {
      inferRequestTextEmbeddingsModel.set_input_tensor(tokenIdsTensor)
      inferRequestTextEmbeddingsModel.infer()
      inferRequestTextEmbeddingsModel.get_output_tensor()
    }

    if (!firstPass) {
      embedText()
    } else {
      // The tensor shape is read from the first element at every nesting level.
      val pixelShape = Array(
        pixelValues.length,
        pixelValues.head.length,
        pixelValues.head.head.length,
        pixelValues.head.head.head.length,
        pixelValues.head.head.head.head.length)
      val pixelValuesTensor: org.intel.openvino.Tensor =
        new org.intel.openvino.Tensor(
          pixelShape,
          pixelValues.flatten.flatten.flatten.flatten.map(_.toFloat))

      // Vision embeddings first, then text embeddings, then the merge of both.
      inferRequestVisionEmbeddingsModel.set_input_tensor(pixelValuesTensor)
      inferRequestVisionEmbeddingsModel.infer()
      val visionEmbeddings = inferRequestVisionEmbeddingsModel.get_output_tensor()

      val textEmbeddings = embedText()

      inferRequestMergeModel.set_tensor("vision_embeds", visionEmbeddings)
      inferRequestMergeModel.set_tensor("inputs_embeds", textEmbeddings)
      inferRequestMergeModel.set_tensor("input_ids", tokenIdsTensor)
      inferRequestMergeModel.infer()

      inferRequestMergeModel.get_tensor("final_embeddings")
    }
  }

  /** Converts logits into probabilities. The maximum logit is subtracted before exponentiation
    * and the intermediate values are kept in double precision. Throws on an empty array.
    */
  def softmax(logitValues: Array[Float]): Array[Float] = {
    val peak = logitValues.max
    val weights = for (logit <- logitValues) yield Math.exp(logit - peak)
    val normalizer = weights.sum
    for (weight <- weights) yield (weight / normalizer).toFloat
  }

  /** Computes log-probabilities from logits as `logit - max - log(sum(exp(logit - max)))`.
    * Throws on an empty array.
    */
  def logSoftmax(logitValues: Array[Float]): Array[Float] = {
    val peak = logitValues.max
    val logNormalizer = Math.log(logitValues.map(logit => Math.exp(logit - peak)).sum)
    logitValues.map(logit => (logit - peak - logNormalizer).toFloat)
  }

  /** Reshapes a flat, row-major array into nested arrays of the given shape.
    *
    * The innermost level is an `Array[Float]`; every outer level is an `Array[Any]` holding the
    * next level. A one-dimensional shape returns the input array itself, and an empty shape
    * returns the single element.
    *
    * @param flatArray
    *   Values in row-major order
    * @param shape
    *   Target dimensions, outermost first; their product must equal `flatArray.length`
    * @return
    *   The nested structure described above
    * @throws IllegalArgumentException
    *   if the shape does not account for every element
    */
  def reshapeArray(flatArray: Array[Float], shape: Array[Int]): Any = {
    require(flatArray.length == shape.product, "Shape does not match data length")

    def recursiveReshape(data: Array[Float], dims: List[Int]): Any = dims match {
      case Nil => data.head // Scalar: the single remaining element
      case _ :: Nil => data // Last dimension: the slice already is the row
      case outer :: inner =>
        // Each entry along this dimension owns `inner.product` consecutive elements.
        val stride = inner.product
        Array.tabulate[Any](outer) { i =>
          recursiveReshape(data.slice(i * stride, (i + 1) * stride), inner)
        }
    }

    recursiveReshape(flatArray, shape.toList)
  }

  /** Cuts a flat row-major array into `rows` consecutive rows of `cols` elements. Rows that
    * extend past the end of `data` come back shorter (possibly empty).
    */
  def reshape2D(data: Array[Float], rows: Int, cols: Int): Array[Array[Float]] =
    Array.tabulate(rows) { row =>
      val from = row * cols
      data.slice(from, (row + 1) * cols)
    }

  /** Splits a flat row-major buffer into `depth` planes, each reshaped by `reshape2D`.
    *
    * Plane `i` is built from the slice `[i * rows * cols, (i + 1) * rows * cols)` of `data`.
    *
    * @param data
    *   the flat source buffer
    * @param depth
    *   the number of planes to produce
    * @param rows
    *   the number of rows in each plane
    * @param cols
    *   the number of elements in each row
    * @return
    *   an array containing `depth` planes of `rows` row arrays
    */
  def reshape3D(
      data: Array[Float],
      depth: Int,
      rows: Int,
      cols: Int): Array[Array[Array[Float]]] = {
    val planeSize = rows * cols
    Array.tabulate(depth) { plane =>
      val flatPlane = data.slice(plane * planeSize, (plane + 1) * planeSize)
      reshape2D(flatPlane, rows, cols)
    }
  }

  /** Splits a flat row-major buffer into consecutive volumes, each reshaped by `reshape3D`.
    *
    * The buffer is cut into chunks of `depth * rows * cols` elements. Note that `batch` is not
    * consulted: the number of volumes returned follows from the length of `data` alone.
    *
    * @param data
    *   the flat source buffer
    * @param batch
    *   the nominal number of volumes (currently unused by the implementation)
    * @param depth
    *   the number of planes in each volume
    * @param rows
    *   the number of rows in each plane
    * @param cols
    *   the number of elements in each row
    * @return
    *   one 3-D array per chunk of `data`
    */
  def reshape4D(
      data: Array[Float],
      batch: Int,
      depth: Int,
      rows: Int,
      cols: Int): Array[Array[Array[Array[Float]]]] = {
    val volumeSize = depth * rows * cols
    data
      .grouped(volumeSize)
      .map(volume => reshape3D(volume, depth, rows, cols))
      .toArray
  }

  /** Permutes the axes of a flat, row-major array.
    *
    * Output dimension `j` takes its extent and its index from input dimension `axes(j)`. Every
    * element of `inputArray` is decoded into per-axis coordinates using the row-major strides
    * of `inputArrayShape`, the coordinates are reordered by `axes`, and the element is stored at
    * the matching row-major position of the permuted shape.
    *
    * @param inputArray
    *   the elements in row-major order
    * @param inputArrayShape
    *   the size of each input dimension, outermost first
    * @param axes
    *   for each output dimension, the input dimension it is taken from
    * @tparam T
    *   the element type
    * @return
    *   a new array of the same length holding the elements in transposed order
    * @throws IllegalArgumentException
    *   if `axes` and `inputArrayShape` differ in length
    */
  def transposeArray[T: ClassTag](
      inputArray: Array[T],
      inputArrayShape: Array[Int],
      axes: Array[Int]): Array[T] = {
    require(
      inputArrayShape.length == axes.length,
      "Axes must have the same length as the shape dimensions")

    val rank = inputArrayShape.length
    val permutedShape = axes.map(axis => inputArrayShape(axis))
    val sourceStrides = inputArrayShape.scanRight(1)(_ * _).tail
    val targetStrides = permutedShape.scanRight(1)(_ * _).tail

    val transposed = new Array[T](inputArray.length)
    // Scratch buffer reused for the per-axis coordinates of the element being moved.
    val coordinates = new Array[Int](rank)

    var flat = 0
    while (flat < inputArray.length) {
      // Decode the flat position into one coordinate per input axis.
      var dim = 0
      while (dim < rank) {
        coordinates(dim) = (flat / sourceStrides(dim)) % inputArrayShape(dim)
        dim += 1
      }

      // Re-encode the reordered coordinates against the permuted shape.
      var destination = 0
      dim = 0
      while (dim < rank) {
        destination += coordinates(axes(dim)) * targetStrides(dim)
        dim += 1
      }

      transposed(destination) = inputArray(flat)
      flat += 1
    }
    transposed
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy