All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.TestCoroutineScheduler.kt Maven / Gradle / Ivy

The newest version!
package kotlinx.coroutines.test

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.selects.*
import kotlin.coroutines.*
import kotlin.jvm.*
import kotlin.time.*
import kotlin.time.Duration.Companion.milliseconds

/**
 * This is a scheduler for coroutines used in tests, providing the delay-skipping behavior.
 *
 * [Test dispatchers][TestDispatcher] are parameterized with a scheduler. Several dispatchers can share the
 * same scheduler, in which case their knowledge about the virtual time will be synchronized. When the dispatchers
 * require scheduling an event at a later point in time, they notify the scheduler, which will establish the order of
 * the tasks.
 *
 * The scheduler can be queried to advance the time (via [advanceTimeBy]), run all the scheduled tasks advancing the
 * virtual time as needed (via [advanceUntilIdle]), or run the tasks that are scheduled to run as soon as possible but
 * haven't yet been dispatched (via [runCurrent]).
 */
public class TestCoroutineScheduler : AbstractCoroutineContextElement(TestCoroutineScheduler),
    CoroutineContext.Element {

    /** @suppress */
    public companion object Key : CoroutineContext.Key<TestCoroutineScheduler>

    /** This heap stores the knowledge about which dispatchers are interested in which moments of virtual time. */
    // TODO: all the synchronization is done via a separate lock, so a non-thread-safe priority queue can be used.
    private val events = ThreadSafeHeap<TestDispatchEvent<Any>>()

    /** Establishes that [currentTime] can't exceed the time of the earliest event in [events]. */
    private val lock = SynchronizedObject()

    /** This counter establishes some order on the events that happen at the same virtual time. */
    private val count = atomic(0L)

    /** The current virtual time in milliseconds. Reads are performed under [lock]. */
    @ExperimentalCoroutinesApi
    public var currentTime: Long = 0
        get() = synchronized(lock) { field }
        private set

    /** A channel for notifying about the fact that a foreground work dispatch recently happened. */
    private val dispatchEventsForeground: Channel<Unit> = Channel(CONFLATED)

    /** A channel for notifying about the fact that a dispatch recently happened. */
    private val dispatchEvents: Channel<Unit> = Channel(CONFLATED)

    /**
     * Registers a request for the scheduler to notify [dispatcher] at a virtual moment [timeDeltaMillis] milliseconds
     * later via [TestDispatcher.processEvent], which will be called with the provided [marker] object.
     *
     * The event is treated as foreground work unless [context] contains the [BackgroundWork] element.
     * [isCancelled] is stored with the event and is consulted by the non-strict form of [isIdle].
     *
     * Returns the handler which can be used to cancel the registration.
     *
     * @throws IllegalArgumentException if [timeDeltaMillis] is negative.
     * @throws IllegalStateException if [context] contains a scheduler other than this one.
     */
    internal fun <T : Any> registerEvent(
        dispatcher: TestDispatcher,
        timeDeltaMillis: Long,
        marker: T,
        context: CoroutineContext,
        isCancelled: (T) -> Boolean
    ): DisposableHandle {
        require(timeDeltaMillis >= 0) { "Attempted scheduling an event earlier in time (with the time delta $timeDeltaMillis)" }
        checkSchedulerInContext(this, context)
        // The sequence number is taken before entering the lock; it only breaks ties between equal-time events.
        val count = count.getAndIncrement()
        val isForeground = context[BackgroundWork] === null
        return synchronized(lock) {
            // Overflowing target times are clamped to `Long.MAX_VALUE`.
            val time = addClamping(currentTime, timeDeltaMillis)
            val event = TestDispatchEvent(dispatcher, count, time, marker as Any, isForeground) { isCancelled(marker) }
            events.addLast(event)
            // Can't be moved above: otherwise, [onDispatchEventForeground] or [onDispatchEvent] could consume the
            // token sent here before there's actually anything in the event queue.
            sendDispatchEvent(context)
            DisposableHandle {
                synchronized(lock) {
                    events.remove(event)
                }
            }
        }
    }

    /**
     * Runs the next enqueued task, advancing the virtual time to the time of its scheduled awakening,
     * unless [condition] holds.
     *
     * [condition] is evaluated under the lock. The task itself is processed after the lock is released.
     *
     * Returns `true` if a task was run, and `false` if [condition] held or there were no enqueued tasks.
     */
    internal fun tryRunNextTaskUnless(condition: () -> Boolean): Boolean {
        val event = synchronized(lock) {
            if (condition()) return false
            val event = events.removeFirstOrNull() ?: return false
            // The virtual time must never be ahead of the earliest event.
            if (currentTime > event.time)
                currentTimeAheadOfEvents()
            currentTime = event.time
            event
        }
        event.dispatcher.processEvent(event.marker)
        return true
    }

    /**
     * Runs the enqueued tasks in the specified order, advancing the virtual time as needed until there are no more
     * tasks associated with the dispatchers linked to this scheduler.
     *
     * A breaking change from `TestCoroutineDispatcher.advanceTimeBy` is that it no longer returns the total number of
     * milliseconds by which the execution of this method has advanced the virtual time. If you want to recreate that
     * functionality, query [currentTime] before and after the execution to achieve the same result.
     */
    public fun advanceUntilIdle(): Unit = advanceUntilIdleOr { events.none(TestDispatchEvent<*>::isForeground) }

    /**
     * Keeps running the next enqueued task until [tryRunNextTaskUnless] reports that nothing was run.
     *
     * [condition]: guaranteed to be invoked under the lock.
     */
    internal fun advanceUntilIdleOr(condition: () -> Boolean) {
        while (true) {
            if (!tryRunNextTaskUnless(condition))
                return
        }
    }

    /**
     * Runs the tasks that are scheduled to execute at this moment of virtual time.
     *
     * The virtual time is read once on entry; tasks are removed and processed one by one as long as the earliest
     * task is scheduled no later than that moment.
     */
    public fun runCurrent() {
        val timeMark = synchronized(lock) { currentTime }
        while (true) {
            val event = synchronized(lock) {
                events.removeFirstIf { it.time <= timeMark } ?: return
            }
            // Processed outside the lock.
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTimeMillis], running the
     * scheduled tasks in the meantime.
     *
     * Breaking changes from [TestCoroutineDispatcher.advanceTimeBy]:
     * - Intentionally doesn't return a `Long` value, as its use cases are unclear. We may restore it in the future;
     *   please describe your use cases at [the issue tracker](https://github.com/Kotlin/kotlinx.coroutines/issues/).
     *   For now, it's possible to query [currentTime] before and after execution of this method, to the same effect.
     * - It doesn't run the tasks that are scheduled at exactly [currentTime] + [delayTimeMillis]. For example,
     *   advancing the time by one millisecond used to run the tasks at the current millisecond *and* the next
     *   millisecond, but now will stop just before executing any task starting at the next millisecond.
     * - Overflowing the target time used to lead to nothing being done, but will now run the tasks scheduled at up to
     *   (but not including) [Long.MAX_VALUE].
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTimeMillis].
     */
    @ExperimentalCoroutinesApi
    public fun advanceTimeBy(delayTimeMillis: Long): Unit = advanceTimeBy(delayTimeMillis.milliseconds)

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTime], running the
     * scheduled tasks in the meantime.
     *
     * Only the tasks scheduled strictly before the target time are run; once none remain, [currentTime] is set
     * to the target time.
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTime].
     */
    public fun advanceTimeBy(delayTime: Duration) {
        require(!delayTime.isNegative()) { "Can not advance time by a negative delay: $delayTime" }
        val startingTime = currentTime
        // Overflowing target times are clamped to `Long.MAX_VALUE`.
        val targetTime = addClamping(startingTime, delayTime.inWholeMilliseconds)
        while (true) {
            val event = synchronized(lock) {
                val timeMark = currentTime
                val event = events.removeFirstIf { targetTime > it.time }
                when {
                    event == null -> {
                        // Nothing left before the target: jump the clock there and finish.
                        currentTime = targetTime
                        return
                    }
                    timeMark > event.time -> currentTimeAheadOfEvents()
                    else -> {
                        currentTime = event.time
                        event
                    }
                }
            }
            // Processed outside the lock.
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Checks that the only tasks remaining in the scheduler are cancelled.
     *
     * With [strict] set, cancelled tasks also count: the result is `true` only when there are no tasks at all.
     */
    internal fun isIdle(strict: Boolean = true): Boolean =
        synchronized(lock) {
            if (strict) events.isEmpty else events.none { !it.isCancelled() }
        }

    /**
     * Notifies this scheduler about a dispatch event.
     *
     * [context] is the context in which the task will be dispatched. The foreground channel is only notified
     * when [context] does not contain the [BackgroundWork] element.
     */
    internal fun sendDispatchEvent(context: CoroutineContext) {
        dispatchEvents.trySend(Unit)
        if (context[BackgroundWork] !== BackgroundWork)
            dispatchEventsForeground.trySend(Unit)
    }

    /**
     * Waits for a notification about a dispatch event.
     */
    internal suspend fun receiveDispatchEvent() = dispatchEvents.receive()

    /**
     * Consumes the knowledge that a dispatch event happened recently.
     */
    internal val onDispatchEvent: SelectClause1<Unit> get() = dispatchEvents.onReceive

    /**
     * Consumes the knowledge that a foreground work dispatch event happened recently.
     */
    internal val onDispatchEventForeground: SelectClause1<Unit> get() = dispatchEventsForeground.onReceive

    /**
     * Returns the [TimeSource] representation of the virtual time of this scheduler.
     */
    public val timeSource: TimeSource.WithComparableMarks = object : AbstractLongTimeSource(DurationUnit.MILLISECONDS) {
        override fun read(): Long = currentTime
    }
}

// Some error-throwing functions for pretty stack traces

/** Reports that the virtual time was found to be later than the time of a scheduled event. */
private fun currentTimeAheadOfEvents(): Nothing {
    invalidSchedulerState()
}

/** Fails with an [IllegalStateException] asking the user to report the broken scheduler invariant. */
private fun invalidSchedulerState(): Nothing {
    error("The test scheduler entered an invalid state. Please report this at https://github.com/Kotlin/kotlinx.coroutines/issues.")
}

/**
 * [ThreadSafeHeap] node representing a scheduled task, ordered by the planned execution time.
 *
 * Events with the same [time] are ordered by the `count` sequence number they were created with.
 */
private class TestDispatchEvent<T>(
    @JvmField val dispatcher: TestDispatcher,
    private val count: Long,
    @JvmField val time: Long,
    @JvmField val marker: T,
    @JvmField val isForeground: Boolean,
    // TODO: remove once the deprecated API is gone
    @JvmField val isCancelled: () -> Boolean
) : Comparable<TestDispatchEvent<*>>, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    /** Compares by [time] first and by the creation sequence number second. */
    override fun compareTo(other: TestDispatchEvent<*>): Int =
        compareValuesBy(this, other, TestDispatchEvent<*>::time, TestDispatchEvent<*>::count)

    override fun toString(): String =
        "TestDispatchEvent(time=$time, dispatcher=$dispatcher${if (isForeground) "" else ", background"})"
}

/**
 * Adds [a] and [b], returning [Long.MAX_VALUE] when the sum comes out negative.
 *
 * Works with positive `a`, `b`: for such arguments, a negative sum can only be the result of an overflow.
 */
private fun addClamping(a: Long, b: Long): Long {
    val sum = a + b
    return if (sum < 0) Long.MAX_VALUE else sum
}

/**
 * Verifies that [context] does not carry a [TestCoroutineScheduler] other than [scheduler].
 *
 * A context without any scheduler passes the check.
 *
 * @throws IllegalStateException if [context] contains a different scheduler instance.
 */
internal fun checkSchedulerInContext(scheduler: TestCoroutineScheduler, context: CoroutineContext) {
    // No scheduler in the context means there is nothing to conflict with.
    val schedulerInContext = context[TestCoroutineScheduler] ?: return
    check(schedulerInContext === scheduler) {
        "Detected use of different schedulers. If you need to use several test coroutine dispatchers, " +
            "create one `TestCoroutineScheduler` and pass it to each of them."
    }
}

/**
 * A coroutine context key denoting that the work is to be executed in the background.
 *
 * This object is both the context element and its own key, so `context[BackgroundWork]` yields this object
 * when the element is present and `null` otherwise.
 * @see [TestScope.backgroundScope]
 */
internal object BackgroundWork : CoroutineContext.Key<BackgroundWork>, CoroutineContext.Element {
    override val key: CoroutineContext.Key<*>
        get() = this

    override fun toString(): String = "BackgroundWork"
}

/** Returns `true` if no element of this heap matches [predicate], as determined by `find`. */
private fun <T> ThreadSafeHeap<T>.none(predicate: (T) -> Boolean): Boolean
    where T : ThreadSafeHeapNode, T : Comparable<T> =
    find(predicate) == null




© 2015 - 2024 Weber Informatics LLC | Privacy Policy