![JAR search and dependency download from the Maven repository](/logo.png)
com.johnsnowlabs.ml.onnx.OnnxWrapper.scala Maven / Gradle / Ivy
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.onnx
import ai.onnxruntime.OrtSession.SessionOptions
import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel}
import ai.onnxruntime.providers.OrtCUDAProviderOptions
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import com.johnsnowlabs.ml.util.LoadExternalModel
import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}
import java.io._
import java.nio.file.{Files, Paths}
import java.util.UUID
import scala.util.{Failure, Success, Try}
class OnnxWrapper(var modelFileName: Option[String] = None, var dataFileDirectory: Option[String])
extends Serializable {
/** For Deserialization */
def this() = {
this(null, null)
}
// Important for serialization on none-kyro serializers
@transient private var ortSession: OrtSession = _
@transient private var ortEnv: OrtEnvironment = _
def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) =
this.synchronized {
if (ortSession == null && ortEnv == null) {
val modelFilePath = if (modelFileName.isDefined) {
SparkFiles.get(modelFileName.get)
} else {
throw new UnsupportedOperationException("modelFileName not defined")
}
val (session, env) =
OnnxWrapper.withSafeOnnxModelLoader(onnxSessionOptions, Some(modelFilePath))
ortEnv = env
ortSession = session
}
(ortSession, ortEnv)
}
def saveToFile(file: String, zip: Boolean = true): Unit = {
// 1. Create tmp director
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx")
.toAbsolutePath
.toString
val tmpModelFilePath = SparkFiles.get(modelFileName.get)
// 2. Zip folder
if (zip) ZipArchiveUtil.zip(tmpModelFilePath, file)
// 3. Remove tmp directory
FileHelper.delete(tmpFolder)
}
}
/** Companion object */
object OnnxWrapper {
private[OnnxWrapper] val logger: Logger = LoggerFactory.getLogger("OnnxWrapper")
// TODO: make sure this.synchronized is needed or it's not a bottleneck
private def withSafeOnnxModelLoader(
sessionOptions: Map[String, String],
onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) =
this.synchronized {
val env = OrtEnvironment.getEnvironment()
val sessionOptionsObject = if (sessionOptions.isEmpty) {
new SessionOptions()
} else {
mapToSessionOptionsObject(sessionOptions)
}
if (onnxModelPath.isDefined) {
val session = env.createSession(onnxModelPath.get, sessionOptionsObject)
(session, env)
} else {
throw new UnsupportedOperationException("onnxModelPath not defined")
}
}
def read(
sparkSession: SparkSession,
modelPath: String,
zipped: Boolean = true,
useBundle: Boolean = false,
modelName: String = "model",
dataFileSuffix: Option[String] = Some("_data"),
onnxFileSuffix: Option[String] = None): OnnxWrapper = {
// 1. Create tmp folder
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx")
.toAbsolutePath
.toString
// 2. Unpack archive
val randomSuffix = generateRandomSuffix(onnxFileSuffix)
val folder =
if (zipped)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), randomSuffix)
else
modelPath
val onnxFile =
if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString
else Paths.get(folder, new File(folder).list().head).toString
var onnxDataFile: File = null
// see if the onnx model has a .onnx_data file
// get parent directory of onnx file if modelPath is a file
val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath
val onnxDataFileExist: Boolean = {
if (onnxFileSuffix.isDefined && dataFileSuffix.isDefined) {
val onnxDataFilePath = s"${onnxFileSuffix.get}_$modelName${dataFileSuffix.get}"
onnxDataFile = Paths.get(parentDir, onnxDataFilePath).toFile
onnxDataFile.exists()
} else false
}
if (onnxDataFileExist) {
sparkSession.sparkContext.addFile(onnxDataFile.toString)
}
sparkSession.sparkContext.addFile(onnxFile)
val onnxFileName = Some(new File(onnxFile).getName)
val dataFileDirectory = if (onnxDataFileExist) Some(onnxDataFile.toString) else None
val onnxWrapper = new OnnxWrapper(onnxFileName, dataFileDirectory)
onnxWrapper
}
private def generateRandomSuffix(fileSuffix: Option[String]): Option[String] = {
val randomSuffix = Some(LoadExternalModel.generateRandomString(10))
Some(s"${randomSuffix.get}${fileSuffix.getOrElse("")}")
}
private def mapToSessionOptionsObject(sessionOptions: Map[String, String]): SessionOptions = {
val providers = OrtEnvironment.getAvailableProviders
if (providers.toArray.map(x => x.toString).contains("CUDA")) {
mapToCUDASessionConfig(sessionOptions)
} else mapToCPUSessionConfig(sessionOptions)
}
private def mapToCUDASessionConfig(sessionOptionsMap: Map[String, String]): SessionOptions = {
logger.info("Using CUDA")
println("Using CUDA")
// it seems there is no easy way to use multiple GPUs
// at least not without using multiple threads
// TODO: add support for multiple GPUs
val gpuDeviceId = sessionOptionsMap(ConfigHelper.onnxGpuDeviceId).toInt
val sessionOptions = new OrtSession.SessionOptions()
logger.info(s"ONNX session option gpuDeviceId=$gpuDeviceId")
val cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId)
sessionOptions.addCUDA(cudaOpts)
sessionOptions
}
private def mapToCPUSessionConfig(sessionOptionsMap: Map[String, String]): SessionOptions = {
val defaultExecutionMode = ExecutionMode.SEQUENTIAL
val defaultOptLevel = OptLevel.ALL_OPT
def getOptLevel(optLevel: String): OptLevel = {
Try(OptLevel.valueOf(optLevel)) match {
case Success(value) => value
case Failure(_) => {
logger.warn(
s"Error while getting OptLevel, using default value: ${defaultOptLevel.name()}")
defaultOptLevel
}
}
}
def getExecutionMode(executionMode: String): ExecutionMode = {
Try(ExecutionMode.valueOf(executionMode)) match {
case Success(value) => value
case Failure(_) => {
logger.warn(
s"Error while getting Execution Mode, using default value: ${defaultExecutionMode.name()}")
defaultExecutionMode
}
}
}
logger.info("Using CPUs")
println("Using CPUs")
// TODO: the following configs can be tested for performance
// However, so far, they seem to be slower than the ones used
// opts.setIntraOpNumThreads(Runtime.getRuntime.availableProcessors())
// opts.setMemoryPatternOptimization(true)
// opts.setCPUArenaAllocator(false)
val intraOpNumThreads = sessionOptionsMap(ConfigHelper.onnxIntraOpNumThreads).toInt
val optimizationLevel = getOptLevel(sessionOptionsMap(ConfigHelper.onnxOptimizationLevel))
val executionMode = getExecutionMode(sessionOptionsMap(ConfigHelper.onnxExecutionMode))
val sessionOptions = new OrtSession.SessionOptions()
logger.info(s"ONNX session option intraOpNumThreads=$intraOpNumThreads")
sessionOptions.setIntraOpNumThreads(intraOpNumThreads)
logger.info(s"ONNX session option optimizationLevel=$optimizationLevel")
sessionOptions.setOptimizationLevel(optimizationLevel)
logger.info(s"ONNX session option executionMode=$executionMode")
sessionOptions.setExecutionMode(executionMode)
sessionOptions
}
case class EncoderDecoderWrappers(
encoder: OnnxWrapper,
decoder: OnnxWrapper,
decoderWithPast: OnnxWrapper)
case class DecoderWrappers(decoder: OnnxWrapper)
case class EncoderDecoderWithoutPastWrappers(encoder: OnnxWrapper, decoder: OnnxWrapper)
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy