/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.IOException
import java.util.{Locale, ServiceLoader}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.{Since, Unstable}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.{Utils, VersionUtils}

/**
 * Base trait shared by `MLWriter` and `MLReader` for managing the Spark session used when
 * saving and loading.
 */
private[util] sealed trait BaseReadWrite {
  private var optionSparkSession: Option[SparkSession] = None

  /**
   * Sets the Spark Session to use for saving/loading.
   */
  @Since("2.0.0")
  def session(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  /**
   * Returns the user-specified Spark Session or the default.
   */
  protected final def sparkSession: SparkSession = {
    if (optionSparkSession.isEmpty) {
      optionSparkSession = Some(SparkSession.builder().getOrCreate())
    }
    optionSparkSession.get
  }

  /**
   * Returns the user-specified SQL context or the default.
   */
  protected final def sqlContext: SQLContext = sparkSession.sqlContext

  /** Returns the underlying `SparkContext`. */
  protected final def sc: SparkContext = sparkSession.sparkContext
}

/**
 * Trait to be implemented by classes that provide ML exportability.
 *
 * A new instance of this class will be instantiated each time a save call is made.
 *
 * Implementations must have a valid zero-argument constructor, which will be called to
 * instantiate them.
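 *
 * A minimal implementation might look like this (a sketch only, not an actual Spark
 * format; the class name and the output layout are made up for illustration):
 * {{{
 *   class PlainTextWriterFormat extends MLWriterFormat {
 *     override def write(path: String, session: SparkSession,
 *         optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
 *       // Write one "name=value" line per param of the stage.
 *       val lines = stage.extractParamMap().toSeq
 *         .map(pair => s"${pair.param.name}=${pair.value}")
 *       session.createDataset(lines)(org.apache.spark.sql.Encoders.STRING)
 *         .write.text(path)
 *     }
 *   }
 * }}}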
 *
 * @since 2.4.0
 */
@Unstable
@Since("2.4.0")
trait MLWriterFormat {
  /**
   * Function to write the provided pipeline stage out.
   *
   * @param path  The path to write the result out to.
   * @param session  SparkSession associated with the write request.
   * @param optionMap  User provided options stored as strings.
   * @param stage  The pipeline stage to be saved.
   */
  @Since("2.4.0")
  def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
    stage: PipelineStage): Unit
}

/**
 * ML export formats should implement this trait so that users can specify a short name rather
 * than the fully qualified class name of the exporter.
 *
 * A new instance of this class will be instantiated each time a save call is made.
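 *
 * For example, a provider for a hypothetical "plaintext" format could declare the aliases
 * below (a sketch; the class name is made up, and the `write(...)` method inherited from
 * `MLWriterFormat` still has to be implemented):
 * {{{
 *   class PlainTextLinearRegressionWriter extends MLFormatRegister {
 *     override def format(): String = "plaintext"
 *     override def stageName(): String =
 *       "org.apache.spark.ml.regression.LinearRegressionModel"
 *     // ... plus the write(...) implementation from MLWriterFormat ...
 *   }
 * }}}
 * The class name must also be listed in
 * META-INF/services/org.apache.spark.ml.util.MLFormatRegister so that the ServiceLoader
 * can discover it.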
 *
 * @since 2.4.0
 */
@Unstable
@Since("2.4.0")
trait MLFormatRegister extends MLWriterFormat {
  /**
   * The string that represents the format this format provider uses. This, along with
   * `stageName`, is overridden by children to provide a nice alias for the writer. For example:
   *
   * {{{
   *   override def format(): String =
   *       "pmml"
   * }}}
   * Indicates that this format is capable of saving a PMML model.
   *
   * Implementations must have a valid zero-argument constructor, which will be called to
   * instantiate them.
   *
   * Format discovery is done using a ServiceLoader, so make sure to list your format in
   * META-INF/services.
   * @since 2.4.0
   */
  @Since("2.4.0")
  def format(): String

  /**
   * The string that represents the stage type that this writer supports. This, along with
   * `format`, is overridden by children to provide a nice alias for the writer. For example:
   *
   * {{{
   *   override def stageName(): String =
   *       "org.apache.spark.ml.regression.LinearRegressionModel"
   * }}}
   * Indicates that this writer supports saving Spark's own `LinearRegressionModel`.
   *
   * Format discovery is done using a ServiceLoader, so make sure to list your format in
   * META-INF/services.
   * @since 2.4.0
   */
  @Since("2.4.0")
  def stageName(): String

  private[ml] def shortName(): String = s"${format()}+${stageName()}"
}

/**
 * Abstract class for utility classes that can save ML instances in Spark's internal format.
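 *
 * A typical usage sketch (the `model` value, the option name, and the path are
 * placeholders):
 * {{{
 *   model.write
 *     .overwrite()
 *     .option("someOption", "someValue")  // options are specific to each writer
 *     .save("/tmp/my-model")
 * }}}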
 */
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instance to the given path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sparkSession)
    saveImpl(path)
  }

  /**
   * `save()` handles overwriting and then calls this method.  Subclasses should override this
   * method to implement the actual saving of the instance.
   */
  @Since("1.6.0")
  protected def saveImpl(path: String): Unit

  /**
   * Instructs the writer to overwrite the output path if it already exists.
   */
  @Since("1.6.0")
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  /**
   * Map to store extra options for this writer.
   */
  protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()

  /**
   * Adds an option to the underlying MLWriter. See the documentation for the specific model's
   * writer for possible options. The option name (key) is case-insensitive.
   */
  @Since("2.3.0")
  def option(key: String, value: String): this.type = {
    require(key != null && !key.isEmpty)
    optionMap.put(key.toLowerCase(Locale.ROOT), value)
    this
  }

  // override for Java compatibility
  @Since("1.6.0")
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
}

/**
 * An ML writer that delegates to a registered `MLWriterFormat` based on the requested format.
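 *
 * A usage sketch, assuming a model whose `write` returns a `GeneralMLWriter` (for example
 * Spark's `LinearRegressionModel`):
 * {{{
 *   model.write.format("pmml").save("/tmp/my-model-pmml")
 * }}}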
 */
@Unstable
@Since("2.4.0")
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
  private var source: String = "internal"

  /**
   * Specifies the format of ML export (e.g. "pmml", "internal", or
   * the fully qualified class name for export).
   */
  @Since("2.4.0")
  def format(source: String): this.type = {
    this.source = source
    this
  }

  /**
   * Dispatches the save to the correct MLFormat.
   */
  @Since("2.4.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  @throws[SparkException]("If multiple sources for a given short name format are found.")
  override protected def saveImpl(path: String): Unit = {
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
    val stageName = stage.getClass.getName
    val targetName = s"$source+$stageName"
    val formats = serviceLoader.asScala.toList
    val shortNames = formats.map(_.shortName())
    val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
      // requested name did not match any given registered alias
      case Nil =>
        Try(loader.loadClass(source)) match {
          case Success(writer) =>
            // Found the ML writer using the fully qualified path
            writer
          case Failure(error) =>
            throw new SparkException(
              s"Could not load requested format $source for $stageName ($targetName) had $formats" +
              s"supporting $shortNames", error)
        }
      case head :: Nil =>
        head.getClass
      case _ =>
        // Multiple sources
        throw new SparkException(
          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
    }
    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
      val writer = writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat]
      writer.write(path, sparkSession, optionMap, stage)
    } else {
      throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
    }
  }

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
}

/**
 * Trait for classes that provide `MLWriter`.
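 *
 * A usage sketch (the `model` value and the path are placeholders):
 * {{{
 *   model.save("/tmp/my-model")  // shortcut for model.write.save("/tmp/my-model")
 * }}}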
 */
@Since("1.6.0")
trait MLWritable {

  /**
   * Returns an `MLWriter` instance for this ML instance.
   */
  @Since("1.6.0")
  def write: MLWriter

  /**
   * Saves this ML instance to the given path, a shortcut of `write.save(path)`.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}

/**
 * Trait for classes that provide `GeneralMLWriter`.
 */
@Since("2.4.0")
@Unstable
trait GeneralMLWritable extends MLWritable {
  /**
   * Returns a `GeneralMLWriter` instance for this ML instance.
   */
  @Since("2.4.0")
  override def write: GeneralMLWriter
}

/**
 * Helper trait for making simple `Params` types writable.  If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of writing saved instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
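 *
 * For example, a transformer whose state lives entirely in its Params can gain save support
 * simply by mixing this in (a sketch; `MyTransformer` is a hypothetical class):
 * {{{
 *   class MyTransformer(override val uid: String)
 *     extends Transformer with DefaultParamsWritable {
 *     // ... only simple Param values, no model data ...
 *   }
 *
 *   new MyTransformer("myUid").save("/tmp/my-transformer")
 * }}}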
 *
 * @see `DefaultParamsReadable`, the counterpart to this trait
 */
trait DefaultParamsWritable extends MLWritable { self: Params =>

  override def write: MLWriter = new DefaultParamsWriter(this)
}

/**
 * Abstract class for utility classes that can load ML instances.
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
abstract class MLReader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  @Since("1.6.0")
  def load(path: String): T

  // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
}

/**
 * Trait for objects that provide `MLReader`.
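 *
 * A usage sketch (the model class and the path are placeholders):
 * {{{
 *   val model = PipelineModel.load("/tmp/my-pipeline-model")
 * }}}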
 *
 * @tparam T ML instance type
 */
@Since("1.6.0")
trait MLReadable[T] {

  /**
   * Returns an `MLReader` instance for this class.
   */
  @Since("1.6.0")
  def read: MLReader[T]

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   *
   * @note Implementing classes should override this to be Java-friendly.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
}


/**
 * Helper trait for making simple `Params` types readable.  If a `Params` class stores
 * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
 * a default implementation of reading saved instances of the class.
 * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
 * [[org.apache.spark.sql.Dataset]].
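 *
 * For example, the companion object of the hypothetical `MyTransformer` sketched for
 * `DefaultParamsWritable` can provide load support with (a sketch):
 * {{{
 *   object MyTransformer extends DefaultParamsReadable[MyTransformer]
 *
 *   val restored = MyTransformer.load("/tmp/my-transformer")
 * }}}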
 *
 * @tparam T ML instance type
 * @see `DefaultParamsWritable`, the counterpart to this trait
 */
trait DefaultParamsReadable[T] extends MLReadable[T] {

  override def read: MLReader[T] = new DefaultParamsReader[T]
}

/**
 * Default `MLWriter` implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {

  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

private[ml] object DefaultParamsWriter {

  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - defaultParamMap
   *  - paramMap
   *  - (optionally, extra metadata)
   *
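   * The metadata is written as a single JSON object, e.g. (an illustrative sketch with
   * made-up values):
   * {{{
   *   {"class":"org.apache.spark.ml.feature.Binarizer","timestamp":1555000000000,
   *    "sparkVersion":"2.4.0","uid":"binarizer_abc123","paramMap":{"threshold":0.5},
   *    "defaultParamMap":{"outputCol":"binarizer_abc123__output"}}
   * }}}
   *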
   * @param extraMetadata  Extra metadata to be saved at the same level as uid, paramMap, etc.
   * @param paramMap  If given, this is saved in the "paramMap" field.
   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }

  /**
   * Helper for [[saveMetadata()]] which extracts the JSON to save.
   * This is useful for ensemble models which need to save metadata for many sub-models.
   *
   * @see [[saveMetadata()]] for details on what this includes.
   */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.paramMap.toSeq
    val defaultParams = instance.defaultParamMap.toSeq
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList)
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams) ~
      ("defaultParamMap" -> jsonDefaultParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataJson: String = compact(render(metadata))
    metadataJson
  }
}

/**
 * Default `MLReader` implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 *
 * @tparam T ML instance type
 * TODO: Consider adding check for correct class name.
 */
private[ml] class DefaultParamsReader[T] extends MLReader[T] {

  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    metadata.getAndSetParams(instance)
    instance.asInstanceOf[T]
  }
}

private[ml] object DefaultParamsReader {

  /**
   * All info from metadata file.
   *
   * @param params  paramMap, as a `JValue`
   * @param defaultParams defaultParamMap, as a `JValue`. For metadata files written prior to
   *                      Spark 2.4, this is `JNothing`.
   * @param metadata  All metadata, including the other fields
   * @param metadataJson  Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      defaultParams: JValue,
      metadata: JValue,
      metadataJson: String) {

    private def getValueFromParams(params: JValue): Seq[(String, JValue)] = {
      params match {
        case JObject(pairs) => pairs
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: $metadataJson.")
      }
    }

    /**
     * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
     * This can be useful for getting a Param value before an instance of `Params`
     * is available. It looks up `params` first and falls back to `defaultParams` if the
     * param is not found there.
     */
    def getParamValue(paramName: String): JValue = {
      implicit val format = DefaultFormats

      // Look up `params` first.
      var pairs = getValueFromParams(params)
      var foundPairs = pairs.filter { case (pName, jsonValue) =>
        pName == paramName
      }
      if (foundPairs.isEmpty) {
        // Fall back to `defaultParams`.
        pairs = getValueFromParams(defaultParams)
        foundPairs = pairs.filter { case (pName, jsonValue) =>
          pName == paramName
        }
      }
      assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
        s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))

      foundPairs.map(_._2).head
    }

    /**
     * Extract Params from metadata, and set them in the instance.
     * This works if all Params (except those included in the `skipParams` list) implement
     * [[org.apache.spark.ml.param.Param.jsonDecode()]].
     *
     * @param skipParams The params included in `skipParams` won't be set. This is useful if some
     *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
     *                   and need special handling.
     */
    def getAndSetParams(
        instance: Params,
        skipParams: Option[List[String]] = None): Unit = {
      setParams(instance, skipParams, isDefault = false)

      // Metadata files written prior to Spark 2.4 have no default param section.
      val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
      if (major > 2 || (major == 2 && minor >= 4)) {
        setParams(instance, skipParams, isDefault = true)
      }
    }

    private def setParams(
        instance: Params,
        skipParams: Option[List[String]],
        isDefault: Boolean): Unit = {
      implicit val format = DefaultFormats
      val paramsToSet = if (isDefault) defaultParams else params
      paramsToSet match {
        case JObject(pairs) =>
          pairs.foreach { case (paramName, jsonValue) =>
            if (skipParams.isEmpty || !skipParams.get.contains(paramName)) {
              val param = instance.getParam(paramName)
              val value = param.jsonDecode(compact(render(jsonValue)))
              if (isDefault) {
                instance.setDefault(param, value)
              } else {
                instance.set(param, value)
              }
            }
          }
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: ${metadataJson}.")
      }
    }
  }

  /**
   * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]].
   *
   * @param expectedClassName  If non-empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
   * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
   * This is a helper function for [[loadMetadata()]].
   *
   * @param metadataStr  JSON string of metadata
   * @param expectedClassName  If non-empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val defaultParams = metadata \ "defaultParamMap"
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr)
  }

  /**
   * Load a `Params` instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T =
    loadParamsInstanceReader(path, sc).load(path)

  /**
   * Load a `Params` instance reader from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstanceReader[T](path: String, sc: SparkContext): MLReader[T] = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Utils.classForName(metadata.className)
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
  }
}

/**
 * Default Meta-Algorithm read and write implementation.
 */
private[ml] object MetaAlgorithmReadWrite {
  /**
   * Examine the given estimator (which may be a compound estimator) and extract a mapping
   * from UIDs to corresponding `Params` instances.
   */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidList = getUidMapImpl(instance)
    val uidMap = uidList.toMap
    if (uidList.size != uidMap.size) {
      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
    }
    uidMap
  }

  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
    val subStages: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array.empty[Params]
    }
    val subStageMaps = subStages.flatMap(getUidMapImpl)
    List((instance.uid, instance)) ++ subStageMaps
  }
}

private[ml] class FileSystemOverwrite extends Logging {

  def handleOverwrite(path: String, shouldOverwrite: Boolean, session: SparkSession): Unit = {
    val hadoopConf = session.sessionState.newHadoopConf()
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(qualifiedOutputPath, true)
      } else {
        throw new IOException(s"Path $path already exists. To overwrite it, " +
          s"please use write.overwrite().save(path) for Scala and use " +
          s"write().overwrite().save(path) for Java and Python.")
      }
    }
  }
}