All downloads are free. Search and download functionality uses the official Maven repository.

commonTest.flow.FlowInvariantsTest.kt Maven / Gradle / Ivy

There is a newer version: 1.9.0
Show newest version
/*
 * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.test.*

/**
 * Tests for the flow *context preservation* invariant: values may only be emitted from the
 * coroutine, and in the context, in which the flow is being collected.
 *
 * Emitting from a different context (via `withContext`) or from a different coroutine (via
 * `launch`) is expected to fail with [IllegalStateException]. Concurrency through `channelFlow`,
 * and emission from the collecting coroutine inside `coroutineScope`, must not trip the check.
 *
 * Most tests are executed twice by [runParametrizedTest]: once against the [flow] builder and
 * once against an [AbstractFlow] implementation.
 */
class FlowInvariantsTest : TestBase() {

    /**
     * Runs [testBody] twice, each time handing it a different flow factory: first the [flow]
     * builder, then (after calling `reset()`, presumably to clear the [TestBase] expect/finish
     * state between runs) the [abstractFlow] factory. The outcome of each run is validated by `check`.
     *
     * @param expectedException exception type each run must fail with, or `null` if each run
     *   must complete normally.
     * @param testBody test code, parameterized by the flow factory under test.
     */
    private fun <T> runParametrizedTest(
        expectedException: KClass<out Throwable>? = null,
        testBody: suspend (flowFactory: (suspend FlowCollector<T>.() -> Unit) -> Flow<T>) -> Unit
    ) = runTest {
        val r1 = runCatching { testBody { flow(it) } }.exceptionOrNull()
        check(r1, expectedException)
        reset()

        val r2 = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
        check(r2, expectedException)
    }

    /** Creates an [AbstractFlow] whose `collectSafely` simply runs [block] on the collector. */
    private fun <T> abstractFlow(block: suspend FlowCollector<T>.() -> Unit): Flow<T> = object : AbstractFlow<T>() {
        override suspend fun collectSafely(collector: FlowCollector<T>) {
            collector.block()
        }
    }

    /**
     * Validates the outcome of one parameterized run.
     *
     * - No exception expected: rethrows [exception] if there is one.
     * - Exception expected but none thrown: fails.
     * - Exception expected and thrown: fails unless it is an instance of [expectedException].
     *   The failure message names the actual exception.
     */
    private fun check(exception: Throwable?, expectedException: KClass<out Throwable>?) {
        when {
            expectedException == null -> if (exception != null) throw exception
            exception == null -> fail("Expected $expectedException, but test completed successfully")
            else -> assertTrue(
                expectedException.isInstance(exception),
                "Expected $expectedException, but got $exception"
            )
        }
    }

    // In the three tests below, `flow { ... }` resolves to the lambda parameter (the flow factory
    // under test), not to the top-level builder. Changing the Job, dispatcher or name around
    // `emit` must be rejected.

    @Test
    fun testWithContextContract() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(NonCancellable) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    @Test
    fun testWithDispatcherContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(NamedDispatchers("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    @Test
    fun testWithNameContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(CoroutineName("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    /**
     * The `flowOn` closest to the source decides where the emission runs. Neither a later
     * `flowOn` nor the collector's own dispatcher changes it.
     */
    @Test
    fun testWithContextDoesNotChangeExecution() = runTest {
        val flow = flow {
            emit(NamedDispatchers.name())
        }.flowOn(NamedDispatchers("original"))

        var result = "unknown"
        withContext(NamedDispatchers("misc")) {
            flow
                .flowOn(NamedDispatchers("upstream"))
                .launchIn(this + NamedDispatchers("consumer")) {
                    onEach {
                        result = it
                    }
                }.join()
        }
        assertEquals("original", result)
    }

    /**
     * Emitting from a coroutine launched inside the flow body violates the invariant even when
     * that coroutine gets no extra context (`EmptyCoroutineContext`).
     */
    @Test
    fun testScopedJob() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow { emit(1) }.buffer(EmptyCoroutineContext, flow).collect {
            expect(1)
        }
        finish(2)
    }

    /** Same as [testScopedJob], but the emitting coroutine also switches dispatcher. */
    @Test
    fun testScopedJobWithViolation() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow { emit(1) }.buffer(Dispatchers.Unconfined, flow).collect {
            expect(1)
        }
        finish(2)
    }

    /**
     * Hand-rolled merges that `emit` from a child coroutine must be rejected. This holds even when
     * the emission is wrapped in a nested `coroutineScope`.
     */
    @Test
    fun testMergeViolation() = runParametrizedTest<Int> { flow ->
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value -> emit(value) }
                }
                other.collect { value -> emit(value) }
            }
        }

        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { emit(value) }
                    }
                }
                other.collect { value -> emit(value) }
            }
        }

        val flowInstance = flowOf(1)
        assertFailsWith<IllegalStateException> { flowInstance.merge(flowInstance).toList() }
        assertFailsWith<IllegalStateException> { flowInstance.trickyMerge(flowInstance).toList() }
    }

    /** The same merges written with `channelFlow` and `send` are legal and deliver both values. */
    @Test
    fun testNoMergeViolation() = runTest {
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
            launch {
                collect { value -> send(value) }
            }
            other.collect { value -> send(value) }
        }

        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = channelFlow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { send(value) }
                    }
                }
                other.collect { value -> send(value) }
            }
        }

        val flow = flowOf(1)
        assertEquals(listOf(1, 1), flow.merge(flow).toList())
        assertEquals(listOf(1, 1), flow.trickyMerge(flow).toList())
    }

    /**
     * Emitting from the collecting coroutine inside `coroutineScope` is allowed, even when the
     * values are produced concurrently and handed over through a `produce` channel.
     */
    @Test
    fun testScopedCoroutineNoViolation() = runParametrizedTest<Int> { flow ->
        fun Flow<Int>.buffer(): Flow<Int> = flow {
            coroutineScope {
                val channel = produce<Int> {
                    collect {
                        send(it)
                    }
                }
                channel.consumeEach {
                    emit(it)
                }
            }
        }
        assertEquals(listOf(1, 1), flowOf(1, 1).buffer().toList())
    }

    /**
     * Re-emits this flow through a rendezvous [Channel]. The `emit` calls happen in a coroutine
     * launched with [emitContext], not in the collecting coroutine. This is intentionally an
     * invariant violation, used by [testScopedJob] and [testScopedJobWithViolation].
     *
     * @param emitContext extra context for the coroutine that calls `emit`.
     * @param flow factory used to build the resulting flow (the implementation under test).
     */
    private fun Flow<Int>.buffer(
        emitContext: CoroutineContext,
        flow: (suspend FlowCollector<Int>.() -> Unit) -> Flow<Int>
    ): Flow<Int> = flow {
        coroutineScope {
            val channel = Channel<Int>()
            launch {
                collect { value ->
                    channel.send(value)
                }
                channel.close()
            }

            // Named distinctly from CoroutineScope.coroutineContext so that the caller-supplied
            // context is unambiguously the one used here.
            launch(emitContext) {
                for (i in channel) {
                    emit(i)
                }
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextMap() = runTest {
        emptyContextTest {
            map {
                expect(it)
                it + 1
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransform() = runTest {
        emptyContextTest {
            transform {
                expect(it)
                emit(it + 1)
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransformWhile() = runTest {
        emptyContextTest {
            transformWhile {
                expect(it)
                emit(it + 1)
                true
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextViolationTransform() = runTest {
        try {
            emptyContextTest {
                transform {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            // The message check makes sure the caught exception is the flow-invariant violation
            // and not some other IllegalStateException.
            assertTrue(e.message?.contains("Flow invariant is violated") == true, "But had: ${e.message}")
            finish(2)
        }
    }

    @Test
    fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
        try {
            emptyContextTest {
                transformWhile {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                    true
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            // See testEmptyCoroutineContextViolationTransform: only the invariant violation is accepted.
            assertTrue(e.message?.contains("Flow invariant is violated") == true, "But had: ${e.message}")
            finish(2)
        }
    }

    /**
     * Collects `channelFlow { send(1) }` transformed by [block] inside `withEmptyContext`.
     *
     * [block] is expected to call `expect(1)` on the incoming value and emit `2`. The collector
     * then calls `expect(2)`. The collected value must be `2`, and the test finishes at step 3.
     */
    private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
        suspend fun collector(): Int {
            var result = -1
            channelFlow<Int> {
                send(1)
            }.block()
                .collect {
                    expect(it)
                    result = it
                }
            return result
        }

        val result = withEmptyContext { collector() }
        assertEquals(2, result)
        finish(3)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy