All downloads are free. The search and download features use the official Maven repository.

com.holdenkarau.spark.testing.DatasetSuiteBase.scala Maven / Gradle / Ivy

The newest version!
package com.holdenkarau.spark.testing

import org.apache.spark.rdd.RDD
import org.scalatest.Suite

import scala.reflect.ClassTag

import java.time.Duration

import org.apache.spark.sql.{DataFrame, Dataset}

/**
 * Mix-in that combines [[DataFrameSuiteBase]] with the Dataset assertions
 * declared in [[DatasetSuiteBaseLike]].
 *
 * The self-type restricts this trait to classes that are also a
 * ScalaTest `Suite`.
 */
trait DatasetSuiteBase
  extends DataFrameSuiteBase
  with DatasetSuiteBaseLike {
  self: Suite =>
}

/**
 * Dataset-level assertions built on top of the DataFrame assertions
 * inherited from [[DataFrameSuiteBaseLike]].
 */
trait DatasetSuiteBaseLike extends DataFrameSuiteBaseLike {

  /**
   * Checks that two Datasets of the same type are equal.
   *
   * The element counts are compared first. Elements are then paired by
   * their index (via `zipWithIndex`) and compared with `equals`, so the
   * comparison can be customized by overriding `equals` on the element
   * class. Two null elements are considered equal; a null element never
   * equals a non-null one.
   *
   * @param expected the Dataset holding the expected elements
   * @param result   the Dataset under test
   */
  def assertDatasetEquals[U](expected: Dataset[U], result: Dataset[U])
                            (implicit UCT: ClassTag[U]): Unit = {
    val expectedRdd = expected.rdd
    val resultRdd = result.rdd

    // Only cache (and later unpersist) RDDs the caller has not persisted
    // already: cache() on an RDD that was persisted with a different storage
    // level throws, and unpersisting here would drop the caller's own cache.
    val notPersisted = org.apache.spark.storage.StorageLevel.NONE
    val cacheExpected = expectedRdd.getStorageLevel == notPersisted
    val cacheResult = resultRdd.getStorageLevel == notPersisted

    try {
      if (cacheExpected) expectedRdd.cache()
      if (cacheResult) resultRdd.cache()
      assert("Length not Equal", expectedRdd.count(), resultRdd.count())

      val expectedIndexValue: RDD[(Long, U)] = zipWithIndex(expectedRdd)
      val resultIndexValue: RDD[(Long, U)] = zipWithIndex(resultRdd)

      // Null-safe comparison that still delegates to the element's `equals`.
      // The predicate is kept inline so the closure shipped to the executors
      // does not reference this suite instance.
      val unequalRDD = expectedIndexValue.join(resultIndexValue).filter {
        case (_, (o1, o2)) =>
          if (o1 == null) o2 != null else !o1.equals(o2)
      }

      assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
    } finally {
      if (cacheExpected) expectedRdd.unpersist()
      if (cacheResult) resultRdd.unpersist()
    }
  }

  /**
    * Compares if two Datasets are equal, Datasets should have the same type.
    * When comparing inexact fields uses tol.
    *
    * Both Datasets are converted to DataFrames and compared with
    * `assertDataFrameApproximateEquals`. The same `tol` value is reused as
    * the timestamp tolerance in milliseconds after truncation to a whole
    * number (`tol.toLong`), so any `tol` below 1 gives a timestamp tolerance
    * of zero. Differences are displayed with the default `show()`.
    *
    * @param expected   the Dataset holding the expected elements
    * @param result     the Dataset under test
    * @param tol        max acceptable tolerance for numeric (between(0, 1)) &
    *                   timestamp (millis).
    */
  @deprecated(
    "Use `assertDatasetApproximateEquals` with timestamp tolerance",
    since = "1.5.0"
  )
  def assertDatasetApproximateEquals[U](expected: Dataset[U], result: Dataset[U],
                                        tol: Double)
    (implicit UCT: ClassTag[U]): Unit = {
    assertDataFrameApproximateEquals(expected.toDF, result.toDF, tol,
      Duration.ofMillis(tol.toLong), _.show())
  }

  /**
    * Compares if two Datasets are equal, Datasets should have the same type.
    * When comparing inexact fields uses tol & tolTimestamp.
    *
    * Both Datasets are converted to DataFrames and the comparison is
    * delegated to `assertDataFrameApproximateEquals`.
    *
    * @param expected     the Dataset holding the expected elements
    * @param result       the Dataset under test
    * @param tol          max acceptable tolerance for numeric (between(0, 1))
    * @param tolTimestamp max acceptable timestamp tolerance.
    * @param customShow   unit function to customize the '''show''' method
    *                     when dataframes are not equal. IE: '''df.show(false)''' or
    *                     '''df.toJSON.show(false)'''.
    */
  def assertDatasetApproximateEquals[U]
    (expected: Dataset[U], result: Dataset[U], tol: Double,
     tolTimestamp: Duration,
     customShow: DataFrame => Unit = _.show())
    (implicit UCT: ClassTag[U]): Unit = {

    assertDataFrameApproximateEquals(expected.toDF, result.toDF, tol,
      tolTimestamp, customShow)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy