All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.comet.fuzz.QueryRunner.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.fuzz

import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}

import scala.io.Source

import org.apache.spark.sql.{Row, SparkSession}

object QueryRunner {

  def runQueries(
      spark: SparkSession,
      numFiles: Int,
      filename: String,
      showMatchingResults: Boolean,
      showFailedSparkQueries: Boolean = false): Unit = {

    val outputFilename = s"results-${System.currentTimeMillis()}.md"
    // scalastyle:off println
    println(s"Writing results to $outputFilename")
    // scalastyle:on println

    val w = new BufferedWriter(new FileWriter(outputFilename))

    // register input files
    for (i <- 0 until numFiles) {
      val table = spark.read.parquet(s"test$i.parquet")
      val tableName = s"test$i"
      table.createTempView(tableName)
      w.write(
        s"Created table $tableName with schema:\n\t" +
          s"${table.schema.fields.map(f => s"${f.name}: ${f.dataType}").mkString("\n\t")}\n\n")
    }

    val querySource = Source.fromFile(filename)
    try {
      querySource
        .getLines()
        .foreach(sql => {

          try {
            // execute with Spark
            spark.conf.set("spark.comet.enabled", "false")
            val df = spark.sql(sql)
            val sparkRows = df.collect()
            val sparkPlan = df.queryExecution.executedPlan.toString

            try {
              spark.conf.set("spark.comet.enabled", "true")
              val df = spark.sql(sql)
              val cometRows = df.collect()
              val cometPlan = df.queryExecution.executedPlan.toString

              if (sparkRows.length == cometRows.length) {
                var i = 0
                while (i < sparkRows.length) {
                  val l = sparkRows(i)
                  val r = cometRows(i)
                  assert(l.length == r.length)
                  for (j <- 0 until l.length) {
                    val same = (l(j), r(j)) match {
                      case (a: Float, b: Float) if a.isInfinity => b.isInfinity
                      case (a: Float, b: Float) if a.isNaN => b.isNaN
                      case (a: Float, b: Float) => (a - b).abs <= 0.000001f
                      case (a: Double, b: Double) if a.isInfinity => b.isInfinity
                      case (a: Double, b: Double) if a.isNaN => b.isNaN
                      case (a: Double, b: Double) => (a - b).abs <= 0.000001
                      case (a: Array[Byte], b: Array[Byte]) => a.sameElements(b)
                      case (a, b) => a == b
                    }
                    if (!same) {
                      showSQL(w, sql)
                      showPlans(w, sparkPlan, cometPlan)
                      w.write(s"First difference at row $i:\n")
                      w.write("Spark: `" + formatRow(l) + "`\n")
                      w.write("Comet: `" + formatRow(r) + "`\n")
                      i = sparkRows.length
                    }
                  }
                  i += 1
                }
              } else {
                showSQL(w, sql)
                showPlans(w, sparkPlan, cometPlan)
                w.write(
                  s"[ERROR] Spark produced ${sparkRows.length} rows and " +
                    s"Comet produced ${cometRows.length} rows.\n")
              }
            } catch {
              case e: Exception =>
                // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
                showSQL(w, sql)
                w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
                w.write("```\n")
                val sw = new StringWriter()
                val p = new PrintWriter(sw)
                e.printStackTrace(p)
                p.close()
                w.write(s"${sw.toString}\n")
                w.write("```\n")
            }

            // flush after every query so that results are saved in the event of the driver crashing
            w.flush()

          } catch {
            case e: Exception =>
              // we expect many generated queries to be invalid
              if (showFailedSparkQueries) {
                showSQL(w, sql)
                w.write(s"Query failed in Spark: ${e.getMessage}\n")
              }
          }
        })

    } finally {
      w.close()
      querySource.close()
    }
  }

  private def formatRow(row: Row): String = {
    row.toSeq
      .map {
        case null => "NULL"
        case v: Array[Byte] => v.mkString
        case other => other.toString
      }
      .mkString(",")
  }

  private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
    w.write("## SQL\n")
    w.write("```\n")
    val words = sql.split(" ")
    val currentLine = new StringBuilder
    for (word <- words) {
      if (currentLine.length + word.length + 1 > maxLength) {
        w.write(currentLine.toString.trim)
        w.write("\n")
        currentLine.setLength(0)
      }
      currentLine.append(word).append(" ")
    }
    if (currentLine.nonEmpty) {
      w.write(currentLine.toString.trim)
      w.write("\n")
    }
    w.write("```\n")
  }

  private def showPlans(w: BufferedWriter, sparkPlan: String, cometPlan: String): Unit = {
    w.write("### Spark Plan\n")
    w.write(s"```\n$sparkPlan\n```\n")
    w.write("### Comet Plan\n")
    w.write(s"```\n$cometPlan\n```\n")
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy