All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.livy.repl.ReplDriver.scala Maven / Gradle / Ivy

The newest version!
package org.apache.livy.repl

import io.netty.channel.ChannelHandlerContext
import org.apache.livy.rsc.driver.{QueryPlan, SparkEntries, StatementSpyt}
import org.apache.livy.rsc.{BaseProtocol, RSCConf, ReplJobResults}
import org.apache.spark.{SparkConf, StoreUtils}
import tech.ytsaurus.spyt.patch.annotations.{OriginClass, Subclass}

@Subclass
@OriginClass("org.apache.livy.repl.ReplDriver")
class ReplDriverSpyt(conf: SparkConf, livyConf: RSCConf) extends ReplDriver(conf, livyConf) {
  override def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults = {
    val jobResults = super.handle(ctx, msg)
    jobResults.statements.foreach { s =>
      s.asInstanceOf[StatementSpyt].setPlan(queryPlan(s.id))
    }
    jobResults
  }

  private def getEntries: SparkEntries = {
    val entries = classOf[org.apache.livy.repl.Session].getDeclaredField("entries")
    entries.setAccessible(true)
    val result = entries.get(session).asInstanceOf[SparkEntries]
    entries.setAccessible(false)
    result
  }

  private def queryPlan(stmtId: Int): QueryPlan = {
    val entries = getEntries
    val jobIds = entries.sc().sc.statusTracker.getJobIdsForGroup(stmtId.toString)
    val store = StoreUtils.getStatusStore(entries.sc().sc)
    val executionIds = jobIds.flatMap { id => store.asOption(store.jobWithAssociatedSql(id)).flatMap(_._2) }.distinct
    if (executionIds.isEmpty) {
      logger.info(s"No SQL executions found for statement $stmtId")
      null
    } else {
      if (executionIds.length > 1) {
        logger.warn(s"Found more than 1 satisfying execution ids: $executionIds for statement $stmtId")
      }
      val executionId = executionIds.head
      val sqlStore = entries.sparkSession().sharedState.statusStore
      val metrics = sqlStore.executionMetrics(executionId)
      val graph = sqlStore.planGraph(executionId)
      val dotContent = graph.makeDotFile(metrics)
      val metadata = graph.allNodes.sortBy(_.id).map(_.desc)
      new QueryPlan(dotContent, metadata.toArray)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy