All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.ml.parity.SparkParityBase.scala Maven / Gradle / Ivy

The newest version!
package org.apache.spark.ml.parity

import java.io.File
import java.nio.file.{Files, Path}
import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.scalatest.BeforeAndAfterAll
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.bundle.BundleFile
import ml.combust.bundle.serializer.SerializationFormat
import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types.{DataType, NodeShape, TensorType}
import ml.combust.mleap.runtime.frame.{BaseTransformer, MultiTransformer, SimpleTransformer}
import ml.combust.mleap.runtime.{MleapContext, frame}
import ml.combust.mleap.runtime.function.UserDefinedFunction
import org.apache.spark.ml.bundle.SparkBundleContext
import ml.combust.mleap.spark.SparkSupport._
import ml.combust.mleap.runtime.transformer.Pipeline

import scala.util.Using
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.Row
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.TestingUtils._
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers._

/**
  * Shared fixtures for the Spark/MLeap parity suites: default bundle contexts and
  * loaders for the datasets shipped as test resources under `datasources/`.
  *
  * Created by hollinwilkins on 10/30/16.
  */
object SparkParityBase extends AnyFunSpec {
  val sparkRegistry = SparkBundleContext.defaultContext
  val mleapRegistry = MleapContext.defaultContext

  /**
    * Resolves a classpath resource to a URL string that Spark readers accept.
    *
    * `ClassLoader.getResource` returns null for a missing resource, which would otherwise
    * surface as an anonymous NullPointerException; fail with the resource name instead.
    *
    * @throws IllegalArgumentException if the resource is not on the classpath
    */
  private def resourcePath(name: String): String =
    Option(getClass.getClassLoader.getResource(name))
      .map(_.toString)
      .getOrElse(throw new IllegalArgumentException(s"test resource not found on classpath: $name"))

  /** Reads `carroll-alice.txt` as text, renaming Spark's default `value` column to `text`. */
  def textDataset(spark: SparkSession): DataFrame =
    spark.read.text(resourcePath("datasources/carroll-alice.txt"))
      .withColumnRenamed("value", "text")

  /** Reads the lending-club sample with Spark's `avro` data source. */
  def dataset(spark: SparkSession): DataFrame =
    spark.read.format("avro").load(resourcePath("datasources/lending_club_sample.avro"))

  /** Reads the multi-class sample with Spark's `libsvm` data source. */
  def multiClassClassificationDataset(spark: SparkSession): DataFrame =
    spark.read.format("libsvm").load(resourcePath("datasources/sample_multiclass_classification_data.txt"))

  /** One parsed rating line. */
  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  /**
    * Parses a `userId::movieId::rating::timestamp` line into a [[Rating]].
    *
    * Fails the `assert` when the line does not split into exactly four `::`-separated
    * fields; throws NumberFormatException when a field is not numeric.
    */
  def parseRating(str: String): Rating = {
    val fields = str.split("::")
    assert(fields.length == 4, s"expected 4 '::'-separated fields in rating line: $str")
    Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
  }

  /** Reads `sample_movielens_ratings.txt` line by line and parses each line with [[parseRating]]. */
  def recommendationDataset(spark: SparkSession): DataFrame = {
    import spark.implicits._
    spark.read.textFile(resourcePath("datasources/sample_movielens_ratings.txt"))
      .map(parseRating)
      .toDF()
  }
}


/**
  * Holds the local SparkSession used by the parity suites. It is built lazily on first
  * access via `getOrCreate`, with the Spark UI disabled and two local worker threads.
  */
object SparkEnv {
  lazy val spark: SparkSession = {
    val builder = SparkSession.builder()
      .appName("Spark/MLeap Parity Tests")
      .config("spark.ui.enabled", "false")
      .master("local[2]")
    val session = builder.getOrCreate()
    // Only show log messages at WARN level or above.
    session.sparkContext.setLogLevel("WARN")
    session
  }
}


/**
  * Base suite for Spark/MLeap parity tests.
  *
  * Subclasses provide `dataset` and a fitted `sparkTransformer`. Construction registers
  * (via `parityTransformer`) tests that:
  *  1. serialize the transformer to an MLeap bundle and compare MLeap output to Spark output,
  *  2. load the bundle back into Spark and compare params with the original,
  *  3. compare each MLeap model's declared input/output schema with its UDF types.
  */
abstract class SparkParityBase extends AnyFunSpec with BeforeAndAfterAll {
  lazy val baseDataset: DataFrame = SparkParityBase.dataset(spark)
  lazy val textDataset: DataFrame = SparkParityBase.textDataset(spark)
  lazy val recommendationDataset: DataFrame = SparkParityBase.recommendationDataset(spark)
  lazy val multiClassClassificationDataset: DataFrame = SparkParityBase.multiClassClassificationDataset(spark)

  /** Input data the transformer under test is applied to. */
  val dataset: DataFrame
  /** The Spark transformer under test. */
  val sparkTransformer: Transformer

  val spark = SparkEnv.spark

  /** Bundle written by the first `serializedModel` call; returned by all later calls. */
  var bundleCache: Option[File] = None

  /**
    * Writes `transformer` as a JSON-format bundle into a fresh temp directory and caches it.
    *
    * Once a bundle is cached it is returned as-is, whatever `transformer` is passed later.
    * Any failure from writing the bundle is rethrown rather than caching a broken file.
    */
  def serializedModel(transformer: Transformer)
                     (implicit context: SparkBundleContext): File = {
    bundleCache.getOrElse {
      val tempDir: Path = Files.createTempDirectory("mleap-spark-parity")
      // deleteOnExit deletes in reverse registration order, so the bundle file registered
      // below is removed before its parent directory (which must be empty to be deleted).
      tempDir.toFile.deleteOnExit()

      val file = tempDir.toAbsolutePath.resolve(s"${getClass.getName}.zip").toFile
      file.deleteOnExit()

      // `save` yields a Try; `.get` surfaces serialization failures instead of dropping them.
      Using(BundleFile(file)) { bf =>
        transformer.writeBundle.format(SerializationFormat.Json).save(bf)
      }.flatten.get

      bundleCache = Some(file)
      file
    }
  }

  /** Loads the (cached) bundle of `transformer` into the MLeap runtime and returns its root. */
  def mleapTransformer(transformer: Transformer)
                      (implicit context: SparkBundleContext): frame.Transformer = {
    Using(BundleFile(serializedModel(transformer))) { bf =>
      bf.loadMleapBundle()
    }.flatten.get.root
  }

  /** Loads the (cached) bundle of `transformer` back into Spark and returns its root. */
  def deserializedSparkTransformer(transformer: Transformer)
                                  (implicit context: SparkBundleContext): Transformer = {
    Using(BundleFile(serializedModel(transformer))) { bf =>
      bf.loadSparkBundle()
    }.flatten.get.root
  }

  /**
    * Asserts that `model`'s input/output schema types agree with the types of `exec`.
    * Input types of the UDF are flattened across all of its inputs. `shape` is not inspected.
    */
  def assertModelTypesMatchTransformerTypes(model: Model, shape: NodeShape, exec: UserDefinedFunction): Unit = {
    val inputFields = model.inputSchema.fields.map(_.name)
    val modelInputTypes = model.inputSchema.fields.map(_.dataType)
    val transformerInputTypes = exec.inputs.flatMap(_.dataTypes)

    val outputFields = model.outputSchema.fields.map(_.name)
    val modelOutputTypes = model.outputSchema.fields.map(_.dataType)
    val transformerOutputTypes = exec.outputTypes

    checkTypes(modelInputTypes, transformerInputTypes, inputFields)
    checkTypes(modelOutputTypes, transformerOutputTypes, outputFields)
  }

  /**
    * Pairwise compares model and transformer types. Tensor types only need matching base
    * types; all other types must be equal.
    */
  def checkTypes(modelTypes: Seq[DataType], transformerTypes: Seq[DataType], fields: Seq[String]): Unit = {
    // Without this, zip below would silently ignore extra types on either side.
    assert(modelTypes.size == transformerTypes.size,
      s"model declares ${modelTypes.size} types but transformer has ${transformerTypes.size}")
    modelTypes.zip(transformerTypes).zip(fields).foreach {
      case ((modelType: TensorType, transformerType), field) =>
        assert(
          transformerType.isInstanceOf[TensorType] && modelType.base == transformerType.base,
          s"Type of ${field} does not match, $transformerType")
      case ((modelType, transformerType), field) =>
        assert(modelType == transformerType, s"Type of ${field} does not match")
    }
  }

  /**
    * Compares two rows field by field. Doubles, floats, vectors and numeric elements of
    * sequences use relative tolerance `eps`; nested rows are compared recursively; all
    * other values must be equal.
    */
  def checkRowWithRelTol(actualAnswer: Row, expectedAnswer: Row, eps: Double): Unit = {
    assert(actualAnswer.length == expectedAnswer.length,
      s"actual answer length ${actualAnswer.length} != " +
        s"expected answer length ${expectedAnswer.length}")
    actualAnswer.toSeq.zip(expectedAnswer.toSeq).zipWithIndex.foreach {
      case ((actual: Double, expected: Double), _) =>
        assert(actual ~== expected relTol eps)
      case ((actual: Float, expected: Float), _) =>
        assert(actual ~== expected relTol eps)
      case ((actual: Vector, expected: Vector), _) =>
        assert(actual ~= expected relTol eps)
      case ((actual: Seq[_], expected: Seq[_]), fieldIdx) =>
        assert(actual.length == expected.length, s"actual length ${actual.length} != " +
          s"expected length ${expected.length}")
        actual.zip(expected).foreach {
          case (actualElem: Double, expectedElem: Double) =>
            assert(actualElem ~== expectedElem relTol eps)
          case (actualElem: Float, expectedElem: Float) =>
            assert(actualElem ~== expectedElem relTol eps)
          case (actualElem, expectedElem) =>
            withClue(s"Field ${actualAnswer.schema(fieldIdx)} differs.") {
              actualElem shouldBe expectedElem
            }
        }
      case ((actual: Row, expected: Row), _) =>
        checkRowWithRelTol(actual, expected, eps)
      case ((actual, expected), _) =>
        assert(actual == expected, s"$actual did not equal $expected")
    }
  }

  /** Relative tolerance used by `equalityTest`; suites may reassign it. */
  var relTolEps: Double = 1E-6

  /**
    * Asserts both frames have the same columns and the same number of rows, and that rows
    * match positionally (with the MLeap columns reordered to Spark's column order).
    */
  def equalityTest(sparkDataset: DataFrame, mleapDataset: DataFrame): Unit = {
    val sparkCols = sparkDataset.columns.toSeq
    assert(mleapDataset.columns.toSet === sparkCols.toSet)
    val sparkRows = sparkDataset.collect()
    val mleapRows = mleapDataset.select(sparkCols.map(col): _*).collect()
    // zip truncates to the shorter side, so a dropped or extra row must be caught here.
    assert(sparkRows.length == mleapRows.length,
      s"spark produced ${sparkRows.length} rows but mleap produced ${mleapRows.length}")
    sparkRows.zip(mleapRows).foreach { case (sparkRow, mleapRow) =>
      checkRowWithRelTol(sparkRow, mleapRow, relTolEps)
    }
  }

  /**
    * Compares params of `original` and `deserialized` positionally, skipping names in
    * `unserializedParams` and `additionalIgnore`. Array values are compared element-wise.
    */
  def checkParamsEquality(original: Transformer, deserialized: Transformer, additionalIgnore: Set[String]): Unit = {
    val ignoredParams = unserializedParams.union(additionalIgnore)
    assert(original.params.length == deserialized.params.length,
      s"original has ${original.params.length} params, deserialized has ${deserialized.params.length}")
    original.params.zip(deserialized.params)
      .filterNot { case (param1, _) => ignoredParams.contains(param1.name) }
      .foreach { case (param1, param2) =>
        assert(original.isDefined(param1) == deserialized.isDefined(param2),
          s"spark transformer param ${param1.name} is defined ${original.isDefined(param1)} deserialized is ${deserialized.isDefined(param2)}")

        if (original.isDefined(param1)) {
          val v1Value = original.getOrDefault(param1)
          val v2Value = deserialized.getOrDefault(param2)

          v1Value match {
            case v1Array: Array[_] =>
              assert(v1Array sameElements v2Value.asInstanceOf[Array[_]], s"$param1 is not equivalent")
            case _ => assert(v1Value == v2Value, s"$param1 is not equivalent")
          }
        }
      }
  }

  /**
    * Asserts both transformers have the same class, uid and params; for a PipelineModel the
    * stages are compared recursively.
    */
  def checkEquality(original: Transformer, deserialized: Transformer, additionalIgnoreParams: Set[String]): Unit = {
    assert(original.getClass == deserialized.getClass)
    assert(original.uid == deserialized.uid)
    checkParamsEquality(original, deserialized, additionalIgnoreParams)
    original match {
      case original: PipelineModel =>
        val deStages = deserialized.asInstanceOf[PipelineModel].stages
        assert(original.stages.length == deStages.length)
        original.stages.zip(deStages).foreach {
          case (o, d) => checkEquality(o, d, additionalIgnoreParams)
        }
      case _ =>
    }
  }

  /** Columns dropped from both outputs before the parity comparison. */
  val excludedColsForComparison = Array[String]()

  /** Checks schema/UDF type agreement for one MLeap stage; stages without a UDF are skipped. */
  private def assertStageUdfTypes(stage: frame.Transformer): Unit = stage match {
    case transformer: SimpleTransformer =>
      assertModelTypesMatchTransformerTypes(transformer.model, transformer.shape, transformer.typedExec)
    case transformer: MultiTransformer =>
      assertModelTypesMatchTransformerTypes(transformer.model, transformer.shape, transformer.typedExec)
    case transformer: BaseTransformer =>
      assertModelTypesMatchTransformerTypes(transformer.model, transformer.shape, transformer.exec)
    case _ => // no udf to check against
  }

  /** Registers the three parity tests described on this class. */
  def parityTransformer(): Unit = {
    it("has parity between Spark/MLeap") {
      val sparkTransformed = sparkTransformer.transform(dataset)
      implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
      val mTransformer = mleapTransformer(sparkTransformer)
      val sparkDataset = sparkTransformed.toSparkLeapFrame.toSpark.drop(this.excludedColsForComparison: _*)
      val mleapDataset = mTransformer.sparkTransform(dataset).drop(this.excludedColsForComparison: _*)

      equalityTest(sparkDataset, mleapDataset)
    }

    it("serializes/deserializes the Spark model properly") {
      if (!ignoreSerializationTest) {
        checkEquality(sparkTransformer, deserializedSparkTransformer(sparkTransformer), Set())
      }
    }

    it("model input/output schema matches transformer UDF") {
      mleapTransformer(sparkTransformer) match {
        // Only the pipeline's direct stages are checked.
        case pipeline: Pipeline => pipeline.transformers.foreach(assertStageUdfTypes)
        case other => assertStageUdfTypes(other)
      }
    }
  }

  /**
    * Params that are only relevant during training and are not serialized
    */
  protected val unserializedParams: Set[String] = Set.empty

  /**
    * Can be set to true for models that are not serialized
    */
  protected val ignoreSerializationTest: Boolean = false

  it should behave like parityTransformer()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy