All downloads are FREE. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureManager.scala Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow.sign

import org.slf4j.{Logger, LoggerFactory}
import org.tensorflow.SavedModelBundle
import org.tensorflow.proto.framework.TensorInfo
import org.tensorflow.proto.util.SaverDef

import java.util
import scala.util.matching.Regex

/** Builds and discovers the input/output signature maps used to run TensorFlow saved models.
  *
  * Signatures discovered from a [[SavedModelBundle]] are keyed as
  * `prefix::signatureKey::attribute`, where `prefix` is `input` or `output`, `signatureKey` is the
  * key of the tensor inside its signature definition, and `attribute` is one of the
  * `ModelSignatureConstants` TensorInfo attribute keys (name, dtype, dim count, shape dim list,
  * serialized size).
  */
object ModelSignatureManager {

  /** Model providers scored by [[classifyProvider]]; on a tie the earlier entry wins. */
  val KnownProviders: Array[String] = Array("TF1", "TF2")

  private[ModelSignatureManager] val logger: Logger =
    LoggerFactory.getLogger("ModelSignatureManager")

  /** Separator between the segments of a discovered signature key. */
  private val KeySep = "::"

  /** Build the default signature map for a model provider.
    *
    * @param tfSignatureType
    *   model provider name, matched case-insensitively; only `TF1` is currently supported
    * @param tokenIdsValue
    *   value stored under the input ids key
    * @param maskIdsValue
    *   value stored under the attention mask key
    * @param segmentIdsValue
    *   value stored under the token type ids key
    * @param embeddingsValue
    *   value stored under the last hidden state key
    * @param sentenceEmbeddingsValue
    *   value stored under the pooler output key
    * @return
    *   map from the generic `ModelSignatureConstants` keys to the supplied values
    * @throws Exception
    *   if `tfSignatureType` is anything other than `TF1`
    */
  def apply(
      tfSignatureType: String = "TF1",
      tokenIdsValue: String = ModelSignatureConstants.InputIdsV1.value,
      maskIdsValue: String = ModelSignatureConstants.AttentionMaskV1.value,
      segmentIdsValue: String = ModelSignatureConstants.TokenTypeIdsV1.value,
      embeddingsValue: String = ModelSignatureConstants.LastHiddenStateV1.value,
      sentenceEmbeddingsValue: String = ModelSignatureConstants.PoolerOutputV1.value)
      : Map[String, String] =
    tfSignatureType.toUpperCase match {
      case "TF1" =>
        Map[String, String](
          ModelSignatureConstants.InputIds.key -> tokenIdsValue,
          ModelSignatureConstants.AttentionMask.key -> maskIdsValue,
          ModelSignatureConstants.TokenTypeIds.key -> segmentIdsValue,
          ModelSignatureConstants.LastHiddenState.key -> embeddingsValue,
          ModelSignatureConstants.PoolerOutput.key -> sentenceEmbeddingsValue)
      // TF2 is listed in KnownProviders for classification only; no default map exists for it.
      case _ => throw new Exception("Model provider not available.")
    }

  // Accessors for the generic signature keys and their default values.

  def getInputIdsKey: String = ModelSignatureConstants.InputIds.key

  def getInputIdsValue: String = ModelSignatureConstants.InputIds.value

  def getAttentionMaskIdsKey: String = ModelSignatureConstants.AttentionMask.key

  def getAttentionMaskIdsValue: String = ModelSignatureConstants.AttentionMask.value

  def getTokenTypeIdsKey: String = ModelSignatureConstants.TokenTypeIds.key

  def getTokenTypeIdsValue: String = ModelSignatureConstants.TokenTypeIds.value

  def getLastHiddenStateKey: String = ModelSignatureConstants.LastHiddenState.key

  def getLastHiddenStateValue: String = ModelSignatureConstants.LastHiddenState.value

  def getPoolerOutputKey: String = ModelSignatureConstants.PoolerOutput.key

  def getPoolerOutputValue: String = ModelSignatureConstants.PoolerOutput.value

  /** Re-key discovered signature entries with their adopted (generic) names.
    *
    * Each key of the form `prefix::signatureKey::attribute` is first reduced to `signatureKey` and
    * then mapped through `ModelSignatureConstants.toAdoptedKeys`. Entries that reduce to the same
    * key collapse into a single entry (whichever the map iteration yields last wins).
    *
    * @param matched
    *   discovered signature entries keyed as `prefix::signatureKey::attribute`
    * @return
    *   a formatted map of adopted key -> value
    */
  def convertToAdoptedKeys(matched: Map[String, String]): Map[String, String] = {
    val bySignatureKey = matched.map { case (key, value) => key.split(KeySep)(1) -> value }
    bySignatureKey.map { case (signatureKey, value) =>
      ModelSignatureConstants.toAdoptedKeys(signatureKey) -> value
    }
  }

  /** Extract the input and output signatures declared by a loaded model.
    *
    * For every tensor of every signature definition, five entries are produced, keyed
    * `prefix::signatureKey::attribute`: the tensor name, its dtype, its dimension count, its
    * shape dimension list (newlines and `size:` labels stripped) and its serialized size. Nothing
    * is extracted when the meta graph has no graph def or no signature defs. If several
    * signature definitions share a tensor key, the entry processed last wins.
    *
    * @param model
    *   a loaded SavedModelBundle
    * @return
    *   map from `prefix::signatureKey::attribute` to the attribute value rendered as a string
    */
  def getSignaturesFromModel(model: SavedModelBundle): Map[String, String] = {
    import collection.JavaConverters._

    val InputPrefix = "input"
    val OutputPrefix = "output"

    /** Flatten one TensorInfo map into `prefix::signatureKey::attribute -> value` pairs. */
    def describe(prefix: String, infos: util.Map[String, TensorInfo]): Seq[(String, String)] =
      infos.asScala.toSeq.flatMap { case (signatureKey, info) =>
        def entry(attribute: String, value: String): (String, String) =
          s"$prefix$KeySep$signatureKey$KeySep$attribute" -> value

        val shape = info.getTensorShape
        Seq(
          entry(ModelSignatureConstants.Name.key, info.getName),
          entry(ModelSignatureConstants.DType.key, info.getDtype.toString),
          entry(ModelSignatureConstants.DimCount.key, shape.getDimCount.toString),
          entry(
            ModelSignatureConstants.ShapeDimList.key,
            shape.getDimList.toString
              .replaceAll("\n", "")
              .replaceAll("size:", "")),
          // Previously filled with the tensor name by mistake; report the protobuf size instead.
          entry(ModelSignatureConstants.SerializedSize.key, info.getSerializedSize.toString))
      }

    val metaGraph = model.metaGraphDef
    if (metaGraph.hasGraphDef && metaGraph.getSignatureDefCount > 0)
      metaGraph.getSignatureDefMap.values.asScala.toSeq.flatMap { sigDef =>
        describe(InputPrefix, sigDef.getInputsMap) ++ describe(OutputPrefix, sigDef.getOutputsMap)
      }.toMap
    else
      Map.empty[String, String]
  }

  /** Test whether the signature-key segment of a discovered key matches a pattern.
    *
    * @param candidate
    *   a discovered key such as `input::input_ids::name`; its second `::` segment is tested
    * @param pattern
    *   regex searched (unanchored) within that segment
    * @return
    *   true if the pattern occurs in the segment; false if it does not or the segment is missing
    */
  def findTFKeyMatch(candidate: String, pattern: Regex): Boolean =
    candidate
      .split(KeySep)
      .drop(1)
      .headOption
      .exists(segment => pattern.findFirstIn(segment).isDefined)

  /** Guess the model provider by counting signature-pattern matches.
    *
    * Every provider in [[KnownProviders]] is scored by the number of (signature key, provider
    * pattern) pairs that match; the highest score wins and ties go to the earlier provider.
    *
    * @param signDefNames
    *   the candidate signature definitions inputs and outputs, keyed as discovered
    * @param modelProvider
    *   not consulted; kept for compatibility with existing callers
    * @return
    *   the model provider name in between TF1 and TF2
    */
  def classifyProvider(
      signDefNames: Map[String, String],
      modelProvider: Option[String] = None): String = {

    def score(provider: String): Int = {
      // The provider's patterns do not depend on the signature key, so fetch them once.
      val patterns: Array[Regex] = ModelSignatureConstants.getSignaturePatterns(provider)
      signDefNames.keysIterator.map(name => patterns.count(findTFKeyMatch(name, _))).sum
    }

    // maxBy keeps the first maximum, so ties resolve in KnownProviders order.
    KnownProviders.maxBy(score)
  }

  /** Extract input and output signatures from TF saved models.
    *
    * Only the tensor-name entries of the discovered signatures are kept and re-keyed via
    * [[convertToAdoptedKeys]]. The saver tensor/op names are added under `filenameTensorName_`,
    * `restoreOpName_` and `saveTensorName_` with every `:0` occurrence removed.
    *
    * @param model
    *   loaded SavedModelBundle
    * @param saverDef
    *   saver definition whose filename tensor, restore op and save tensor names are recorded
    * @return
    *   always defined: the adopted signature names plus the saver entries
    */
  def extractSignatures(
      model: SavedModelBundle,
      saverDef: SaverDef): Option[Map[String, String]] = {

    // Compare the attribute segment exactly: a substring test on the whole key would also accept
    // dtype/shape entries of any tensor whose own key happens to contain the name attribute.
    val nameSuffix = s"$KeySep${ModelSignatureConstants.Name.key}"
    val signDefNames: Map[String, String] =
      getSignaturesFromModel(model).filter { case (key, _) => key.endsWith(nameSuffix) }

    def withoutOutputIndex(tensorName: String): String = tensorName.replaceAll(":0", "")

    val saverEntries = Map(
      "filenameTensorName_" -> withoutOutputIndex(saverDef.getFilenameTensorName),
      "restoreOpName_" -> withoutOutputIndex(saverDef.getRestoreOpName),
      "saveTensorName_" -> withoutOutputIndex(saverDef.getSaveTensorName))

    Some(convertToAdoptedKeys(signDefNames) ++ saverEntries)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy