All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.combust.bundle.serializer.ModelSerializer.scala Maven / Gradle / Ivy

package ml.combust.bundle.serializer

import java.nio.file.{Files, Path}

import ml.combust.bundle.{BundleContext, HasBundleRegistry}
import ml.combust.bundle.dsl.{Bundle, Model}
import ml.combust.bundle.json.JsonSupport._
import spray.json._

import scala.util.Try

/** Trait for serializing a protobuf model definition.
  */
trait FormatModelSerializer {
  /** Write a protobuf model definition.
    *
    * @param path path to write to
    * @param model model to write
    */
  def write(path: Path, model: Model): Unit

  /** Read a protobuf model definition.
    *
    * @param path path to read from
    * @return model that was read
    */
  def read(path: Path): Model
}

/** Companion object for utility methods related to model definition serialization.
  */
object FormatModelSerializer {
  /** Get the appropriate JSON or Protobuf serializer based on the serialization context.
    *
    * @param bc bundle context
    * @return JSON or Protobuf model definition serializer depending on serialization context
    */
  def serializer[T](implicit bc: BundleContext[T]): FormatModelSerializer = bc.format match {
    case SerializationFormat.Json => JsonFormatModelSerializer()
    case SerializationFormat.Protobuf => ProtoFormatModelSerializer()
  }
}

/** Object for serializing/deserializing model definitions with JSON.
  */
case class JsonFormatModelSerializer(implicit hr: HasBundleRegistry) extends FormatModelSerializer {
  override def write(path: Path, model: Model): Unit = {
    Files.write(path, model.asBundle.toJson.prettyPrint.getBytes("UTF-8"))
  }

  override def read(path: Path): Model = {
    Model.fromBundle(new String(Files.readAllBytes(path), "UTF-8").parseJson.convertTo[ml.bundle.Model])
  }
}

/** Object for serializing/deserializing model definitions with Protobuf.
  */
case class ProtoFormatModelSerializer(implicit hr: HasBundleRegistry) extends FormatModelSerializer {
  override def write(path: Path, model: Model): Unit = {
    Files.write(path, model.asBundle.toByteArray)
  }

  override def read(path: Path): Model = {
    Model.fromBundle(ml.bundle.Model.parseFrom(Files.readAllBytes(path)))
  }
}

/** Class for serializing Bundle.ML models.
  *
  * @param bundleContext bundle context for path and bundle registry
  * @tparam Context context class for implementation
  */
case class ModelSerializer[Context](bundleContext: BundleContext[Context]) {
  private implicit val bc = bundleContext
  private implicit val format = bundleContext.format

  /** Write a model to the current context path.
    *
    * This will write the model to the current path in the [[BundleContext]].
    *
    * @param obj model to write
    */
  def write(obj: Any): Try[Any] = Try {
    Files.createDirectories(bundleContext.path)
    val m = bundleContext.bundleRegistry.modelForObj[Context, Any](obj)
    val model = Model(op = m.modelOpName(obj))
    m.store(model, obj)(bundleContext)
  }.flatMap {
    model => Try(FormatModelSerializer.serializer.write(bundleContext.file(Bundle.modelFile), model))
  }

  /** Read a model from the current bundle context.
    *
    * @return deserialized model
    */
  def read(): Try[Any] = readWithModel().map(_._1)

  /** Read a model and return it along with its type class.
    *
    * @return (deserialized model, model type class)
    */
  def readWithModel(): Try[(Any, Model)] = {
    Try(FormatModelSerializer.serializer.read(bundleContext.file(Bundle.modelFile))).map {
      bundleModel =>
        val om = bundleContext.bundleRegistry.model[Context, Any](bundleModel.op)
        (om.load(bundleModel), bundleModel)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy