All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.classification.SmileClassifier.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see .
 */

package org.apache.spark.ml.classification

import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.classification.{ClassificationModel => SparkClassificationModel, Classifier => SparkClassifier}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Instrumentation._
import org.apache.spark.ml.util.{Identifiable, _}
import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JObject}

/**
  * Params for SmileClassifier
  */
private[ml] trait SmileClassifierParams extends Params with ClassifierParams {

  type Trainer = (Array[Array[Double]], Array[Int]) => smile.classification.Classifier[Array[Double]]

  /**
    * Param for the smile classification trainer
    *
    * @group param
    */
  val trainer: Param[Trainer] = new Param[Trainer](this, "trainer", "Smile classifier trainer")

  /** @group getParam */
  def getTrainer: Trainer = $(trainer)

}

private[ml] object SmileClassifierParams {
  def saveImpl(instance: SmileClassifierParams, path: String, sc: SparkContext, extraMetadata: Option[JObject] = None): Unit = {
    val params = instance.extractParamMap().toSeq
    val jsonParams = render(
      params
        .filter { case ParamPair(p, _) => p.name != "trainer" }
        .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
        .toList
    )

    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
  }

  def loadImpl(path: String, sc: SparkContext, expectedClassName: String): DefaultParamsReader.Metadata = {
    DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
  }
}

/**
  * A Spark Estimator based on Smile's classification algorithms.
  * It allows to add a Smile model into a Spark MLLib Pipeline.
  *
  * @note SmileClassifier will collect the training Dataset
  *       to the Spark Driver but train the model on a Spark Executor.
  */
class SmileClassifier(override val uid: String)
  extends SparkClassifier[Vector, SmileClassifier, SmileClassificationModel]
    with SmileClassifierParams
    with MLWritable {

  def this() = this(Identifiable.randomUID("SmileClassifier"))

  /** @group setParam */
  def setTrainer(value: (Array[Array[Double]], Array[Int]) => smile.classification.Classifier[Array[Double]]): this.type =
    set(trainer, value.asInstanceOf[Trainer])

  override protected def train(dataset: Dataset[_]): SmileClassificationModel =
    instrumented { instr =>
      instr.logPipelineStage(this)
      instr.logDataset(dataset)
      instr.logParams(this, labelCol, featuresCol, predictionCol)

      val spark = dataset.sparkSession
      val sc = spark.sparkContext

      val df = dataset.select($(labelCol), $(featuresCol))

      val persist = dataset.storageLevel == StorageLevel.NONE && (df.storageLevel == StorageLevel.NONE)
      if (persist) df.persist(StorageLevel.MEMORY_AND_DISK)

      val x = df.select(getFeaturesCol).collect().map(row => row.getAs[Vector](0).toArray)
      val y = df.select(getLabelCol).collect().map(row => row.getDouble(0).toInt)

      val trainersRDD = sc.parallelize(Seq(getTrainer))
      val model = trainersRDD.map(_.apply(x, y)).collect()(0)
      val numClasses = getNumClasses(df)

      if (persist) df.unpersist()

      new SmileClassificationModel(numClasses, model)
    }

  override def copy(extra: ParamMap): SmileClassifier = {
    val copy = new SmileClassifier(uid)
    copyValues(copy, extra)
    copy.setTrainer(getTrainer)
  }

  override def write: MLWriter = new SmileClassifier.SmileClassifierWriter(this)
}

object SmileClassifier extends MLReadable[SmileClassifier] {

  override def read: MLReader[SmileClassifier] = new SmileClassifierReader

  override def load(path: String): SmileClassifier = super.load(path)

  /** [[MLWriter]] instance for [[SmileClassifier]] */
  private[SmileClassifier] class SmileClassifierWriter(instance: SmileClassifier) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      SmileClassifierParams.saveImpl(instance, path, sc)
    }
  }

  /** [[MLReader]] instance for [[SmileClassifier]] */
  private class SmileClassifierReader extends MLReader[SmileClassifier] {
    /** Checked against metadata when loading model */
    private val className = classOf[SmileClassifier].getName

    override def load(path: String): SmileClassifier = {
      val metadata = SmileClassifierParams.loadImpl(path, sc, className)
      val bc = new SmileClassifier(metadata.uid)
      metadata.getAndSetParams(bc)
      bc
    }
  }
}

/**
  * Model produced by [[SmileClassifier]].
  * This stores the models resulting from training the smile classifier.
  *
  * @param model smile classifier
  */
class SmileClassificationModel(override val uid: String,
                               override val numClasses: Int,
                               val model: smile.classification.Classifier[Array[Double]])
  extends SparkClassificationModel[Vector, SmileClassificationModel]
    with SmileClassifierParams
    with MLWritable {

  def this(numClasses: Int, model: smile.classification.Classifier[Array[Double]]) =
    this(Identifiable.randomUID("SmileClassificationModel"), numClasses, model)

  override def predictRaw(features: Vector): Vector = {
    // The Spark API predictRaw function outputs a Vector with
    // "confidence scores" for each class.
    val posteriori = Array.ofDim[Double](numClasses)
    if (model.soft()) {
      model.predict(features.toArray, posteriori)
    } else {
        // The "hard" confidence for the predicted class.
      posteriori(model.predict(features.toArray)) = 1.0
    }

    Vectors.dense(posteriori)
  }

  override def copy(extra: ParamMap): SmileClassificationModel = {
    val copy = new SmileClassificationModel(uid, numClasses, model)
    copyValues(copy, extra).setParent(parent)
  }

  override def write: MLWriter = new SmileClassificationModel.SmileClassificationModelWriter(this)
}

object SmileClassificationModel extends MLReadable[SmileClassificationModel] {
  override def read: MLReader[SmileClassificationModel] = new SmileClassificationModelReader

  override def load(path: String): SmileClassificationModel = super.load(path)

  /** [[MLWriter]] instance for [[SmileClassificationModel]] */
  private[SmileClassificationModel] class SmileClassificationModelWriter(instance: SmileClassificationModel) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      val extraJson = "numClasses" -> instance.numClasses
      val fs = FileSystem.get(sc.hadoopConfiguration)
      try {
        val fileOut = fs.create(new Path(path, "model"))
        val objectOut = new ObjectOutputStream(fileOut)
        objectOut.writeObject(instance.model)
        objectOut.close()
      } catch {
        case ex: Exception => ex.printStackTrace()
      }

      SmileClassifierParams.saveImpl(instance, path, sc, Some(extraJson))
    }
  }

  /** [[MLReader]] instance for [[SmileClassificationModel]] */
  private class SmileClassificationModelReader extends MLReader[SmileClassificationModel] {
    /** Checked against metadata when loading model */
    private val className = classOf[SmileClassificationModel].getName

    override def load(path: String): SmileClassificationModel = {
      val metadata = SmileClassifierParams.loadImpl(path, sc, className)
      val fs = FileSystem.get(sc.hadoopConfiguration)
      implicit val format: DefaultFormats = DefaultFormats
      val numClasses = (metadata.metadata \ "numClasses").extract[Int]
      val impl = try {
        val fileIn = fs.open(new Path(path, "model"))
        val objectIn = new ObjectInputStream(fileIn)
        val obj = objectIn.readObject
        objectIn.close()
        obj
      } catch {
        case ex: Exception =>
          ex.printStackTrace()
          return null
      }

      val model = new SmileClassificationModel(
        metadata.uid,
        numClasses,
        impl.asInstanceOf[smile.classification.Classifier[Array[Double]]]
      )

      metadata.getAndSetParams(model)
      model
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy