All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.onnx.ONNXUtils.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.onnx

import ai.onnx.proto.OnnxMl.{GraphProto, ModelProto, NodeProto, ValueInfoProto}
import ai.onnxruntime._
import shade.com.google.protobuf.ProtocolStringList
import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types._

import java.nio._
import java.util
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters.mapAsScalaMapConverter
import scala.reflect.ClassTag

object ONNXUtils {
  /**
    * Maps an ONNX Runtime element type to the Spark SQL data type used to represent one element of it.
    *
    * @param javaType element type reported by ONNX Runtime.
    * @return the matching Spark data type; `UNKNOWN` is represented as [[BinaryType]].
    * @throws NotImplementedError if the element type has no mapping in this method.
    */
  private[onnx] def mapOnnxJavaTypeToDataType(javaType: OnnxJavaType): DataType = {
    javaType match {
      case OnnxJavaType.INT8 => ByteType
      case OnnxJavaType.INT16 => ShortType
      case OnnxJavaType.INT32 => IntegerType
      case OnnxJavaType.INT64 => LongType
      case OnnxJavaType.FLOAT => FloatType
      case OnnxJavaType.DOUBLE => DoubleType
      case OnnxJavaType.BOOL => BooleanType
      case OnnxJavaType.STRING => StringType
      case OnnxJavaType.UNKNOWN => BinaryType
      case other =>
        // Any enum member not listed above would otherwise surface as a bare MatchError;
        // fail with a message that names the offending type instead.
        throw new NotImplementedError(s"ONNX element type $other is not supported. " +
          s"Only INT8, INT16, INT32, INT64, FLOAT, DOUBLE, BOOL, STRING and UNKNOWN types can be mapped.")
    }
  }

  /**
    * Derives the Spark data type of one row for a tensor-valued input or output.
    *
    * Scalars and one-dimensional tensors map to the plain element type (the first dimension is assumed
    * to be the batch size). Every further dimension adds one level of array nesting.
    */
  private[onnx] def mapTensorInfoToDataType(tensorInfo: TensorInfo): DataType = {
    val elementType = mapOnnxJavaTypeToDataType(tensorInfo.`type`)

    if (tensorInfo.isScalar || tensorInfo.getShape.length == 1) {
      elementType
    } else {
      // One array level per dimension after the batch dimension.
      val nestingDepth = tensorInfo.getShape.length - 1
      (1 until nestingDepth).foldLeft(ArrayType(elementType)) { (wrapped, _) => ArrayType(wrapped) }
    }
  }

  /**
    * Resolves the Spark data type for a tensor, sequence or map value description.
    * A sequence resolves to the type of its entries: the map type for a sequence of maps,
    * otherwise the sequence's element type.
    */
  @tailrec
  private[onnx] def mapValueInfoToDataType(valueInfo: ValueInfo): DataType = valueInfo match {
    case tensor: TensorInfo =>
      mapTensorInfoToDataType(tensor)
    case sequence: SequenceInfo if sequence.sequenceOfMaps =>
      mapValueInfoToDataType(sequence.mapInfo)
    case sequence: SequenceInfo =>
      mapOnnxJavaTypeToDataType(sequence.sequenceType)
    case map: MapInfo =>
      MapType(mapOnnxJavaTypeToDataType(map.keyType), mapOnnxJavaTypeToDataType(map.valueType))
  }

  /**
    * Converts an ONNX value into a Scala sequence.
    * A scalar tensor or a single map becomes a one-element sequence; a non-scalar tensor or a
    * sequence yields its entries, with Java maps converted to immutable Scala maps.
    */
  private[onnx] def mapOnnxValueToArray(value: OnnxValue): Seq[Any] = value.getInfo match {
    case tensor: TensorInfo if tensor.isScalar =>
      Seq(value.getValue)
    case _: TensorInfo =>
      value.getValue.asInstanceOf[Array[_]].toSeq
    case sequence: SequenceInfo if sequence.sequenceOfMaps =>
      value.getValue.asInstanceOf[java.util.List[java.util.Map[_, _]]]
        .asScala.toArray.map(_.asScala.toMap)
    case _: SequenceInfo =>
      value.getValue.asInstanceOf[java.util.List[_]].asScala
    case _: MapInfo =>
      Array(value.getValue.asInstanceOf[java.util.Map[_, _]].asScala.toMap)
  }

  /**
    * Builds an ONNX tensor from a batch of values.
    * The first dimension of the tensor shape is replaced by the batch length; after that, at most one
    * dimension may still be -1. The batch is then validated and written into a tensor buffer.
    */
  private[onnx] def createTensor(env: OrtEnvironment, tensorInfo: TensorInfo, batchedValues: Seq[_]): OnnxTensor = {
    val declaredShape: Array[Long] = tensorInfo.getShape
    declaredShape(0) = batchedValues.length

    val variableDimensions = declaredShape.count(_ == -1)
    if (variableDimensions > 1) {
      throw new Exception(s"The input tensor has shape [${declaredShape.mkString(",")}], " +
        s"but -1 is only allowed for at most one dimension. " +
        s"If the array size in each row can vary, either pass in one row at a time, " +
        s"or set the mini batch size to 1.")
    }

    loadTensorBuffer(env, tensorInfo, batchedValues, validateBatchShapes(batchedValues, declaredShape))
  }

  /**
    * Checks that every element of a batch has the shape the tensor expects and returns the concrete
    * shape of the whole batch.
    *
    * Only the first sequence at each nesting level of an element is inspected. A dimension given as -1
    * in the expected shape is filled in with the length found in the element.
    *
    * @param batchedValues one entry per batch element; each is either a (possibly nested) sequence or a
    *                      non-sequence value, which contributes an empty element shape.
    * @param expectedShape expected tensor shape; the first entry is the batch dimension and is ignored here.
    * @return the batch size followed by the shape shared by all elements.
    * @throws IllegalArgumentException if the batch is empty, a nested sequence is empty, an element is nested
    *                                  deeper than the expected shape, or an element's shape does not match.
    * @throws Exception if the elements of the batch do not all have the same shape.
    */
  private def validateBatchShapes(batchedValues: Seq[_], expectedShape: Array[Long]): Array[Long] = {
    if (batchedValues.isEmpty) {
      // Without at least one element there is nothing to infer the element shape from.
      throw new IllegalArgumentException(
        "Input batch is empty. At least one element is required to determine the input tensor shape.")
    }

    // Validate input shape based on first sequence in each parent
    @tailrec
    def validateOneShape(nestedSeq: Seq[_], currentSize: Array[Long], elementShape: Array[Long]): Array[Long] = {
      if (nestedSeq.isEmpty) {
        throw new IllegalArgumentException("Input element dimension is empty")
      }

      val currentDim = currentSize.length
      val newSize = currentSize :+ nestedSeq.length.toLong
      if (currentDim >= elementShape.length) {
        // The element is nested deeper than the tensor allows; report it instead of indexing past the shape.
        throw new IllegalArgumentException(
          s"Input element does not match input tensor shape [${elementShape.mkString(",")}]." +
            s" Found shape starting with [${newSize.mkString(",")}], which has more dimensions than expected.")
      }

      if (elementShape(currentDim) == -1) {
        // If the current dimension is variable length, fill in with actual length
        elementShape(currentDim) = nestedSeq.length.toLong
      }

      nestedSeq.head match {
        case s: Seq[_] => validateOneShape(s, newSize, elementShape)
        case _ =>
          if (!util.Arrays.equals(newSize, elementShape)) {
            throw new IllegalArgumentException(
              s"Input element does not match input tensor shape [${elementShape.mkString(",")}]." +
                s" Found shape [${newSize.mkString(",")}]. Consider setting mini batch size to 1.")
          } else {
            elementShape
          }
      }
    }

    val inferredShapes: Seq[Array[Long]] = batchedValues.map {
      // drop(1) yields a fresh array per element, so filling in -1 never leaks between elements
      case s: Seq[_] => validateOneShape(s, Array[Long](), expectedShape.drop(1))
      case _ => Array.empty[Long]
    }

    val headShape = inferredShapes.head
    inferredShapes.find(!_.sameElements(headShape)).foreach {
      shape =>
        throw new Exception("Each element in the input batch must have the same shape. " +
          s"The head of the batch has shape ${headShape.mkString(",")}, but detected an " +
          s"element with shape ${shape.mkString(",")}")
    }

    // batch size         shape for inner elements
    inferredShapes.length.toLong +: headShape
  }

  // scalastyle:off cyclomatic.complexity
  /**
    * Flattens a batch of (possibly nested) values into a buffer of the tensor's element type and
    * creates an ONNX tensor of the given shape from it.
    *
    * @param env           ONNX Runtime environment used to create the tensor.
    * @param tensorInfo    tensor description; its element type selects the buffer type.
    * @param batchedValues values to write, in row-major order of the nested sequences.
    * @param shape         concrete shape of the tensor to create.
    * @return the created tensor.
    * @throws IllegalArgumentException if the shape has a negative dimension, describes more elements than a
    *                                  buffer can hold, or (for non-string types) the number of values written
    *                                  differs from the number of elements the shape describes.
    * @throws NotImplementedError if the tensor element type is not supported.
    */
  private[onnx] def loadTensorBuffer(env: OrtEnvironment,
                                     tensorInfo: TensorInfo,
                                     batchedValues: Seq[_],
                                     shape: Array[Long]): OnnxTensor = {
    // Compute the element count exactly: a plain Long product truncated with toInt would wrap around
    // silently for large shapes and lead to a wrongly sized buffer.
    val elementCount = shape.map(BigInt(_)).product
    if (shape.exists(_ < 0) || elementCount > Int.MaxValue) {
      throw new IllegalArgumentException(
        s"Tensor shape [${shape.mkString(",")}] is not supported: every dimension must be non-negative " +
          s"and the total number of elements must not exceed ${Int.MaxValue}.")
    }
    val size = elementCount.toInt

    tensorInfo.`type` match {
      case OnnxJavaType.FLOAT =>
        val buffer = FloatBuffer.allocate(size)
        val actualCount = writeNestedSeqToBuffer[Float](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.DOUBLE =>
        val buffer = DoubleBuffer.allocate(size)
        val actualCount = writeNestedSeqToBuffer[Double](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.BOOL =>
        // Booleans are written as one byte each: 1 for true, 0 for false.
        val buffer = ByteBuffer.allocateDirect(size)
        val bool2byte: Boolean => Byte = b => if (b) 1.toByte else 0.toByte
        val actualCount = writeNestedSeqToBuffer[Boolean](batchedValues, (bool2byte andThen buffer.put)(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape, OnnxJavaType.BOOL)
      case OnnxJavaType.INT8 =>
        val buffer = ByteBuffer.allocateDirect(size)
        val actualCount = writeNestedSeqToBuffer[Byte](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT16 =>
        val buffer = ShortBuffer.allocate(size)
        val actualCount = writeNestedSeqToBuffer[Short](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT32 =>
        val buffer = IntBuffer.allocate(size)
        val actualCount = writeNestedSeqToBuffer[Int](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.INT64 =>
        val buffer = LongBuffer.allocate(size)
        val actualCount = writeNestedSeqToBuffer[Long](batchedValues, buffer.put(_))
        assertBufferElementsWritten(size, actualCount, shape)
        buffer.rewind()
        OnnxTensor.createTensor(env, buffer, shape)
      case OnnxJavaType.STRING =>
        // Strings are collected into a flat array rather than a java.nio buffer.
        val flattened = writeNestedSeqToStringBuffer(batchedValues, size).toArray
        OnnxTensor.createTensor(env, flattened, shape)
      case other =>
        throw new NotImplementedError(s"Tensor input type $other not supported. " +
          s"Only FLOAT, DOUBLE, BOOL, INT8, INT16, INT32, INT64, STRING types are supported.")
    }
  }

  /**
    * Verifies that the number of values written to a tensor buffer matches the number expected.
    *
    * @param expected number of elements the buffer was sized for.
    * @param actual   number of elements actually written.
    * @param shape    tensor shape, used only in the error message.
    * @throws IllegalArgumentException if the two counts differ.
    */
  private def assertBufferElementsWritten(expected: Long, actual: Long, shape: Array[Long]): Unit = {
    if (expected != actual) {
      // Separate the dimensions with commas; without a separator shape [2,3] would be printed as "[23]".
      throw new IllegalArgumentException(s"Expected $expected batch elements but found $actual." +
        s" Expected shape is [${shape.mkString(",")}].")
    }
  }

  /**
    * Walks a (possibly nested) sequence depth-first and hands every element of type `T` to `bufferWrite`.
    * Nested sequences are descended into; elements that are neither a `T` nor a sequence are skipped.
    *
    * @return the number of elements passed to `bufferWrite`.
    */
  private def writeNestedSeqToBuffer[T: ClassTag](nestedSeq: Seq[_], bufferWrite: T => Unit): Long = {
    nestedSeq.iterator.map {
      case value: T =>
        bufferWrite(value)
        1L
      case inner: Seq[_] =>
        writeNestedSeqToBuffer(inner, bufferWrite)
      case _ =>
        0L
    }.sum
  }

  /**
    * Flattens a (possibly nested) sequence of strings, depth-first, into a buffer of `size` entries.
    * Positions that are not written keep the empty string they were initialised with.
    */
  private def writeNestedSeqToStringBuffer(nestedSeq: Seq[_], size: Int): ArrayBuffer[String] = {
    val flattened = ArrayBuffer.fill[String](size)("")

    // Writes the strings found in `values` starting at `start` and returns the next free position.
    def writeFrom(values: Seq[_], start: Int): Int = {
      values.foldLeft(start) { (position, element) =>
        element match {
          case text: String =>
            flattened.update(position, text)
            position + 1
          case inner: Seq[_] =>
            writeFrom(inner, position)
        }
      }
    }

    writeFrom(nestedSeq, 0)
    flattened
  }

  /**
    * Tells whether a value of type `from` can be used where type `to` is expected.
    * Array types are compared element type by element type, and a vector is treated like an array of
    * doubles when the target is an array. Once the element level is reached, the types are compatible if
    * 1. both are numeric types, or
    * 2. both are the same type.
    */
  @tailrec
  private[onnx] def compatible(from: DataType, to: DataType): Boolean = to match {
    case target: ArrayType =>
      from match {
        case VectorType => compatible(DoubleType, target.elementType)
        case source: ArrayType => compatible(source.elementType, target.elementType)
        case other => other == to
      }
    case _: NumericType =>
      from match {
        case _: NumericType => true
        case other => other == to
      }
    case _ =>
      from == to
  }

  /*
   * Build a copy of the given ONNXModel whose graph only contains what is needed for the given outputs.
   * The copy keeps the parameters of the original model; only the model payload is replaced.
   */
  private[onnx] def sliceModelAtOutputs(fullModel: ONNXModel, outputs: Array[String]): ONNXModel = {
    val originalProto = ModelProto.parseFrom(fullModel.getModelPayload)
    val (keptNodes, keptOutputs) = findUsedNodesForOutputs(originalProto, outputs)

    // Rebuild the graph, then the model, around the reduced node and output lists
    val reducedGraph = makeGraph(keptNodes, keptOutputs, originalProto.getGraph)
    val reducedPayload = makeModel(reducedGraph, originalProto).toByteArray

    fullModel.copy(ParamMap.empty).setModelPayload(reducedPayload)
  }

  /*
   * Find all nodes required to produce the given outputs.
   *
   * Starting from the nodes that produce the requested outputs, the graph is walked backwards along node
   * inputs. Returns the required nodes (in their original graph order) together with one output
   * description per requested output name.
   *
   * Throws IllegalArgumentException if no output name is given or if a name is not produced by any node.
   */
  private def findUsedNodesForOutputs(model: ModelProto,
                                      newOutputNames: Array[String]): (Array[NodeProto], Array[ValueInfoProto]) = {
    val graph = model.getGraph
    val nodes = graph.getNodeList.toArray.map(_.asInstanceOf[NodeProto])

    // Nodes are identified by their position in the graph rather than by name: node names may be empty
    // or repeated, in which case distinct nodes would be mistaken for each other during the traversal.
    val outputNameToNodeIndex: Map[String, Int] =
      nodes.indices.flatMap(index => nodes(index).getNodeOutputNames.map(name => name -> index)).toMap

    if (newOutputNames.isEmpty) {
      throw new IllegalArgumentException("At least one output name is required to slice the model")
    }
    newOutputNames.foreach(out => {
      if (!outputNameToNodeIndex.contains(out)) throw new IllegalArgumentException(s"Unknown output: $out")
    })

    // Positions of all nodes found to be needed so far
    val usedIndices = mutable.HashSet[Int]()

    // Traverse the graph backwards, marking every node reachable from the pending ones as used
    @tailrec
    def markAsUsed(pending: List[Int]): Unit = pending match {
      case Nil => ()
      case index :: rest =>
        if (usedIndices.add(index)) {
          // An input might be an actual external input variable or initializer, which has no upstream node.
          // Otherwise continue up the graph chain marking other upstream nodes as used.
          val upstream = nodes(index).getNodeInputNames.flatMap(outputNameToNodeIndex.get)
          markAsUsed(upstream.toList ::: rest)
        } else {
          // Already handled: skip it
          markAsUsed(rest)
        }
    }

    // Starting at the outputs we wish to slice at, mark all needed nodes
    markAsUsed(newOutputNames.map(outputNameToNodeIndex).toList)

    val newNodes = nodes.zipWithIndex.collect { case (node, index) if usedIndices(index) => node }
    val newOutputs = newOutputNames.map(out => ValueInfoProto.newBuilder().setName(out).build())

    (newNodes, newOutputs)
  }

  /*
   * Build a GraphProto from a builder seeded with the source graph, with its node and output lists
   * replaced by the given ones
   */
  private def makeGraph(nodes: Array[NodeProto],
                        outputs: Array[ValueInfoProto],
                        source: GraphProto): GraphProto = {
    val builder = GraphProto.newBuilder(source).clearNode().clearOutput()
    for (node <- nodes) builder.addNode(node)
    for (output <- outputs) builder.addOutput(output)
    builder.build()
  }

  /*
   * Build a ModelProto from a builder seeded with the source model, with its graph replaced by the given one
   */
  private def makeModel(graph: GraphProto, source: ModelProto): ModelProto =
    ModelProto.newBuilder(source).setGraph(graph).build()

  /** Adds a conversion from a protobuf string list to a Scala string array. */
  private implicit class AugmentedProtocolStringList(list: ProtocolStringList) {
    // Copies the entries, in order, into a new array
    def toStringArray: Array[String] = list.asScala.toArray
  }

  /** Adds accessors that expose a node's input and output names as Scala string arrays. */
  private implicit class AugmentedNodeProto(node: NodeProto) {
    /** Names listed in the node's output list. */
    def getNodeOutputNames: Array[String] = node.getOutputList.toStringArray

    /** Names listed in the node's input list. */
    def getNodeInputNames: Array[String] = node.getInputList.toStringArray
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy