All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.onnx.TensorResources.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.onnx

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}

import scala.collection.mutable.ArrayBuffer

/** Class to manage the creation of ONNX Tensors (WIP).
  *
  * Tensors that do not belong to a [[OrtSession]], will need to be explicitly closed. This class
  * manages created tensors and can free all of them at once.
  *
  * The [[OrtEnvironment]] exists at most once, so all TensorResources will share the same one.
  */
class TensorResources extends AutoCloseable {
  private val tensors = ArrayBuffer[OnnxTensor]()
  private val env = OrtEnvironment.getEnvironment()
  def createTensor(data: Any): OnnxTensor = {
    val tensor = OnnxTensor.createTensor(env, data)
    tensors.append(tensor)
    tensor
  }

  def clearTensors(): Unit = {
    tensors.foreach(_.close())
    tensors.clear()
  }

  override def close(): Unit = clearTensors()
}

object TensorResources {

  def getOnnxTensor(result: OrtSession.Result, key: String): OnnxTensor =
    result.get(key).get().asInstanceOf[OnnxTensor]
  object implicits {
    implicit class OnnxSessionResult(result: OrtSession.Result) {
      def getOnnxTensor(key: String): OnnxTensor =
        TensorResources.getOnnxTensor(result, key)

      def getOnnxTensors(keys: Array[String]): Map[String, OnnxTensor] = {
        keys.map { key =>
          (key, TensorResources.getOnnxTensor(result, key))
        }.toMap
      }

      def getFloatArray(key: String): Array[Float] = Option(
        TensorResources.getOnnxTensor(result, key).getFloatBuffer) match {
        case Some(floats) =>
          val floatArray = floats.array()
          floatArray
        case None => throw new IllegalStateException("Could not extract floats from OnnxTensor.")
      }

    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy