All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.dimajix.spark.testing.QueryTest.scala Maven / Gradle / Ivy

package com.dimajix.spark.testing

import java.util.TimeZone

import scala.util.control.NonFatal

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.scalatest.Assertions


/**
  * Mix-in providing helpers to compare the rows produced by a [[DataFrame]] with an expected
  * answer.
  *
  * The self-type requires the trait to be mixed into a ScalaTest [[Assertions]]; its `fail`
  * method is used to report mismatches and execution errors.
  */
trait QueryTest { this: Assertions =>
    /**
      * Runs the plan and makes sure the answer matches the expected result.
      * If an exception is thrown during the execution or the contents of the DataFrame do not
      * match the expected result, the test is failed (via `fail`) with a message describing the
      * query and the difference. Otherwise the method returns normally.
      *
      * Rows are compared irrespective of their order.
      *
      * @param df the [[DataFrame]] to be executed
      * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
      */
    def checkAnswer(
                       df: DataFrame,
                       expectedAnswer: Seq[Row]): Unit = {
        // Detection of a sort in the logical plan is not wired up here, so both answers are
        // always treated as unordered (isSorted = false) and sorted before being compared.
        val failure: Option[String] = try {
            val actualAnswer = df.collect().toSeq
            sameRows(expectedAnswer, actualAnswer, isSorted = false).map { results =>
                s"""
                   |Results do not match for query:
                   |Timezone: ${TimeZone.getDefault}
                   |Timezone Env: ${sys.env.getOrElse("TZ", "")}
                   |
                   |${df.queryExecution}
                   |== Results ==
                   |$results
       """.stripMargin
            }
        } catch {
            // Only non-fatal errors are turned into a test failure; fatal ones propagate.
            case NonFatal(e) =>
                Some(
                    s"""
                       |Exception thrown while executing query:
                       |${df.queryExecution}
                       |== Exception ==
                       |$e
                       |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
          """.stripMargin)
        }

        failure.foreach(message => fail(message))
    }

    /**
      * Convenience overload of `checkAnswer` for a result that consists of a single [[Row]].
      *
      * @param df the [[DataFrame]] to be executed
      * @param expectedAnswer the only row the result is expected to contain
      */
    def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
        checkAnswer(df, Seq(expectedAnswer))
    }

    /**
      * Convenience overload of `checkAnswer` that collects the expected rows from another
      * [[DataFrame]].
      *
      * @param df the [[DataFrame]] to be executed
      * @param expectedAnswer a [[DataFrame]] whose collected rows form the expected result
      */
    def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
        // Convert explicitly instead of relying on the implicit Array-to-Seq conversion.
        checkAnswer(df, expectedAnswer.collect().toSeq)
    }

    /**
      * Normalizes an answer so that it can be compared using plain Scala collection equality.
      *
      * Every row is passed through [[prepareRow]]. Unless `isSorted` is set, the rows are
      * additionally sorted by their string representation to make the comparison independent
      * of the row order.
      *
      * @param answer the rows to normalize
      * @param isSorted `true` to keep the given row order, `false` to sort the rows
      * @return the normalized (and possibly re-ordered) rows
      */
    def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
        val converted: Seq[Row] = answer.map(prepareRow)
        if (isSorted) converted else converted.sortBy(_.toString())
    }

    /**
      * Converts the values of a row to types that support value based equality.
      *
      * The conversion is applied recursively, so values nested inside struct, array or sequence
      * fields are normalized as well.
      *
      * @param row the row to convert
      * @return a new [[Row]] holding the converted values
      */
    def prepareRow(row: Row): Row = {
        Row.fromSeq(row.toSeq.map(value => prepareValue(value)))
    }

    /**
      * Normalizes a single value for equality comparison, descending into nested rows, arrays
      * and sequences.
      */
    private def prepareValue(value: Any): Any = value match {
        case null => null
        // For BigDecimal type, the Scala type has a better definition of equality test (similar
        // to Java's java.math.BigDecimal.compareTo).
        case d: java.math.BigDecimal => BigDecimal(d)
        // We need to recurse to handle schemas with struct types.
        case r: Row => prepareRow(r)
        // Arrays only support reference equality, so they are converted to Seq; the elements
        // are normalized too because they may again be decimals, rows or arrays.
        case a: Array[_] => a.toSeq.map(element => prepareValue(element))
        case s: scala.collection.Seq[_] => s.map(element => prepareValue(element))
        case o => o
    }

    /**
      * Renders two lists of lines next to each other, one output line per pair of input lines.
      *
      * The shorter list is padded with empty lines. Each output line starts with a blank if the
      * left and right line are equal and with `!` otherwise, followed by the left line, padding
      * up to the width of the longest left line plus three blanks, and the right line.
      *
      * @param left the lines of the left column
      * @param right the lines of the right column
      * @return the combined lines
      */
    def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = {
        val maxLeftSize = left.foldLeft(0)((widest, line) => widest max line.length)

        left.zipAll(right, "", "").map { case (l, r) =>
            val marker = if (l == r) " " else "!"
            val gap = " " * ((maxLeftSize - l.length) + 3)
            s"$marker$l$gap$r"
        }
    }

    /**
      * Builds the side-by-side report of the expected and the actual answer.
      *
      * Each column consists of a title with the number of rows, the schema of the first row
      * (`struct<>` if there is no row or the row carries no schema) and the prepared rows.
      *
      * @param expectedAnswer the expected rows
      * @param sparkAnswer the rows actually returned
      * @param isSorted passed on to [[prepareAnswer]]
      * @return the formatted report
      */
    private def genError(expectedAnswer: Seq[Row],
                            sparkAnswer: Seq[Row],
                            isSorted: Boolean = false): String = {
        // The schema of a row may be null, hence the Option wrapper.
        def rowType(row: Option[Row]): String =
            row.flatMap(r => Option(r.schema)).map(_.catalogString).getOrElse("struct<>")

        def column(title: String, answer: Seq[Row]): Seq[String] =
            s"== $title - ${answer.size} ==" +:
                rowType(answer.headOption) +:
                prepareAnswer(answer, isSorted).map(_.toString())

        val comparison = sideBySide(
            column("Correct Answer", expectedAnswer),
            column("Spark Answer", sparkAnswer)).mkString("\n")

        s"""
           |== Results ==
           |$comparison
    """.stripMargin
    }

    /**
      * Compares two answers after normalizing them with [[prepareAnswer]].
      *
      * @param expectedAnswer the expected rows
      * @param sparkAnswer the rows actually returned
      * @param isSorted `true` if the row order is significant, `false` to ignore it
      * @return [[None]] if both answers are equal, otherwise a report of the differences
      */
    def sameRows(expectedAnswer: Seq[Row],
                    sparkAnswer: Seq[Row],
                    isSorted: Boolean = false): Option[String] = {
        val expectedPrepared = prepareAnswer(expectedAnswer, isSorted)
        val sparkPrepared = prepareAnswer(sparkAnswer, isSorted)
        if (expectedPrepared != sparkPrepared) {
            Some(genError(expectedAnswer, sparkAnswer, isSorted))
        } else {
            None
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy