commonTest.flow.FlowInvariantsTest.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kotlinx-coroutines-core
Coroutines support libraries for Kotlin
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.flow
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.test.*
/**
 * Tests for the flow invariant: a flow must emit from the same coroutine and context in which it
 * is collected. Emitting after `withContext(...)` or from a `launch`ed child coroutine must fail
 * with [IllegalStateException] (message containing "Flow invariant is violated"), whereas
 * [channelFlow] and `produce` are accepted ways to pass values between coroutines.
 */
class FlowInvariantsTest : TestBase() {

    /**
     * Runs [testBody] twice: first with the [flow] builder as the flow factory, then with
     * [abstractFlow], so each scenario covers both ways of implementing a [Flow].
     *
     * @param expectedException exception type both runs must fail with, or `null` if both runs
     *   must complete normally.
     * @param testBody scenario under test. It receives the flow factory; the tests name that
     *   parameter `flow`, which shadows the builder, so `flow { ... }` inside a test body calls
     *   the factory under test.
     */
    private fun <T> runParametrizedTest(
        expectedException: KClass<out Throwable>? = null,
        testBody: suspend (flowFactory: (suspend FlowCollector<T>.() -> Unit) -> Flow<T>) -> Unit
    ) = runTest {
        val builderFailure = runCatching { testBody { flow(it) } }.exceptionOrNull()
        checkOutcome(builderFailure, expectedException)
        // NOTE(review): reset() comes from TestBase; presumably it restarts the expect()/finish()
        // sequence so the second run can reuse the same indices — confirm in TestBase.
        reset()
        val abstractFlowFailure = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
        checkOutcome(abstractFlowFailure, expectedException)
    }

    /** Builds a [Flow] that runs [block] by subclassing [AbstractFlow] instead of using the [flow] builder. */
    private fun <T> abstractFlow(block: suspend FlowCollector<T>.() -> Unit): Flow<T> = object : AbstractFlow<T>() {
        override suspend fun collectSafely(collector: FlowCollector<T>) {
            collector.block()
        }
    }

    /**
     * Verifies the outcome of a single run of a parametrized test.
     *
     * @param exception exception the run failed with, or `null` if it completed normally.
     * @param expectedException required exception type, or `null` if the run must not fail.
     * @throws Throwable rethrows [exception] unchanged when no failure was expected.
     */
    private fun checkOutcome(exception: Throwable?, expectedException: KClass<out Throwable>?) {
        when {
            expectedException == null -> if (exception != null) throw exception
            exception == null -> fail("Expected $expectedException, but test completed successfully")
            // Include the actual exception so a wrong failure type is diagnosable.
            else -> assertTrue(
                expectedException.isInstance(exception),
                "Expected $expectedException, but had $exception"
            )
        }
    }

    @Test
    fun testWithContextContract() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        // Even NonCancellable counts as a context change around emit.
        flow {
            withContext(NonCancellable) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    @Test
    fun testWithDispatcherContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(NamedDispatchers("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    @Test
    fun testWithNameContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow {
            withContext(CoroutineName("foo")) {
                emit(1)
            }
        }.collect {
            expectUnreached()
        }
    }

    @Test
    fun testWithContextDoesNotChangeExecution() = runTest {
        val flow = flow {
            emit(NamedDispatchers.name())
        }.flowOn(NamedDispatchers("original"))

        var result = "unknown"
        withContext(NamedDispatchers("misc")) {
            flow
                .flowOn(NamedDispatchers("upstream"))
                .launchIn(this + NamedDispatchers("consumer")) {
                    onEach {
                        result = it
                    }
                }.join()
        }

        // The outer flowOn, withContext and launch contexts must not replace the dispatcher
        // given to the innermost flowOn.
        assertEquals("original", result)
    }

    @Test
    fun testScopedJob() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        // Emitting from a launched child coroutine is a violation even when no context is added.
        flow { emit(1) }.buffer(EmptyCoroutineContext, flow).collect {
            expect(1)
        }

        finish(2)
    }

    @Test
    fun testScopedJobWithViolation() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
        flow { emit(1) }.buffer(Dispatchers.Unconfined, flow).collect {
            expect(1)
        }

        finish(2)
    }

    @Test
    fun testMergeViolation() = runParametrizedTest<Int> { flow ->
        // Emits into the outer collector from a launched child coroutine.
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value -> emit(value) }
                }
                other.collect { value -> emit(value) }
            }
        }

        // Like merge, but each child-side emit is additionally wrapped in coroutineScope.
        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = flow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { emit(value) }
                    }
                }
                other.collect { value -> emit(value) }
            }
        }

        val flowInstance = flowOf(1)
        assertFailsWith<IllegalStateException> { flowInstance.merge(flowInstance).toList() }
        assertFailsWith<IllegalStateException> { flowInstance.trickyMerge(flowInstance).toList() }
    }

    @Test
    fun testNoMergeViolation() = runTest {
        // Same shapes as in testMergeViolation, but channelFlow allows sending from child coroutines.
        fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
            launch {
                collect { value -> send(value) }
            }
            other.collect { value -> send(value) }
        }

        fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = channelFlow {
            coroutineScope {
                launch {
                    collect { value ->
                        coroutineScope { send(value) }
                    }
                }
                other.collect { value -> send(value) }
            }
        }

        val flowInstance = flowOf(1)
        assertEquals(listOf(1, 1), flowInstance.merge(flowInstance).toList())
        assertEquals(listOf(1, 1), flowInstance.trickyMerge(flowInstance).toList())
    }

    @Test
    fun testScopedCoroutineNoViolation() = runParametrizedTest<Int> { flow ->
        // A produce child collects upstream; emit still happens in the flow's own coroutine
        // (inside coroutineScope), which must be accepted.
        fun Flow<Int>.bufferedViaProduce(): Flow<Int> = flow {
            coroutineScope {
                val channel = produce<Int> {
                    collect {
                        send(it)
                    }
                }
                channel.consumeEach {
                    emit(it)
                }
            }
        }
        assertEquals(listOf(1, 1), flowOf(1, 1).bufferedViaProduce().toList())
    }

    /**
     * Re-emits this flow's values from a coroutine launched with [emitterContext], relaying them
     * through a [Channel]. Because `emit` runs in a launched child coroutine, collecting the
     * result is expected to fail with [IllegalStateException] (see [testScopedJob] and
     * [testScopedJobWithViolation]).
     *
     * @param emitterContext context for the coroutine that calls `emit`.
     * @param flowFactory flow factory under test, supplied by [runParametrizedTest].
     */
    private fun Flow<Int>.buffer(
        emitterContext: CoroutineContext,
        flowFactory: (suspend FlowCollector<Int>.() -> Unit) -> Flow<Int>
    ): Flow<Int> = flowFactory {
        coroutineScope {
            val channel = Channel<Int>()
            launch {
                collect { value ->
                    channel.send(value)
                }
                channel.close()
            }

            launch(emitterContext) {
                for (i in channel) {
                    emit(i)
                }
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextMap() = runTest {
        emptyContextTest {
            map {
                expect(it)
                it + 1
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransform() = runTest {
        emptyContextTest {
            transform {
                expect(it)
                emit(it + 1)
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextTransformWhile() = runTest {
        emptyContextTest {
            transformWhile {
                expect(it)
                emit(it + 1)
                true
            }
        }
    }

    @Test
    fun testEmptyCoroutineContextViolationTransform() = runTest {
        try {
            emptyContextTest {
                transform {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            assertTrue(e.message.orEmpty().contains("Flow invariant is violated"), "But had: ${e.message}")
            finish(2)
        }
    }

    @Test
    fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
        try {
            emptyContextTest {
                transformWhile {
                    expect(it)
                    withContext(Dispatchers.Unconfined) {
                        emit(it + 1)
                    }
                    true
                }
            }
            expectUnreached()
        } catch (e: IllegalStateException) {
            assertTrue(e.message.orEmpty().contains("Flow invariant is violated"), "But had: ${e.message}")
            finish(2)
        }
    }

    /**
     * Collects `channelFlow { send(1) }` transformed by [block] inside `withEmptyContext` and
     * asserts that the last collected value is `2`.
     *
     * Expected `expect` sequence: [block] calls `expect(1)` for the single input and emits `2`,
     * the collector calls `expect(2)`, and this function ends with `finish(3)`.
     */
    private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
        suspend fun collector(): Int {
            var result: Int = -1
            channelFlow<Int> {
                send(1)
            }.block()
                .collect {
                    expect(it)
                    result = it
                }
            return result
        }

        // NOTE(review): withEmptyContext is a test helper defined elsewhere; presumably it runs the
        // block under an EmptyCoroutineContext — confirm.
        val result = withEmptyContext { collector() }
        assertEquals(2, result)
        finish(3)
    }
}