org.apache.spark.BarrierTaskContext.scala
Shaded version of Apache Spark 2.x.x for Presto
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.util.{Properties, Timer, TimerTask}
import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
/**
* :: Experimental ::
* A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage.
* Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task.
*/
@Experimental
@Since("2.4.0")
class BarrierTaskContext private[spark] (
taskContext: TaskContext) extends TaskContext with Logging {
// Find the driver-side RpcEndpointRef of the coordinator that handles all barrier() calls.
private val barrierCoordinator: RpcEndpointRef = {
val env = SparkEnv.get
RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv)
}
private val timer = new Timer("Barrier task timer for barrier() calls.")
// Local barrierEpoch that identifies a barrier() call from the current task; it shall stay
// in sync with the driver-side epoch.
private var barrierEpoch = 0
// Number of tasks in the current barrier stage. A barrier() call succeeds only after it
// collects requests from all tasks within the same barrier stage attempt.
private lazy val numTasks = getTaskInfos().size
/**
* :: Experimental ::
 * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
 * the MPI_Barrier function in MPI, the barrier() call blocks until all tasks in the same
 * stage have reached this routine.
 *
 * CAUTION! In a barrier stage, each task must issue the same number of barrier() calls, in
 * all possible code branches. Otherwise the job may hang, or a SparkException may be thrown
 * after the timeout. Some examples of '''misuses''' are listed below:
 * 1. Calling barrier() in only a subset of the tasks of the same barrier stage, which leads
 * to a timeout of the call:
* {{{
 *   rdd.barrier().mapPartitions { iter =>
 *     val context = BarrierTaskContext.get()
 *     if (context.partitionId() == 0) {
 *       // Do nothing.
 *     } else {
 *       context.barrier()
 *     }
 *     iter
 *   }
* }}}
*
 * 2. Wrapping a barrier() call in a try-catch block, which may lead to a timeout of the
 * second barrier() call: tasks in which doSomething() throws skip the first barrier() and
 * wait at the second one, while the remaining tasks are still waiting at the first.
* {{{
 *   rdd.barrier().mapPartitions { iter =>
 *     val context = BarrierTaskContext.get()
 *     try {
 *       // Do something that might throw an Exception.
 *       doSomething()
 *       context.barrier()
 *     } catch {
 *       case e: Exception => logWarning("...", e)
 *     }
 *     context.barrier()
 *     iter
 *   }
* }}}
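 *
 * By contrast, a correct usage calls barrier() unconditionally, exactly once per task, in
 * every code branch (a minimal sketch; doSomething() stands in for arbitrary user code):
 * {{{
 *   rdd.barrier().mapPartitions { iter =>
 *     val context = BarrierTaskContext.get()
 *     // Every task reaches this single barrier() call, so the global sync can complete.
 *     doSomething()
 *     context.barrier()
 *     iter
 *   }
 * }}}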
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
val startTime = System.currentTimeMillis()
val timerTask = new TimerTask {
override def run(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
s"under the global sync since $startTime, has been waiting for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
s"is $barrierEpoch.")
}
}
// Log a waiting-status update for the global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
try {
barrierCoordinator.askSync[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
// Set a fixed, effectively unbounded timeout for the RPC here, so users get a SparkException
// thrown by BarrierCoordinator on timeout, instead of an RpcTimeoutException from the RPC
// framework.
timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout"))
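// A successful reply means every task in the stage has reached this barrier; advance the
// local epoch so it stays in sync with the driver side for the next barrier() call.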
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " +
s"$barrierEpoch.")
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
"to perform global sync, waited for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
s"is $barrierEpoch.")
throw e
} finally {
timerTask.cancel()
timer.purge()
}
}
/**
* :: Experimental ::
* Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID.
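 *
 * A minimal usage sketch ([[BarrierTaskInfo]] exposes the executor address of each task;
 * the variable names here are illustrative):
 * {{{
 *   val context = BarrierTaskContext.get()
 *   // Gather the addresses of all tasks in the stage, e.g. to set up an external rendezvous.
 *   val addresses = context.getTaskInfos().map(_.address)
 * }}}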
*/
@Experimental
@Since("2.4.0")
def getTaskInfos(): Array[BarrierTaskInfo] = {
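// The scheduler passes the comma-separated addresses of all tasks in this barrier stage
// down to the task via the "addresses" local property.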
val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("")
addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
}
// delegate methods
override def isCompleted(): Boolean = taskContext.isCompleted()
override def isInterrupted(): Boolean = taskContext.isInterrupted()
override def isRunningLocally(): Boolean = taskContext.isRunningLocally()
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
taskContext.addTaskCompletionListener(listener)
this
}
override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
taskContext.addTaskFailureListener(listener)
this
}
override def stageId(): Int = taskContext.stageId()
override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber()
override def partitionId(): Int = taskContext.partitionId()
override def attemptNumber(): Int = taskContext.attemptNumber()
override def taskAttemptId(): Long = taskContext.taskAttemptId()
override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key)
override def taskMetrics(): TaskMetrics = taskContext.taskMetrics()
override def getMetricsSources(sourceName: String): Seq[Source] = {
taskContext.getMetricsSources(sourceName)
}
override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted()
override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason()
override private[spark] def taskMemoryManager(): TaskMemoryManager = {
taskContext.taskMemoryManager()
}
override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
taskContext.registerAccumulator(a)
}
override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
taskContext.setFetchFailed(fetchFailed)
}
override private[spark] def markInterrupted(reason: String): Unit = {
taskContext.markInterrupted(reason)
}
override private[spark] def markTaskFailed(error: Throwable): Unit = {
taskContext.markTaskFailed(error)
}
override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {
taskContext.markTaskCompleted(error)
}
override private[spark] def fetchFailed: Option[FetchFailedException] = {
taskContext.fetchFailed
}
override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties
}
@Experimental
@Since("2.4.0")
object BarrierTaskContext {
/**
* :: Experimental ::
 * Returns the currently active BarrierTaskContext. This can be called inside user functions to
* access contextual information about running barrier tasks.
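 *
 * For example (a sketch; valid only inside functions running as part of a barrier stage):
 * {{{
 *   rdd.barrier().mapPartitions { iter =>
 *     val context = BarrierTaskContext.get()
 *     // The context gives access to stage-level info such as the partition ID.
 *     if (context.partitionId() == 0) {
 *       // e.g. let partition 0 do some one-off coordination work here.
 *     }
 *     context.barrier()
 *     iter
 *   }
 * }}}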
*/
@Experimental
@Since("2.4.0")
def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
}