All Downloads are FREE. Search and download functionalities are using the official Maven repository.

jvmTest.VirtualTimeSource.kt Maven / Gradle / Ivy

There is a newer version: 1.9.0
Show newest version
/*
 * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines

import java.io.*
import java.util.concurrent.*
import java.util.concurrent.locks.*

// Timeout handed to DefaultExecutor.shutdownForTests() by withVirtualTimeSource(), both before installing the
// virtual time source and when tearing it down. NOTE(review): presumably milliseconds -- the unit is defined by
// shutdownForTests(), which is not visible in this file; confirm there.
private const val SHUTDOWN_TIMEOUT = 1000L

/**
 * Runs [block] while the global `timeSource` is replaced by a fresh [VirtualTimeSource].
 *
 * The `DefaultExecutor` is shut down first (in case it was running with the previous time source) and started
 * again once the virtual time source is installed. After [block] completes, normally or exceptionally, the
 * executor and the virtual time source are shut down and `timeSource` is reset to `null`.
 *
 * The reset is performed in its own `finally`, so the global `timeSource` is restored even when starting the
 * executor or one of the shutdown steps throws (for example when the waiting thread is interrupted inside
 * [VirtualTimeSource.shutdown]).
 *
 * @param log optional stream that receives a line every time the virtual clock is advanced; `null` disables logging.
 * @param block the test body to execute under virtual time.
 */
internal inline fun withVirtualTimeSource(log: PrintStream? = null, block: () -> Unit) {
    DefaultExecutor.shutdownForTests(SHUTDOWN_TIMEOUT) // shutdown execution with old time source (in case it was working)
    val testTimeSource = VirtualTimeSource(log)
    timeSource = testTimeSource
    try {
        DefaultExecutor.ensureStarted() // should start with new time source
        block()
    } finally {
        try {
            DefaultExecutor.shutdownForTests(SHUTDOWN_TIMEOUT)
            // Deliberately skipped when the executor shutdown above throws: shutdown() waits until every
            // registered thread has unregistered, which may never happen if the executor is still running.
            testTimeSource.shutdown()
        } finally {
            timeSource = null // restore time source, even if the teardown above failed
        }
    }
}

// Sentinel stored in ThreadStatus.parkedTill while the thread is not parked inside VirtualTimeSource.parkNanos().
// Real deadlines are computed as "virtual time + positive delay", so a negative value cannot collide with them.
private const val NOT_PARKED = -1L

/**
 * Per-thread bookkeeping kept by [VirtualTimeSource] for every thread registered as a time loop thread.
 */
private class ThreadStatus {
    /** Virtual-time deadline (in nanoseconds) the thread is parked until, or [NOT_PARKED] when it is not parked. */
    @JvmField
    @Volatile
    var parkedTill: Long = NOT_PARKED

    /** Set by [VirtualTimeSource.unpark]; consumed (reset to `false`) when the thread leaves `parkNanos`. */
    @JvmField
    @Volatile
    var permit: Boolean = false

    /** Number of outstanding `registerTimeLoopThread` calls for this thread. */
    var registered: Int = 0

    override fun toString(): String {
        val parkedTillMillis = TimeUnit.NANOSECONDS.toMillis(parkedTill)
        return "parkedTill = $parkedTillMillis ms, permit = $permit"
    }
}

// Upper bound applied to the delay requested from VirtualTimeSource.parkNanos(): a single park never sets a
// virtual deadline further than this from the current virtual time.
private const val MAX_WAIT_NANOS = 10_000_000_000L // 10s
// Once more than this much real time has elapsed since the last checkpoint, checkAdvanceTime() advances the
// virtual clock by the same amount (limited by the earliest pending deadline, if there is one).
private const val REAL_TIME_STEP_NANOS = 200_000_000L // 200 ms
// Real duration of each LockSupport.parkNanos() slice inside the VirtualTimeSource.parkNanos() polling loop.
private const val REAL_PARK_NANOS = 10_000_000L // 10 ms -- park for a little to better track real-time

/**
 * Time source for tests whose clock is virtual: [nanoTime] and [currentTimeMillis] report an internal counter
 * that starts at zero and only moves when [checkAdvanceTime] advances it.
 *
 * Threads that take part in time-based waiting register themselves with [registerTimeLoopThread] and park via
 * [parkNanos]. While parked, each thread polls [checkAdvanceTime], which either
 *  - advances virtual time in step with real time once more than [REAL_TIME_STEP_NANOS] of real time has passed, or
 *  - jumps virtual time straight to the earliest pending deadline when the main thread is registered, no tracked
 *    task is outstanding and every registered thread is parked without a pending permit.
 *
 * Mutable bookkeeping (`trackedTasks`, `checkpointNanos`, registration counts) is only touched from methods
 * synchronized on this instance; `time` and `isShutdown` are volatile so they can be read without the lock.
 *
 * @param log optional stream that receives a line for every clock advance; `null` disables logging.
 */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
internal class VirtualTimeSource(
    private val log: PrintStream?
) : AbstractTimeSource() {
    // The thread that constructed this time source; virtual-time jumps only happen while it is registered.
    private val mainThread: Thread = Thread.currentThread()

    // Real (System.nanoTime) timestamp of the last real-time-driven advance of the virtual clock.
    private var checkpointNanos: Long = System.nanoTime()

    @Volatile
    private var isShutdown = false

    // Current virtual time in nanoseconds.
    @Volatile
    private var time: Long = 0

    // Number of tasks announced via trackTask()/wrapTask() that have not been untracked yet.
    private var trackedTasks = 0

    // Status of every currently registered time loop thread.
    // The explicit type arguments are required: they cannot be inferred from an empty constructor call.
    private val threads = ConcurrentHashMap<Thread, ThreadStatus>()

    /** Current virtual time in milliseconds. */
    override fun currentTimeMillis(): Long = TimeUnit.NANOSECONDS.toMillis(time)

    /** Current virtual time in nanoseconds. */
    override fun nanoTime(): Long = time

    /**
     * Tracks a task immediately and returns a wrapper around [block] that untracks it once [block] has run,
     * whether it completes normally or throws.
     */
    override fun wrapTask(block: Runnable): Runnable {
        trackTask()
        return Runnable {
            try { block.run() }
            finally { unTrackTask() }
        }
    }

    /** Records one more outstanding task; virtual-time jumps are suppressed while any task is outstanding. */
    @Synchronized
    override fun trackTask() {
        trackedTasks++
    }

    /** Records completion of a task previously announced with [trackTask]. */
    @Synchronized
    override fun unTrackTask() {
        assert(trackedTasks > 0)
        trackedTasks--
    }

    /**
     * Registers the calling thread as a time loop thread. Registration is counted, so nested calls must be
     * balanced by the same number of [unregisterTimeLoopThread] calls.
     */
    @Synchronized
    override fun registerTimeLoopThread() {
        val status = threads.getOrPut(Thread.currentThread()) { ThreadStatus() }
        status.registered++
    }

    /**
     * Undoes one [registerTimeLoopThread] call of the calling thread. When the last registration is released the
     * thread is removed and everybody is woken up, so that parked threads and [shutdown] re-evaluate their state.
     */
    @Synchronized
    override fun unregisterTimeLoopThread() {
        val currentThread = Thread.currentThread()
        // Unbalanced unregistration is a caller bug: fail fast with an NPE.
        val status = threads[currentThread]!!
        if (--status.registered == 0) {
            threads.remove(currentThread)
            wakeupAll()
        }
    }

    /**
     * Parks the calling (registered) thread until virtual time reaches the deadline `time + nanos` (with the delay
     * capped at [MAX_WAIT_NANOS]), a permit is granted through [unpark], or this time source is shut down.
     * Returns immediately for non-positive [nanos].
     */
    override fun parkNanos(blocker: Any, nanos: Long) {
        if (nanos <= 0) return
        // The caller must have registered itself first; otherwise this fails fast with an NPE.
        val status = threads[Thread.currentThread()]!!
        assert(status.parkedTill == NOT_PARKED)
        status.parkedTill = time + nanos.coerceAtMost(MAX_WAIT_NANOS)
        while (true) {
            // Every parked thread drives the clock itself; there is no dedicated timer thread.
            checkAdvanceTime()
            if (isShutdown || time >= status.parkedTill || status.permit) {
                status.parkedTill = NOT_PARKED
                status.permit = false
                break
            }
            // Sleep for a short slice of real time, then poll again.
            LockSupport.parkNanos(blocker, REAL_PARK_NANOS)
        }
    }

    /** Grants a permit to [thread] if it is registered and wakes it up; does nothing for unregistered threads. */
    override fun unpark(thread: Thread) {
        val status = threads[thread] ?: return
        status.permit = true
        LockSupport.unpark(thread)
    }

    /**
     * Advances the virtual clock if one of the two advance conditions holds (see the class documentation) and
     * wakes up all registered threads after each advance. Does nothing after [shutdown].
     */
    @Synchronized
    private fun checkAdvanceTime() {
        if (isShutdown) return
        val realNanos = System.nanoTime()
        // Compare the difference rather than the absolute values: System.nanoTime() may wrap around,
        // and only differences between its readings are meaningful.
        if (realNanos - checkpointNanos > REAL_TIME_STEP_NANOS) {
            checkpointNanos = realNanos
            val minParkedTill = minParkedTill()
            // Follow real time, but never run past the earliest pending deadline (a negative value means that
            // some thread is running or there are no threads, so there is no deadline to respect).
            time = (time + REAL_TIME_STEP_NANOS).coerceAtMost(if (minParkedTill < 0) Long.MAX_VALUE else minParkedTill)
            logTime("R")
            wakeupAll()
            return
        }
        if (threads[mainThread] == null) return
        if (trackedTasks != 0) return
        val minParkedTill = minParkedTill()
        // Also covers NOT_PARKED (-1): some thread is still running, so virtual time must not jump.
        if (minParkedTill <= time) return
        time = minParkedTill
        logTime("V")
        wakeupAll()
    }

    // Prints the current virtual time, tagged with [s] ("R" = real-time step, "V" = virtual jump).
    private fun logTime(s: String) {
        log?.println("[$s: Time = ${TimeUnit.NANOSECONDS.toMillis(time)} ms]")
    }

    // Earliest deadline among registered threads. Yields NOT_PARKED when there are no threads or when at least
    // one thread is not parked or has a pending permit (such a thread contributes NOT_PARKED, the minimum).
    private fun minParkedTill(): Long =
        threads.values.minOfOrNull { if (it.permit) NOT_PARKED else it.parkedTill } ?: NOT_PARKED

    /**
     * Marks this time source as shut down, wakes up every registered thread and blocks until all of them have
     * unregistered.
     */
    @Synchronized
    fun shutdown() {
        isShutdown = true
        wakeupAll()
        while (threads.isNotEmpty()) (this as Object).wait()
    }

    // Unparks every registered thread and notifies waiters on this monitor.
    // Must be called while holding the monitor of this instance (all callers are @Synchronized).
    private fun wakeupAll() {
        threads.keys.forEach { LockSupport.unpark(it) }
        (this as Object).notifyAll()
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy