// org.apache.spark.deploy.rest.RestSubmissionClientAppSpyt.scala

package org.apache.spark.deploy.rest

import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.{SparkConf, SparkException}
import tech.ytsaurus.spyt.patch.annotations.{OriginClass, Subclass}

import java.io.File
import java.net.URI
import java.nio.charset.StandardCharsets
import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.util.{Failure, Random, Success, Try}

@Subclass
@OriginClass("org.apache.spark.deploy.rest.RestSubmissionClientApp")
class RestSubmissionClientAppSpyt extends RestSubmissionClientApp with Logging {

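  /**
   * Submits the application through the standalone REST protocol. Unlike the stock
   * RestSubmissionClientApp, the master URL is read from spark.rest.master rather
   * than spark.master, so the REST endpoint can be configured independently.
   */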
  override def run(appResource: String,
                   mainClass: String,
                   appArgs: Array[String],
                   conf: SparkConf,
                   env: Map[String, String]): SubmitRestProtocolResponse = {
    // An almost exact copy of super.run(...), except that the master URL is read
    // from the spark.rest.master property instead of spark.master
    val master = conf.getOption("spark.rest.master").getOrElse {
      throw new IllegalArgumentException("'spark.rest.master' must be set.")
    }
    val sparkProperties = conf.getAll.toMap
    val client = new RestSubmissionClientSpyt(master)
    val submitRequest = client.constructSubmitRequest(
      appResource, mainClass, appArgs, sparkProperties, env)
    client.createSubmission(submitRequest)
  }

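  /**
   * Requests the status of a submission, retrying up to `retry` times with a
   * randomized sleep between attempts. Rethrows the last failure once the retry
   * budget is exhausted.
   */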
  @tailrec
  private def getSubmissionStatus(submissionId: String,
                                  client: RestSubmissionClient,
                                  retry: Int,
                                  retryInterval: Duration,
                                  rnd: Random = new Random): SubmissionStatusResponse = {
    val response = Try(client.requestSubmissionStatus(submissionId)
      .asInstanceOf[SubmissionStatusResponse])
    response match {
      case Success(value) => value
      case Failure(exception) if retry > 0 =>
        log.error(s"Exception while getting submission status: ${exception.getMessage}")
        val sleepInterval = if (retryInterval > 1.second) {
          1000 + rnd.nextInt(retryInterval.toMillis.toInt - 1000)
        } else rnd.nextInt(retryInterval.toMillis.toInt)
        Thread.sleep(sleepInterval)
        getSubmissionStatus(submissionId, client, retry - 1, retryInterval, rnd)
      case Failure(exception) => throw exception
    }
  }

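  /**
   * Polls the driver state every `checkStatusInterval` until it leaves the
   * SUBMITTED/RUNNING states, then throws a SparkException unless the final
   * state is FINISHED.
   */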
  def awaitAppTermination(submissionId: String,
                          conf: SparkConf,
                          checkStatusInterval: Duration): Unit = {
    import org.apache.spark.deploy.master.DriverState._

    val master = conf.getOption("spark.rest.master").getOrElse {
      throw new IllegalArgumentException("'spark.rest.master' must be set.")
    }
    val client = new RestSubmissionClient(master)
    val runningStates = Set(RUNNING.toString, SUBMITTED.toString)
    val finalStatus = Stream.continually {
      Thread.sleep(checkStatusInterval.toMillis)
      val response = getSubmissionStatus(submissionId, client, retry = 3, checkStatusInterval)
      logInfo(s"Driver report for $submissionId (state: ${response.driverState})")
      response
    }.find(response => !runningStates.contains(response.driverState)).get
    logInfo(s"Driver $submissionId finished with status ${finalStatus.driverState}")
    finalStatus.driverState match {
      case s if s == FINISHED.toString => // success
      case s if s == FAILED.toString =>
        throw new SparkException(s"Driver $submissionId failed")
      case other =>
        throw new SparkException(s"Driver $submissionId finished with unexpected state $other")
    }
  }

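  /**
   * Closes the cached Hadoop FileSystem for the yt:/// scheme so the underlying
   * YTsaurus client is shut down before the JVM proceeds.
   */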
  def shutdownYtClient(sparkConf: SparkConf): Unit = {
    val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf)
    val fs = FileSystem.get(new URI("yt:///"), hadoopConf)
    fs.close()
  }

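  /**
   * Writes `message` to `file` atomically: the content is written to a sibling
   * "_tmp" file first and then moved into place, so readers never observe a
   * partially written file.
   */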
  private def writeToFile(file: File, message: String): Unit = {
    val tmpFile = new File(file.getParentFile, s"${file.getName}_tmp")
    // Specify the charset explicitly; the two-argument overload is deprecated
    // and falls back to the platform default encoding.
    FileUtils.writeStringToFile(tmpFile, message, StandardCharsets.UTF_8)
    FileUtils.moveFile(tmpFile, file)
  }

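  /**
   * Mirrors super.start(...): parses [app resource] [main class] [app args*],
   * forwards spark.yt* and spark.hadoop.yt* settings as environment variables for
   * the submission, submits the application, optionally records the submission id
   * in spark.rest.client.submissionIdFile and, unless disabled via
   * spark.rest.client.awaitTermination.enabled, waits for the driver to finish.
   * Any failure is mirrored to spark.rest.client.submissionErrorFile before being
   * rethrown.
   */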
  override def start(args: Array[String], conf: SparkConf): Unit = {
    val submissionIdFile = conf.getOption("spark.rest.client.submissionIdFile").map(new File(_))
    val submissionErrorFile = conf.getOption("spark.rest.client.submissionErrorFile")
      .map(new File(_))
    try {
      // Here starts an almost exact copy of super.start(...) method
      if (args.length < 2) {
        // sys.error throws a RuntimeException, so sys.exit below is never reached;
        // both calls are kept to mirror the original super.start(...) body.
        sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]")
        sys.exit(1)
      }
      val appResource = args(0)
      val mainClass = args(1)
      val appArgs = args.slice(2, args.length)
      val confEnv = conf.getAll.filter {
        case (key, _) => key.startsWith("spark.yt") || key.startsWith("spark.hadoop.yt")
      }.map {
        case (key, value) => key.toUpperCase().replace(".", "_") -> value
      }.toMap
      val env = RestSubmissionClient.filterSystemEnvironment(sys.env) ++ confEnv

      val submissionId = try {
        val response = run(appResource, mainClass, appArgs, conf, env)
        response match {
          case r: CreateSubmissionResponse => r.submissionId
          case _ => throw new IllegalStateException("Job is not submitted")
        }
      } finally {
        shutdownYtClient(conf)
      }

      submissionIdFile.foreach(writeToFile(_, submissionId))

      if (conf.getOption("spark.rest.client.awaitTermination.enabled").forall(_.toBoolean)) {
        val checkStatusInterval = conf.getOption("spark.rest.client.statusInterval")
          .map(_.toInt.seconds).getOrElse(5.seconds)
        awaitAppTermination(submissionId, conf, checkStatusInterval)
      }
    } catch {
      case e: Throwable =>
        submissionErrorFile.foreach(writeToFile(_, e.getMessage))
        throw e
    }
  }
}
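
// A minimal usage sketch (not part of the original source): shows how this client
// could be driven programmatically. The REST endpoint, file paths, jar and main
// class below are hypothetical placeholders, not values from the original code.
object RestSubmissionClientAppSpytExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.rest.master", "spark://rest.example.com:6066") // hypothetical REST master URL
      .set("spark.rest.client.submissionIdFile", "/tmp/submission-id") // submission id is written here
      .set("spark.rest.client.statusInterval", "10") // poll driver state every 10 seconds
    // args(0) = app resource, args(1) = main class, the rest are app arguments
    new RestSubmissionClientAppSpyt().start(
      Array("yt:///home/user/app.jar", "com.example.Main", "--input", "/data"),
      conf)
  }
}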



