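// NOTE: web-page navigation residue from the site this file was copied from ("Maven / Gradle / Ivy"); not part of the source.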
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml.util

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.json4s.JsonDSL._

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
/**
 * Trait for [[MLWriter]] and [[MLReader]].
 *
 * Holds the optional, user-supplied [[SQLContext]] used for saving/loading and exposes it
 * (or a default one) to subclasses.
 */
private[util] sealed trait BaseReadWrite {
  // None until a context is supplied via `context(...)` or the default is created lazily.
  private var optionSQLContext: Option[SQLContext] = None

  /**
   * Sets the SQL context to use for saving/loading.
   *
   * @param sqlContext the context to use; it is wrapped with `Option(...)`, so a null argument
   *                   clears any previously set context
   * @return this instance, so calls can be chained
   */
  def context(sqlContext: SQLContext): this.type = {
    optionSQLContext = Option(sqlContext)
    this
  }

  /**
   * Returns the user-specified SQL context or the default.
   *
   * When no context has been set, one is obtained with `SQLContext.getOrCreate` and remembered
   * for subsequent calls.
   */
  protected final def sqlContext: SQLContext = {
    optionSQLContext.getOrElse {
      val defaultContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
      optionSQLContext = Some(defaultContext)
      defaultContext
    }
  }

  /** Returns the [[SparkContext]] underlying [[sqlContext]] */
  protected final def sc: SparkContext = sqlContext.sparkContext
}
/**
 * Abstract class for utility classes that can save ML instances.
 */
abstract class MLWriter extends BaseReadWrite with Logging {

  // Set by `overwrite()`; when true, `save` deletes an existing output path before writing.
  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instances to the input path.
   *
   * If the (qualified) path already exists it is either deleted recursively, when overwriting
   * has been enabled, or an [[IOException]] is thrown. Afterwards [[saveImpl]] is invoked.
   *
   * @param path output path
   */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    val hadoopConf = sc.hadoopConfiguration
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(qualifiedOutputPath, true)
      } else {
        throw new IOException(
          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
      }
    }
    saveImpl(path)
  }

  /**
   * [[save()]] handles overwriting and then calls this method.  Subclasses should override this
   * method to implement the actual saving of the instance.
   *
   * @param path output path
   */
  protected def saveImpl(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   *
   * @return this writer, so calls can be chained
   */
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}
/**
 * Trait for classes that provide [[MLWriter]].
 */
trait MLWritable {

  /**
   * Returns an [[MLWriter]] instance for this ML instance.
   */
  def write: MLWriter

  /**
   * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
   *
   * @param path output path
   */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}
/**
 * Mix-in for [[Params]] implementations whose writer is a [[DefaultParamsWriter]] built around
 * the instance itself.
 */
private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>

  /** Returns a new [[DefaultParamsWriter]] wrapping this instance. */
  override def write: MLWriter = new DefaultParamsWriter(this)
}
/**
 * Abstract class for utility classes that can load ML instances.
 *
 * @tparam T ML instance type
 */
abstract class MLReader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   *
   * @param path input path
   * @return the loaded instance
   */
  def load(path: String): T

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}
/**
 * Trait for objects that provide [[MLReader]].
 *
 * @tparam T ML instance type
 */
trait MLReadable[T] {

  /**
   * Returns an [[MLReader]] instance for this class.
   */
  def read: MLReader[T]

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   *
   * Note: Implementing classes should override this to be Java-friendly.
   *
   * @param path input path
   */
  def load(path: String): T = read.load(path)
}
/**
 * Mix-in for objects whose reader is a [[DefaultParamsReader]].
 *
 * @tparam T ML instance type
 */
private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {

  /** Returns a new [[DefaultParamsReader]] for `T`. */
  override def read: MLReader[T] = new DefaultParamsReader[T]
}
/**
 * Default [[MLWriter]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {

  /** Writes only the instance's metadata, via [[DefaultParamsWriter.saveMetadata]]. */
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}
private[ml] object DefaultParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * The metadata is rendered as one compact JSON string and written as a single-partition text
   * file.
   *
   * @param instance the object whose class name, uid and params are recorded
   * @param path base directory; the metadata goes to its "metadata" child
   * @param sc context used to write the file and to read the Spark version
   * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap  If given, this is saved in the "paramMap" field.
   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    // Keyed by param name, which is what DefaultParamsReader.getAndSetParams looks params up by.
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}
/**
 * Default [[MLReader]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @tparam T ML instance type
 * TODO: Consider adding check for correct class name.
 */
private[ml] class DefaultParamsReader[T] extends MLReader[T] {

  /**
   * Loads the metadata stored under `path`, instantiates the recorded class and applies the
   * saved params to it.
   *
   * @param path directory previously written by a [[DefaultParamsWriter]]
   * @return the re-created instance, cast to `T` (unchecked; see the TODO on this class)
   */
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    // NOTE(review): the right-hand side was missing from the source. This assumes the saved
    // class has a public constructor taking the uid as its single String argument -- confirm.
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}
private[ml] object DefaultParamsReader {

  /**
   * All info from metadata file.
   *
   * @param className class name recorded under "class"
   * @param uid value recorded under "uid"
   * @param timestamp value recorded under "timestamp"
   * @param sparkVersion value recorded under "sparkVersion"
   * @param params  paramMap, as a [[JValue]]
   * @param metadata  All metadata, including the other fields
   * @param metadataJson  Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String)

  /**
   * Load metadata from file.
   *
   * Only the first line of the text file at path + "/metadata" is read and parsed as JSON.
   *
   * @param path base directory containing the "metadata" child
   * @param sc context used to read the file
   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   *
   * @param instance object whose params are set; each JSON field name is resolved with
   *                 `instance.getParam`
   * @param metadata metadata whose `params` field must be a JSON object
   * @throws IllegalArgumentException if `metadata.params` is not a JSON object
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          // Re-serialize the field so the param can decode its own JSON representation.
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a [[Params]] instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   *
   * @param path directory the instance was saved to
   * @param sc context used to read the metadata
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    // NOTE(review): the final expression was missing from the source. This assumes the class
    // exposes a static `read` (i.e. its companion object is an MLReadable) -- confirm.
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
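// NOTE: web-page footer residue from the site this file was copied from ("© 2015 - 2025 Weber Informatics LLC | Privacy Policy"); not part of the source.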