All downloads are FREE. Search and download functionality uses the official Maven repository.

com.databricks.spark.sql.perf.Benchmarkable.scala Maven / Gradle / Ivy

/*
 * Copyright 2015 Databricks Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.spark.sql.perf

import java.util.UUID

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkEnv, SparkContext}

import scala.collection.mutable.ArrayBuffer

/**
 * A trait to describe things that can be benchmarked.
 *
 * Implementations provide [[name]], [[executionMode]] and [[doBenchmark]]. Callers use
 * [[benchmark]], which runs the implementation on a separate thread with a timeout and
 * removes leftover blocks afterwards.
 */
trait Benchmarkable {
  val sqlContext: SQLContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
  val sparkContext: SparkContext = sqlContext.sparkContext

  /** Name of this benchmark; used in the job description, job group and results. */
  val name: String

  /** Execution mode of this benchmark; its `toString` is recorded in failure results. */
  protected val executionMode: ExecutionMode

  /**
   * Runs this benchmark once.
   *
   * Sets the Spark job description, calls [[beforeBenchmark]], runs the benchmark with the
   * given timeout and finally performs a best-effort cleanup of leftover blocks. The cleanup
   * runs even if running the benchmark throws.
   *
   * @param includeBreakdown forwarded unchanged to [[doBenchmark]]
   * @param description free-form text appended to the Spark job description and forwarded
   *                    to [[doBenchmark]]
   * @param messages buffer forwarded unchanged to [[doBenchmark]]
   * @param timeout maximum time to wait for the benchmark, in milliseconds, as accepted by
   *                `Thread.join(long)` (so `0` waits indefinitely)
   * @return the result produced by [[doBenchmark]], or a result carrying a failure if the
   *         benchmark timed out or threw
   */
  final def benchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String],
      timeout: Long): BenchmarkResult = {
    sparkContext.setJobDescription(s"Execution: $name, $description")
    beforeBenchmark()
    // Clean up even when the run throws (e.g. the calling thread is interrupted while waiting).
    try {
      runBenchmark(includeBreakdown, description, messages, timeout)
    } finally {
      afterBenchmark(sqlContext.sparkContext)
    }
  }

  /** Hook invoked before each run; does nothing by default. */
  protected def beforeBenchmark(): Unit = { }

  /** Requests a GC and then asks the block manager master to remove every reported block. */
  private def afterBenchmark(sc: SparkContext): Unit = {
    // Best-effort clean up of weakly referenced RDDs, shuffles, and broadcasts
    System.gc()
    // Remove any leftover blocks that still exist
    sc.getExecutorStorageStatus
        .flatMap { status => status.blocks.map { case (bid, _) => bid } }
        .foreach { bid => SparkEnv.get.blockManager.master.removeBlock(bid) }
  }

  /**
   * Runs [[doBenchmark]] on a daemon thread inside its own Spark job group and waits up to
   * `timeout` milliseconds for it to finish.
   *
   * On timeout the job group is cancelled, the runner thread is interrupted and a "Timeout"
   * failure result is returned. A non-fatal exception thrown by [[doBenchmark]] is converted
   * into a failure result instead of being lost on the runner thread.
   */
  private def runBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String],
      timeout: Long): BenchmarkResult = {
    import java.util.concurrent.atomic.AtomicReference
    import scala.util.control.NonFatal

    // Builds a result that only records a failure for this benchmark.
    def failed(kind: String, message: String): BenchmarkResult =
      BenchmarkResult(
        name = name,
        mode = executionMode.toString,
        failure = Some(Failure(kind, message))
      )

    val jobgroup = UUID.randomUUID().toString
    // Hand-off slot written only by the runner thread. The caller never writes to it, so a
    // runner that finishes late cannot replace the timeout result returned below.
    val outcome = new AtomicReference[BenchmarkResult]()
    val thread = new Thread("benchmark runner") {
      override def run(): Unit = {
        val produced =
          try {
            sparkContext.setJobGroup(jobgroup, s"benchmark $name", true)
            doBenchmark(includeBreakdown, description, messages)
          } catch {
            // Report the failure to the caller rather than leaving it with no result.
            case NonFatal(e) =>
              failed(e.getClass.getName, Option(e.getMessage).getOrElse(e.toString))
          }
        outcome.set(produced)
      }
    }
    thread.setDaemon(true)
    thread.start()
    thread.join(timeout)
    if (thread.isAlive) {
      sparkContext.cancelJobGroup(jobgroup)
      thread.interrupt()
      failed("Timeout", s"timeout after ${timeout / 1000} seconds")
    } else {
      // The slot is empty only if the runner died from a fatal error or produced null.
      Option(outcome.get()).getOrElse(
        failed("NoResult", "benchmark runner terminated without producing a result"))
    }
  }

  /**
   * Performs the actual benchmark. Invoked on a dedicated runner thread by [[benchmark]].
   *
   * @param includeBreakdown passed through from [[benchmark]]
   * @param description passed through from [[benchmark]]
   * @param messages passed through from [[benchmark]]
   * @return the result of the run
   */
  protected def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult

  /**
   * Evaluates `f` once and returns the elapsed wall-clock time in milliseconds, measured
   * with `System.nanoTime()`. The value produced by `f` is discarded.
   */
  protected def measureTimeMs[A](f: => A): Double = {
    val startTime = System.nanoTime()
    f
    val endTime = System.nanoTime()
    (endTime - startTime).toDouble / 1000000
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy