
com.databricks.spark.sql.perf.Benchmark.scala

/*
 * Copyright 2015 Databricks Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.spark.sql.perf

import java.util.UUID

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Try

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedRelation}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.{SparkContext, SparkEnv}

import com.databricks.spark.sql.perf.cpu._

/**
 * A collection of queries that test a particular aspect of Spark SQL.
 *
 * @param sqlContext An existing SQLContext.
 */
abstract class Benchmark(
    @transient val sqlContext: SQLContext)
  extends Serializable {

  import sqlContext.implicits._

  def this() = this(SQLContext.getOrCreate(SparkContext.getOrCreate()))
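
  // A concrete benchmark is typically defined by subclassing and declaring tables and
  // queries as members; vals/defs whose declared type is exactly Table/Benchmarkable
  // (or Seq of them) are discovered reflectively via `allTables`/`allQueries` below.
  // A minimal hypothetical sketch (the table name and SQL are illustrative only):
  //
  //   class SimpleBenchmark(sqlContext: SQLContext) extends Benchmark(sqlContext) {
  //     val src: Table = Table("src", sqlContext.range(1000000).toDF())
  //     val countAll: Benchmarkable =
  //       Query("count-all", "SELECT COUNT(*) FROM src", description = "counts all rows")
  //   }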

  val resultsLocation =
    sqlContext.getAllConfs.getOrElse(
      "spark.sql.perf.results",
      "/spark/sql/performance")

  protected def sparkContext = sqlContext.sparkContext

  protected implicit def toOption[A](a: A): Option[A] = Option(a)
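
  // Note: this implicit lets code below pass a plain value where an Option is expected,
  // e.g. `failure = Failure(...)` instead of `failure = Some(Failure(...))` when
  // building a BenchmarkResult.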

  val buildInfo = Try(getClass.getClassLoader.loadClass("org.apache.spark.BuildInfo")).map { cls =>
    cls.getMethods
      .filter(_.getReturnType == classOf[String])
      .filterNot(_.getName == "toString")
      .map(m => m.getName -> m.invoke(cls).asInstanceOf[String])
      .toMap
  }.getOrElse(Map.empty)

  def currentConfiguration = BenchmarkConfiguration(
    sqlConf = sqlContext.getAllConfs,
    sparkConf = sparkContext.getConf.getAll.toMap,
    defaultParallelism = sparkContext.defaultParallelism,
    buildInfo = buildInfo)

  /**
   * A Variation represents a setting (e.g. the number of shuffle partitions or whether
   * tables are cached in memory) that we want to change in an experiment run.
   * A Variation has three parts: `name`, `options`, and `setup`.
   * The `name` is the identifier of a Variation. `options` is a Seq of options that
   * will be used for a query: the query is executed once with every option in
   * `options`. `setup` defines the action needed to apply each option. For example,
   * the following Variation changes the number of shuffle partitions of a query. Its
   * name is "shufflePartitions", its two options are 200 and 2000, and its setup sets
   * the value of the property "spark.sql.shuffle.partitions".
   *
   * {{{
   *   Variation("shufflePartitions", Seq("200", "2000")) {
   *     case num => sqlContext.setConf("spark.sql.shuffle.partitions", num)
   *   }
   * }}}
   */
  case class Variation[T](name: String, options: Seq[T])(val setup: T => Unit)

  val codegen = Variation("codegen", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.codegen", "false")
    case "on" => sqlContext.setConf("spark.sql.codegen", "true")
  }

  val unsafe = Variation("unsafe", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.unsafe.enabled", "false")
    case "on" => sqlContext.setConf("spark.sql.unsafe.enabled", "true")
  }

  val tungsten = Variation("tungsten", Seq("on", "off")) {
    case "off" => sqlContext.setConf("spark.sql.tungsten.enabled", "false")
    case "on" => sqlContext.setConf("spark.sql.tungsten.enabled", "true")
  }
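
  // For example (hypothetical usage sketch): run every discovered query across the
  // cross product of the codegen and tungsten variations (2 x 2 = 4 configurations
  // per iteration) and block until the experiment finishes:
  //
  //   val experiment = runExperiment(
  //     executionsToRun = allQueries,
  //     iterations = 2,
  //     variations = Seq(codegen, tungsten))
  //   experiment.waitForFinish(60 * 60)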

  /**
   * Starts an experiment run with a given set of executions.
   * @param executionsToRun a list of executions to run.
   * @param includeBreakdown if true, per-operator breakdown results of an execution are
   *                         recorded. Setting it to true may significantly increase the
   *                         time needed to run an execution.
   * @param iterations the number of iterations to run for each execution.
   * @param variations [[Variation]]s used in this run. The cross product of all variations
   *                   is run for each execution and iteration.
   * @param tags tags of this run.
   * @param timeout wait at most this many milliseconds for each query; 0 means wait forever.
   * @return an `ExperimentStatus` object that can be used to track the progress of this
   *         experiment run.
   */
  def runExperiment(
      executionsToRun: Seq[Benchmarkable],
      includeBreakdown: Boolean = false,
      iterations: Int = 3,
      variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("true")) { _ => {} }),
      tags: Map[String, String] = Map.empty,
      timeout: Long = 0L) = {

    class ExperimentStatus {
      val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]()
      val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]()
      val currentMessages = new collection.mutable.ArrayBuffer[String]()

      def logMessage(msg: String) = {
        println(msg)
        currentMessages += msg
      }

      // Stats for HTML status message.
      @volatile var currentExecution = ""
      @volatile var currentPlan = "" // for queries only
      @volatile var currentConfig = ""
      @volatile var failures = 0
      @volatile var startTime = 0L

      /** An optional log collection task that will run after the experiment. */
      @volatile var logCollection: () => Unit = () => {}


      def cartesianProduct[T](xss: List[List[T]]): List[List[T]] = xss match {
        case Nil => List(Nil)
        case h :: t => for (xh <- h; xt <- cartesianProduct(t)) yield xh :: xt
      }
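
      // E.g. cartesianProduct(List(List(0, 1), List(0, 1, 2))) yields the six index
      // vectors List(0, 0), List(0, 1), List(0, 2), List(1, 0), List(1, 1), List(1, 2);
      // `combinations` below uses it to pick one option index per Variation for each
      // run configuration.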

      val timestamp = System.currentTimeMillis()
      val resultPath = s"$resultsLocation/timestamp=$timestamp"
      val combinations = cartesianProduct(variations.map(l => (0 until l.options.size).toList).toList)
      val resultsFuture = Future {

        // If we're running queries, create tables for them
        executionsToRun
          .collect { case query: Query => query }
          .flatMap { query =>
            try {
              query.newDataFrame().queryExecution.logical.collect {
                case UnresolvedRelation(t, _) => t.table
              }
            } catch {
              // ignore the queries that can't be parsed
              case e: Exception => Seq()
            }
          }
          .distinct
          .foreach { name =>
            try {
              sqlContext.table(name)
              logMessage(s"Table $name exists.")
            } catch {
              case ae: Exception =>
                val table = allTables
                  .find(_.name == name)
                if (table.isDefined) {
                  logMessage(s"Creating table: $name")
                  table.get.data
                    .write
                    .mode("overwrite")
                    .saveAsTable(name)
                } else {
                  // The table could be a subquery.
                  logMessage(s"Couldn't read table $name and it's not defined as a Benchmark.Table.")
                }
            }
          }

        // Run the benchmarks!
        val results = (1 to iterations).flatMap { i =>
          combinations.map { setup =>
            val currentOptions = variations.asInstanceOf[Seq[Variation[Any]]].zip(setup).map {
              case (v, idx) =>
                v.setup(v.options(idx))
                v.name -> v.options(idx).toString
            }
            currentConfig = currentOptions.map { case (k,v) => s"$k: $v" }.mkString(", ")

            val result = ExperimentRun(
              timestamp = timestamp,
              iteration = i,
              tags = currentOptions.toMap ++ tags,
              configuration = currentConfiguration,

              executionsToRun.flatMap { q =>
                val setup = s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v"}.mkString(", ")}"
                logMessage(s"Running execution ${q.name} $setup")

                currentExecution = q.name
                currentPlan = q match {
                  case query: Query =>
                    try {
                      query.newDataFrame().queryExecution.executedPlan.toString()
                    } catch {
                      case e: Exception =>
                        s"failed to parse: $e"
                    }
                  case _ => ""
                }
                startTime = System.currentTimeMillis()

                val singleResult =
                  q.benchmark(includeBreakdown, setup, currentMessages, timeout)

                singleResult.failure.foreach { f =>
                  failures += 1
                  logMessage(s"Execution '${q.name}' failed: ${f.message}")
                }
                singleResult.executionTime.foreach { time =>
                  logMessage(s"Execution time: ${time / 1000}s")
                }
                currentResults += singleResult
                singleResult :: Nil
              })

            currentRuns += result

            result
          }
        }

        try {
          val resultsTable = results.toDF()
          logMessage(s"Results written to table: 'sqlPerformance' at $resultPath")
          resultsTable
            .coalesce(1)
            .write
            .format("json")
            .save(resultPath)

          resultsTable
        } catch {
          case e: Throwable => logMessage(s"Failed to write data: $e")
        }

        logCollection()
      }

      def scheduleCpuCollection(fs: FS) = {
        logCollection = () => {
          logMessage(s"Begining CPU log collection")
          try {
            val location = cpu.collectLogs(sqlContext, fs, timestamp)
            logMessage(s"cpu results recorded to $location")
          } catch {
            case e: Throwable =>
              logMessage(s"Error collecting logs: $e")
              throw e
          }
        }
      }

      def cpuProfile = new Profile(sqlContext, sqlContext.read.json(getCpuLocation(timestamp)))
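
      // Usage sketch: if scheduleCpuCollection(fs) was invoked before the run, the
      // collected samples can be inspected once the experiment completes, e.g.
      //
      //   experiment.waitForFinish(60 * 60)
      //   experiment.cpuProfile.buildGraph(fs)  // as cpuProfileHtml renders below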

      def cpuProfileHtml(fs: FS) = {
        s"""
           |

CPU Profile

|Permalink: sqlContext.read.json("${getCpuLocation(timestamp)}")
|${cpuProfile.buildGraph(fs)} """.stripMargin } /** Waits for the finish of the experiment. */ def waitForFinish(timeoutInSeconds: Int) = { Await.result(resultsFuture, timeoutInSeconds.seconds) } /** Returns results from an actively running experiment. */ def getCurrentResults() = { val tbl = sqlContext.createDataFrame(currentResults) tbl.registerTempTable("currentResults") tbl } /** Returns full iterations from an actively running experiment. */ def getCurrentRuns() = { val tbl = sqlContext.createDataFrame(currentRuns) tbl.registerTempTable("currentRuns") tbl } def tail(n: Int = 20) = { currentMessages.takeRight(n).mkString("\n") } def status = if (resultsFuture.isCompleted) { if (resultsFuture.value.get.isFailure) "Failed" else "Successful" } else { "Running" } override def toString = s"""Permalink: table("sqlPerformance").where('timestamp === ${timestamp}L)""" def html: String = { val maybeQueryPlan: String = if (currentPlan.nonEmpty) { s""" |

QueryPlan

|
              |${currentPlan.replaceAll("\n", "
")} |
""".stripMargin } else { "" } s""" |

$status Experiment

|Permalink: sqlContext.read.json("$resultPath")
|Iterations complete: ${currentRuns.size / combinations.size} / $iterations
|Failures: $failures
|Executions run: ${currentResults.size} / ${iterations * combinations.size * executionsToRun.size} |
|Run time: ${(System.currentTimeMillis() - timestamp) / 1000}s
| |

Current Execution: $currentExecution

|Runtime: ${(System.currentTimeMillis() - startTime) / 1000}s
|$currentConfig
|$maybeQueryPlan |

Logs

|
           |${tail()}
           |
""".stripMargin } } new ExperimentStatus } case class Table( name: String, data: DataFrame) import reflect.runtime._, universe._ import reflect.runtime._ import universe._ @transient private val runtimeMirror = universe.runtimeMirror(getClass.getClassLoader) @transient val myType = runtimeMirror.classSymbol(getClass).toType def singleTables = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Table]) .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Table]) def groupedTables = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Seq[Table]]) .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Table]]) @transient lazy val allTables: Seq[Table] = (singleTables ++ groupedTables).toSeq def singleQueries = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Benchmarkable]) .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Benchmarkable]) def groupedQueries = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Seq[Benchmarkable]]) .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Benchmarkable]]) @transient lazy val allQueries = (singleQueries ++ groupedQueries).toSeq def html: String = { val singleQueries = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Query]) .map(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Query]) .mkString(",") val queries = myType.declarations .filter(m => m.isMethod) .map(_.asMethod) .filter(_.asMethod.returnType =:= typeOf[Seq[Query]]) .map { method => val queries = runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Query]] val queryList = queries.map(_.name).mkString(", ") s""" |

${method.name}

|
    $queryList
""".stripMargin }.mkString("\n") s""" |

Spark SQL Performance Benchmarking

|

Available Queries

|$singleQueries |$queries """.stripMargin } /** Factory object for benchmark queries. */ case object Query { def apply( name: String, sqlText: String, description: String, executionMode: ExecutionMode = ExecutionMode.ForeachResults): Query = { new Query(name, sqlContext.sql(sqlText), description, Some(sqlText), executionMode) } def apply( name: String, dataFrameBuilder: => DataFrame, description: String): Query = { new Query(name, dataFrameBuilder, description, None, ExecutionMode.CollectResults) } } object RDDCount { def apply( name: String, rdd: RDD[_]) = { new SparkPerfExecution( name, Map.empty, () => Unit, () => rdd.count(), rdd.toDebugString) } } /** A class for benchmarking Spark perf results. */ class SparkPerfExecution( override val name: String, parameters: Map[String, String], prepare: () => Unit, run: () => Unit, description: String = "") extends Benchmarkable { override def toString: String = s""" |== $name == |$description """.stripMargin protected override val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults protected override def beforeBenchmark(): Unit = { prepare() } protected override def doBenchmark( includeBreakdown: Boolean, description: String = "", messages: ArrayBuffer[String]): BenchmarkResult = { try { val timeMs = measureTimeMs(run()) BenchmarkResult( name = name, mode = executionMode.toString, parameters = parameters, executionTime = Some(timeMs)) } catch { case e: Exception => BenchmarkResult( name = name, mode = executionMode.toString, parameters = parameters, failure = Some(Failure(e.getClass.getSimpleName, e.getMessage))) } } } /** Holds one benchmark query and its metadata. */ class Query( override val name: String, buildDataFrame: => DataFrame, val description: String = "", val sqlText: Option[String] = None, override val executionMode: ExecutionMode = ExecutionMode.ForeachResults) extends Benchmarkable with Serializable { override def toString: String = { try { s""" |== Query: $name == |${buildDataFrame.queryExecution.analyzed} """.stripMargin } catch { case e: Exception => s""" |== Query: $name == | Can't be analyzed: $e | | $description """.stripMargin } } lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect { case UnresolvedRelation(tableIdentifier, _) => { // We are ignoring the database name. tableIdentifier.table } } def newDataFrame() = buildDataFrame protected override def doBenchmark( includeBreakdown: Boolean, description: String = "", messages: ArrayBuffer[String]): BenchmarkResult = { try { val dataFrame = buildDataFrame val queryExecution = dataFrame.queryExecution // We are not counting the time of ScalaReflection.convertRowToScala. 
val parsingTime = measureTimeMs { queryExecution.logical } val analysisTime = measureTimeMs { queryExecution.analyzed } val optimizationTime = measureTimeMs { queryExecution.optimizedPlan } val planningTime = measureTimeMs { queryExecution.executedPlan } val breakdownResults = if (includeBreakdown) { val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i))) val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap val timeMap = new mutable.HashMap[Int, Double] physicalOperators.reverse.map { case (index, node) => messages += s"Breakdown: ${node.simpleString}" val newNode = buildDataFrame.queryExecution.executedPlan(index) val executionTime = measureTimeMs { newNode.execute().foreach((row: Any) => Unit) } timeMap += ((index, executionTime)) val childIndexes = node.children.map(indexMap) val childTime = childIndexes.map(timeMap).sum messages += s"Breakdown time: $executionTime (+${executionTime - childTime})" BreakdownResult( node.nodeName, node.simpleString.replaceAll("#\\d+", ""), index, childIndexes, executionTime, executionTime - childTime) } } else { Seq.empty[BreakdownResult] } // The executionTime for the entire query includes the time of type conversion from catalyst // to scala. // The executionTime for the entire query includes the time of type conversion // from catalyst to scala. var result: Option[Long] = None val executionTime = measureTimeMs { executionMode match { case ExecutionMode.CollectResults => dataFrame.rdd.collect() case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit } case ExecutionMode.WriteParquet(location) => dataFrame.saveAsParquetFile(s"$location/$name.parquet") case ExecutionMode.HashResults => val columnStr = dataFrame.schema.map(_.name).mkString(",") // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query) val row = dataFrame .selectExpr(s"hash($columnStr) as hashValue") .groupBy() .sum("hashValue") .head() result = if (row.isNullAt(0)) None else Some(row.getLong(0)) } } val joinTypes = dataFrame.queryExecution.executedPlan.collect { case k if k.nodeName contains "Join" => k.nodeName } BenchmarkResult( name = name, mode = executionMode.toString, joinTypes = joinTypes, tables = tablesInvolved, parsingTime = parsingTime, analysisTime = analysisTime, optimizationTime = optimizationTime, planningTime = planningTime, executionTime = executionTime, result = result, queryExecution = dataFrame.queryExecution.toString, breakDown = breakdownResults) } catch { case e: Exception => BenchmarkResult( name = name, mode = executionMode.toString, failure = Failure(e.getClass.getName, e.getMessage)) } } /** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */ def checkResult: Query = { new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults) } } }