All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dispatch.test.CoroutineTestExtension.kt Maven / Gradle / Ivy

/*
 * Copyright (C) 2020 Rick Busarow
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package dispatch.test

import dispatch.test.CoroutineTestExtension.*
import dispatch.test.internal.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.*
import org.junit.jupiter.api.extension.*
import org.junit.jupiter.api.extension.support.*
import kotlin.coroutines.*

/**
 * JUnit 5 [ParameterResolver] [extension][Extension] for injecting and managing a [TestProvidedCoroutineScope] in a test instance.
 * This creates a new instance of [TestProvidedCoroutineScope] each time the scope is injected, optionally using a custom [ScopeFactory].
 *
 * If this extension is initialized via [RegisterExtension], there is also a [scope] property which is automatically managed.
 *
 * ### Before Each:
 * [Dispatchers.Main] is set to the [TestCoroutineDispatcher] used by the [CoroutineContext].
 *
 * ### After Each:
 * * [cleanupTestCoroutines][TestCoroutineScope.cleanupTestCoroutines] is called to ensure there are no leaking coroutines.  Any unfinished coroutine
 * will throw an [UncompletedCoroutinesError].
 * * [Dispatchers.Main] is reset via [Dispatchers.resetMain].
 *
 * ### Requires JUnit 5.
 * ``` groovy
 * dependencies {
 *   testImplementation "org.junit.jupiter:junit-jupiter:5.6.2"
 * }
 * ```
 *
 * @see CoroutineTestExtension
 * @see CoroutineTest
 * @param scopeFactory *optional* factory for a custom [TestProvidedCoroutineScope].  If a factory is not provided,
 * the resultant scope uses the same [TestCoroutineDispatcher] for each property in its [TestDispatcherProvider]
 * @sample samples.RegisterSample
 * @sample samples.RegisterWithFactorySample
 * @sample samples.CoroutineTestExtensionExtendWithSample
 */
@ExperimentalCoroutinesApi
public class CoroutineTestExtension(
  private val scopeFactory: ScopeFactory = ScopeFactory()
) : TypeBasedParameterResolver(),
    BeforeEachCallback,
    AfterEachCallback {

  private val lazyScope = lazy { scopeFactory.create() }

  /**
   * A single `TestProvidedCoroutineScope` instance which is reset via [cleanUpTestCoroutines][TestCoroutineScope.cleanupTestCoroutines] after each test.
   */
  val scope: TestProvidedCoroutineScope
    get() = lazyScope.value

  private val contextScopeMap = mutableMapOf()

  /**
   * @suppress
   */
  override fun resolveParameter(
    parameterContext: ParameterContext,
    extensionContext: ExtensionContext
  ): TestProvidedCoroutineScope {

    val scope = contextScopeMap[extensionContext]

    // if a scope was already created for this context (not sure how), return it
    if (scope != null) return scope

    val annotation = extensionContext.getAnnotationRecursive()

    // if the extension is registered via @RegisterExtension,
    // then its factory would be passed at instance creation and is already set
    val newScope = if (annotation == null) {
      scopeFactory.create()
    } else {
      @Suppress("UNCHECKED_CAST")
      val factoryClass = annotation.scopeFactory

      val factory = factoryClass.java.getConstructor().newInstance() as? ScopeFactory
        ?: throw ParameterResolutionException(
          """A ${CoroutineTest::class.simpleName} annotation was found with an incompatible factory type.
            |
            |The specified factory must be <${ScopeFactory::class.qualifiedName}> or a subtype.
            |
            |The provided factory type was:  <${factoryClass.qualifiedName}>
            |""".trimMargin()
        )

      factory.create()
    }

    contextScopeMap[extensionContext] = newScope

    return newScope
  }

  /**
   * @suppress
   */
  override fun beforeEach(context: ExtensionContext) {
    contextScopeMap[context]?.dispatcherProvider?.main?.let {
      Dispatchers.setMain(it)
    }
  }

  /**
   * @suppress
   */
  override fun afterEach(context: ExtensionContext) {

    if (lazyScope.isInitialized()) {
      scope.cleanupTestCoroutines()
    }

    contextScopeMap[context]?.cleanupTestCoroutines()
    Dispatchers.resetMain()
  }

  /**
   * Class used to create the [TestProvidedCoroutineScope] used in [CoroutineTestExtension].
   *
   * In order to provide a custom implementation of [TestProvidedCoroutineScope]:
   * 1. Create a custom factory which has a default constructor and extends this `ScopeFactory`
   * 2. Annotate your test class with [CoroutineTest] and pass your custom factory's `KClass` in as its parameter.
   *
   * @sample samples.CoroutineTestNamedFactorySample
   */
  @ExperimentalCoroutinesApi
  open class ScopeFactory {

    /**
     * Creates an instance of [TestProvidedCoroutineScope].  Uses the no-arg factory by default.
     */
    open fun create(): TestProvidedCoroutineScope =
      TestProvidedCoroutineScope()
  }
}

/**
 * Factory function for creating a [CoroutineTestExtension].
 *
 * @see CoroutineTestExtension
 * @see CoroutineTest
 * @param scopeFactory *optional* factory for a custom [TestProvidedCoroutineScope].  If a factory is not provided,
 * the resultant scope uses the same [TestCoroutineDispatcher] for each property in its [TestDispatcherProvider]
 * @sample samples.RegisterSample
 * @sample samples.RegisterWithFactorySample
 */
@ExperimentalCoroutinesApi
public inline fun coroutineTestExtension(
  /**
   * This factory lambda creates the [TestProvidedCoroutineScope] which is managed by the [CoroutineTestExtension].
   *
   * By default, it creates a standard [TestProvidedCoroutineScope].
   */
  crossinline scopeFactory: () -> TestProvidedCoroutineScope = { TestProvidedCoroutineScope() }
): CoroutineTestExtension =
  CoroutineTestExtension(object : ScopeFactory() {
    override fun create(): TestProvidedCoroutineScope = scopeFactory.invoke()
  })




© 2015 - 2024 Weber Informatics LLC | Privacy Policy