// All downloads are free. Search and download functionality uses the official Maven repository.

// emergentOrder.onnx-scala-backends_2.13.0.11.0.source-code.ORTOperatorBackend213.scala Maven / Gradle / Ivy

package org.emergentorder.onnx.backends

import java.nio._
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.language.implicitConversions
import scala.util.Using
import ai.onnxruntime._
import ai.onnxruntime.TensorInfo.OnnxTensorType._
import org.emergentorder.onnx._
import org.emergentorder.onnx.Tensors._
import org.emergentorder.onnx.Tensors.Tensor._
import ORTTensorUtils._

/** Operator backend that runs single ONNX operators through ONNX Runtime (ORT).
  *
  * An operator call is serialized to ONNX model bytes via `opToONNXBytes`, loaded into a
  * short-lived `OrtSession`, executed, and the first output is converted back to a [[Tensor]].
  *
  * Sessions, input tensors and run results wrap native ORT handles; each of them is released
  * as soon as the call that created it has finished.
  */
trait ORTOperatorBackend
    extends OpToONNXBytesConverter
    with AutoCloseable {

  //Java map performs better
  //val sessionCache = new java.util.HashMap[Integer, OrtSession]

  /** ORT environment used to create sessions and input tensors. */
  val env: OrtEnvironment = OrtEnvironment.getEnvironment()

  /** Creates a new ORT session from a serialized ONNX model.
    *
    * The caller owns the returned session and is responsible for closing it.
    *
    * @param bytes serialized ONNX model
    * @return a session created with default session options
    */
  def getSession(bytes: Array[Byte]): OrtSession = {

//    val session_options = new OrtSession.SessionOptions()
//    session_options.addDnnl(true)
    env.createSession(bytes) //, session_options)
  }

  /** Runs `sess` on the given inputs and returns its first output as a tensor.
    *
    * Names and tensors are paired positionally; `zip` stops at the shorter of the two, so
    * surplus names (or surplus tensors) are ignored.
    *
    * @param sess                session to run; not closed by this method
    * @param input_tensor_values input tensors; not closed by this method
    * @param inputNames          graph input names, paired positionally with the tensors
    * @param outputNames         currently unused: only the first output of the run is read
    * @tparam T  element type the caller expects (not checked at runtime)
    * @tparam Ax axes type the caller expects (not checked at runtime)
    * @return the first output of the run, cast to `Tensor[T, Ax]`
    */
  def runModel[T, Ax <: Axes](
      sess: OrtSession,
      input_tensor_values: Array[OnnxTensor],
      inputNames: List[String],
      outputNames: List[String]
  ): Tensor[T, Ax] = {
    val inputs = (inputNames zip input_tensor_values).toMap.asJava

    //TODO: More outputs / handle via ONNXSequence / ONNXMap
    // The run result owns native memory for every output value, so it is closed once the
    // first output has been extracted.
    // NOTE(review): assumes getArrayFromOnnxTensor copies the data out of the OnnxTensor
    // (ORT's get*Buffer accessors return copies) — confirm against ORTTensorUtils.
    Using.resource(sess.run(inputs)) { outputs =>
      val firstOut = outputs.get(0).asInstanceOf[OnnxTensor]
      val shape = firstOut.getInfo.getShape.map(_.toInt)

      val result: Tensor[T, Ax] =
        Tensor.create(getArrayFromOnnxTensor[T](firstOut), shape).asInstanceOf[Tensor[T, Ax]] //dangerous
      result
    }
  }

// def cachedSess(bytes: Array[Byte]) = sessionCache.computeIfAbsent(java.util.Arrays.hashCode(bytes), _ => getSession(bytes))

  /** Executes a serialized single-operator model on the given inputs.
    *
    * Each input must be either a tensor or an `Option` of a tensor; any other element makes
    * the match below fail with a `MatchError`. `None` inputs are dropped, and the remaining
    * tensors are bound, in order, to the input names "0", "1", ... "8".
    *
    * @param opModel serialized ONNX model
    * @param inputs  operator inputs (tensors and/or optional tensors)
    * @return the first output of the model
    */
  def callByteArrayOp[
      T,
      Ax <: Axes
  ](
      opModel: Array[Byte],
      inputs: Seq[_]
  ): Tensor[T, Ax] = {
    val input_node_names = List("0", "1", "2", "3", "4", "5", "6", "7", "8")
    //TODO: more outputs
    val output_node_names = List("outName")

    //TODO: don't mix up Options and Tensors here
    val inputTensors: Array[OnnxTensor] = inputs.toArray.map{elem =>
      elem match {
            case opt: Option[Tensor[T, Ax]] =>
              opt.map(x => getOnnxTensor(data(x), shape(x), env))
            case tens: Tensor[T, Ax] => Some(getOnnxTensor(data(tens), shape(tens), env))
          }
      }.flatten

    try {
      // The session is closed by Using.resource once the run has produced its result.
      Using.resource(getSession(opModel)) { sess =>
        runModel[T, Ax](
          sess,
          inputTensors,
          input_node_names,
          output_node_names
        )
      }
    } finally {
      // Input tensors hold native ORT handles; release them whether or not the run succeeded.
      inputTensors.foreach(_.close())
    }
  }

  /** Builds the ONNX model for a single operator and executes it.
    *
    * The operator's output is always named "outName".
    *
    * @param name   node name passed to `opToONNXBytes`
    * @param opName operator type passed to `opToONNXBytes`
    * @param inputs operator inputs (tensors and/or optional tensors)
    * @param attrs  operator attributes passed to `opToONNXBytes`
    * @return the first output of the operator
    */
  def callOp[
      T, Ax <: Axes](
      name: String,
      opName: String,
      inputs: Seq[_],
      //    outName: String,
      attrs: Map[String, Any]
  ): Tensor[T, Ax] = {
    //TODO: prevent passing input to opToONNXBytes

    val bytes = opToONNXBytes(name, opName, inputs, "outName", attrs)
    callByteArrayOp[T, Ax](bytes,inputs)
  }

  /** Closes the ORT environment held by this backend. */
  override def close(): Unit = {
      env.close()
//    super.close
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy