All downloads are free. Search and download functionality uses the official Maven repository.

com.holdenkarau.spark.testing.DataFrameSuiteBase.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.holdenkarau.spark.testing

import java.io.File
import java.sql.Timestamp
import java.time.Duration

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.internal.SQLConf
import org.json4s._
import org.json4s.jackson.Serialization
import org.json4s.jackson.JsonMethods._
import org.scalactic.source
import org.scalatest.Suite
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

import scala.math.abs

/**
 * Base trait for testing Spark DataFrames in Scala.
 *
 * On top of [[DataFrameSuiteBase]] it offers helpers that register ScalaTest tests
 * which run with specific SQL codegen settings applied through `withSQLConf`.
 */
trait ScalaDataFrameSuiteBase extends AnyFunSuite with DataFrameSuiteBase {
  /**
   * Registers `testFun` twice, once on the codegen path and once on the interpreted path.
   *
   * Use this if you need to test your function with both codegen and non-codegen paths.
   * This should be relatively rare unless you are writing your own Spark expressions
   * (w/ custom codegen).
   * This is taken from the "test" function inside of the PlanTest trait in SparkSQL.
   *
   * @param testName base name; " (codegen path)" / " (interpreted path)" is appended
   * @param testTags tags forwarded to the underlying `test` registration
   * @param testFun  test body, evaluated once per registered test
   */
  def testCombined(
      testName: String,
      testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit = {
    registerCodegenTest(testName, testTags, pos)(testFun)
    registerInterpretedTest(testName, testTags, pos)(testFun)
  }

  /**
   * Registers `testFun` as a single test named `testName + " (codegen path)"` that runs
   * with the codegen factory mode set to CODEGEN_ONLY.
   */
  def testCodegenOnly(
      testName: String,
      testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit =
    registerCodegenTest(testName, testTags, pos)(testFun)

  /**
   * Registers `testFun` as a single test named `testName + " (interpreted path)"` that runs
   * with the codegen factory mode set to NO_CODEGEN and whole-stage codegen switched off.
   * Uses exactly the same settings as the interpreted half of [[testCombined]].
   */
  def testNonCodegen(
      testName: String,
      testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit =
    registerInterpretedTest(testName, testTags, pos)(testFun)

  /** Sets the SPARK_TESTING system property; codegen modes are not always respected without it. */
  private def markSparkTesting(): Unit = {
    System.setProperty("SPARK_TESTING", "yes")
  }

  /** Registers the "(codegen path)" variant of a test. */
  private def registerCodegenTest(
      testName: String,
      testTags: Seq[Tag],
      pos: source.Position)(testFun: => Any): Unit = {
    markSparkTesting()
    val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString

    test(testName + " (codegen path)", testTags: _*)(
      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { testFun })(pos)
  }

  /** Registers the "(interpreted path)" variant of a test. */
  private def registerInterpretedTest(
      testName: String,
      testTags: Seq[Tag],
      pos: source.Position)(testFun: => Any): Unit = {
    markSparkTesting()
    val interpretedMode = CodegenObjectFactoryMode.NO_CODEGEN.toString

    // One shared definition so testCombined and testNonCodegen cannot drift apart.
    test(testName + " (interpreted path)", testTags: _*)(
      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode,
        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
        SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "0",
        SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "0"
      ) { testFun })(pos)
  }
}

/**
 * :: Experimental ::
 * Base class for testing Spark DataFrames.
 *
 * Wires the session set-up from [[DataFrameSuiteBaseLike]] into the suite's
 * `beforeAll` / `afterAll` hooks and provides a helper for temporarily
 * overriding SQL configuration.
 */
trait DataFrameSuiteBase extends TestSuite
    with SharedSparkContext with DataFrameSuiteBaseLike { self: Suite =>
  /** Runs the parent set-up and then makes sure a Spark session is available. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    super.sqlBeforeAllTestCases()
  }

  /**
   * Runs the parent tear-down; unless `reuseContextIfPossible` is set, also stops the
   * session and context (when non-null) and clears the shared session reference.
   */
  override def afterAll(): Unit = {
    super.afterAll()
    val keepContext = reuseContextIfPossible
    if (!keepContext) {
      if (spark != null) spark.stop()
      if (sc != null) sc.stop()
      SparkSessionProvider._sparkSession = null
    }
  }

  /**
   * Applies every SQL configuration in `pairs`, evaluates `f`, and afterwards puts each key
   * back to the value it had before (or unsets it if it had none), even when `f` throws.
   * Taken from Spark SQLHelper.
   *
   * @param pairs configuration key/value pairs to apply while `f` runs
   * @param f     code to evaluate under the temporary configuration
   */
  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
    val conf = spark.sessionState.conf

    // Snapshot all previous values before touching anything.
    val previous = pairs.map { case (key, _) =>
      val old = if (conf.contains(key)) Some(conf.getConfString(key)) else None
      key -> old
    }

    pairs.foreach { case (key, value) => conf.setConfString(key, value) }

    try {
      f
    } finally {
      previous.foreach {
        case (key, Some(old)) => conf.setConfString(key, old)
        case (key, None) => conf.unsetConf(key)
      }
    }
  }
}

/**
 * Session construction and DataFrame assertion helpers shared by the DataFrame test bases.
 *
 * The Spark session itself lives in [[SparkSessionProvider]]; this trait knows how to build
 * one (optionally with Hive and/or Iceberg settings) and how to compare DataFrames.
 */
trait DataFrameSuiteBaseLike extends SparkContextProvider
    with TestSuiteLike with Serializable {
  /** Upper bound on the number of differing rows fetched when reporting a mismatch. */
  val maxUnequalRowsToShow = 10
  @transient lazy val spark: SparkSession = SparkSessionProvider._sparkSession
  @transient lazy val sqlContext: SQLContext = SparkSessionProvider.sqlContext

  protected implicit def impSqlContext: SQLContext = sqlContext

  /** Whether [[builder]] should configure and enable Hive support. */
  protected def enableHiveSupport: Boolean = true
  /** Whether [[builder]] should add the Iceberg extension and catalog settings. */
  protected def enableIcebergSupport: Boolean = false

  lazy val tempDir = Utils.createTempDir()
  lazy val localMetastorePath = new File(tempDir, "metastore").getCanonicalPath
  lazy val localWarehousePath = new File(tempDir, "warehouse").getCanonicalPath
  // Not lazy: evaluating it forces `tempDir` when the trait is initialised.
  val icebergWarehouse = new File(tempDir, "iceberg-warehouse").getCanonicalPath

  /**
   * Constructs a configuration for hive or iceberg, where the metastore is located in a
   * temp directory.
   *
   * @return a session builder with compression codec, warehouse, checkpoint and (depending
   *         on the `enable*Support` flags) Hive / Iceberg settings applied
   */
  def builder(): org.apache.spark.sql.SparkSession.Builder = {
    // Yes this is using a lot of mutation on the builder.
    val builder = SparkSession.builder()
    // Long story with lz4 issues in 2.3+
    builder.config("spark.io.compression.codec", "snappy")
    // We have to mask all properties in hive-site.xml that relates to metastore
    // data source as we used a local metastore here.
    if (enableHiveSupport) {
      import org.apache.hadoop.hive.conf.HiveConf
      val hiveConfVars = HiveConf.ConfVars.values()
      val accessibleHiveConfVars = hiveConfVars.map(WrappedConfVar(_))
      accessibleHiveConfVars.foreach { confvar =>
        if (confvar.varname.contains("datanucleus") ||
          confvar.varname.contains("jdo")) {
          builder.config(confvar.varname, confvar.getDefaultExpr())
        }
      }
      builder.config(HiveConf.ConfVars.METASTOREURIS.varname, "")
      builder.config("javax.jdo.option.ConnectionURL",
        s"jdbc:derby:;databaseName=$localMetastorePath;create=true")
      builder.config("datanucleus.rdbms.datastoreAdapterClassName",
        "org.datanucleus.store.rdbms.adapter.DerbyAdapter")
    }
    builder.config("spark.sql.streaming.checkpointLocation",
      Utils.createTempDir().toPath().toString)
    builder.config("spark.sql.warehouse.dir",
      localWarehousePath)
    // Enable hive support if available
    try {
      if (enableHiveSupport) {
        builder.enableHiveSupport()
      }
    } catch {
      // Exception is thrown in Spark if hive is not present; deliberately best-effort.
      case _: IllegalArgumentException =>
    }
    if (enableIcebergSupport) {
      builder.config("spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      builder.config("spark.sql.catalog.spark_catalog",
        "org.apache.iceberg.spark.SparkSessionCatalog")
      builder.config("spark.sql.catalog.spark_catalog.type",
        "hive")
      builder.config("spark.sql.catalog.local",
        "org.apache.iceberg.spark.SparkCatalog")
      builder.config("spark.sql.catalog.local.type",
        "hadoop")
      builder.config("spark.sql.catalog.local.warehouse",
        icebergWarehouse)
    }
    builder
  }

  /**
   * Ensures [[SparkSessionProvider]] holds a usable session: an existing session whose
   * context is still running is kept, otherwise a new one is created from [[builder]].
   */
  def sqlBeforeAllTestCases(): Unit = {
    val existing = SparkSessionProvider._sparkSession
    val usable = (existing ne null) && !existing.sparkContext.isStopped
    if (!usable) {
      SparkSessionProvider._sparkSession = builder().getOrCreate()
    }
  }

  /**
   * Compare if two schemas are equal, ignoring __autoGeneratedAlias magic
   *
   * @param expected schema the result is expected to have
   * @param result   schema actually produced
   */
  def assertSchemasEqual(expected: StructType, result: StructType): Unit = {
    // Returns the field unchanged, or a copy whose metadata lacks "__autoGeneratedAlias".
    def dropInternal(s: StructField) = {
      // No metadata no need to filter
      val metadata = s.metadata
      if (!metadata.contains("__autoGeneratedAlias")) {
        s
      } else {
        val jsonString = metadata.json
        val metadataMap =
          parse(jsonString).values.asInstanceOf[Map[String, Any]] - "__autoGeneratedAlias"
        implicit val formats = org.json4s.DefaultFormats
        val cleanedJson = Serialization.write(metadataMap)
        val cleanedMetadata = org.apache.spark.sql.types.Metadata.fromJson(cleanedJson)
        new StructField(s.name, s.dataType, s.nullable, cleanedMetadata)
      }
    }
    val cleanExpected = expected.iterator.map(dropInternal).toList
    val cleanResult = result.iterator.map(dropInternal).toList
    assert(cleanExpected, cleanResult)
  }

  /**
   * Compares if two [[DataFrame]]s are equal, checks the schema and then if that
   * matches checks if the rows are equal.
   *
   * @param customShow unit function to customize the '''show''' method
   *                   when dataframes are not equal. IE: '''df.show(false)''' or
   *                   '''df.toJSON.show(false)'''.
   */
  def assertDataFrameEquals(expected: DataFrame, result: DataFrame,
                            customShow: DataFrame => Unit = _.show()): Unit = {
    assertDataFrameApproximateEquals(expected, result, 0.0,
      Duration.ZERO, customShow)
  }

  /**
   * Compares if two [[DataFrame]]s are equal, checks that the schemas are the same.
   * When comparing inexact fields uses tol.
   *
   * @param tol max acceptable tolerance for numeric (between(0, 1)) &
   *            timestamp (millis).
   */
  @deprecated(
    "Use `assertDataFrameApproximateEquals` with timestamp tolerance",
    since = "1.5.0"
  )
  def assertDataFrameApproximateEquals(expected: DataFrame, result: DataFrame,
                                       tol: Double): Unit =
    assertDataFrameApproximateEquals(expected, result, tol,
      Duration.ofMillis(tol.toLong), _.show())

  /**
   * Compares if two [[DataFrame]]s are equal, checks that the schemas are the same.
   * When comparing inexact fields uses tol & tolTimestamp.
   * Rows are paired by position, so row order matters.
   *
   * @param tol          max acceptable numeric tolerance, should be less than 1.
   * @param tolTimestamp max acceptable timestamp tolerance.
   * @param customShow   unit function to customize the '''show''' method
   *                     when dataframes are not equal. IE: '''df.show(false)''' or
   *                     '''df.toJSON.show(false)'''.
   */
  def assertDataFrameApproximateEquals(
    expected: DataFrame, result: DataFrame,
    tol: Double, tolTimestamp: Duration,
    customShow: DataFrame => Unit = _.show()): Unit = {
    import scala.collection.JavaConverters._

    assertSchemasEqual(expected.schema, result.schema)

    try {
      expected.rdd.cache()
      result.rdd.cache()
      assert("Length not Equal", expected.rdd.count(), result.rdd.count())

      val expectedIndexValue = zipWithIndex(expected.rdd)
      val resultIndexValue = zipWithIndex(result.rdd)

      // Keep only the index-paired rows that are neither equal nor approximately equal.
      val unequalRDD = expectedIndexValue.join(resultIndexValue).
        filter { case (_, (r1, r2)) =>
          val approxEquals = DataFrameSuiteBase
            .approxEquals(r1, r2, tol, tolTimestamp)

          !(r1.equals(r2) || approxEquals)
        }

      val unEqualRows = unequalRDD.take(maxUnequalRowsToShow)
      if (unEqualRows.nonEmpty) {
        // Same columns as `expected`, prefixed with a column naming the row's origin.
        val unequalSchema = StructType(
          StructField("source_dataframe", StringType) ::
            expected.schema.fields.toList)

        val df = spark.createDataFrame(
          unEqualRows
            .flatMap(un =>
                       Seq(tagRow(un._2._1, "expected", unequalSchema),
                           tagRow(un._2._2, "result", unequalSchema)))
            .toList.asJava, unequalSchema
        )

        customShow(df)
        fail("There are some unequal rows")
      }
    } finally {
      expected.rdd.unpersist()
      result.rdd.unpersist()
    }
  }

  /**
   * Builds a copy of `row` with `tag` prepended as the first value, carrying `schema`.
   *
   * @throws UnsupportedOperationException if `row` is not a `GenericRowWithSchema`
   */
  private[testing] def tagRow(
    row: Row, tag: String, schema: StructType): Row = row match {
    case generic: GenericRowWithSchema =>
      new GenericRowWithSchema((tag :: generic.toSeq.toList).toArray, schema)
    case _ => throw new UnsupportedOperationException(
      s"row of type ${row.getClass} is not supported")
  }


  /**
   * Compares if two [[DataFrame]]s are equal without caring about order of rows, by
   * finding elements in one DataFrame that is not in the other. The resulting
   * DataFrame should be empty inferring the two DataFrames have the same elements.
   * Also verifies that the schema is identical.
   */
  def assertDataFrameNoOrderEquals(expected: DataFrame, result: DataFrame): Unit = {
    assertSchemasEqual(expected.schema, result.schema)
    assertDataFrameDataEquals(expected, result)
  }


  /**
   * Compares if two [[DataFrame]]s are equal without caring about order of rows, by
   * finding elements in one DataFrame that is not in the other. The resulting
   * DataFrame should be empty inferring the two DataFrames have the same elements.
   * Does not compare the schema.
   *
   * Distinct rows are counted on each side and the counts are compared, so differing
   * numbers of duplicates are detected.
   */
  def assertDataFrameDataEquals(expected: DataFrame, result: DataFrame): Unit = {
    val expectedCol = "assertDataFrameNoOrderEquals_expected"
    val actualCol = "assertDataFrameNoOrderEquals_actual"
    try {
      expected.rdd.cache()
      result.rdd.cache()
      assert("Column size not Equal", expected.columns.size, result.columns.size)
      assert("Length not Equal", expected.rdd.count(), result.rdd.count())

      // With zero columns there is nothing to join on; the length check above is then
      // the whole comparison (previously `reduce` threw on the empty column list).
      val joinCondition = expected.columns
        .map(s => expected.col(s) <=> result.col(s))
        .reduceOption(_.and(_))

      joinCondition.foreach { joinExprs =>
        val columns = expected.columns.map(s => col(s))
        val expectedElementsCount = expected
          .groupBy(columns: _*)
          .agg(count(lit(1)).as(expectedCol))
        val resultElementsCount = result
          .groupBy(columns: _*)
          .agg(count(lit(1)).as(actualCol))

        // A full outer join keeps rows present on one side only (their count is null).
        val diff = expectedElementsCount
          .join(resultElementsCount, joinExprs, "full_outer")
          .filter(not(col(expectedCol) <=> col(actualCol)))

        assertEmpty(diff.take(maxUnequalRowsToShow))
      }
    } finally {
      expected.rdd.unpersist()
      result.rdd.unpersist()
    }
  }


  /**
   * Compares if two [[DataFrame]]s are equal without caring about order of rows, by
   * finding elements in one DataFrame that is not in the other.
   * Similar to the function assertDataFrameDataEquals but for small [[DataFrame]]
   * that can be collected in memory for the comparison.
   *
   * Rows are compared together with how often they occur, so `[a, a, b]` and
   * `[a, b, b]` are reported as different. The result's columns are selected in the
   * order of `expected.columns` before comparing.
   */
  def assertSmallDataFrameDataEquals(
    expected: DataFrame, result: DataFrame): Unit = {

    // Multiset view of the collected rows: row -> number of occurrences.
    def rowCounts(rows: Array[Row]): Map[Row, Int] =
      rows.groupBy(identity).map { case (row, dupes) => row -> dupes.length }

    val cols = expected.columns
    assert("Column size not Equal", cols.size, result.columns.size)
    val expectedRows = expected.collect()
    // Selecting via Column varargs also works for a zero-column DataFrame.
    val resultRows = result.select(cols.map(c => col(c)): _*).collect()
    assertTrue(rowCounts(expectedRows) == rowCounts(resultRows))
  }

  /**
   * Zip RDD's with precise indexes. This is used so we can join two DataFrame's
   * Rows together regardless of if the source is different but still compare
   * based on the order.
   *
   * @return an RDD of `(index, element)` pairs
   */
  private[testing] def zipWithIndex[U](rdd: RDD[U]) = {
    rdd.zipWithIndex().map{ case (row, idx) => (idx, row) }
  }

  /**
   * Approximate row equality where `tol` is used both as the numeric tolerance and,
   * truncated to whole milliseconds, as the timestamp tolerance.
   */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
    DataFrameSuiteBase.approxEquals(r1, r2, tol, Duration.ofMillis(tol.toLong))
  }

  /** Approximate row equality with only a timestamp tolerance. */
  def approxEquals(r1: Row, r2: Row, tolTimestamp: Duration): Boolean = {
    DataFrameSuiteBase.approxEquals(r1, r2, tolTimestamp)
  }

  /** Approximate row equality with separate numeric and timestamp tolerances. */
  def approxEquals(r1: Row, r2: Row, tol: Double,
                   tolTimestamp: Duration): Boolean = {
    DataFrameSuiteBase.approxEquals(r1, r2, tol, tolTimestamp)
  }
}

/**
 * Row comparison helpers used by the DataFrame assertions.
 */
object DataFrameSuiteBase {

  /**
   * Approximate equality, based on equals from [[Row]].
   *
   * `tol` is used as the numeric tolerance and, truncated to whole milliseconds, as the
   * timestamp tolerance.
   */
  def approxEquals(r1: Row, r2: Row, tol: Double): Boolean =
    approxEquals(r1, r2, tol, Duration.ofMillis(tol.toLong))

  /** Approximate equality, based on equals from [[Row]]; numeric tolerance is zero. */
  def approxEquals(r1: Row, r2: Row, tolTimestamp: Duration): Boolean =
    approxEquals(r1, r2, 0, tolTimestamp)

  /** True when the absolute distance between the two timestamps does not exceed the tolerance. */
  private def compareTimestamp(t1: Timestamp, t2: Timestamp,
                               tolTimestamp: Duration): Boolean = {
    !(Duration.between(t1.toInstant, t2.toInstant).abs.compareTo(tolTimestamp) > 0)
  }

  /** True when both are NaN or both are non-NaN, and their distance is not greater than `tol`. */
  private def compareDouble(d1: Double, d2: Double, tol: Double): Boolean = {
    val sameNaNness = java.lang.Double.isNaN(d1) == java.lang.Double.isNaN(d2)
    // Written as !(x > tol) rather than x <= tol so a NaN difference counts as equal.
    sameNaNness && !(abs(d1 - d2) > tol)
  }

  /** Float counterpart of [[compareDouble]]. */
  private def compareFloat(f1: Float, f2: Float, tol: Double): Boolean = {
    val sameNaNness = java.lang.Float.isNaN(f1) == java.lang.Float.isNaN(f2)
    sameNaNness && !(abs(f1 - f2) > tol)
  }

  /** True when the values compare equal or differ by at most `tol`. */
  private def compareJavaBigDecimal(d1: java.math.BigDecimal,
                                    d2: java.math.BigDecimal,
                                    tol: Double): Boolean = {
    d1.compareTo(d2) == 0 ||
      d1.subtract(d2).abs.compareTo(new java.math.BigDecimal(tol)) <= 0
  }

  /** True when the values differ by at most `tol`. */
  private def compareScalaBigDecimal(d1: scala.math.BigDecimal,
                                    d2: scala.math.BigDecimal,
                                    tol: Double): Boolean = {
    !((d1 - d2).abs > tol)
  }

  /**
   * Element-wise comparison of two values that are both expected to be sequences of `T`.
   * Sequences of different length are never equal; `zip` alone would silently ignore the
   * surplus elements of the longer one.
   */
  private def seqApproxEquals[T](o1: Any, o2: Any)(same: (T, T) => Boolean): Boolean = {
    val s1 = o1.asInstanceOf[Seq[T]]
    val s2 = o2.asInstanceOf[Seq[T]]
    s1.length == s2.length && s1.zip(s2).forall { case (a, b) => same(a, b) }
  }

  /**
   * Compares two non-null column values. Floating point, decimal, timestamp and nested
   * row values (and `List`s whose first element is one of those) use the tolerances;
   * byte arrays are compared by content; everything else uses `==`.
   */
  private def valueApproxEquals(o1: Any, o2: Any, tol: Double,
                                tolTimestamp: Duration): Boolean = {
    o1 match {
      case b1: Array[Byte] =>
        java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])

      case f1: Float =>
        compareFloat(f1, o2.asInstanceOf[Float], tol)

      case d1: Double =>
        compareDouble(d1, o2.asInstanceOf[Double], tol)

      case d1: java.math.BigDecimal =>
        compareJavaBigDecimal(d1, o2.asInstanceOf[java.math.BigDecimal], tol)

      case d1: scala.math.BigDecimal =>
        compareScalaBigDecimal(d1, o2.asInstanceOf[scala.math.BigDecimal], tol)

      case t1: Timestamp =>
        compareTimestamp(t1, o2.asInstanceOf[Timestamp], tolTimestamp)

      case row1: Row =>
        approxEquals(row1, o2.asInstanceOf[Row], tol, tolTimestamp)

      // The element type of a list is decided by looking at its first element only.
      case head :: _ if head.isInstanceOf[Row] =>
        seqApproxEquals[Row](o1, o2)((a, b) => approxEquals(a, b, tol, tolTimestamp))

      case head :: _ if head.isInstanceOf[Timestamp] =>
        seqApproxEquals[Timestamp](o1, o2)((a, b) => compareTimestamp(a, b, tolTimestamp))

      case head :: _ if head.isInstanceOf[Double] =>
        seqApproxEquals[Double](o1, o2)((a, b) => compareDouble(a, b, tol))

      case head :: _ if head.isInstanceOf[Float] =>
        seqApproxEquals[Float](o1, o2)((a, b) => compareFloat(a, b, tol))

      case head :: _ if head.isInstanceOf[java.math.BigDecimal] =>
        seqApproxEquals[java.math.BigDecimal](o1, o2)(
          (a, b) => compareJavaBigDecimal(a, b, tol))

      case head :: _ if head.isInstanceOf[scala.math.BigDecimal] =>
        seqApproxEquals[scala.math.BigDecimal](o1, o2)(
          (a, b) => compareScalaBigDecimal(a, b, tol))

      case _ =>
        o1 == o2
    }
  }

  /**
   * Approximate equality, based on equals from [[Row]].
   *
   * @param tol          max acceptable difference for float, double and decimal values
   * @param tolTimestamp max acceptable difference for timestamp values
   * @return true when both rows have the same length and every column is either null in
   *         both rows or approximately equal
   */
  def approxEquals(r1: Row, r2: Row, tol: Double,
                   tolTimestamp: Duration): Boolean = {
    r1.length == r2.length && (0 until r1.length).forall { idx =>
      val null1 = r1.isNullAt(idx)
      val null2 = r2.isNullAt(idx)
      if (null1 || null2) {
        // A null only matches another null.
        null1 == null2
      } else {
        valueApproxEquals(r1.get(idx), r2.get(idx), tol, tolTimestamp)
      }
    }
  }
}

/**
 * Holder for the Spark session shared between suites. The reference starts out unset
 * (null) and is assigned and cleared by the suite base traits.
 */
object SparkSessionProvider {
  @transient var _sparkSession: SparkSession = _

  /** The currently stored session; null when none has been stored. */
  def sparkSession: SparkSession = _sparkSession

  /** SQLContext obtained from the stored session via `EvilSessionTools`. */
  def sqlContext: SQLContext = EvilSessionTools.extractSQLContext(_sparkSession)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy