com.microsoft.ml.spark.downloader.ModelDownloader.scala

// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.downloader

import java.io._
import java.net.{URI, URL}
import java.util

import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.{Configuration => HadoopConf}
import org.apache.hadoop.fs.{FileSystem, LocatedFileStatus, Path}
import org.apache.hadoop.io.{IOUtils => HUtils}
import org.apache.log4j.LogManager
import org.apache.spark.sql.SparkSession
import spray.json._

import scala.collection.JavaConverters._
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

/** Abstract representation of a repository, for future expansion.
  *
  * @tparam S the schema type the repository stores
  */
private[spark] abstract class Repository[S <: Schema] {

  def listSchemas(): Iterable[S]

  def getBytes(schema: S): InputStream

  def addBytes(schema: S, location: URI, bytes: InputStream): S

}

object FaultToleranceUtils {
  def retryWithTimeout[T](times: Int, timeout: Duration)(f: => T): T = {
    try {
      Await.result(Future(f)(ExecutionContext.global), timeout)
    } catch {
      case e: Exception if times >= 1 =>
        println(s"Received exception on call, retrying: $e")
        retryWithTimeout(times - 1, timeout)(f)
    }
  }
}
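
// Illustrative usage (a sketch added for exposition, not part of the original
// source): retryWithTimeout re-evaluates the by-name block on every attempt,
// bounding each attempt by `timeout`; once `times` reaches zero the last
// exception propagates to the caller. The URL below is a placeholder.
//
//   val manifest: String = FaultToleranceUtils.retryWithTimeout(3, Duration(30, "seconds")) {
//     scala.io.Source.fromURL("https://example.com/MANIFEST").mkString
//   }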

/** Exception thrown when a repository cannot locate the requested file
  *
  * @param uri location of the missing file
  */
class ModelNotFoundException(uri: URI) extends FileNotFoundException(s"model located at $uri could not be found")

private[spark] class HDFSRepo[S <: Schema](val uri: URI, val hconf: HadoopConf)
                                            (implicit val jsonFormat: JsonFormat[S])
  extends Repository[S] {

  private val rootPath = new Path(uri)

  private val fs = FileSystem.get(uri, hconf)

  if (!fs.exists(rootPath)) {
    fs.mkdirs(rootPath)
  }

  override def listSchemas(): Iterable[S] = {
    val fileIteratorHadoop = fs.listFiles(rootPath, false)
    val fileIterator = new Iterator[LocatedFileStatus] {
      def hasNext: Boolean = fileIteratorHadoop.hasNext

      def next(): LocatedFileStatus = fileIteratorHadoop.next()
    }

    val schemaStrings = fileIterator
      .filter(status =>
        status.isFile && status.getPath.toString.endsWith(".meta"))
      .map { status =>
        // Read each ".meta" file fully, closing the stream afterwards
        val is = fs.open(status.getPath)
        try {
          IOUtils.toString(is)
        } finally {
          is.close()
        }
      }

    schemaStrings.map(s => s.parseJson.convertTo[S]).toList
  }

  override def getBytes(schema: S): InputStream = {
    try {
      fs.open(new Path(schema.uri))
    } catch {
      case _: IOException => throw new ModelNotFoundException(schema.uri)
    }
  }

  override def addBytes(schema: S, location: URI, bytes: InputStream): S = {
    val path = new Path(location)
    val os = fs.create(path)
    try {
      HUtils.copyBytes(bytes, os, hconf)
    } finally {
      os.close()
    }
    val downloadedIs = fs.open(path)
    try {
      schema.assertMatchingHash(downloadedIs)
    } finally {
      downloadedIs.close()
    }

    val newSchema = schema.updateURI(location)
    val schemaPath = new Path(location.getPath + ".meta")
    val osSchema = fs.create(schemaPath)
    val schemaIs = IOUtils.toInputStream(newSchema.toJson.prettyPrint)
    try {
      HUtils.copyBytes(schemaIs, osSchema, hconf)
    } finally {
      osSchema.close()
      schemaIs.close()
    }
    newSchema
  }

}
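
// Illustrative sketch (the local path below is a placeholder): an HDFSRepo
// rooted at a directory. listSchemas() scans the root for ".meta" files and
// parses each as JSON; addBytes() streams the payload in, verifies its hash
// against the schema, then writes the updated schema to "<file>.meta"
// alongside it.
//
//   import SchemaJsonProtocol._
//   val repo = new HDFSRepo[ModelSchema](new URI("file:/tmp/models"), new HadoopConf())
//   repo.listSchemas().foreach(s => println(s.name))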

/** Class representing a repository of models that will eventually be hosted
  * outside the repo.
  */
private[spark] class DefaultModelRepo(val baseURL: URL) extends Repository[ModelSchema] {
  var connectTimeout = 15000
  var readTimeout = 5000

  import SchemaJsonProtocol._

  private def toStream(url: URL) = {
    val urlCon = url.openConnection()
    urlCon.setConnectTimeout(connectTimeout)
    urlCon.setReadTimeout(readTimeout)
    new BufferedInputStream(urlCon.getInputStream)
  }

  private def join(root: URL, file: String) = {
    new Path(new Path(root.toURI), file).toUri.toURL
  }

  override def listSchemas(): Iterable[ModelSchema] = {
    val url = join(baseURL, "MANIFEST")
    val manifestStream = toStream(url)
    try {
      val modelStreams = IOUtils.readLines(manifestStream).asScala.map(fn => toStream(join(baseURL, fn)))
      try {
        modelStreams.map(s => IOUtils.toString(s).parseJson.convertTo[ModelSchema])
      } finally {
        modelStreams.foreach(_.close())
      }
    } finally {
      manifestStream.close()
    }
  }

  override def getBytes(schema: ModelSchema): InputStream = {
    try {
      // Reuse toStream so the configured timeouts apply here as well
      toStream(schema.uri.toURL)
    } catch {
      case _: IOException => throw new ModelNotFoundException(schema.uri)
    }
  }

  override def addBytes(schema: ModelSchema, location: URI, bytes: InputStream): ModelSchema =
    throw new IllegalAccessError("Do not have the credentials to write a file to the remote repository")
}
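
// Layout implied by listSchemas above (an inference, not documented elsewhere
// in this file): the server hosts a plain-text MANIFEST under baseURL whose
// lines are file names, each resolving to a JSON-serialized ModelSchema. The
// host and file names below are hypothetical.
//
//   https://host/models/MANIFEST      ->  "resnet.meta\nvgg.meta"
//   https://host/models/resnet.meta   ->  JSON for one ModelSchema
//
//   val remote = new DefaultModelRepo(new URL("https://host/models/"))
//   remote.listSchemas().foreach(s => println(s.name))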

private[spark] abstract class Client {
  var quiet = false

  private def log(s: String): Unit = {
    if (!quiet) LogManager.getRootLogger.info(s)
  }

  def repoTransfer[T <: Schema](schema: T, targetLocation: URI,
                                source: Repository[T], target: Repository[T],
                                overwrite: Boolean = false, closeStream: Boolean = true): T = {
    target.listSchemas().find(s =>
      (s.uri == targetLocation) && (s.hash == schema.hash)) match {
      case Some(existing) =>
        log(s"Using model at $targetLocation, skipping download")
        existing
      case None =>
        log(s"No matching model found in target repo, writing bytes to $targetLocation")
        val sourceStream = source.getBytes(schema)
        try {
          target.addBytes(schema, targetLocation, sourceStream)
        } finally {
          if (closeStream) sourceStream.close()
        }
    }
  }

}
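
// Semantics sketch for repoTransfer (restating the logic above; the value
// names below are hypothetical): it is a cache-aside copy. If `target`
// already holds a schema at `targetLocation` with a matching hash, that
// schema is returned and `source` is never contacted; otherwise the bytes
// are streamed across and the schema, rewritten to point at `targetLocation`,
// is returned.
//
//   val cached: ModelSchema = client.repoTransfer(schema, targetUri, remoteRepo, localRepo)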

private[spark] object ModelDownloader {
  private[spark] val DefaultURL = new URL("https://mmlspark.azureedge.net/datasets/CNTKModels/")
}

/** Class for downloading models from a server to Local or HDFS
  *
  * @param spark Spark session so that the downloader can save to HDFS
  * @param localPath path to a directory that will store the models (local or HDFS)
  * @param serverURL URL of the server that supplies models (the default URL is subject to change)
  */
class ModelDownloader(val spark: SparkSession,
                      val localPath: URI,
                      val serverURL: URL = ModelDownloader.DefaultURL) extends Client {

  import SchemaJsonProtocol._

  def this(spark: SparkSession, localPath: String, serverURL: String) = {
    this(spark, new URI(localPath), new URL(serverURL))
  }

  private val localModelRepo = new HDFSRepo[ModelSchema](localPath, spark.sparkContext.hadoopConfiguration)

  private val remoteModelRepo = new DefaultModelRepo(serverURL)

  /** Function for querying the local repository for its registered models
    *
    * @return the model schemas found in the downloader's local path
    */
  def localModels: util.Iterator[ModelSchema] =
    FaultToleranceUtils.retryWithTimeout(3, Duration.apply(60, "seconds")) {
      localModelRepo.listSchemas().iterator.asJava
    }

  /** Function for querying the remote server for its registered models
    *
    * @return the model schemas found in the remote repository accessed through the serverURL
    */
  def remoteModels: util.Iterator[ModelSchema] =
    FaultToleranceUtils.retryWithTimeout(3, Duration.apply(60, "seconds")) {
      remoteModelRepo.listSchemas().iterator.asJava
    }

  /** Method to download a single model
    * @param model the remote model schema
    * @return the new local model schema with a URI that points to the model's location (on HDFS or local)
    */
  def downloadModel(model: ModelSchema): ModelSchema = {
    FaultToleranceUtils.retryWithTimeout(3, Duration.apply(10, "minutes")) {
      repoTransfer(model,
        new Path(new Path(localPath), NamingConventions.canonicalModelFilename(model)).toUri,
        remoteModelRepo, localModelRepo)
    }
  }

  def downloadByName(name: String): ModelSchema = {
    val models = remoteModels.asScala.filter(_.name == name).toList
    if (models.length != 1) {
      throw new IllegalArgumentException(s"Expected exactly one model named '$name', found ${models.length}")
    }
    downloadModel(models.head)
  }

  /** @param models An iterable of remote model schemas
    * @return A list of local model schemas whose URIs point to the models' locations (on HDFS or local)
    */
  def downloadModels(models: Iterable[ModelSchema] = remoteModels.asScala.toIterable): List[ModelSchema] =
  // Call toList so that all models are downloaded when downloadModels is called
    models.map(downloadModel).toList

  /** @param models A java ArrayList of remote model schemas, for the Java API (used by the Python wrapper)
    * @return A java List of local model schemas whose URIs point to the models' locations (on HDFS or local)
    */
  def downloadModels(models: util.ArrayList[ModelSchema]): util.List[ModelSchema] =
  // Call toList so that all models are downloaded when downloadModels is called
    models.asScala.map(downloadModel).toList.asJava

}
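
// Illustrative end-to-end usage (a sketch; "ResNet50" is a placeholder name,
// not a guaranteed entry in the remote MANIFEST):
//
//   val spark = SparkSession.builder().master("local[*]").appName("models").getOrCreate()
//   val downloader = new ModelDownloader(spark, new URI("file:/tmp/models/"))
//   downloader.remoteModels.asScala.foreach(m => println(m.name))
//   val local = downloader.downloadByName("ResNet50")
//   println(s"model cached at ${local.uri}")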



