All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.sona.ml.Pipeline.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tencent.angel.sona.ml

import java.{util => ju}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import com.tencent.angel.sona.ml.param.{Param, ParamMap, Params}
import org.apache.spark.internal.Logging
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType


/**
  * :: DeveloperApi ::
  * Base type for every step of a pipeline; a concrete stage is either an [[Estimator]] or a
  * [[Transformer]].
  */
abstract class PipelineStage extends Params with Logging {

  /**
    * :: DeveloperApi ::
    *
    * Validates this stage against `schema` and returns the schema it would produce.
    *
    * Checks that involve interactions between several parameters belong here and should raise an
    * exception when a combination is invalid; checks on a single parameter in isolation are
    * handled by `Param.validate()`.
    *
    * Implementations should typically verify the schema change and parameter validity (including
    * cross-parameter constraints) first.
    */
  def transformSchema(schema: StructType): StructType

  /**
    * :: DeveloperApi ::
    *
    * Delegates to `transformSchema(schema)`; when `logging` is true, the input schema and the
    * derived output schema are additionally written as JSON at debug level.
    *
    * This should be optimistic: if it is unclear whether the schema will be valid, assume it is
    * valid until proven otherwise.
    *
    * @param schema  input schema
    * @param logging whether to log the input and output schemas at debug level
    * @return the output schema computed by `transformSchema(schema)`
    */
  protected def transformSchema(
      schema: StructType,
      logging: Boolean): StructType = {
    if (logging) logDebug(s"Input schema: ${schema.json}")
    val derivedSchema = transformSchema(schema)
    if (logging) logDebug(s"Expected output schema: ${derivedSchema.json}")
    derivedSchema
  }

  /** Copies this stage, taking `extra` into account (see `Params.copy`); narrows the return type. */
  override def copy(extra: ParamMap): PipelineStage
}

/**
  * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
  * of which is either an [[Estimator]] or a [[Transformer]]. When `Pipeline.fit` is called, the
  * stages are executed in order. If a stage is an [[Estimator]], its `Estimator.fit` method will
  * be called on the input dataset to fit a model. Then the model, which is a transformer, will be
  * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]],
  * its `Transformer.transform` method will be called to produce the dataset for the next stage.
  * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and
  * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
  * an identity transformer.
  *
  * @param uid unique identifier of this pipeline
  */
class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable {

  /** Creates a pipeline whose uid is generated by `Identifiable.randomUID("pipeline")`. */
  def this() = this(Identifiable.randomUID("pipeline"))

  /**
    * param for pipeline stages
    *
    * @group param
    */
  val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")

  /**
    * Sets the stages of this pipeline.
    *
    * @group setParam
    */
  def setStages(value: Array[_ <: PipelineStage]): this.type = {
    set(stages, value.asInstanceOf[Array[PipelineStage]])
    this
  }

  /**
    * Returns a clone of the configured stages, so that modifications to the returned array do not
    * change the Param value held by this Pipeline.
    *
    * @group getParam
    */
  def getStages: Array[PipelineStage] = $(stages).clone()

  /**
    * Fits the pipeline to the input dataset. If a stage is an [[Estimator]], its `Estimator.fit`
    * method is called on the current dataset to fit a model. Then the model, which is a
    * transformer, is used to transform the dataset as the input to the next stage. If a stage is a
    * [[Transformer]], its `Transformer.transform` method is called to produce the dataset for the
    * next stage. The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of
    * fitted models and transformers, corresponding to the pipeline stages. If there are no
    * stages, the output model acts as an identity transformer.
    *
    * @param dataset input dataset
    * @return fitted pipeline, whose parent is this pipeline
    * @throws IllegalArgumentException if any stage is neither an [[Estimator]] nor a
    *                                  [[Transformer]]
    */
  override def fit(dataset: Dataset[_]): PipelineModel = {
    transformSchema(dataset.schema, logging = true)
    val theStages = $(stages)
    // -1 when the pipeline contains no estimator at all.
    val indexOfLastEstimator = theStages.lastIndexWhere {
      case _: Estimator[_] => true
      case _ => false
    }
    var curDataset = dataset
    val transformers = ListBuffer.empty[Transformer]
    theStages.zipWithIndex.foreach { case (stage, index) =>
      // Stages after the last estimator can only be transformers (or invalid); a single match
      // handles every position, so an invalid stage always gets the same descriptive error
      // instead of a ClassCastException when it follows the last estimator.
      val transformer = stage match {
        case estimator: Estimator[_] =>
          estimator.fit(curDataset)
        case t: Transformer =>
          t
        case _ =>
          throw new IllegalArgumentException(
            s"Does not support stage $stage of type ${stage.getClass}")
      }
      // Data only needs to flow up to the last estimator; transforming beyond that point would
      // produce output that no later fit consumes.
      if (index < indexOfLastEstimator) {
        curDataset = transformer.transform(curDataset)
      }
      transformers += transformer
    }

    new PipelineModel(uid, transformers.toArray).setParent(this)
  }

  /**
    * Returns a new [[Pipeline]] with the same uid whose stages are `stage.copy(extra)` for every
    * stage resolved from `extractParamMap(extra)`.
    */
  override def copy(extra: ParamMap): Pipeline = {
    val map = extractParamMap(extra)
    val newStages = map(stages).map(_.copy(extra))
    new Pipeline(uid).setStages(newStages)
  }

  /**
    * Threads `schema` through every stage's `transformSchema`, in order.
    *
    * @throws IllegalArgumentException if the same stage instance appears more than once
    */
  override def transformSchema(schema: StructType): StructType = {
    val theStages = $(stages)
    require(theStages.toSet.size == theStages.length,
      "Cannot have duplicate components in a pipeline.")
    theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
  }

  /** Returns a writer that saves this pipeline's metadata and each of its stages. */
  override def write: MLWriter = new Pipeline.PipelineWriter(this)
}


object Pipeline extends MLReadable[Pipeline] {

  /** Returns a reader that restores a [[Pipeline]] saved by its `write` method. */
  override def read: MLReader[Pipeline] = new PipelineReader

  /** Loads a [[Pipeline]] from `path`; delegates to the `MLReadable` default implementation. */
  override def load(path: String): Pipeline = super.load(path)

  /**
    * Writer for [[Pipeline]]. Stage writability is validated when the writer is constructed,
    * i.e. before `saveImpl` is ever invoked.
    */
  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {

    SharedReadWrite.validateStages(instance.getStages)

    override protected def saveImpl(path: String): Unit = {
      val stagesToSave = instance.getStages
      SharedReadWrite.saveImpl(instance, stagesToSave, sc, path)
    }
  }

  /** Reader for [[Pipeline]]; rebuilds the pipeline from its saved uid and stages. */
  private class PipelineReader extends MLReader[Pipeline] {

    /** Expected class name, checked against the saved metadata when loading. */
    private val className = classOf[Pipeline].getName

    override def load(path: String): Pipeline = {
      val (savedUid, savedStages) = SharedReadWrite.load(className, sc, path)
      new Pipeline(savedUid).setStages(savedStages)
    }
  }

  /**
    * Methods for `MLReader` and `MLWriter` shared between [[Pipeline]] and [[PipelineModel]]
    */
  private[sona] object SharedReadWrite {

    import org.json4s.JsonDSL._

    /** Throws `UnsupportedOperationException` for the first stage that is not `MLWritable`. */
    def validateStages(stages: Array[PipelineStage]): Unit = {
      val firstNonWritable = stages.find {
        case _: MLWritable => false
        case _ => true
      }
      firstNonWritable.foreach { other =>
        throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
          s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
          s" ${other.uid} of type ${other.getClass}")
      }
    }

    /**
      * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]:
      *  - metadata (including the ordered stage uids under "stageUids") goes to `path`
      *  - each stage is written under `path/stages/IDX_UID`
      */
    def saveImpl(
        instance: Params,
        stages: Array[PipelineStage],
        sc: SparkContext,
        path: String): Unit = {
      val uids = stages.map(_.uid)
      val metadataParams = List("stageUids" -> parse(compact(render(uids.toSeq))))
      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(metadataParams))

      val stagesDir = new Path(path, "stages").toString
      for ((stage, idx) <- stages.zipWithIndex) {
        val target = getStagePath(stage.uid, idx, stages.length, stagesDir)
        stage.asInstanceOf[MLWritable].write.save(target)
      }
    }

    /**
      * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
      *
      * @return (UID, stages in their saved order)
      */
    def load(
        expectedClassName: String,
        sc: SparkContext,
        path: String): (String, Array[PipelineStage]) = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

      implicit val format = DefaultFormats
      val stagesDir = new Path(path, "stages").toString
      val uids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
      val loadedStages = Array.tabulate[PipelineStage](uids.length) { idx =>
        val stagePath = getStagePath(uids(idx), idx, uids.length, stagesDir)
        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
      }
      (metadata.uid, loadedStages)
    }

    /** Get path for saving the given stage: `stagesDir/<zero-padded index>_<uid>`. */
    def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
      // The index is zero-padded to the digit count of `numStages`, so lexicographic order of the
      // directory names matches stage order.
      val width = numStages.toString.length
      val paddedIdx = s"%0${width}d".format(stageIdx)
      new Path(stagesDir, s"${paddedIdx}_$stageUid").toString
    }
  }

}

/**
  * Represents a fitted pipeline: an ordered sequence of [[Transformer]]s applied one after another.
  *
  * @param uid    unique identifier of this model
  * @param stages transformers applied, in order, by `transform`
  */
class PipelineModel private[angel](
    override val uid: String,
    val stages: Array[Transformer])
  extends Model[PipelineModel] with MLWritable with Logging {

  /** A Java/Python-friendly auxiliary constructor. */
  private[sona] def this(uid: String, stages: ju.List[Transformer]) = {
    this(uid, stages.asScala.toArray)
  }

  /** Validates the schema (with debug logging), then pipes the data through every stage. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    stages.foldLeft(dataset.toDF) { (df, stage) =>
      stage.transform(df)
    }
  }

  /** Threads `schema` through every stage's `transformSchema`, in order. */
  override def transformSchema(schema: StructType): StructType =
    stages.foldLeft(schema) { (acc, stage) =>
      stage.transformSchema(acc)
    }

  /** Copies every stage with `extra` and keeps this model's uid and parent. */
  override def copy(extra: ParamMap): PipelineModel = {
    val copiedStages = stages.map(_.copy(extra))
    new PipelineModel(uid, copiedStages).setParent(parent)
  }

  /** Returns a writer that saves this model's metadata and each of its stages. */
  override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
}


object PipelineModel extends MLReadable[PipelineModel] {

  import Pipeline.SharedReadWrite

  /** Returns a reader that restores a [[PipelineModel]] saved by its `write` method. */
  override def read: MLReader[PipelineModel] = new PipelineModelReader

  /** Loads a [[PipelineModel]] from `path`; delegates to the `MLReadable` default implementation. */
  override def load(path: String): PipelineModel = super.load(path)

  /**
    * Writer for [[PipelineModel]]. Stage writability is validated when the writer is constructed,
    * i.e. before `saveImpl` is ever invoked.
    */
  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {

    // The Array[Transformer] -> Array[PipelineStage] cast relies on JVM array covariance.
    SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])

    override protected def saveImpl(path: String): Unit = {
      val stagesToSave = instance.stages.asInstanceOf[Array[PipelineStage]]
      SharedReadWrite.saveImpl(instance, stagesToSave, sc, path)
    }
  }

  /** Reader for [[PipelineModel]]; every loaded stage must be a [[Transformer]]. */
  private class PipelineModelReader extends MLReader[PipelineModel] {

    /** Expected class name, checked against the saved metadata when loading. */
    private val className = classOf[PipelineModel].getName

    override def load(path: String): PipelineModel = {
      val (savedUid, savedStages) = SharedReadWrite.load(className, sc, path)
      // Fail on the first stage (in saved order) that is not a Transformer.
      val firstNonTransformer = savedStages.find {
        case _: Transformer => false
        case _ => true
      }
      firstNonTransformer.foreach { other =>
        throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
          s" was not a Transformer.  Bad stage ${other.uid} of type ${other.getClass}")
      }
      val transformers = savedStages.collect { case transformer: Transformer => transformer }
      new PipelineModel(savedUid, transformers)
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy