All downloads are free. The search and download functionality uses the official Maven repository.

ml.bundle.Serializer.scala Maven / Gradle / Ivy

The newest version!
package ml.bundle

import java.io._

import spray.json.DefaultJsonProtocol._
import ml.bundle.support.JsonStreamSerializer._
import spray.json.RootJsonFormat

import scala.reflect.ClassTag
import scala.util.{Failure, Try}

/**
  * Header record stored with every bundle.
  *
  * Created by hollinwilkins on 3/4/18.
  *
  * @param className   serializer key used to locate the serializer for the bundle content
  * @param version     version string recorded by the serializer that wrote the bundle
  * @param hasMetaData true when a [[MetaData]] entry was written alongside the content
  */
case class BundleInfo(
  className: String,
  version: String,
  hasMetaData: Boolean
)

/** JSON format and stream serializer for [[BundleInfo]]. */
object BundleInfo {
  /** spray-json format covering the three constructor fields. */
  implicit val format: RootJsonFormat[BundleInfo] =
    jsonFormat3(BundleInfo.apply)

  /** Stream serializer built from [[format]] under a fixed serializer key. */
  val serializer: StreamSerializer[BundleInfo] =
    format.toStreamSerializer("ml.bundle.v1.serializer.BundleInfo")
}

/**
  * Optional descriptive record that can be stored in a bundle next to its content.
  *
  * @param name        name recorded in the metadata entry
  * @param description description recorded in the metadata entry
  * @param version     version recorded in the metadata entry; NOTE(review): presumably the
  *                    version of the bundled object rather than of the serializer - confirm
  */
case class MetaData(
  name: String,
  description: String,
  version: String
)

/** JSON format and stream serializer for [[MetaData]]. */
object MetaData {
  /** spray-json format covering the three constructor fields. */
  implicit val format: RootJsonFormat[MetaData] =
    jsonFormat3(MetaData.apply)

  /** Stream serializer built from [[format]] under a fixed serializer key. */
  val serializer: StreamSerializer[MetaData] =
    format.toStreamSerializer("ml.bundle.v1.serializer.MetaData")
}

/**
  * The three parts of a bundle as handled by [[Serializer]]: header, optional metadata
  * and the (untyped) object that was written or read.
  *
  * @param info  header describing the serializer key, version and metadata presence
  * @param meta  metadata entry, if one was supplied or found
  * @param model the serialized / deserialized object itself
  */
case class Bundle(
  info: BundleInfo,
  meta: Option[MetaData],
  model: Any
)

/**
  * Registry of serializers addressed by a string key, plus helpers that write objects to
  * raw streams or to bundles and read them back.
  *
  * Two kinds of serializers can be registered: [[StreamSerializer]]s, which work on a
  * single stream, and [[BundleSerializer]]s, which are handed the whole bundle. When a key
  * is present in both registries the stream serializer takes precedence.
  *
  * The registries are plain `var`s holding immutable maps; no synchronization is performed
  * here.
  */
trait Serializer {
  /** Stream serializers indexed by serializer key. */
  var serializers: Map[String, StreamSerializer[_]] = Map()

  /** Bundle serializers indexed by serializer key. */
  var bundleSerializers: Map[String, BundleSerializer[_]] = Map()

  /** Canonical class name -> serializer key. */
  var mlNameLookup: Map[String, String] = Map()

  /** Serializer key -> canonical class name. */
  var canonicalNameLookup: Map[String, String] = Map()

  /** Version written into every [[BundleInfo]] and checked when one is read back. */
  val version: String

  /**
    * Decides whether a bundle written with `otherVersion` may be read.
    * The default accepts only an exact match with [[version]].
    */
  def isCompatibleVersion(otherVersion: String): Boolean = version == otherVersion

  /**
    * Registers a stream serializer for `T` and records the mapping between the canonical
    * name of `T`'s runtime class and the serializer key (in both directions).
    */
  def addSerializer[T: ClassTag](serializer: StreamSerializer[T]): Unit = {
    serializers += (serializer.key -> serializer)
    registerNames[T](serializer.key)
  }

  /**
    * Registers a bundle serializer for `T` and records the mapping between the canonical
    * name of `T`'s runtime class and the serializer key (in both directions).
    */
  def addSerializer[T: ClassTag](serializer: BundleSerializer[T]): Unit = {
    bundleSerializers += (serializer.key -> serializer)
    registerNames[T](serializer.key)
  }

  /** Stream serializer registered under `key`, if any. */
  def getSerializer(key: String): Option[StreamSerializer[_]] = serializers.get(key)

  /** Bundle serializer registered under `key`, if any. */
  def getBundleSerializer(key: String): Option[BundleSerializer[_]] = bundleSerializers.get(key)

  /**
    * Serializer key for a canonical class name.
    *
    * @param key canonical class name (despite the parameter name, not a serializer key)
    * @throws NoSuchElementException if the class name was never registered
    */
  def getMlName(key: String): String = mlNameLookup(key)

  /**
    * Canonical class name for a serializer key.
    *
    * @throws NoSuchElementException if the key was never registered
    */
  def getCanonicalName(key: String): String = canonicalNameLookup(key)

  /**
    * Writes `obj` to `out` preceded by its serializer key, so that
    * [[deserializeWithClass(in:java\.io\.InputStream)*]] can pick the serializer on its own.
    *
    * Layout: a 4-byte big-endian length (as written by `DataOutputStream.writeInt`), that
    * many bytes of the UTF-8 encoded key, then whatever the stream serializer emits.
    *
    * @throws NoSuchElementException if the class of `obj` has no registered stream serializer
    */
  def serializeWithClass(obj: Any, out: OutputStream): Unit = {
    val key = keyFor(obj)
    // Resolve the serializer before touching `out` so that a missing registration does
    // not leave a dangling key header in the stream.
    val serializer = serializers(key)
    val bytes = key.getBytes(keyCharset)
    val dataOut = new DataOutputStream(out)
    // The prefix must be the number of bytes that follow, not the number of characters.
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
    serializer.serializeAny(obj, out)
  }

  /**
    * Writes `obj` to `out` with the stream serializer registered for its class; no key
    * header is written.
    *
    * @throws NoSuchElementException if the class of `obj` has no registered stream serializer
    */
  def serialize(obj: Any, out: OutputStream): Unit = {
    val key = keyFor(obj)
    serializers(key).serializeAny(obj, out)
  }

  /**
    * Reads an object written by [[serializeWithClass(obj:Any,out:java\.io\.OutputStream)*]]:
    * first the length-prefixed serializer key, then the payload via the matching serializer.
    *
    * @throws IOException            if the length prefix is negative or the stream ends early
    * @throws NoSuchElementException if no stream serializer is registered for the key read
    */
  def deserializeWithClass(in: InputStream): Any = {
    val dataIn = new DataInputStream(in)
    val size = dataIn.readInt()
    if (size < 0) {
      throw new IOException(s"Invalid serializer key length: $size")
    }
    val bytes = new Array[Byte](size)
    // Actually consume the key bytes; without this the buffer stays zero-filled and the
    // payload serializer would start reading in the middle of the header.
    dataIn.readFully(bytes)
    val key = new String(bytes, keyCharset)
    serializers(key).deserializeAny(in)
  }

  /**
    * Reads a `T` from `in` with the stream serializer registered for `T`'s runtime class.
    * The result of the serializer is cast to `T` without further checks.
    *
    * @throws NoSuchElementException if `T` has no registered stream serializer
    */
  def deserialize[T: ClassTag](in: InputStream): T = {
    val key = mlNameLookup(implicitly[ClassTag[T]].runtimeClass.getCanonicalName)
    serializers(key).deserializeAny(in).asInstanceOf[T]
  }

  /**
    * Writes `obj` into `bundle`.
    *
    * Entries are written in this order: "info.ml" (a [[BundleInfo]]), "meta" (only when
    * `metaData` is defined), then the object itself - into a "content" entry when a
    * stream serializer is registered for the key, otherwise through the bundle serializer.
    *
    * @param obj      object to write; its class must have been registered
    * @param bundle   destination bundle
    * @param metaData optional metadata stored next to the content
    * @return the info, metadata and object that were written
    * @throws NoSuchElementException if the class of `obj` was never registered
    * @throws Error                  if the key has neither a stream nor a bundle serializer
    */
  def serializeWithClass(obj: Any,
                         bundle: BundleWriter,
                         metaData: Option[MetaData] = None): Bundle = {
    val key = keyFor(obj)

    // Look both registries up front so an unknown key fails before anything is written.
    val streamSerializer = serializers.get(key)
    val bundleSerializer = bundleSerializers.get(key)
    if (streamSerializer.isEmpty && bundleSerializer.isEmpty) {
      throw new Error("Could not serialize to bundle: " + key)
    }

    val info = BundleInfo(key, version, metaData.isDefined)
    val infoWriter = bundle.contentWriter("info.ml")
    try BundleInfo.serializer.serialize(info, infoWriter)
    finally bundle.close(infoWriter)

    metaData.foreach { m =>
      val metaWriter = bundle.contentWriter("meta")
      try MetaData.serializer.serialize(m, metaWriter)
      finally bundle.close(metaWriter)
    }

    streamSerializer match {
      case Some(serializer) =>
        val contentWriter = bundle.contentWriter("content")
        try serializer.serializeAny(obj, contentWriter)
        finally bundle.close(contentWriter)
      case None =>
        // Non-empty here thanks to the check above.
        bundleSerializer.foreach(_.serializeAny(obj, bundle))
    }

    Bundle(info, metaData, obj)
  }

  /**
    * Reads a bundle written by the bundle variant of `serializeWithClass`.
    *
    * @return the header, the metadata (if the header says there is one) and the object
    * @throws Error if the bundle version is incompatible or no serializer is registered
    *               for the class name stored in the header
    */
  def deserializeWithClass(bundle: BundleReader): Bundle = {
    val info = deserializeBundleInfo(bundle)
    val meta = readMetaData(info, bundle)

    val model = serializers.get(info.className) match {
      case Some(serializer) =>
        val contentReader = bundle.contentReader("content")
        try serializer.deserializeAny(contentReader)
        finally bundle.close(contentReader)
      case None =>
        bundleSerializers.get(info.className) match {
          case Some(serializer) =>
            serializer.deserializeAny(bundle)
          case None =>
            throw new Error("Could not deserialize: " + info.className)
        }
    }

    Bundle(info, meta, model)
  }

  /**
    * Checks a bundle without returning its content: reads the header (including the
    * version check) and the metadata, then delegates to the `validate` method of the
    * serializer registered for the stored class name.
    *
    * @return `Success` with the header and metadata when the serializer's validation
    *         succeeds; otherwise a `Failure`. Non-fatal exceptions raised while reading
    *         the header or metadata are captured in the returned `Try`.
    */
  def validate(bundle: BundleReader): Try[(BundleInfo, Option[MetaData])] = {
    Try(deserializeBundleInfo(bundle)).flatMap { info =>
      val meta = readMetaData(info, bundle)

      serializers.get(info.className) match {
        case Some(serializer) =>
          val contentReader = bundle.contentReader("content")
          val result =
            try serializer.validate(contentReader)
            finally bundle.close(contentReader)
          result.map(_ => (info, meta))
        case None =>
          bundleSerializers.get(info.className) match {
            case Some(serializer) =>
              serializer.validate(bundle).map(_ => (info, meta))
            case None =>
              Failure(new Error("Unknown serialization object: " + info.className))
          }
      }
    }
  }

  /** Charset used for the key header of the stream based `*WithClass` methods. */
  private def keyCharset: java.nio.charset.Charset = java.nio.charset.StandardCharsets.UTF_8

  /** Serializer key registered for the runtime class of `obj`. */
  private def keyFor(obj: Any): String = mlNameLookup(obj.getClass.getCanonicalName)

  /** Records the two-way mapping between `T`'s canonical class name and `key`. */
  private def registerNames[T: ClassTag](key: String): Unit = {
    val name = implicitly[ClassTag[T]].runtimeClass.getCanonicalName
    mlNameLookup += (name -> key)
    canonicalNameLookup += (key -> name)
  }

  /** Reads the "meta" entry when the header announces one. */
  private def readMetaData(info: BundleInfo, bundle: BundleReader): Option[MetaData] = {
    if (info.hasMetaData) Some(deserializeMetaData(bundle)) else None
  }

  /**
    * Reads the "info.ml" entry and rejects bundles whose version is not accepted by
    * [[isCompatibleVersion]].
    */
  private def deserializeBundleInfo(bundle: BundleReader): BundleInfo = {
    val reader = bundle.contentReader("info.ml")
    // Hand the reader back to the bundle, mirroring how the "content" reader is treated.
    val info =
      try BundleInfo.serializer.deserialize(reader)
      finally bundle.close(reader)

    if (!isCompatibleVersion(info.version)) {
      throw new Error(s"Incompatible version: ${info.version}, supported version is: $version")
    }

    info
  }

  /** Reads the "meta" entry of the bundle. */
  private def deserializeMetaData(bundle: BundleReader): MetaData = {
    val reader = bundle.contentReader("meta")
    // Hand the reader back to the bundle, mirroring how the "content" reader is treated.
    try MetaData.serializer.deserialize(reader)
    finally bundle.close(reader)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy