All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.downloader.Schema.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.downloader

import java.io.InputStream
import java.net.URI
import org.apache.commons.codec.digest.DigestUtils
import spray.json._

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

private[spark] object NamingConventions {

  /** File extension appended to every canonical model filename. */
  private val ModelExtension = ".model"

  /** Builds the canonical filename for a model.
    *
    * @param name    name of the model architecture
    * @param dataset dataset the model was trained on
    * @return the string `<name>_<dataset>.model`
    */
  def canonicalModelFilename(name: String, dataset: String): String =
    s"${name}_$dataset$ModelExtension"

  /** Builds the canonical filename for the model described by `model`,
    * using its `name` and `dataset` fields.
    *
    * @param model schema whose name and dataset determine the filename
    * @return the string `<model.name>_<model.dataset>.model`
    */
  def canonicalModelFilename(model: ModelSchema): String =
    canonicalModelFilename(model.name, model.dataset)

}

/** Abstract representation of a schema for an item that can be held in a repository
  *
  * @param uri  location of the underlying file (local, HDFS, or HTTP)
  * @param hash sha256 hash of the underlying file, hex-encoded
  * @param size size in bytes of the underlying file
  */
abstract class Schema(val uri: URI, val hash: String, val size: Long) {

  /** Returns a schema that points at `newURI` instead of [[uri]].
    *
    * The `ModelSchema` implementation returns a copy with only `uri` changed.
    *
    * @param newURI new location of the underlying file
    */
  def updateURI(newURI: URI): this.type

  /** Verifies that the SHA-256 digest of the data in `bytes` matches [[hash]].
    *
    * The whole stream is read to compute the digest. This method does not close it;
    * the caller remains responsible for that.
    *
    * The comparison ignores case. Hex digests are case-insensitive, and
    * `DigestUtils.sha256Hex` always returns lowercase, so a correct hash recorded
    * in uppercase must not be rejected.
    *
    * @param bytes stream over the contents of the downloaded file
    * @throws IllegalArgumentException if the computed digest differs from `hash`
    */
  def assertMatchingHash(bytes: InputStream): Unit = {
    val newHash = DigestUtils.sha256Hex(bytes)
    if (!newHash.equalsIgnoreCase(hash)) {
      throw new IllegalArgumentException(s"downloaded hash: $newHash does not match given hash: $hash")
    }
  }

}

/** Class representing the schema of a CNTK model
  *
  * Note: `layerNames` is an `Array`, so the case-class generated `equals`/`hashCode`
  * compare that field by reference, not by contents.
  *
  * @param name       name of the model architecture
  * @param dataset    dataset the model was trained on
  * @param modelType  type of problem the model is suited for, e.g. image, text, sound, sentiment
  * @param uri        location of the underlying file (local, HDFS, or HTTP)
  * @param hash       sha256 hash of the underlying file
  * @param size       size in bytes of the underlying file
  * @param inputNode  the node which represents the input
  * @param numLayers  the number of layers of the model
  * @param layerNames the names of the nodes that represent layers in the network
  */
case class ModelSchema(name: String,
                       dataset: String,
                       modelType: String,
                       override val uri: URI,
                       override val hash: String,
                       override val size: Long,
                       inputNode: Int,
                       numLayers: Int,
                       layerNames: Array[String])
  extends Schema(uri, hash, size) {

  /** Auxiliary constructor for callers that hold the layer names in a
    * `java.util.ArrayList` rather than a Scala `Array`.
    */
  def this(name: String, dataset: String, modelType: String,
           uri: URI, hash: String, size: Long, inputNode: Int, numLayers: Int,
           layerNames: java.util.ArrayList[String]) = {
    // Use the explicit `asScala` decorator (JavaConverters) rather than the
    // deprecated implicit conversions from JavaConversions.
    this(name, dataset, modelType, uri, hash, size,
      inputNode, numLayers, layerNames.asScala.toArray)
  }

  /** Returns a copy of this schema with `uri` replaced by `newURI`.
    *
    * NOTE(review): `asInstanceOf[this.type]` is unsound at the type level, because the
    * copy is a different instance than `this`. It presumably succeeds at runtime only
    * because `this.type` erases to `ModelSchema`. It is kept because the signature is
    * inherited from `Schema`.
    */
  override def updateURI(newURI: URI): this.type = this.copy(uri = newURI).asInstanceOf[this.type]

}

private[spark] object SchemaJsonProtocol extends DefaultJsonProtocol {

  /** JSON format that represents a `java.net.URI` as its string form. */
  implicit object URIJsonFormat extends JsonFormat[URI] {
    def write(u: URI): JsValue = JsString(u.toString)

    /** Parses a JSON string into a URI.
      *
      * Failures are reported through spray-json's `deserializationError`, so callers
      * of `convertTo` see a `DeserializationException` rather than a raw
      * `ClassCastException` or `URISyntaxException`.
      *
      * @throws DeserializationException if `value` is not a JSON string or is not a valid URI
      */
    def read(value: JsValue): URI = value match {
      case JsString(s) =>
        try new URI(s)
        catch {
          case e: java.net.URISyntaxException =>
            deserializationError(s"Invalid URI: '$s'", e)
        }
      case other =>
        deserializationError(s"Expected URI as JsString, but got $other")
    }
  }

  /** JSON format for [[ModelSchema]]. Each constructor parameter maps to the JSON
    * field listed in the same position below.
    */
  implicit val ModelSchemaFormat: RootJsonFormat[ModelSchema] =
    jsonFormat(ModelSchema.apply,
      "name", "dataset", "modelType", "uri", "hash", "size", "inputNode", "numLayers", "layerNames")

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy