// scala.concurrent.impl.ExecutionContextImpl.scala Maven / Gradle / Ivy
/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */
package scala.concurrent.impl
import java.util.concurrent.{ ForkJoinPool, ForkJoinWorkerThread, ForkJoinTask, Callable, Executor, ExecutorService, ThreadFactory, TimeUnit }
import java.util.concurrent.atomic.AtomicInteger
import java.util.Collection
import scala.concurrent.{ BlockContext, ExecutionContext, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService }
import scala.annotation.tailrec
/**
 * An `ExecutionContextExecutor` that hands every task to a wrapped `Executor`
 * and every reported failure to a callback.
 *
 * @param executor the executor each submitted `Runnable` is passed to; must not be `null`
 * @param reporter callback that `reportFailure` invokes with the reported `Throwable`
 */
private[scala] class ExecutionContextImpl private[impl] (val executor: Executor, val reporter: Throwable => Unit) extends ExecutionContextExecutor {
  require(executor ne null, "Executor must not be null")

  // Task submission is a plain delegation to the wrapped executor.
  override def execute(runnable: Runnable): Unit = executor.execute(runnable)

  // Failure reporting is a plain delegation to the configured callback.
  override def reportFailure(t: Throwable): Unit = reporter.apply(t)
}
private[concurrent] object ExecutionContextImpl {
// Implement BlockContext on FJP threads
/**
 * Thread factory for both plain threads and ForkJoinPool worker threads that caps the
 * number of threads it has handed out and not yet seen terminate.
 *
 * Worker threads created for a ForkJoinPool additionally implement `BlockContext`, so
 * blocking sections are run through `ForkJoinPool.managedBlock`.
 *
 * @param daemonic   daemon flag applied to every created thread
 * @param maxThreads maximum number of simultaneously reserved threads; must be greater than 0
 * @param prefix     thread-name prefix (the thread id is appended after a dash); must not be `null`
 * @param uncaught   uncaught-exception handler installed on every created thread
 */
final class DefaultThreadFactory(
  daemonic: Boolean,
  maxThreads: Int,
  prefix: String,
  uncaught: Thread.UncaughtExceptionHandler) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {

  require(prefix ne null, "DefaultThreadFactory.prefix must be non null")
  require(maxThreads > 0, "DefaultThreadFactory.maxThreads must be greater than 0")

  // Number of threads currently reserved through this factory.
  private final val currentNumberOfThreads = new AtomicInteger(0)

  // Claims one slot via a CAS retry loop; false when the limit (or Int.MaxValue) is reached.
  @tailrec private final def reserveThread(): Boolean = {
    val reserved = currentNumberOfThreads.get()
    if (reserved == maxThreads || reserved == Int.MaxValue) false
    else if (currentNumberOfThreads.compareAndSet(reserved, reserved + 1)) true
    else reserveThread()
  }

  // Releases one slot via a CAS retry loop; false when the counter is already 0.
  @tailrec private final def deregisterThread(): Boolean = {
    val reserved = currentNumberOfThreads.get()
    if (reserved == 0) false
    else if (currentNumberOfThreads.compareAndSet(reserved, reserved - 1)) true
    else deregisterThread()
  }

  /** Applies the daemon flag, uncaught-exception handler and name to `thread`, then returns it. */
  def wire[T <: Thread](thread: T): T = {
    thread.setDaemon(daemonic)
    thread.setUncaughtExceptionHandler(uncaught)
    val label = prefix + "-" + thread.getId()
    thread.setName(label)
    thread
  }

  /** Creates a plain thread for `runnable`, or returns `null` when no slot can be reserved. */
  def newThread(runnable: Runnable): Thread =
    if (!reserveThread()) null
    else {
      // Give the slot back when the thread's body finishes, whether normally or by throwing.
      val counted = new Runnable {
        override def run(): Unit = try runnable.run() finally deregisterThread()
      }
      wire(new Thread(counted))
    }

  /** Creates a worker thread for `fjp`, or returns `null` when no slot can be reserved. */
  def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread =
    if (!reserveThread()) null
    else wire(new ForkJoinWorkerThread(fjp) with BlockContext {
      // Give the slot back when the worker terminates.
      final override def onTermination(exception: Throwable): Unit = deregisterThread()

      final override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = {
        var outcome: T = null.asInstanceOf[T]
        val blocker: ForkJoinPool.ManagedBlocker = new ForkJoinPool.ManagedBlocker {
          @volatile var finished = false
          override def block(): Boolean = {
            outcome =
              try {
                // While blocked, temporarily swap in the default BlockContext so that
                // nested blocking does not create N new threads.
                BlockContext.withBlockContext(BlockContext.defaultBlockContext) { thunk }
              } finally {
                finished = true
              }
            true
          }
          override def isReleasable = finished
        }
        ForkJoinPool.managedBlock(blocker)
        outcome
      }
    })
}
/**
 * Builds the default `ForkJoinPool`-backed executor service.
 *
 * Sizing is read from these system properties (default in parentheses):
 *  - `scala.concurrent.context.minThreads` ("1")
 *  - `scala.concurrent.context.numThreads` ("x1")
 *  - `scala.concurrent.context.maxThreads` ("x1")
 *  - `scala.concurrent.context.maxExtraThreads` ("256")
 *
 * A value starting with `x` is a multiplier of the available processor count, rounded up;
 * any other value is parsed as an integer. A property that is unset, empty, or unreadable
 * because of a `SecurityException` falls back to its default.
 *
 * @param reporter invoked with the cause of every exception handed to the pool's
 *                 uncaught-exception handler
 * @return a new `ForkJoinPool` whose parallelism is `numThreads` clamped to
 *         `[minThreads, maxThreads]`
 * @throws NumberFormatException if a configured value is not a valid number
 */
def createDefaultExecutorService(reporter: Throwable => Unit): ExecutorService = {
  def getInt(name: String, default: String): Int = {
    val raw = try System.getProperty(name, default) catch {
      case _: SecurityException => default
    }
    // An explicitly empty property (e.g. `-Dname=`) would make `charAt(0)` throw
    // StringIndexOutOfBoundsException; treat it the same as an unset property.
    val setting = if (raw.isEmpty) default else raw
    if (setting.charAt(0) == 'x')
      (Runtime.getRuntime.availableProcessors * setting.substring(1).toDouble).ceil.toInt
    else
      setting.toInt
  }

  def range(floor: Int, desired: Int, ceiling: Int) = scala.math.min(scala.math.max(floor, desired), ceiling)

  val numThreads = getInt("scala.concurrent.context.numThreads", "x1")
  // Upper bound for the pool's parallelism.
  // SI-8955 Deadlocks can happen if the thread limit is too low, although the exact
  // threshold is not known; the factory limit below therefore adds maxExtraThreads on top.
  val maxNoOfThreads = getInt("scala.concurrent.context.maxThreads", "x1")

  val desiredParallelism = range(
    getInt("scala.concurrent.context.minThreads", "1"),
    numThreads,
    maxNoOfThreads)

  // The thread factory must provide additional threads to support managed blocking.
  val maxExtraThreads = getInt("scala.concurrent.context.maxExtraThreads", "256")

  // The hard limit on the number of active threads that the thread factory will produce.
  // Summed as Long and saturated so that large settings cannot overflow to a negative
  // Int, which the factory's `maxThreads > 0` requirement would reject.
  val maxFactoryThreads =
    scala.math.min(maxNoOfThreads.toLong + maxExtraThreads.toLong, Int.MaxValue.toLong).toInt

  val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
    override def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
  }

  val threadFactory = new ExecutionContextImpl.DefaultThreadFactory(daemonic = true,
    maxThreads = maxFactoryThreads,
    prefix = "scala-execution-context-global",
    uncaught = uncaughtExceptionHandler)

  // The last constructor argument is ForkJoinPool's `asyncMode` flag.
  new ForkJoinPool(desiredParallelism, threadFactory, uncaughtExceptionHandler, true) {
    override def execute(runnable: Runnable): Unit = {
      // Reuse the task if it already is a ForkJoinTask, otherwise adapt the Runnable.
      val fjt: ForkJoinTask[_] = runnable match {
        case t: ForkJoinTask[_] => t
        case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
      }
      // Fork when called from one of this pool's own workers; submit externally otherwise.
      Thread.currentThread match {
        case fjw: ForkJoinWorkerThread if fjw.getPool eq this => fjt.fork()
        case _ => super.execute(fjt)
      }
    }
  }
}
/**
 * Adapts a plain `Runnable` to a result-less `ForkJoinTask`.
 *
 * If the runnable throws, the throwable is first passed to the current thread's
 * uncaught-exception handler (when one is present) and then rethrown.
 *
 * @param runnable the work executed by `exec()`
 */
final class AdaptedForkJoinTask(runnable: Runnable) extends ForkJoinTask[Unit] {
  // The task carries no result, so the raw-result accessors are no-ops.
  final override def setRawResult(u: Unit): Unit = ()
  final override def getRawResult(): Unit = ()

  final override def exec(): Boolean =
    try {
      runnable.run()
      true
    } catch {
      case failure: Throwable =>
        val current = Thread.currentThread
        val handler = current.getUncaughtExceptionHandler
        if (handler ne null) handler.uncaughtException(current, failure)
        throw failure
    }
}
/**
 * Wraps `e` in an `ExecutionContextImpl`.
 *
 * @param e        the executor to wrap; when `null`, a default executor service is
 *                 created with `createDefaultExecutorService(reporter)` instead
 * @param reporter failure callback passed to the new context (and to the default
 *                 executor service, if one is created)
 */
def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = {
  val underlying: Executor = e match {
    case null => createDefaultExecutorService(reporter)
    case supplied => supplied
  }
  new ExecutionContextImpl(underlying, reporter)
}
/**
 * Wraps `es` in an `ExecutionContextImpl` that also implements `ExecutorService` by
 * delegating every `ExecutorService` method to the wrapped service.
 *
 * @param es       the executor service to wrap; when `null`, a default one is created
 *                 with `createDefaultExecutorService(reporter)` instead
 * @param reporter failure callback passed to the new context (and to the default
 *                 executor service, if one is created)
 */
def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter):
    ExecutionContextImpl with ExecutionContextExecutorService = {
  // Resolve the backing service once, keeping its static ExecutorService type so that
  // no downcast from `executor: Executor` is needed below.
  val service: ExecutorService = Option(es).getOrElse(createDefaultExecutorService(reporter))
  new ExecutionContextImpl(service, reporter)
    with ExecutionContextExecutorService {
      final def asExecutorService: ExecutorService = service
      override def execute(command: Runnable) = executor.execute(command)
      override def shutdown(): Unit = asExecutorService.shutdown()
      override def shutdownNow() = asExecutorService.shutdownNow()
      override def isShutdown = asExecutorService.isShutdown
      override def isTerminated = asExecutorService.isTerminated
      override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit)
      override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable)
      override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t)
      override def submit(runnable: Runnable) = asExecutorService.submit(runnable)
      override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables)
      override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit)
      override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables)
      override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit)
    }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy