com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureConstants.scala Maven / Gradle / Ivy
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.tensorflow.sign
import scala.util.matching.Regex
/** Based on BERT SavedModel reference, for instance:
*
* signature_def['serving_default']: The given SavedModel SignatureDef contains the following
* input(s):
*
* inputs['attention_mask'] tensor_info: dtype: DT_INT32 shape: (-1, -1) name:
* serving_default_attention_mask:0
*
* inputs['input_ids'] tensor_info: dtype: DT_INT32 shape: (-1, -1) name:
* serving_default_input_ids:0
*
* inputs['token_type_ids'] tensor_info: dtype: DT_INT32 shape: (-1, -1) name:
* serving_default_token_type_ids:0
*
* The given SavedModel SignatureDef contains the following output(s):
*
* outputs['last_hidden_state'] tensor_info: dtype: DT_FLOAT shape: (-1, -1, 768) name:
* StatefulPartitionedCall:0
*
* outputs['pooler_output'] tensor_info: dtype: DT_FLOAT shape: (-1, 768) name:
* StatefulPartitionedCall:1
*
* Method name is: tensorflow/serving/predict*
*/
object ModelSignatureConstants {
sealed trait TFInfoDescriptor {
protected val key: String
}
case object Name extends TFInfoDescriptor {
override val key: String = "name"
}
case object DType extends TFInfoDescriptor {
override val key: String = "dtype"
}
case object DimCount extends TFInfoDescriptor {
override val key: String = "dim_count"
}
case object ShapeDimList extends TFInfoDescriptor {
override val key: String = "shape_dim_list"
}
case object SerializedSize extends TFInfoDescriptor {
override val key: String = "serialized_size"
}
sealed trait TFInfoNameMapper {
protected val key: String
protected val value: String
}
case object InputIds extends TFInfoNameMapper {
override val key: String = "input_ids"
override val value: String = "input_ids:0"
}
case object InputIdsV1 extends TFInfoNameMapper {
override val key: String = "input_ids"
override val value: String = "input_ids:0"
}
case object AttentionMask extends TFInfoNameMapper {
override val key: String = "attention_mask"
override val value: String = "attention_mask:0"
}
case object AttentionMaskV1 extends TFInfoNameMapper {
override val key: String = "attention_mask"
override val value: String = "input_mask:0"
}
case object TokenTypeIds extends TFInfoNameMapper {
override val key: String = "token_type_ids"
override val value: String = "token_type_ids:0"
}
case object TokenTypeIdsV1 extends TFInfoNameMapper {
override val key: String = "token_type_ids"
override val value: String = "segment_ids:0"
}
case object LastHiddenState extends TFInfoNameMapper {
override val key: String = "last_hidden_state"
override val value: String = "StatefulPartitionedCall:0"
}
case object LastHiddenStateV1 extends TFInfoNameMapper {
override val key: String = "last_hidden_state"
override val value: String = "sequence_output:0"
}
case object PoolerOutput extends TFInfoNameMapper {
override val key: String = "pooler_output"
override val value: String = "StatefulPartitionedCall:1"
}
case object PoolerOutputV1 extends TFInfoNameMapper {
override val key: String = "pooler_output"
override val value: String = "pooled_output:0"
}
case object EncoderInputIds extends TFInfoNameMapper {
override val key: String = "encoder_input_ids"
override val value: String = "encoder_encoder_input_ids:0"
}
case object EncoderAttentionMask extends TFInfoNameMapper {
override val key: String = "encoder_attention_mask"
override val value: String = "encoder_encoder_attention_mask:0"
}
case object EncoderOutput extends TFInfoNameMapper {
override val key: String = "last_hidden_state"
override val value: String = "StatefulPartitionedCall_1:0"
}
case object DecoderInputIds extends TFInfoNameMapper {
override val key: String = "decoder_input_ids"
override val value: String = "decoder_decoder_input_ids:0"
}
case object DecoderEncoderInputIds extends TFInfoNameMapper {
override val key: String = "encoder_state"
override val value: String = "decoder_encoder_state:0"
}
case object DecoderEncoderAttentionMask extends TFInfoNameMapper {
override val key: String = "decoder_encoder_attention_mask"
override val value: String = "decoder_decoder_encoder_attention_mask:0:0"
}
case object DecoderAttentionMask extends TFInfoNameMapper {
override val key: String = "decoder_attention_mask"
override val value: String = "decoder_decoder_attention_mask:0"
}
case object DecoderOutput extends TFInfoNameMapper {
override val key: String = "output_0"
override val value: String = "StatefulPartitionedCall:0"
}
case object LogitsOutput extends TFInfoNameMapper {
override val key: String = "logits"
override val value: String = "StatefulPartitionedCall:0"
}
case object EndLogitsOutput extends TFInfoNameMapper {
override val key: String = "end_logits"
override val value: String = "StatefulPartitionedCall:0"
}
case object StartLogitsOutput extends TFInfoNameMapper {
override val key: String = "start_logits"
override val value: String = "StatefulPartitionedCall:1"
}
case object TapasLogitsOutput extends TFInfoNameMapper {
override val key: String = "logits"
override val value: String = "StatefulPartitionedCall:0"
}
case object TapasLogitsAggregationOutput extends TFInfoNameMapper {
override val key: String = "logits_aggregation"
override val value: String = "StatefulPartitionedCall:1"
}
case object PixelValuesInput extends TFInfoNameMapper {
override val key: String = "pixel_values"
override val value: String = "pixel_values:0"
}
case object AudioValuesInput extends TFInfoNameMapper {
override val key: String = "input_values"
override val value: String = "input_values:0"
}
/** Retrieve signature patterns for a given provider
*
* @param modelProvider
* : the provider library that built the model and the signatures
* @return
* reference keys array of signature patterns for a given provider
*/
def getSignaturePatterns(modelProvider: String): Array[Regex] = {
val referenceKeys = modelProvider match {
case "TF1" =>
Array(
"(input)(.*)(ids)".r,
"(input)(.*)(mask)".r,
"(segment)(.*)(ids)".r,
"(sequence)(.*)(output)".r,
"(pooled)(.*)(output)".r)
case "TF2" =>
Array(
// first possible keys
"(input_word)(.*)(ids)".r,
"(input)(.*)(mask)".r,
"(input_type)(.*)(ids)".r,
"(bert)(.*)(encoder)".r,
// second possible keys
"(input)(.*)(ids)".r,
"(attention)(.*)(mask)".r,
"(token_type)(.*)(ids)".r,
"(last_hidden_state)".r)
case _ =>
throw new Exception("Unknown model provider! Please provide one in between TF1 or TF2.")
}
referenceKeys
}
/** Convert signatures key names to normalize naming conventions */
def toAdoptedKeys(keyName: String): String = {
keyName match {
case "input_ids" | "input_word_ids" => InputIds.key
case "input_mask" | "attention_mask" => AttentionMask.key
case "segment_ids" | "input_type_ids" | "token_type_ids" => TokenTypeIds.key
case "sequence_output" | "bert_encoder" | "last_hidden_state" => LastHiddenState.key
case "pooler_output" => PoolerOutput.key
case k => k
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy