All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.onnx.OnnxSession.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2023 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.johnsnowlabs.ml.onnx

import ai.onnxruntime.{OrtEnvironment, OrtProvider}
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader}
import org.slf4j.{Logger, LoggerFactory}

import java.io.Serializable

/**
 * Builds the option map used to configure an ONNX Runtime session.
 *
 * The GPU or CPU option set is chosen depending on whether the CUDA execution provider is
 * available in the current ONNX Runtime environment. All values are read through
 * `ConfigLoader`, keyed by the corresponding `ConfigHelper` setting names.
 */
class OnnxSession extends Serializable {

  // `@transient` keeps the logger out of the serialized form, which matters for non-Kryo
  // serializers (e.g. Java serialization). It is also `lazy` so that, after deserialization,
  // the logger is re-created on first access instead of being left as null.
  @transient lazy val logger: Logger = LoggerFactory.getLogger("OnnxSession")

  /**
   * Returns the session options appropriate for the current runtime.
   *
   * @return the CUDA configuration (GPU device id) when the CUDA execution provider is
   *         available, otherwise the CPU configuration (intra-op thread count, optimization
   *         level and execution mode); keys are the `ConfigHelper` setting names
   */
  def getSessionOptions: Map[String, String] = {
    // Compare against the enum constant rather than its string rendering.
    val cudaAvailable = OrtEnvironment.getAvailableProviders.contains(OrtProvider.CUDA)
    if (cudaAvailable) getCUDASessionConfig else getCPUSessionConfig
  }

  /** GPU options: only the configured GPU device id. */
  private def getCUDASessionConfig: Map[String, String] = {
    val gpuDeviceId = ConfigLoader.getConfigIntValue(ConfigHelper.onnxGpuDeviceId)
    Map(ConfigHelper.onnxGpuDeviceId -> gpuDeviceId.toString)
  }

  /** CPU options: intra-op thread count, optimization level and execution mode. */
  private def getCPUSessionConfig: Map[String, String] = {
    val intraOpNumThreads =
      ConfigLoader.getConfigIntValue(ConfigHelper.onnxIntraOpNumThreads)
    val optimizationLevel =
      ConfigLoader.getConfigStringValue(ConfigHelper.onnxOptimizationLevel)
    val executionMode =
      ConfigLoader.getConfigStringValue(ConfigHelper.onnxExecutionMode)

    Map(
      ConfigHelper.onnxIntraOpNumThreads -> intraOpNumThreads.toString,
      ConfigHelper.onnxOptimizationLevel -> optimizationLevel,
      ConfigHelper.onnxExecutionMode -> executionMode)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy