All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.pipeline.inference.OpenVINOModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.pipeline.inference

import java.io.{File, IOException}
import java.nio.file.{Files, Paths}
import java.util.{ArrayList, Arrays, UUID, List => JList}

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.zoo.pipeline.api.net.{NetUtils, RegistryMap, SerializationHolder}
import com.intel.analytics.zoo.pipeline.inference.DeviceType.DeviceTypeEnumVal
import com.intel.analytics.zoo.pipeline.inference.OpenVINOModel.OpenVINOModelHolder
import org.apache.commons.io.FileUtils
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

class OpenVINOModel(var modelHolder: OpenVINOModelHolder,
                    var isInt8: Boolean,
                    var batchSize: Int = -1,
                    var deviceType: DeviceTypeEnumVal = DeviceType.CPU)
  extends AbstractModel with InferenceSupportive with Serializable {

  private var isRelease: Boolean = false

  @transient
  private lazy val supportive: OpenVinoInferenceSupportive = {
    OpenVINOModel.logger.info("Prepare OpenVinoInferenceSupportive")
    OpenVinoInferenceSupportive.forceLoad()
    new OpenVinoInferenceSupportive()
  }

  @transient
  private lazy val executableNetworkReference: Long = {
    OpenVINOModel.logger.info("Lazy loading OpenVINO model")
    var nativeRef = -1L
    try {
      val modelFile = File.createTempFile("OpenVINO", ".xml")
      Files.write(Paths.get(modelFile.toURI), modelHolder.modelBytes)
      val weightFile = File.createTempFile("OpenVINO", ".bin")
      Files.write(Paths.get(weightFile.toURI), modelHolder.weightBytes)

      nativeRef = if (isInt8) {
        OpenVINOModel.logger.info(s"Load int8 model")
        supportive.loadOpenVinoIRInt8(modelFile.getAbsolutePath,
          weightFile.getAbsolutePath,
          deviceType.value, batchSize)
      } else {
        OpenVINOModel.logger.info(s"Load fp32 model")
        supportive.loadOpenVinoIR(modelFile.getAbsolutePath,
          weightFile.getAbsolutePath,
          deviceType.value, batchSize)
      }
      FileUtils.deleteQuietly(modelFile)
      FileUtils.deleteQuietly(weightFile)
    }
    catch {
      case io: IOException =>
        OpenVINOModel.logger.error("error during loading OpenVINO model")
        throw io
    }
    nativeRef
  }

  override def predict(inputs: JList[JList[JTensor]]): JList[JList[JTensor]] = {
    val outputs = new ArrayList[JList[JTensor]]()
    inputs.asScala.map(input => {
      val tensor = input.get(0)
      val output = if (isInt8) {
        supportive.predictInt8(executableNetworkReference,
          tensor.getData, tensor.getShape)
      } else {
        supportive.predict(executableNetworkReference,
          tensor.getData, tensor.getShape)
      }
      outputs.add(Arrays.asList({
        output
      }))
    })
    outputs
  }

  override def predict(inputActivity: Activity): Activity = {
    val (inputList, batchSize) = inputActivity.isTable match {
      case true =>
        val inputTable = inputActivity.toTable
        val batchSize = inputTable.length()
        (transferBatchTableToJListOfJListOfJTensor(inputTable, batchSize), batchSize)
      case false =>
        val inputTensor = inputActivity.toTensor[Float]
        val batchSize = inputTensor.size(1)
        (transferBatchTensorToJListOfJListOfJTensor(inputTensor, batchSize), batchSize)
    }
    val outputs = predict(inputList)
    transferListOfActivityToActivityOfBatch(outputs, batchSize)
  }

  override def copy(num: Int): Array[AbstractModel] = {
    Array(this)
//    val arr = new Array[AbstractModel](num)
//    val modelBytes = this.modelHolder.modelBytes.clone()
//    val weightBytes = this.modelHolder.weightBytes.clone()
//    val modelHolder = new OpenVINOModelHolder(modelBytes, weightBytes)
//    (0 until num).foreach(i => {
//      arr(i) = new OpenVINOModel(modelHolder, isInt8, batchSize, DeviceType.CPU)
//    })
//    arr
  }

  override def release(): Unit = {
    isReleased match {
      case true =>
      case false =>
        supportive.releaseOpenVINOIR(executableNetworkReference)
        isRelease = true
    }
  }

  override def isReleased(): Boolean = {
    isRelease
  }
}

object OpenVINOModel {

  private val modelBytesRegistry = new RegistryMap[(Array[Byte], Array[Byte])]()

  val logger = LoggerFactory.getLogger(getClass)

  @transient
  private lazy val inDriver = NetUtils.isDriver

  class OpenVINOModelHolder(@transient var modelBytes: Array[Byte],
                            @transient var weightBytes: Array[Byte],
                            private var id: String)
    extends SerializationHolder {

    def this(modelBytes: Array[Byte], weightBytes: Array[Byte]) =
      this(modelBytes, weightBytes, UUID.randomUUID().toString)

    def getModelBytes(): Array[Byte] = {
      modelBytes
    }

    def getWeightBytes(): Array[Byte] = {
      weightBytes
    }

    override def writeInternal(out: CommonOutputStream): Unit = {
      val (graphDef, _) = modelBytesRegistry.getOrCreate(id) {
        (modelBytes, weightBytes)
      }
      logger.debug("Write OpenVINO model into stream")
      out.writeString(id)
      if (inDriver) {
        out.writeInt(graphDef._1.length)
        timing(s"writing " +
          s"${graphDef._1.length / 1024 / 1024}Mb openvino model to stream") {
          out.write(graphDef._1)
        }
        out.writeInt(graphDef._2.length)
        timing(s"writing " +
          s"${graphDef._2.length / 1024 / 1024}Mb openvino weight to stream") {
          out.write(graphDef._2)
        }
      } else {
        out.writeInt(0)
      }
    }

    override def readInternal(in: CommonInputStream): Unit = {
      id = in.readString()
      val (graphDef, _) = modelBytesRegistry.getOrCreate(id) {
        val modelLen = in.readInt()
        logger.debug("Read OpenVINO model from stream")
        assert(modelLen >= 0, "OpenVINO model length should be an non-negative integer")
        val localModelBytes = new Array[Byte](modelLen)
        timing("reading OpenVINO model from stream") {
          var numOfBytes = 0
          while (numOfBytes < modelLen) {
            val read = in.read(localModelBytes, numOfBytes, modelLen - numOfBytes)
            numOfBytes += read
          }
        }
        val weightLen = in.readInt()
        assert(weightLen >= 0, "OpenVINO weight length should be an non-negative integer")
        var localWeightBytes = new Array[Byte](weightLen)
        timing("reading OpenVINO weight from stream") {
          var numOfBytes = 0
          while (numOfBytes < weightLen) {
            val read = in.read(localWeightBytes, numOfBytes, weightLen - numOfBytes)
            numOfBytes += read
          }
        }
        (localModelBytes, localWeightBytes)
      }
      modelBytes = graphDef._1
      weightBytes = graphDef._2
      id = id
    }
  }

  def apply(modelHolder: OpenVINOModelHolder, isInt8: Boolean): OpenVINOModel = {
    new OpenVINOModel(modelHolder, isInt8)
  }

  def apply(modelHolder: OpenVINOModelHolder, isInt8: Boolean, batchSize: Int): OpenVINOModel = {
    new OpenVINOModel(modelHolder, isInt8, batchSize)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy