All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.CNTK.SerializableFunction.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.CNTK

import java.io._
import java.util.UUID.randomUUID

import com.microsoft.CNTK.CNTKUtils._
import com.microsoft.ml.spark.core.env.StreamUtilities.using
import org.apache.commons.io.FileUtils.{forceDelete, getTempDirectoryPath, writeByteArrayToFile}
import org.apache.commons.io.IOUtils
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructField, DataType => SDataType}

import scala.language.implicitConversions

/** Implicit bridges between a CNTK `Function` and its [[SerializableFunction]] wrapper.
  *
  * Importing `CNTKExtensions._` lets either type be used where the other is expected.
  */
object CNTKExtensions {

  /** Wraps a CNTK function in a new [[SerializableFunction]]. */
  implicit def toSerializable(in: Function): SerializableFunction = {
    new SerializableFunction(in)
  }

  /** Unwraps a [[SerializableFunction]], yielding the function currently held in its `fvar`. */
  implicit def fromSerializable(in: SerializableFunction): Function = {
    in.fvar
  }

}

/** Helpers for loading CNTK models and mapping CNTK variables onto Spark SQL types. */
object CNTKUtils {

  /** Loads a CNTK model from its serialized bytes.
    *
    * The bytes are first handed directly to `Function.load`. If that raises an
    * `IllegalArgumentException`, the bytes are written to a uniquely named file in the
    * temp directory and the model is loaded from that path instead (the original code's
    * commented-out warning describes this case as a "legacy model").
    *
    * @param bytes  serialized model contents
    * @param device device passed through to `Function.load`
    * @return the loaded model
    */
  def loadModelFromBytes(bytes: Array[Byte],
                         device: DeviceDescriptor =
                         DeviceDescriptor.useDefaultDevice()): Function = {
    try {
      Function.load(bytes, device)
    } catch {
      case _: IllegalArgumentException =>
        // println("WARNING: model is a legacy model," +
        //           " consider using a more recent model for faster loading")
        val modelFile = new File(s"$getTempDirectoryPath/$randomUUID.model")
        try {
          // The write happens inside the guarded region so that a failed or partial
          // write does not leave the temp file behind.
          writeByteArrayToFile(modelFile, bytes)
          Function.load(modelFile.getPath, device)
        } finally {
          // Only delete what exists: if the write never created the file, attempting the
          // delete could raise its own error and hide the original failure.
          if (modelFile.exists()) forceDelete(modelFile)
        }
    }
  }

  /** Maps the element data type of a CNTK variable to the matching Spark SQL type.
    *
    * @param v the variable whose `getDataType` is inspected
    * @return `DoubleType` for `DataType.Double`, `FloatType` for `DataType.Float`
    * @throws IllegalArgumentException for any other data type
    */
  def variableToElementType(v: Variable): SDataType = {
    v.getDataType match {
      case DataType.Double => DoubleType
      case DataType.Float => FloatType
      case t => throw new IllegalArgumentException(s"Data Type $t not supported")
    }
  }

  /** Builds a nested array type for a variable: the element type wrapped in one
    * `ArrayType` per dimension reported by `v.getShape.getRank`.
    */
  def variableToSchema(v: Variable): SDataType = {
    val baseType: SDataType = variableToElementType(v)
    // Each fold step adds one level of array nesting; a rank of 0 yields the bare element type.
    (1 to v.getShape.getRank.toInt).foldLeft(baseType) { case (t, _) => ArrayType(t) }
  }

  /** Builds a single-level array type of the variable's element type, regardless of rank. */
  def variableToArraySchema(v: Variable): SDataType = {
    ArrayType(variableToElementType(v))
  }

  val ArgumentPrefix = "ARGUMENT_" // used for indexing into the arguments array of the model

  val OutputPrefix = "OUTPUT_" // used for indexing into the outputs array of the model

}

/** Factory methods that load a CNTK model and return it wrapped as a
  * [[SerializableFunction]].
  */
object SerializableFunction {

  /** Loads the model stored at `filepath` via `Function.load`. */
  def load(filepath: String): SerializableFunction = {
    val loaded = Function.load(filepath)
    new SerializableFunction(loaded)
  }

  /** Loads a model from an in-memory buffer onto the given device via `Function.load`. */
  def load(buffer: Array[Byte], device: DeviceDescriptor): SerializableFunction = {
    val loaded = Function.load(buffer, device)
    new SerializableFunction(loaded)
  }

  /** Loads the model stored at `filepath` onto the given device via `Function.load`. */
  def load(filepath: String, device: DeviceDescriptor): SerializableFunction = {
    val loaded = Function.load(filepath, device)
    new SerializableFunction(loaded)
  }

  /** Loads a model from serialized bytes by delegating to `CNTKUtils.loadModelFromBytes`.
    *
    * @param bytes  serialized model contents
    * @param device target device; defaults to `DeviceDescriptor.useDefaultDevice()`
    */
  def loadModelFromBytes(bytes: Array[Byte],
                         device: DeviceDescriptor = DeviceDescriptor.useDefaultDevice()): SerializableFunction = {
    val loaded = CNTKUtils.loadModelFromBytes(bytes, device)
    new SerializableFunction(loaded)
  }

}

/** Serializable wrapper around a CNTK `Function`.
  *
  * Java serialization is customized through `writeObject` / `readObject`: the wrapped
  * model is saved to a temporary file, its bytes are written to the stream, and on
  * deserialization the model is rebuilt from those bytes.
  *
  * @param f the function to wrap; it becomes the initial value of `fvar`
  */
class SerializableFunction(f: Function) extends Serializable {
  // Mutable because readObject replaces it with the deserialized model.
  var fvar: Function = f

  /** Serializes the wrapped model as a byte array.
    *
    * The model is saved to a uniquely named file in the temp directory, read back into
    * memory, and the file is removed before the bytes are written to `out`.
    */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    val modelFile = new File(s"$getTempDirectoryPath/$randomUUID.model")
    val bytes = try {
      // Saving happens inside the guarded region so that a save which fails after
      // creating the file does not leave it behind.
      fvar.save(modelFile.toString)
      using(new FileInputStream(modelFile)) { fis =>
        IOUtils.toByteArray(fis)
      }.get // presumably `using` returns a Try; `.get` rethrows any captured failure
    } finally {
      // Only delete what exists, so a save that never created the file cannot have its
      // error hidden by a failing delete.
      if (modelFile.exists()) forceDelete(modelFile)
    }
    out.writeObject(bytes)
  }

  /** Restores the wrapped model from the byte array written by `writeObject`,
    * loading it with `DeviceDescriptor.useDefaultDevice()`.
    */
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    val bytes = in.readObject().asInstanceOf[Array[Byte]]
    val device = DeviceDescriptor.useDefaultDevice()
    fvar = CNTKUtils.loadModelFromBytes(bytes, device)
  }

  /** Looks a node up by name and returns its output variable.
    *
    * @throws IllegalArgumentException if `fvar.findByName` returns null for `name`
    */
  private def findByNameSafe(name: String): Variable = {
    Option(fvar.findByName(name)).map(_.getOutput).getOrElse(
      throw new IllegalArgumentException(s"Node $name could not be found in the model")
    )
  }

  /** Resolves an input variable.
    *
    * A name starting with `CNTKUtils.ArgumentPrefix` is treated as an index: the rest of
    * the name is parsed as an integer and used to index `fvar.getArguments`. Any other
    * name is looked up as a node name.
    */
  def getInputVar(name: String): Variable = {
    if (name.startsWith(CNTKUtils.ArgumentPrefix)) {
      fvar.getArguments.get(name.stripPrefix(CNTKUtils.ArgumentPrefix).toInt)
    } else {
      findByNameSafe(name)
    }
  }

  /** Resolves an output variable.
    *
    * A name starting with `CNTKUtils.OutputPrefix` is treated as an index: the rest of
    * the name is parsed as an integer and used to index `fvar.getOutputs`. Any other
    * name is looked up as a node name.
    */
  def getOutputVar(name: String): Variable = {
    if (name.startsWith(CNTKUtils.OutputPrefix)) {
      fvar.getOutputs.get(name.stripPrefix(CNTKUtils.OutputPrefix).toInt)
    } else {
      findByNameSafe(name)
    }
  }

  /** Builds one `StructField` per feed entry, in sorted entry order.
    *
    * Each entry is read as (variable name, column name): the field is named after the
    * column and typed as an array of the input variable's element type.
    */
  def getInputTypes(feedDict: Map[String, String]): List[StructField] = {
    feedDict.toList.sorted.map { case (varname, colname) =>
      StructField(colname, variableToArraySchema(getInputVar(varname)))
    }
  }

  /** Builds one `StructField` per fetch entry, in sorted entry order.
    *
    * Each entry is read as (column name, variable name) -- the reverse orientation of
    * `getInputTypes`: the field is named after the column and typed as an array of the
    * output variable's element type.
    */
  def getOutputSchema(fetchDict: Map[String, String]): List[StructField] = {
    fetchDict.toList.sorted.map { case (colname, varname) =>
      StructField(colname, variableToArraySchema(getOutputVar(varname)))
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy