/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import java.io.IOException
import java.util.{Locale, ServiceLoader}
// NOTE(review): the next three imports have lost their package prefix, which is not valid Scala.
// Restore the project packages that declare these classes.
import{OneVsRest, OneVsRestModel}
import{Param, ParamPair, Params}
import{Pipeline, PipelineModel, PipelineStage}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.util.SparkUtil
import org.apache.spark.{SparkContext, SparkException}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JObject, _}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
/**
 * Trait for `MLWriter` and `MLReader`.
 *
 * Holds the optional user-specified `SparkSession` used for saving/loading. When none was set,
 * the session from `SparkSession.builder().getOrCreate()` is used and cached.
 */
private[sona] sealed trait BaseReadWrite {
  private var optionSparkSession: Option[SparkSession] = None

  /**
   * Sets the Spark SQLContext to use for saving/loading.
   *
   * @deprecated Use session instead. This method will be removed in 3.0.0.
   * @return this instance, for call chaining
   */
  @deprecated("Use session instead. This method will be removed in 3.0.0.", "2.0.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSparkSession = Option(sqlContext.sparkSession)
    this
  }

  /**
   * Sets the Spark Session to use for saving/loading.
   *
   * Passing `null` clears the setting, so the default session is used instead.
   *
   * @return this instance, for call chaining
   */
  def session(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  /**
   * Returns the user-specified Spark Session or the default.
   *
   * The default session is obtained lazily on first access and remembered for later calls.
   */
  protected final def sparkSession: SparkSession = optionSparkSession.getOrElse {
    val defaultSession = SparkSession.builder().getOrCreate()
    optionSparkSession = Some(defaultSession)
    defaultSession
  }

  /** Returns the user-specified SQL context or the default. */
  protected final def sqlContext: SQLContext = sparkSession.sqlContext

  /** Returns the underlying `SparkContext`. */
  protected final def sc: SparkContext = sparkSession.sparkContext
}
/**
 * Trait to be implemented by objects that provide ML exportability.
 *
 * A new instance of this class will be instantiated each time a save call is made.
 *
 * Must have a valid zero argument constructor which will be called to instantiate.
 *
 * @since 2.4.0
 */
trait MLWriterFormat {

  /**
   * Function to write the provided pipeline stage out.
   *
   * @param path The path to write the result out to.
   * @param session SparkSession associated with the write request.
   * @param optionMap User provided options stored as strings. Keys are lower-cased by
   *                  `MLWriter.option`.
   * @param stage The pipeline stage to be saved.
   */
  def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
    stage: PipelineStage): Unit
}
/**
 * ML export formats should implement this trait so that users can specify a short name rather
 * than the fully qualified class name of the exporter.
 *
 * A new instance of this class will be instantiated each time a save call is made.
 *
 * @since 2.4.0
 */
trait MLFormatRegister extends MLWriterFormat {

  /**
   * The string that represents the format that this format provider uses. This, along with
   * `stageName`, is overridden by children to provide a nice alias for the writer. For example:
   *
   * {{{
   *   override def format(): String =
   *       "pmml"
   * }}}
   * Indicates that this format is capable of saving a pmml model.
   *
   * Must have a valid zero argument constructor which will be called to instantiate.
   *
   * Format discovery is done using a ServiceLoader so make sure to list your format in
   * META-INF/services.
   *
   * @since 2.4.0
   */
  def format(): String

  /**
   * The string that represents the stage type that this writer supports. This, along with
   * `format`, is overridden by children to provide a nice alias for the writer. It should be the
   * fully qualified class name of the supported stage. For example:
   *
   * {{{
   *   override def stageName(): String =
   *       "com.example.MyRegressionModel"
   * }}}
   * Indicates that this format is capable of saving `com.example.MyRegressionModel` instances.
   *
   * Format discovery is done using a ServiceLoader so make sure to list your format in
   * META-INF/services.
   *
   * @since 2.4.0
   */
  def stageName(): String

  /**
   * The alias `"<format>+<stageName>"`. `GeneralMLWriter` compares it, ignoring case, with the
   * requested format plus the stage's class name.
   */
  private[angel] def shortName(): String = s"${format()}+${stageName()}"
}
/**
 * Abstract class for utility classes that can save ML instances in Spark's internal format.
 */
abstract class MLWriter extends BaseReadWrite with Logging {

  /** Whether an existing output path should be deleted before saving; set by `overwrite()`. */
  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instances to the input path.
   *
   * An existing path is first handled by `FileSystemOverwrite`: it is deleted if `overwrite()`
   * was called, otherwise an `IOException` is thrown. The actual write is then delegated to
   * `saveImpl`.
   */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc)
    saveImpl(path)
  }

  /**
   * `save()` handles overwriting and then calls this method. Subclasses should override this
   * method to implement the actual saving of the instance.
   */
  protected def saveImpl(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   *
   * @return this writer, for call chaining
   */
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  /**
   * Map to store extra options for this writer.
   */
  protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()

  /**
   * Adds an option to the underlying MLWriter. See the documentation for the specific model's
   * writer for possible options. The option name (key) is case-insensitive.
   *
   * @throws IllegalArgumentException if `key` is null or empty
   * @return this writer, for call chaining
   */
  def option(key: String, value: String): this.type = {
    require(key != null && key.nonEmpty, "Option key must be a non-empty string.")
    // Keys are normalized with a locale-independent lower-casing so lookups are case-insensitive.
    optionMap.put(key.toLowerCase(Locale.ROOT), value)
    this
  }

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
/**
 * A ML Writer which delegates based on the requested format.
 *
 * @param stage the pipeline stage to be saved
 */
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
  private var source: String = "internal"

  /**
   * Specifies the format of ML export (e.g. "pmml", "internal", or
   * the fully qualified class name for export).
   *
   * @return this writer, for call chaining
   */
  def format(source: String): this.type = {
    this.source = source
    this
  }

  /**
   * Dispatches the save to the correct MLFormat.
   *
   * The writer class is resolved as follows:
   *  - A registered `MLFormatRegister` is used if one is discovered through `ServiceLoader` and
   *    its `shortName()` equals `"<source>+<stage class name>"`, ignoring case.
   *  - Otherwise, `source` is loaded as a fully qualified class name.
   *
   * The resolved class must implement `MLWriterFormat`. A fresh instance is created through its
   * public zero-argument constructor.
   */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  @throws[SparkException]("If multiple sources for a given short name format are found.")
  override protected def saveImpl(path: String): Unit = {
    val loader = SparkUtil.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
    val stageName = stage.getClass.getName
    val targetName = s"$source+$stageName"
    val formats = serviceLoader.asScala.toList
    val shortNames = formats.map(_.shortName())
    val writerCls: Class[_] = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
      // requested name did not match any given registered alias
      case Nil =>
        Try(loader.loadClass(source)) match {
          case Success(writer) =>
            // Found the ML writer using the fully qualified path
            writer
          case Failure(error) =>
            throw new SparkException(
              s"Could not load requested format $source for $stageName ($targetName) had $formats" +
              s" supporting $shortNames", error)
        }
      case head :: Nil =>
        head.getClass
      case _ =>
        // Multiple sources
        throw new SparkException(
          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
    }
    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
      // Class.newInstance() is deprecated since Java 9 because it bypasses checked-exception
      // handling. Call the public no-arg constructor explicitly instead.
      val writer = writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat]
      writer.write(path, sparkSession, optionMap, stage)
    } else {
      throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
    }
  }

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
/**
 * Trait for classes that provide `MLWriter`.
 */
trait MLWritable {

  /**
   * Returns an `MLWriter` instance for this ML instance.
   */
  def write: MLWriter

  /**
   * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
   */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}
/**
 * Trait for classes that provide `GeneralMLWriter`.
 */
trait GeneralMLWritable extends MLWritable {

  /**
   * Returns a `GeneralMLWriter` instance for this ML instance. The return type is narrowed so
   * callers can choose an export format with `write.format(...)`.
   */
  override def write: GeneralMLWriter
}
/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types writable. If a `Params` class stores
 * all data as [[Param]] values, then extending this trait will provide
 * a default implementation of writing saved instances of the class.
 * This only handles simple [[Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @see `DefaultParamsReadable`, the counterpart to this trait
 */
trait DefaultParamsWritable extends MLWritable { self: Params =>

  /** Returns a `DefaultParamsWriter` that saves this instance's params as JSON metadata. */
  override def write: MLWriter = new DefaultParamsWriter(this)
}
/**
 * Abstract class for utility classes that can load ML instances.
 *
 * @tparam T ML instance type
 */
abstract class MLReader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  def load(path: String): T

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
/**
 * Trait for objects that provide `MLReader`.
 *
 * @tparam T ML instance type
 */
trait MLReadable[T] {

  /**
   * Returns an `MLReader` instance for this class.
   */
  def read: MLReader[T]

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   *
   * @note Implementing classes should override this to be Java-friendly.
   */
  def load(path: String): T = read.load(path)
}
/**
 * :: DeveloperApi ::
 *
 * Helper trait for making simple `Params` types readable. If a `Params` class stores
 * all data as [[Param]] values, then extending this trait will provide
 * a default implementation of reading saved instances of the class.
 * This only handles simple [[Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
 *
 * @tparam T ML instance type
 * @see `DefaultParamsWritable`, the counterpart to this trait
 */
trait DefaultParamsReadable[T] extends MLReadable[T] {

  /** Returns a `DefaultParamsReader` that rebuilds an instance from its saved params metadata. */
  override def read: MLReader[T] = new DefaultParamsReader[T]
}
/**
 * Default `MLWriter` implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @param instance object to save
 */
private[angel] class DefaultParamsWriter(instance: Params) extends MLWriter {

  /** Writes only the params metadata of `instance`, under `path/metadata`. */
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}
private[angel] object DefaultParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - defaultParamMap
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * The metadata is written as a single JSON line in a one-partition text file.
   *
   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap If given, this is saved in the "paramMap" field.
   *                 Otherwise, all [[Param]]s are encoded using `Param.jsonEncode()`.
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }

  /**
   * Helper for `saveMetadata()` which extracts the JSON to save.
   * This is useful for ensemble models which need to save metadata for many sub-models.
   *
   * @see `saveMetadata()` for details on what this includes.
   * @return the metadata rendered as a compact JSON string
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.paramMap.toSeq
    val defaultParams = instance.defaultParamMap.toSeq
    // Each param becomes a `name -> <parsed jsonEncode(value)>` field of a JSON object.
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList)
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams) ~
      ("defaultParamMap" -> jsonDefaultParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    compact(render(metadata))
  }
}
/**
 * Default `MLReader` implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @tparam T ML instance type
 * TODO: Consider adding check for correct class name.
 */
private[angel] class DefaultParamsReader[T] extends MLReader[T] {

  /**
   * Loads the metadata from `path`, instantiates the recorded class through its
   * `(uid: String)` constructor, and restores its params.
   */
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = SparkUtil.classForName(metadata.className)
    // Construct with the saved uid so the loaded instance keeps its original identity.
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    metadata.getAndSetParams(instance)
    instance.asInstanceOf[T]
  }
}
private[angel] object DefaultParamsReader {

  /**
   * All info from metadata file.
   *
   * @param params paramMap, as a `JValue`
   * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4,
   *                      this is `JNothing`.
   * @param metadata All metadata, including the other fields
   * @param metadataJson Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      defaultParams: JValue,
      metadata: JValue,
      metadataJson: String) {

    /**
     * Returns the fields of `params`.
     *
     * @throws IllegalArgumentException if `params` is not a JSON object
     */
    private def getValueFromParams(params: JValue): Seq[(String, JValue)] = params match {
      case JObject(pairs) => pairs
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: $metadataJson.")
    }

    /**
     * Get the JSON value of the [[Param]] of the given name.
     * This can be useful for getting a Param value before an instance of `Params`
     * is available. This will look up `params` first, if not existing then looking up
     * `defaultParams`.
     *
     * @throws IllegalArgumentException if a section that has to be searched is not a JSON object
     * @throws AssertionError if the name does not occur exactly once in the searched section
     */
    def getParamValue(paramName: String): JValue = {
      val explicitPairs = getValueFromParams(params)
      val explicitMatches = explicitPairs.filter { case (name, _) => name == paramName }
      // `pairs` is the section that was searched last; it is reported if the lookup fails.
      val (pairs, foundPairs) =
        if (explicitMatches.nonEmpty) {
          (explicitPairs, explicitMatches)
        } else {
          val defaultPairs = getValueFromParams(defaultParams)
          (defaultPairs, defaultPairs.filter { case (name, _) => name == paramName })
        }
      assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
        s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
      foundPairs.head._2
    }

    /**
     * Extract Params from metadata, and set them in the instance.
     * This works if all Params (except params included by `skipParams` list) implement
     * `Param.jsonDecode()`.
     *
     * @param skipParams The params included in `skipParams` won't be set. This is useful if some
     *                   params don't implement `Param.jsonDecode()` and need special handling.
     */
    def getAndSetParams(
        instance: Params,
        skipParams: Option[List[String]] = None): Unit = {
      setParams(instance, skipParams, isDefault = false)

      // For metadata file prior to Spark 2.4, there is no default section.
      val (major, minor) = SparkUtil.majorMinorVersion(sparkVersion)
      if (major > 2 || (major == 2 && minor >= 4)) {
        setParams(instance, skipParams, isDefault = true)
      }
    }

    /**
     * Decodes every entry of `params` (or `defaultParams` when `isDefault`) and sets it on
     * `instance`, as a default value when `isDefault`, skipping names listed in `skipParams`.
     *
     * @throws IllegalArgumentException if the selected section is not a JSON object
     */
    private def setParams(
        instance: Params,
        skipParams: Option[List[String]],
        isDefault: Boolean): Unit = {
      implicit val format = DefaultFormats
      val paramsToSet = if (isDefault) defaultParams else params
      paramsToSet match {
        case JObject(pairs) =>
          pairs.foreach { case (paramName, jsonValue) =>
            if (!skipParams.exists(_.contains(paramName))) {
              val param = instance.getParam(paramName)
              val value = param.jsonDecode(compact(render(jsonValue)))
              if (isDefault) {
                Params.setDefault(instance, param, value)
              } else {
                instance.set(param, value)
              }
            }
          }
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: ${metadataJson}.")
      }
    }
  }

  /**
   * Load metadata saved using `DefaultParamsWriter.saveMetadata()`.
   *
   * Reads the first line of the text file at `path/metadata`.
   *
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
   * Parse metadata JSON string produced by `DefaultParamsWriter.getMetadataToSave()`.
   * This is a helper function for `loadMetadata()`.
   *
   * @param metadataStr JSON string of metadata
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val defaultParams = metadata \ "defaultParamMap"
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr)
  }

  /**
   * Load a `Params` instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = SparkUtil.classForName(metadata.className)
    // Invoked statically (null receiver). This relies on the static `read` forwarder that Scala
    // emits for a companion object implementing `MLReadable`.
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
/**
 * Default Meta-Algorithm read and write implementation.
 */
private[angel] object MetaAlgorithmReadWrite {

  /**
   * Examine the given estimator (which may be a compound estimator) and extract a mapping
   * from UIDs to corresponding `Params` instances.
   *
   * @throws RuntimeException if two stages in the tree share the same UID
   */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidList = getUidMapImpl(instance)
    val uidMap = uidList.toMap
    // A size mismatch means `toMap` collapsed entries, i.e. some UID occurs more than once.
    if (uidList.size != uidMap.size) {
      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
    }
    uidMap
  }

  /** Depth-first (pre-order) list of `(uid, instance)` for `instance` and all nested stages. */
  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
    // NOTE(review): `ValidatorParams` and `RFormulaModel` are not imported in the visible import
    // block. Confirm that they resolve in this project.
    val subStages: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array.empty[Params]
    }
    (instance.uid, instance) :: subStages.toList.flatMap(getUidMapImpl)
  }
}
/**
 * Prepares an output path before an `MLWriter` saves to it.
 */
private[angel] class FileSystemOverwrite extends Logging {

  /**
   * If `path` already exists on its Hadoop file system, deletes it recursively when
   * `shouldOverwrite` is true. Otherwise it fails.
   *
   * @param path output path; may be relative, in which case it is qualified against the file
   *             system's working directory
   * @param shouldOverwrite whether an existing path may be deleted
   * @param sc SparkContext whose Hadoop configuration selects the file system
   * @throws IOException if the path exists and `shouldOverwrite` is false
   */
  def handleOverwrite(path: String, shouldOverwrite: Boolean, sc: SparkContext): Unit = {
    val hadoopConf = sc.hadoopConfiguration
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        // `delete` reports failure through its return value. Surface it instead of ignoring it,
        // because the subsequent write would otherwise fail with a less obvious error.
        if (!fs.delete(qualifiedOutputPath, true)) {
          logWarning(s"Failed to delete existing path $path before overwriting it.")
        }
      } else {
        throw new IOException(s"Path $path already exists. To overwrite it, " +
          s"please use write.overwrite().save(path) for Scala and use " +
          s"write().overwrite().save(path) for Java and Python.")
      }
    }
  }
}
