All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.util.ReadWrite.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml.util

import java.io.IOException
import java.util.{Locale, ServiceLoader}


import com.tencent.angel.sona.ml
import com.tencent.angel.sona.ml.classification.{OneVsRest, OneVsRestModel}
import com.tencent.angel.sona.ml.feature.RFormulaModel
import com.tencent.angel.sona.ml.param.{Param, ParamPair, Params}
import com.tencent.angel.sona.ml.tuning.ValidatorParams
import com.tencent.angel.sona.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.util.SparkUtil
import org.apache.spark.{SparkContext, SparkException}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JObject, _}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

/**
  * Common base of `MLWriter` and `MLReader`.
  *
  * Holds the `SparkSession` used for saving/loading. The session is either supplied explicitly by
  * the caller (via `session` or the deprecated `context`) or, on first use, obtained from
  * `SparkSession.builder().getOrCreate()`.
  */
private[sona] sealed trait BaseReadWrite {
  // Explicitly configured session; `None` until set or until first resolved lazily.
  private var optionSparkSession: Option[SparkSession] = None

  /**
    * Sets the Spark SQLContext to use for saving/loading.
    *
    * @deprecated Use session instead. This method will be removed in 3.0.0.
    */
  @deprecated("Use session instead. This method will be removed in 3.0.0.", "2.0.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSparkSession = Option(sqlContext.sparkSession)
    this
  }

  /**
    * Sets the Spark Session to use for saving/loading. A `null` argument leaves no session
    * configured, so the default session is resolved on next use.
    */
  def session(sparkSession: SparkSession): this.type = {
    optionSparkSession = Option(sparkSession)
    this
  }

  /**
    * Returns the configured session; if none was configured, obtains the active/default one via
    * `SparkSession.builder().getOrCreate()` and remembers it for subsequent calls.
    */
  protected final def sparkSession: SparkSession = optionSparkSession.getOrElse {
    val resolved = SparkSession.builder().getOrCreate()
    optionSparkSession = Some(resolved)
    resolved
  }

  /** The `SQLContext` of the session returned by `sparkSession`. */
  protected final def sqlContext: SQLContext = sparkSession.sqlContext

  /** The `SparkContext` of the session returned by `sparkSession`. */
  protected final def sc: SparkContext = sparkSession.sparkContext
}

/**
  * Contract for pluggable ML export formats used by `GeneralMLWriter`.
  *
  * A new instance of the implementing class is created for each save call, so implementations
  * must provide a valid zero-argument constructor.
  *
  * @since 2.4.0
  */
trait MLWriterFormat {

  /**
    * Writes the given pipeline stage out.
    *
    * @param path      destination path for the exported result
    * @param session   SparkSession associated with the write request
    * @param optionMap user-provided writer options, as strings
    * @param stage     the pipeline stage to be saved
    */
  def write(
      path: String,
      session: SparkSession,
      optionMap: mutable.Map[String, String],
      stage: PipelineStage): Unit
}

/**
  * ML export formats should implement this trait so that users can refer to them by a short name
  * rather than by the fully qualified class name of the exporter.
  *
  * A new instance of this class will be instantiated each time a save call is made.
  *
  * @since 2.4.0
  */
trait MLFormatRegister extends MLWriterFormat {

  /**
    * The name of the format this provider produces. Together with `stageName`, children override
    * it to give the writer a friendly alias. For example:
    *
    * {{{
    *   override def format(): String =
    *       "pmml"
    * }}}
    * indicates that this format is capable of saving a PMML model.
    *
    * Implementations must have a valid zero-argument constructor, which is used to instantiate
    * them.
    *
    * Formats are discovered with a `ServiceLoader`, so the implementation must be listed in
    * META-INF/services.
    *
    * @since 2.4.0
    */
  def format(): String

  /**
    * The fully qualified class name of the stage type this writer supports. Together with
    * `format`, children override it to give the writer a friendly alias. For example:
    *
    * {{{
    *   override def stageName(): String =
    *       "org.apache.spark.ml.regression.LinearRegressionModel"
    * }}}
    * indicates that this format is capable of saving that linear regression model.
    *
    * Formats are discovered with a `ServiceLoader`, so the implementation must be listed in
    * META-INF/services.
    *
    * @since 2.4.0
    */
  def stageName(): String

  /**
    * Alias of the form `"<format>+<stageName>"`. `GeneralMLWriter` compares it
    * case-insensitively against the requested format and stage class.
    */
  private[angel] def shortName(): String = {
    val fmt = format()
    val stage = stageName()
    s"$fmt+$stage"
  }
}

/**
  * Base class for utilities that save ML instances in Spark's internal format.
  *
  * Configure the writer fluently (`overwrite()`, `option(...)`, `session(...)`) and then call
  * `save(path)`. Subclasses only implement `saveImpl`.
  */
abstract class MLWriter extends BaseReadWrite with Logging {

  /** Whether an existing output path may be replaced; switched on by `overwrite()`. */
  protected var shouldOverwrite: Boolean = false

  /** Extra options for this writer; keys are stored lower-cased by `option`. */
  protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()

  /**
    * Saves the ML instance to the input path.
    *
    * The existing path is first checked (and deleted when overwriting is enabled), then the
    * actual work is delegated to `saveImpl`.
    */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    val overwriteHandler = new FileSystemOverwrite()
    overwriteHandler.handleOverwrite(path, shouldOverwrite, sc)
    saveImpl(path)
  }

  /**
    * Performs the actual saving of the instance. `save()` handles overwriting before calling
    * this, so subclasses only need to write their data.
    */
  protected def saveImpl(path: String): Unit

  /** Allows `save` to overwrite an output path that already exists. */
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  /**
    * Adds an option to the underlying MLWriter. See the documentation for the specific model's
    * writer for possible options. The option name (key) is case-insensitive: it is lower-cased
    * with `Locale.ROOT` before being stored.
    */
  def option(key: String, value: String): this.type = {
    require(key != null && key.nonEmpty)
    optionMap += (key.toLowerCase(Locale.ROOT) -> value)
    this
  }

  // Re-declared so Java callers see `MLWriter` (not `BaseReadWrite`) as the return type.
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // Re-declared so Java callers see `MLWriter` (not `BaseReadWrite`) as the return type.
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
  * An `MLWriter` that delegates the export to an `MLWriterFormat` selected by name.
  *
  * The format defaults to `"internal"` and can be changed with `format(...)`. See `saveImpl` for
  * how the name is resolved to a writer class.
  *
  * @param stage the pipeline stage to export
  */
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
  // Requested export format: a registered short name (e.g. "pmml") or a fully qualified class name.
  private var source: String = "internal"

  /**
    * Specifies the format of ML export (e.g. "pmml", "internal", or
    * the fully qualified class name for export).
    */
  def format(source: String): this.type = {
    this.source = source
    this
  }

  /**
    * Dispatches the save to the correct MLFormat.
    *
    * Resolution order:
    *  1. `MLFormatRegister` implementations found via `ServiceLoader` whose short name
    *     (`"<format>+<stage class name>"`) equals `"<source>+<stage class name>"`,
    *     compared case-insensitively.
    *  2. If no registered format matches, `source` is loaded as a fully qualified class name.
    *
    * The resolved class must implement `MLWriterFormat`; a new instance is created through its
    * zero-argument constructor for every save.
    *
    * @throws SparkException if no writer can be loaded, several registered writers match, or the
    *                        resolved class is not an `MLWriterFormat`
    */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  @throws[SparkException]("If multiple sources for a given short name format are found.")
  override protected def saveImpl(path: String): Unit = {
    val loader = SparkUtil.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
    val stageName = stage.getClass.getName
    val targetName = s"$source+$stageName"
    val formats = serviceLoader.asScala.toList
    val shortNames = formats.map(_.shortName())
    val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
      // requested name did not match any given registered alias
      case Nil =>
        Try(loader.loadClass(source)) match {
          case Success(writer) =>
            // Found the ML writer using the fully qualified path
            writer
          case Failure(error) =>
            // Note the trailing space after `$formats`: without it the two halves of the
            // message run together ("...)supporting ...").
            throw new SparkException(
              s"Could not load requested format $source for $stageName ($targetName) had $formats " +
                s"supporting $shortNames", error)
        }
      case head :: Nil =>
        head.getClass
      case _ =>
        // Multiple sources
        throw new SparkException(
          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
    }
    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
      val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
      writer.write(path, sparkSession, optionMap, stage)
    } else {
      throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
    }
  }

  // Re-declared so Java callers see `GeneralMLWriter` as the return type.
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // Re-declared so Java callers see `GeneralMLWriter` as the return type.
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
  * Mixin for ML instances that can provide an `MLWriter` for themselves.
  */
trait MLWritable {

  /** Returns an `MLWriter` instance for this ML instance. */
  def write: MLWriter

  /**
    * Saves this ML instance to `path`; equivalent to `write.save(path)`.
    */
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = {
    val writer = write
    writer.save(path)
  }
}

/**
  * Refinement of `MLWritable` for instances whose writer is a `GeneralMLWriter`, so callers can
  * choose an export format via `write.format(...)`.
  */
trait GeneralMLWritable extends MLWritable {

  /** Returns a `GeneralMLWriter` for this ML instance (narrows the inherited return type). */
  override def write: GeneralMLWriter
}

/**
  * :: DeveloperApi ::
  *
  * Helper trait for making simple `Params` types writable. If a `Params` class stores all of its
  * data as [[Param]] values, mixing in this trait provides a default writer that saves only the
  * class metadata and param values. Only simple [[Param]] types are handled; e.g. it will not
  * handle [[org.apache.spark.sql.Dataset]].
  *
  * @see `DefaultParamsReadable`, the counterpart to this trait
  */
@DeveloperApi
trait DefaultParamsWritable extends MLWritable {
  this: Params =>

  /** Returns a `DefaultParamsWriter` wrapping this instance. */
  override def write: MLWriter = new DefaultParamsWriter(this)
}

/**
  * Base class for utilities that load ML instances.
  *
  * @tparam T ML instance type
  */
abstract class MLReader[T] extends BaseReadWrite {

  /**
    * Loads the ML component stored at `path`.
    *
    * @param path location to load from
    * @return the loaded instance
    */
  def load(path: String): T

  // Re-declared so Java callers see `MLReader` (not `BaseReadWrite`) as the return type.
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

  // Re-declared so Java callers see `MLReader` (not `BaseReadWrite`) as the return type.
  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
  * Mixin for objects (typically companion objects) that can provide an `MLReader`.
  *
  * @tparam T ML instance type
  */
trait MLReadable[T] {

  /** Returns an `MLReader` instance for this class. */
  def read: MLReader[T]

  /**
    * Reads an ML instance from `path`; equivalent to `read.load(path)`.
    *
    * @note Implementing classes should override this to be Java-friendly.
    */
  def load(path: String): T = {
    val reader = read
    reader.load(path)
  }
}


/**
  * :: DeveloperApi ::
  *
  * Helper trait for making simple `Params` types readable. If a `Params` class stores all of its
  * data as [[Param]] values, mixing this trait into its companion provides a default reader that
  * restores the instance from saved metadata. Only simple [[Param]] types are handled; e.g. it
  * will not handle [[org.apache.spark.sql.Dataset]].
  *
  * @tparam T ML instance type
  * @see `DefaultParamsWritable`, the counterpart to this trait
  */
@DeveloperApi
trait DefaultParamsReadable[T] extends MLReadable[T] {

  /** Returns a fresh `DefaultParamsReader` for `T`. */
  override def read: MLReader[T] = new DefaultParamsReader[T]
}

/**
  * Default `MLWriter` for transformers and estimators that hold only basic (json4s-serializable)
  * params and no data. It writes nothing but the metadata file, so it is not suitable for types
  * with more complex params or with data (e.g. models with coefficients).
  *
  * @param instance object to save
  */
private[angel] class DefaultParamsWriter(instance: Params) extends MLWriter {

  /** Writes `instance`'s metadata under `path`. */
  override protected def saveImpl(path: String): Unit =
    DefaultParamsWriter.saveMetadata(instance, path, sc)
}

private[angel] object DefaultParamsWriter {

  /**
    * Saves metadata + Params to: path + "/metadata"
    *  - class
    *  - timestamp
    *  - sparkVersion
    *  - uid
    *  - defaultParamMap
    *  - paramMap
    *  - (optionally, extra metadata)
    *
    * The JSON is written as a single-partition text file via `saveAsTextFile`.
    *
    * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
    * @param paramMap      If given, this is saved in the "paramMap" field.
    *                      Otherwise, all [[Param]]s are encoded using
    *                      [[ml.param.Param.jsonEncode()]].
    */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val target = new Path(path, "metadata").toString
    val json = getMetadataToSave(instance, sc, extraMetadata, paramMap)
    sc.parallelize(Seq(json), 1).saveAsTextFile(target)
  }

  /**
    * Helper for [[saveMetadata()]] that builds the compact JSON string to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
    *
    * @see [[saveMetadata()]] for details on what this includes.
    */
  def getMetadataToSave(
      instance: Params,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): String = {
    val explicitParams = instance.paramMap.toSeq
    val defaultParams = instance.defaultParamMap.toSeq
    // An explicitly supplied `paramMap` replaces the encoding of the instance's own params.
    val paramsJson = paramMap.getOrElse(render(explicitParams.map { case ParamPair(param, value) =>
      param.name -> parse(param.jsonEncode(value))
    }.toList))
    val defaultParamsJson = render(defaultParams.map { case ParamPair(param, value) =>
      param.name -> parse(param.jsonEncode(value))
    }.toList)
    val basicMetadata = ("class" -> instance.getClass.getName) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> instance.uid) ~
      ("paramMap" -> paramsJson) ~
      ("defaultParamMap" -> defaultParamsJson)
    // Extra fields, if any, are merged in at the top level.
    val fullMetadata = extraMetadata.fold(basicMetadata)(extra => basicMetadata ~ extra)
    compact(render(fullMetadata))
  }
}

/**
  * Default `MLReader` for transformers and estimators that hold only basic (json4s-serializable)
  * params and no data. It restores an instance from the metadata file alone, so it is not
  * suitable for types with more complex params or with data (e.g. models with coefficients).
  *
  * The saved class is instantiated reflectively through a constructor taking the uid `String`.
  *
  * TODO: Consider adding check for correct class name.
  *
  * @tparam T ML instance type
  */
private[angel] class DefaultParamsReader[T] extends MLReader[T] {

  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val restored = SparkUtil.classForName(metadata.className)
      .getConstructor(classOf[String])
      .newInstance(metadata.uid)
      .asInstanceOf[Params]
    metadata.getAndSetParams(restored)
    restored.asInstanceOf[T]
  }
}

private[angel] object DefaultParamsReader {

  /**
    * All info from metadata file.
    *
    * @param className     value of the "class" field (class name of the saved instance)
    * @param uid           value of the "uid" field
    * @param timestamp     value of the "timestamp" field (`DefaultParamsWriter` writes
    *                      `System.currentTimeMillis()` here)
    * @param sparkVersion  value of the "sparkVersion" field
    * @param params        paramMap, as a `JValue`
    * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4,
    *                      this is `JNothing`.
    * @param metadata      All metadata, including the other fields
    * @param metadataJson  Full metadata file String (for debugging)
    */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      defaultParams: JValue,
      metadata: JValue,
      metadataJson: String) {

    /**
      * Returns the fields of `params`.
      *
      * @throws IllegalArgumentException if `params` is not a JSON object
      */
    private def getValueFromParams(params: JValue): Seq[(String, JValue)] = {
      params match {
        case JObject(pairs) => pairs
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: $metadataJson.")
      }
    }

    /**
      * Get the JSON value of the [[Param]] of the given name.
      * This can be useful for getting a Param value before an instance of `Params`
      * is available. This will look up `params` first, if not existing then looking up
      * `defaultParams`.
      *
      * @throws IllegalArgumentException if the searched section is not a JSON object
      */
    def getParamValue(paramName: String): JValue = {
      // Look up the explicitly set `params` first.
      val explicitPairs = getValueFromParams(params)
      val explicitMatches = explicitPairs.filter(_._1 == paramName)
      // Fall back to `defaultParams` only when nothing was found explicitly. The error message
      // lists the pairs of whichever section was searched last.
      val (searchedPairs, foundPairs) =
        if (explicitMatches.nonEmpty) {
          (explicitPairs, explicitMatches)
        } else {
          val defaultPairs = getValueFromParams(defaultParams)
          (defaultPairs, defaultPairs.filter(_._1 == paramName))
        }
      assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
        s" ${foundPairs.length} in JSON Params: " + searchedPairs.map(_.toString).mkString(", "))

      foundPairs.head._2
    }

    /**
      * Extract Params from metadata, and set them in the instance.
      * This works if all Params (except params included by `skipParams` list) implement
      * [[ml.param.Param.jsonDecode()]].
      *
      * @param skipParams The params included in `skipParams` won't be set. This is useful if some
      *                   params don't implement [[ml.param.Param.jsonDecode()]]
      *                   and need special handling.
      */
    def getAndSetParams(
        instance: Params,
        skipParams: Option[List[String]] = None): Unit = {
      setParams(instance, skipParams, isDefault = false)

      // For metadata file prior to Spark 2.4, there is no default section.
      val (major, minor) = SparkUtil.majorMinorVersion(sparkVersion)
      if (major > 2 || (major == 2 && minor >= 4)) {
        setParams(instance, skipParams, isDefault = true)
      }
    }

    /**
      * Decodes every entry of `params` (or `defaultParams` when `isDefault`) and sets it on
      * `instance`, as a default value when `isDefault`, otherwise as an explicit value.
      *
      * @throws IllegalArgumentException if the selected section is not a JSON object
      */
    private def setParams(
        instance: Params,
        skipParams: Option[List[String]],
        isDefault: Boolean): Unit = {
      implicit val format = DefaultFormats
      val paramsToSet = if (isDefault) defaultParams else params
      paramsToSet match {
        case JObject(pairs) =>
          pairs.foreach { case (paramName, jsonValue) =>
            // Set the param unless a skip list was given and it names this param.
            if (skipParams.forall(skipped => !skipped.contains(paramName))) {
              val param = instance.getParam(paramName)
              val value = param.jsonDecode(compact(render(jsonValue)))
              if (isDefault) {
                Params.setDefault(instance, param, value)
              } else {
                instance.set(param, value)
              }
            }
          }
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: ${metadataJson}.")
      }
    }
  }

  /**
    * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
    *
    * Reads the first line of the text file(s) under `path/metadata` and parses it.
    *
    * @param expectedClassName If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr, expectedClassName)
  }

  /**
    * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
    * This is a helper function for [[loadMetadata()]].
    *
    * @param metadataStr       JSON string of metadata
    * @param expectedClassName If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
  def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val defaultParams = metadata \ "defaultParamMap"
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr)
  }

  /**
    * Load a `Params` instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]]: the saved class is expected to expose
    * a static `read` method (e.g. a Scala companion-object forwarder) returning an `MLReader`.
    */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = SparkUtil.classForName(metadata.className)
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}

/**
  * Default Meta-Algorithm read and write implementation: helpers for compound estimators/models
  * that embed other `Params` instances (pipelines, validators, one-vs-rest, RFormula models).
  */
private[angel] object MetaAlgorithmReadWrite {

  /**
    * Examine the given estimator (which may be a compound estimator) and extract a mapping
    * from UIDs to corresponding `Params` instances.
    *
    * @throws RuntimeException if two of the collected instances share a UID
    */
  def getUidMap(instance: Params): Map[String, Params] = {
    val uidPairs = getUidMapImpl(instance)
    val byUid = uidPairs.toMap
    // `toMap` keeps one entry per key, so a size difference means a UID appeared twice.
    if (byUid.size != uidPairs.size) {
      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
        s" with stages with duplicate UIDs. List of UIDs: ${uidPairs.map(_._1).mkString(", ")}.")
    }
    byUid
  }

  /**
    * Lists `(uid, instance)` for `instance` followed by all nested stages, recursively, in
    * pre-order.
    */
  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
    val children: Array[Params] = instance match {
      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
      case ovr: OneVsRest => Array(ovr.getClassifier)
      case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
      case _: Params => Array.empty[Params]
    }
    (instance.uid, instance) :: children.toList.flatMap(getUidMapImpl)
  }
}

/**
  * Prepares an output path for `MLWriter.save`: deletes it when overwriting is requested, or
  * fails when it already exists and overwriting is not allowed.
  */
private[angel] class FileSystemOverwrite extends Logging {

  /**
    * Checks `path` (qualified against its Hadoop file system) and handles an existing entry.
    *
    * @param path            output path requested by the caller
    * @param shouldOverwrite whether an existing path may be deleted
    * @param sc              context whose Hadoop configuration resolves the file system
    * @throws IOException if the path exists and `shouldOverwrite` is false, or if it exists and
    *                     could not be deleted
    */
  def handleOverwrite(path: String, shouldOverwrite: Boolean, sc: SparkContext): Unit = {
    val hadoopConf = sc.hadoopConfiguration
    val outputPath = new Path(path)
    val fs = outputPath.getFileSystem(hadoopConf)
    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
    if (fs.exists(qualifiedOutputPath)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        // `FileSystem.delete` reports failure through its return value, not an exception.
        // Re-check existence so a concurrent removal (delete returns false, path gone) is not
        // treated as an error, but a path that genuinely survived is reported here rather than
        // as a confusing failure later in the save.
        if (!fs.delete(qualifiedOutputPath, true) && fs.exists(qualifiedOutputPath)) {
          throw new IOException(
            s"Path $path already exists and could not be deleted to overwrite it.")
        }
      } else {
        throw new IOException(s"Path $path already exists. To overwrite it, " +
          s"please use write.overwrite().save(path) for Scala and use " +
          s"write().overwrite().save(path) for Java and Python.")
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy