org.apache.pekko.testkit.TestKit.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pekko-testkit_2.13 Show documentation
Show all versions of pekko-testkit_2.13 Show documentation
Apache Pekko is a toolkit for building highly concurrent, distributed, and resilient message-driven applications for Java and Scala.
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, which was derived from Akka.
*/
/*
* Copyright (C) 2009-2022 Lightbend Inc.
*/
package org.apache.pekko.testkit
import java.util.concurrent._
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
import scala.collection.immutable
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import scala.annotation.nowarn
import org.apache.pekko
import pekko.actor._
import pekko.actor.DeadLetter
import pekko.actor.IllegalActorStateException
import pekko.actor.Terminated
import pekko.annotation.InternalApi
import pekko.util.{ BoxedType, Timeout }
/**
 * Message protocol and helper types for the [[TestActor]] class defined below.
 */
object TestActor {

  /** Optional filter; a message for which it yields `true` is not enqueued. */
  type Ignore = Option[PartialFunction[Any, Boolean]]

  /**
   * Hook that the test actor runs for every non-control message. The value
   * returned from `run` is installed for the next message, unless it is
   * [[KeepRunning]], in which case the current one stays.
   */
  abstract class AutoPilot {
    def run(sender: ActorRef, msg: Any): AutoPilot
    def noAutoPilot: AutoPilot = NoAutoPilot
    def keepRunning: AutoPilot = KeepRunning
  }

  /** Auto pilot that does nothing and returns itself. */
  case object NoAutoPilot extends AutoPilot {
    def run(sender: ActorRef, msg: Any): AutoPilot = this
  }

  /** Marker result meaning "keep the current auto pilot"; it must never be run. */
  case object KeepRunning extends AutoPilot {
    def run(sender: ActorRef, msg: Any): AutoPilot = sys.error("must not call")
  }

  /** Replaces the ignore filter of the test actor. */
  final case class SetIgnore(i: Ignore) extends NoSerializationVerificationNeeded

  /** Asks the test actor to `context.watch` the given reference. */
  final case class Watch(ref: ActorRef) extends NoSerializationVerificationNeeded

  /** Asks the test actor to `context.unwatch` the given reference. */
  final case class UnWatch(ref: ActorRef) extends NoSerializationVerificationNeeded

  /** Replaces the auto pilot of the test actor. */
  final case class SetAutoPilot(ap: AutoPilot) extends NoSerializationVerificationNeeded

  /**
   * Asks the test actor to create a child, optionally named and optionally
   * supervised by a dedicated strategy.
   */
  final case class Spawn(props: Props, name: Option[String] = None, strategy: Option[SupervisorStrategy] = None)
      extends NoSerializationVerificationNeeded {

    /** Creates the actor in `context`, using `name` when one was given. */
    def apply(context: ActorRefFactory): ActorRef =
      name.fold(context.actorOf(props))(chosen => context.actorOf(props, chosen))
  }

  /** Entry of the test kit queue. */
  trait Message {
    def msg: AnyRef
    def sender: ActorRef
  }

  /** A message that was actually received, together with its sender. */
  final case class RealMessage(msg: AnyRef, sender: ActorRef) extends Message

  /** Placeholder used when the last receive did not dequeue anything; both accessors throw. */
  case object NullMessage extends Message {
    private def nothingDequeued: Nothing =
      throw IllegalActorStateException("last receive did not dequeue a message")

    override def msg: AnyRef = nothingDequeued
    override def sender: ActorRef = nothingDequeued
  }

  /** Constant function returning `false` for every argument. */
  val FALSE: Any => Boolean = _ => false

  /** INTERNAL API */
  private[TestActor] class DelegatingSupervisorStrategy extends SupervisorStrategy {
    import SupervisorStrategy._

    // per-child strategies; children without an entry use `stoppingStrategy`
    private var delegates = Map.empty[ActorRef, SupervisorStrategy]

    private def delegate(child: ActorRef) = delegates.getOrElse(child, stoppingStrategy)

    def update(child: ActorRef, supervisor: SupervisorStrategy): Unit =
      delegates = delegates.updated(child, supervisor)

    override def decider = defaultDecider // not actually invoked

    override def handleChildTerminated(context: ActorContext, child: ActorRef, children: Iterable[ActorRef]): Unit =
      delegates = delegates - child

    override def processFailure(
        context: ActorContext,
        restart: Boolean,
        child: ActorRef,
        cause: Throwable,
        stats: ChildRestartStats,
        children: Iterable[ChildRestartStats]): Unit =
      delegate(child).processFailure(context, restart, child, cause, stats, children)

    override def handleFailure(
        context: ActorContext,
        child: ActorRef,
        cause: Throwable,
        stats: ChildRestartStats,
        children: Iterable[ChildRestartStats]): Boolean =
      delegate(child).handleFailure(context, child, cause, stats, children)
  }

  // make creator serializable, for VerifySerializabilitySpec
  def props(queue: BlockingDeque[Message]): Props = Props(classOf[TestActor], queue)
}
/**
 * Actor that handles the control messages declared in the companion object
 * and appends every other message to `queue`, unless the ignore filter
 * rejects it.
 *
 * @param queue deque receiving the observed messages as `RealMessage`s
 */
class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor {
  import TestActor._

  override val supervisorStrategy: DelegatingSupervisorStrategy = new DelegatingSupervisorStrategy

  /** Current ignore filter, replaced through `SetIgnore`. */
  var ignore: Ignore = None

  /** Current auto pilot, replaced through `SetAutoPilot` or by its own result. */
  var autopilot: AutoPilot = NoAutoPilot

  def receive: Receive = {
    case SetIgnore(filter)  => ignore = filter
    case Watch(target)      => context.watch(target)
    case UnWatch(target)    => context.unwatch(target)
    case SetAutoPilot(next) => autopilot = next
    case request: Spawn =>
      val child = request(context)
      request.strategy.foreach(strategy => supervisorStrategy.update(child, strategy))
      // the new child's reference is handed back through the queue
      queue.offerLast(RealMessage(child, self))
    case x: AnyRef =>
      autopilot.run(sender(), x) match {
        case KeepRunning => () // current auto pilot stays installed
        case next        => autopilot = next
      }
      val ignored = ignore.exists(filter => filter.applyOrElse(x, FALSE))
      if (!ignored) queue.offerLast(RealMessage(x, sender()))
  }

  /** Publishes every message still in the queue as a `DeadLetter`. */
  override def postStop(): Unit = {
    import pekko.util.ccompat.JavaConverters._
    for (pending <- queue.asScala)
      context.system.deadLetters.tell(DeadLetter(pending.msg, pending.sender, self), pending.sender)
  }
}
/**
* Implementation trait behind the [[pekko.testkit.TestKit]] class: you may use
* this if inheriting from a concrete class is not possible.
*
* This trait requires the concrete class mixing it in to provide an
* [[pekko.actor.ActorSystem]] which is available before this trait's
* constructor is run. The recommended way is this:
*
* {{{
* class MyTest extends TestKitBase {
* implicit lazy val system = ActorSystem() // may add arguments here
* ...
* }
* }}}
*/
trait TestKitBase {
import TestActor.{ Message, NullMessage, RealMessage, Spawn }
/**
* The actor system used by this test kit; must be supplied by the concrete class.
*/
implicit def system: ActorSystem
/**
* Settings obtained from the `TestKitExtension` of `system`.
*/
def testKitSettings = TestKitExtension(system)
// Deque handed to the test actor's Props in `testActor`; the receive methods below poll it.
private val queue = new LinkedBlockingDeque[Message]()
// Updated by `receiveOne`: the dequeued message, or `NullMessage` when nothing was dequeued.
private[pekko] var lastMessage: Message = NullMessage
/**
* Sender of the last dequeued message. `NullMessage.sender` throws an
* `IllegalActorStateException`, so this fails if the last receive dequeued nothing.
*/
def lastSender = lastMessage.sender
/**
* Defines the testActor name. It is used as the prefix of the generated
* actor name, which is completed with a running number in `testActor`.
*/
protected def testActorName: String = "testActor"
/**
 * ActorRef of the test actor. Access is provided to enable e.g.
 * registration as message target.
 *
 * The actor is created lazily as a system actor on the
 * `CallingThreadDispatcher`, named `<testActorName>-<running number>`, and
 * the reference is only returned once a `RepointableRef` reports it has
 * started.
 */
lazy val testActor: ActorRef = {
  val extendedSystem = system.asInstanceOf[ExtendedActorSystem]
  val testActorProps = TestActor.props(queue).withDispatcher(CallingThreadDispatcher.Id)
  val actorName = "%s-%d".format(testActorName, TestKit.testActorId.incrementAndGet)
  val created = extendedSystem.systemActorOf(testActorProps, actorName)

  // references that are not repointable are considered started right away
  def started: Boolean = created match {
    case repointable: RepointableRef => repointable.isStarted
    case _                           => true
  }

  awaitCond(started, 3.seconds.dilated, 10.millis)
  created
}
// Deadline (on the `now` time line) of the innermost enclosing `within` block;
// `Duration.Undefined` while no `within` block is active.
private var end: Duration = Duration.Undefined
/**
* If the last assertion was expectNoMsg (or receiveWhile), disable the timing
* failure upon within() block end. Reset to false by `receiveOne` and at the
* start of every `within`.
*/
private var lastWasNoMsg = false
/**
 * Install `f` as the ignore filter of the test actor: every message for
 * which the partial function returns true is dropped instead of enqueued.
 */
def ignoreMsg(f: PartialFunction[Any, Boolean]): Unit = {
  val filter: TestActor.Ignore = Some(f)
  testActor ! TestActor.SetIgnore(filter)
}

/**
 * Remove the ignore filter so that the test actor enqueues all messages again.
 */
def ignoreNoMsg(): Unit = {
  val noFilter: TestActor.Ignore = None
  testActor ! TestActor.SetIgnore(noFilter)
}
/**
 * Ask the testActor to watch `ref` (i.e. `context.watch(...)`).
 *
 * @return the given `ref`, unchanged
 */
def watch(ref: ActorRef): ActorRef = {
  testActor.tell(TestActor.Watch(ref), ActorRef.noSender)
  ref
}

/**
 * Ask the testActor to stop watching `ref` (i.e. `context.unwatch(...)`).
 *
 * @return the given `ref`, unchanged
 */
def unwatch(ref: ActorRef): ActorRef = {
  testActor.tell(TestActor.UnWatch(ref), ActorRef.noSender)
  ref
}
/**
 * Install an AutoPilot on the testActor. It is run for each received
 * message and may send or forward messages; the AutoPilot it returns is
 * used for the next message.
 */
def setAutoPilot(pilot: TestActor.AutoPilot): Unit = {
  val command = TestActor.SetAutoPilot(pilot)
  testActor ! command
}
/**
 * Current value of `System.nanoTime`, wrapped as a nanosecond Duration.
 */
def now: FiniteDuration = FiniteDuration(System.nanoTime(), TimeUnit.NANOSECONDS)
/**
 * Time left in the innermost enclosing `within` block, or, outside of any
 * `within`, the dilated default from settings
 * (key "pekko.test.single-expect-default").
 */
def remainingOrDefault: FiniteDuration = {
  val fallback = testKitSettings.SingleExpectDefaultTimeout.dilated
  remainingOr(fallback)
}
/**
 * Time left in the innermost enclosing `within` block.
 *
 * Throws an [[java.lang.AssertionError]] when no `within` block surrounds
 * this call.
 */
def remaining: FiniteDuration = end match {
  case _: Duration.Infinite =>
    throw new AssertionError("`remaining` may not be called outside of `within`")
  case deadline: FiniteDuration =>
    deadline - now
}
/**
 * Time left in the innermost enclosing `within` block, or `duration` when
 * there is no such block.
 *
 * Throws an `IllegalArgumentException` if the current deadline is infinite.
 */
def remainingOr(duration: FiniteDuration): FiniteDuration = end match {
  case deadline: FiniteDuration              => deadline - now
  case unset if unset eq Duration.Undefined => duration
  case _                                     => throw new IllegalArgumentException("`end` cannot be infinite")
}
/**
 * Resolves an optional timeout: a finite `max` is dilated,
 * `Duration.Undefined` falls back to `remainingOrDefault`, and any other
 * infinite value is rejected with an `IllegalArgumentException`.
 */
private def remainingOrDilated(max: Duration): FiniteDuration = max match {
  case finite: FiniteDuration                => finite.dilated
  case unset if unset eq Duration.Undefined => remainingOrDefault
  case _                                     => throw new IllegalArgumentException("max duration cannot be infinite")
}
/**
 * Query queue status: `true` when at least one message is waiting.
 */
def msgAvailable: Boolean = {
  val empty = queue.isEmpty
  !empty
}
/**
 * Block until `p` evaluates to `true`, re-checking after every `interval`,
 * or fail with an assertion error once the timeout has expired.
 *
 * If no timeout is given, take it from the innermost enclosing `within`
 * block.
 *
 * Note that the timeout is scaled using Duration.dilated,
 * which uses the configuration entry "pekko.test.timefactor".
 *
 * @param p        condition to poll
 * @param max      timeout; `Duration.Undefined` means "remaining or default"
 * @param interval pause between two evaluations of `p`
 * @param message  text appended to the failure message
 */
def awaitCond(
    p: => Boolean,
    max: Duration = Duration.Undefined,
    interval: Duration = 100.millis,
    message: String = ""): Unit = {
  val timeout = remainingOrDilated(max)
  val deadline = now + timeout
  var pause: Duration = timeout min interval
  while (!p) {
    assert(now < deadline, s"timeout $timeout expired: $message")
    Thread.sleep(pause.toMillis)
    // never sleep past the deadline
    pause = (deadline - now) min interval
  }
}
/**
 * Evaluate `a` every `interval` until it completes without throwing, and
 * return its result.
 *
 * If the `max` timeout expires, the last exception is thrown.
 *
 * If no timeout is given, take it from the innermost enclosing `within`
 * block.
 *
 * Note that the timeout is scaled using Duration.dilated,
 * which uses the configuration entry "pekko.test.timefactor".
 */
def awaitAssert[A](a: => A, max: Duration = Duration.Undefined, interval: Duration = 100.millis): A = {
  val timeout = remainingOrDilated(max)
  val deadline = now + timeout

  @tailrec
  def attempt(pause: Duration): A = {
    // success is tracked explicitly because `null` is a legitimate result
    // (Java API callers that do not want to return a value "return null")
    val outcome: Either[Throwable, A] =
      try Right(a)
      catch { case NonFatal(e) => Left(e) }

    outcome match {
      case Right(value) => value
      case Left(failure) if (now + pause) >= deadline => throw failure
      case Left(_) =>
        Thread.sleep(pause.toMillis)
        attempt((deadline - now) min interval)
    }
  }

  attempt(timeout min interval)
}
/**
 * Evaluate the given assert every `interval` until it throws or the `max`
 * timeout has expired.
 *
 * An exception thrown by `a` propagates to the caller immediately.
 * Otherwise the result of the last evaluation is returned.
 *
 * Note that `max` is always scaled using Duration.dilated,
 * which uses the configuration entry "pekko.test.timefactor".
 *
 * @param a        assertion to evaluate repeatedly
 * @param max      how long the assertion has to keep holding (dilated)
 * @param interval pause between two evaluations
 * @return the value produced by the last evaluation of `a`
 */
def assertForDuration[A](a: => A, max: FiniteDuration, interval: Duration = 100.millis): A = {
  val _max = remainingOrDilated(max)
  val stop = now + _max

  @tailrec
  def poll(t: Duration): A = {
    // take the timestamp before evaluating, so the time spent in `a` itself
    // does not cut the loop short
    val instantNow = now
    val result = a
    if (instantNow < stop) {
      Thread.sleep(t.toMillis)
      poll((stop - now) min interval)
    } else {
      result
    }
  }

  // the first pause is bounded by the dilated timeout, like the deadline
  poll(_max min interval)
}
/**
 * Execute code block while bounding its execution time between `min` and
 * `max`. `within` blocks may be nested. All methods in this trait which
 * take maximum wait times are available in a version which implicitly uses
 * the remaining time governed by the innermost enclosing `within` block.
 *
 * Note that the timeout is scaled using Duration.dilated, which uses the
 * configuration entry "pekko.test.timefactor", while the min Duration is not.
 *
 * {{{
 * val ret = within(50.millis) {
 *   test ! "ping"
 *   expectMsgClass(classOf[String])
 * }
 * }}}
 *
 * @param min lower bound for the execution time of `f` (not dilated)
 * @param max upper bound for the execution time of `f` (dilated, and capped
 *            by the time left in an enclosing `within` block)
 * @param f   the code block to run
 * @return the result of `f`
 */
def within[T](min: FiniteDuration, max: FiniteDuration)(f: => T): T = {
  val _max = max.dilated
  val start = now
  // `Duration.Undefined` never `==` any Duration, itself included, so the
  // "no enclosing within block" case has to be detected by reference.
  val rem = if (end eq Duration.Undefined) Duration.Inf else end - start
  assert(rem >= min, s"required min time $min not possible, only ${format(min.unit, rem)} left")

  lastWasNoMsg = false

  val max_diff = _max min rem
  val prev_end = end
  end = start + max_diff

  // always restore the enclosing deadline, even when `f` throws
  val ret =
    try f
    finally end = prev_end

  val diff = now - start
  assert(min <= diff, s"block took ${format(min.unit, diff)}, should at least have been $min")
  // a trailing expectNoMessage/receiveWhile is allowed to use up the whole budget
  if (!lastWasNoMsg) {
    assert(diff <= max_diff, s"block took ${format(_max.unit, diff)}, exceeding ${format(_max.unit, max_diff)}")
  }

  ret
}

/**
 * Same as calling `within(0.seconds, max)(f)`.
 */
def within[T](max: FiniteDuration)(f: => T): T = within(0.seconds, max)(f)
/**
 * Same as `expectMsg(remainingOrDefault, obj)`, but correctly treating the timeFactor.
 */
def expectMsg[T](obj: T): T = {
  val timeout = remainingOrDefault
  expectMsg_internal(timeout, obj)
}

/**
 * Receive one message from the test actor and assert that it equals
 * `obj`. The wait is bounded by the dilated `max`; an assertion error is
 * thrown on timeout.
 *
 * @return the received object
 */
def expectMsg[T](max: FiniteDuration, obj: T): T = {
  val timeout = max.dilated
  expectMsg_internal(timeout, obj)
}

/**
 * Like `expectMsg(max, obj)`, with `hint` appended to the failure message.
 *
 * @return the received object
 */
def expectMsg[T](max: FiniteDuration, hint: String, obj: T): T = {
  val timeout = max.dilated
  expectMsg_internal(timeout, obj, hint = Some(hint))
}

// Shared implementation; `max` is used as given (callers dilate it).
private def expectMsg_internal[T](max: Duration, obj: T, hint: Option[String] = None): T = {
  val received = receiveOne(max)
  val suffix = hint.fold("")(": " + _)
  assert(received ne null, s"timeout ($max) during expectMsg while waiting for $obj" + suffix)
  assert(obj == received, s"expected $obj, found $received" + suffix)
  received.asInstanceOf[T]
}
/**
 * Receive one message from the test actor and assert that the partial
 * function `f` is defined for it. The wait is bounded by `max`; an
 * assertion error is thrown on timeout.
 *
 * Use this variant to implement more complicated or conditional
 * processing.
 *
 * @return the received object as transformed by the partial function
 */
def expectMsgPF[T](max: Duration = Duration.Undefined, hint: String = "")(f: PartialFunction[Any, T]): T = {
  val timeout = remainingOrDilated(max)
  val received = receiveOne(timeout)
  assert(received ne null, s"timeout ($timeout) during expectMsg: $hint")
  assert(f.isDefinedAt(received), s"expected: $hint but got unexpected message $received")
  f(received)
}
/**
 * Receive one message from the test actor and assert that it is the Terminated message of the given ActorRef.
 * The target has to be `watch`ed before this method is called.
 * The wait is bounded by `max`; an assertion error is thrown on timeout.
 *
 * @param target the actor ref expected to be Terminated
 * @param max wait no more than max time, otherwise fail
 * @return the received Terminated message
 */
def expectTerminated(target: ActorRef, max: Duration = Duration.Undefined): Terminated = {
  val hint = "Terminated " + target
  expectMsgPF(max, hint) {
    case terminated @ Terminated(`target`) => terminated
  }
}
/**
 * Hybrid of expectMsgPF and receiveWhile: keep receiving while the partial
 * function is defined for the message and returns false. A message for
 * which it is not defined fails the assertion. Use it to skip certain
 * messages while waiting for a specific one.
 *
 * @return the last received message, i.e. the first one for which the
 *         partial function returned true
 */
def fishForMessage(max: Duration = Duration.Undefined, hint: String = "")(f: PartialFunction[Any, Boolean]): Any = {
  val timeout = remainingOrDilated(max)
  val deadline = now + timeout

  @tailrec
  def fish(): Any = {
    val candidate = receiveOne(deadline - now)
    assert(candidate ne null, s"timeout ($timeout) during fishForMessage, hint: $hint")
    assert(f.isDefinedAt(candidate), s"fishForMessage($hint) found unexpected message $candidate")
    if (!f(candidate)) fish() else candidate
  }

  fish()
}
/**
 * Wait for a message for which the partial function is defined, discarding
 * every other message received in the meantime.
 *
 * @return result of applying the partial function to the first message for
 *         which it is defined
 */
def fishForSpecificMessage[T](max: Duration = Duration.Undefined, hint: String = "")(
    f: PartialFunction[Any, T]): T = {
  val timeout = remainingOrDilated(max)
  val deadline = now + timeout

  @tailrec
  def fish(): T = {
    val candidate = receiveOne(deadline - now)
    assert(candidate ne null, s"timeout ($timeout) during fishForSpecificMessage, hint: $hint")
    if (!f.isDefinedAt(candidate)) fish() else f(candidate)
  }

  fish()
}
/**
 * Same as `expectMsgType[T](remainingOrDefault)`, but correctly treating the timeFactor.
 */
def expectMsgType[T](implicit t: ClassTag[T]): T = {
  val timeout = remainingOrDefault
  val expected = t.runtimeClass.asInstanceOf[Class[T]]
  expectMsgClass_internal(timeout, expected)
}

/**
 * Receive one message from the test actor and assert that it conforms to
 * the given type (after erasure). The wait is bounded by the dilated
 * `max`; an assertion error is thrown on timeout.
 *
 * @return the received object
 */
def expectMsgType[T](max: FiniteDuration)(implicit t: ClassTag[T]): T = {
  val timeout = max.dilated
  val expected = t.runtimeClass.asInstanceOf[Class[T]]
  expectMsgClass_internal(timeout, expected)
}
/**
 * Same as `expectMsgClass(remainingOrDefault, c)`, but correctly treating the timeFactor.
 */
def expectMsgClass[C](c: Class[C]): C = {
  val timeout = remainingOrDefault
  expectMsgClass_internal(timeout, c)
}

/**
 * Receive one message from the test actor and assert that it conforms to
 * the given class. The wait is bounded by the dilated `max`; an assertion
 * error is thrown on timeout.
 *
 * @return the received object
 */
def expectMsgClass[C](max: FiniteDuration, c: Class[C]): C = {
  val timeout = max.dilated
  expectMsgClass_internal(timeout, c)
}

// Shared implementation; `BoxedType(c)` is what the instance check runs against.
private def expectMsgClass_internal[C](max: FiniteDuration, c: Class[C]): C = {
  val received = receiveOne(max)
  assert(received ne null, s"timeout ($max) during expectMsgClass waiting for $c")
  assert(BoxedType(c).isInstance(received), s"expected $c, found ${received.getClass} ($received)")
  received.asInstanceOf[C]
}
/**
 * Same as `expectMsgAnyOf(remainingOrDefault, obj...)`, but correctly treating the timeFactor.
 */
def expectMsgAnyOf[T](obj: T*): T = {
  val timeout = remainingOrDefault
  expectMsgAnyOf_internal(timeout, obj: _*)
}

/**
 * Receive one message from the test actor and assert that it equals one of
 * the given objects. The wait is bounded by the dilated `max`; an assertion
 * error is thrown on timeout.
 *
 * @return the received object
 */
def expectMsgAnyOf[T](max: FiniteDuration, obj: T*): T = {
  val timeout = max.dilated
  expectMsgAnyOf_internal(timeout, obj: _*)
}

// Shared implementation; `max` is used as given (callers dilate it).
private def expectMsgAnyOf_internal[T](max: FiniteDuration, obj: T*): T = {
  val received = receiveOne(max)
  assert(received ne null, s"timeout ($max) during expectMsgAnyOf waiting for ${obj.mkString("(", ", ", ")")}")
  val matched = obj.exists(candidate => candidate == received)
  assert(matched, s"found unexpected $received")
  received.asInstanceOf[T]
}
/**
 * Same as `expectMsgAnyClassOf(remainingOrDefault, obj...)`, but correctly treating the timeFactor.
 */
def expectMsgAnyClassOf[C](obj: Class[_ <: C]*): C = {
  val timeout = remainingOrDefault
  expectMsgAnyClassOf_internal(timeout, obj: _*)
}

/**
 * Receive one message from the test actor and assert that it conforms to
 * one of the given classes. The wait is bounded by the dilated `max`; an
 * assertion error is thrown on timeout.
 *
 * @return the received object
 */
def expectMsgAnyClassOf[C](max: FiniteDuration, obj: Class[_ <: C]*): C = {
  val timeout = max.dilated
  expectMsgAnyClassOf_internal(timeout, obj: _*)
}

// Shared implementation; `max` is used as given (callers dilate it).
private def expectMsgAnyClassOf_internal[C](max: FiniteDuration, obj: Class[_ <: C]*): C = {
  val received = receiveOne(max)
  assert(received ne null, s"timeout ($max) during expectMsgAnyClassOf waiting for ${obj.mkString("(", ", ", ")")}")
  val conforms = obj.exists(candidate => BoxedType(candidate).isInstance(received))
  assert(conforms, s"found unexpected $received")
  received.asInstanceOf[C]
}
/**
 * Same as `expectMsgAllOf(remainingOrDefault, obj...)`, but correctly treating the timeFactor.
 */
def expectMsgAllOf[T](obj: T*): immutable.Seq[T] = {
  val timeout = remainingOrDefault
  expectMsgAllOf_internal(timeout, obj: _*)
}

/**
 * Receive as many messages from the test actor as objects are given and
 * assert that for each given object one is received which equals it and
 * vice versa. This construct is useful when the order in which the
 * objects are received is not fixed. The wait is bounded by the dilated
 * `max`; an assertion error is thrown on timeout.
 *
 * {{{
 * dispatcher ! SomeWork1()
 * dispatcher ! SomeWork2()
 * expectMsgAllOf(1.second, Result1(), Result2())
 * }}}
 */
def expectMsgAllOf[T](max: FiniteDuration, obj: T*): immutable.Seq[T] = {
  val timeout = max.dilated
  expectMsgAllOf_internal(timeout, obj: _*)
}

// Fails when anything is missing or unexpected; the message lists both
// groups, each introduced by its own label.
private def checkMissingAndUnexpected(
    missing: Seq[Any],
    unexpected: Seq[Any],
    missingMessage: String,
    unexpectedMessage: String): Unit = {
  // built lazily: only needed when the assertion fails
  def missingPart: String =
    if (missing.isEmpty) "" else missing.mkString(missingMessage + " [", ", ", "] ")
  def unexpectedPart: String =
    if (unexpected.isEmpty) "" else unexpected.mkString(unexpectedMessage + " [", ", ", "]")

  val complete = missing.isEmpty && unexpected.isEmpty
  assert(complete, missingPart + unexpectedPart)
}

// Shared implementation; `max` is used as given (callers dilate it).
private def expectMsgAllOf_internal[T](max: FiniteDuration, obj: T*): immutable.Seq[T] = {
  val received = receiveN_internal(obj.size, max)
  val notFound = obj.filterNot(wanted => received.exists(got => wanted == got))
  val surplus = received.filterNot(got => obj.exists(wanted => got == wanted))
  checkMissingAndUnexpected(notFound, surplus, "not found", "found unexpected")
  received.asInstanceOf[immutable.Seq[T]]
}
/**
 * Same as `expectMsgAllClassOf(remainingOrDefault, obj...)`, but correctly treating the timeFactor.
 */
def expectMsgAllClassOf[T](obj: Class[_ <: T]*): immutable.Seq[T] = {
  val timeout = remainingOrDefault
  internalExpectMsgAllClassOf(timeout, obj: _*)
}

/**
 * Receive as many messages from the test actor as classes are given and
 * assert that for each given class one is received which is of that class
 * (equality, not conformance). This construct is useful when the order in
 * which the objects are received is not fixed. The wait is bounded by the
 * dilated `max`; an assertion error is thrown on timeout.
 */
def expectMsgAllClassOf[T](max: FiniteDuration, obj: Class[_ <: T]*): immutable.Seq[T] = {
  val timeout = max.dilated
  internalExpectMsgAllClassOf(timeout, obj: _*)
}

// Shared implementation; classes are compared by identity after `BoxedType`.
private def internalExpectMsgAllClassOf[T](max: FiniteDuration, obj: Class[_ <: T]*): immutable.Seq[T] = {
  val received = receiveN_internal(obj.size, max)
  val notFound = obj.filterNot(wanted => received.exists(got => got.getClass eq BoxedType(wanted)))
  val surplus = received.filterNot(got => obj.exists(wanted => BoxedType(wanted) eq got.getClass))
  checkMissingAndUnexpected(notFound, surplus, "not found", "found non-matching object(s)")
  received.asInstanceOf[immutable.Seq[T]]
}
/**
 * Same as `expectMsgAllConformingOf(remainingOrDefault, obj...)`, but correctly treating the timeFactor.
 */
def expectMsgAllConformingOf[T](obj: Class[_ <: T]*): immutable.Seq[T] = {
  val timeout = remainingOrDefault
  internalExpectMsgAllConformingOf(timeout, obj: _*)
}

/**
 * Receive as many messages from the test actor as classes are given and
 * assert that for each given class one is received which conforms to that
 * class (and vice versa). This construct is useful when the order in which
 * the objects are received is not fixed. The wait is bounded by the
 * dilated `max`; an assertion error is thrown on timeout.
 *
 * Beware that one object may satisfy all given class constraints, which
 * may be counter-intuitive.
 */
def expectMsgAllConformingOf[T](max: FiniteDuration, obj: Class[_ <: T]*): immutable.Seq[T] = {
  val timeout = max.dilated
  internalExpectMsgAllConformingOf(timeout, obj: _*)
}

// Shared implementation; conformance is checked with `BoxedType(c).isInstance`.
private def internalExpectMsgAllConformingOf[T](max: FiniteDuration, obj: Class[_ <: T]*): immutable.Seq[T] = {
  val received = receiveN_internal(obj.size, max)
  val notFound = obj.filterNot(wanted => received.exists(got => BoxedType(wanted).isInstance(got)))
  val surplus = received.filterNot(got => obj.exists(wanted => BoxedType(wanted).isInstance(got)))
  checkMissingAndUnexpected(notFound, surplus, "not found", "found non-matching object(s)")
  received.asInstanceOf[immutable.Seq[T]]
}
/**
 * Assert that no message is received. Delegates to `expectNoMessage()`,
 * which waits for the dilated default period configured as
 * `pekko.test.expect-no-message-default`.
 */
@deprecated(message = "Use expectNoMessage instead", since = "Akka 2.5.5")
def expectNoMsg(): Unit = {
  expectNoMessage()
}

/**
 * Assert that no message is received for the specified time.
 * NOTE! Supplied value is always dilated.
 */
@deprecated(message = "Use expectNoMessage instead", since = "Akka 2.5.5")
def expectNoMsg(max: FiniteDuration): Unit = {
  val timeout = max.dilated
  expectNoMsg_internal(timeout)
}

/**
 * Assert that no message is received for the specified time.
 * Supplied value is not dilated.
 */
def expectNoMessage(max: FiniteDuration): Unit =
  expectNoMsg_internal(max)

/**
 * Assert that no message is received. Waits for the default period configured as
 * `pekko.test.expect-no-message-default`.
 * That timeout is scaled using the configuration entry "pekko.test.timefactor".
 */
def expectNoMessage(): Unit = {
  val timeout = testKitSettings.ExpectNoMessageDefaultTimeout.dilated
  expectNoMsg_internal(timeout)
}
// Polls the head of the queue until `max` has elapsed. A message seen in
// that period is removed and reported through an AssertionError; otherwise
// `lastWasNoMsg` is set.
private def expectNoMsg_internal(max: FiniteDuration): Unit = {
  val deadlineNanos = System.nanoTime() + max.toNanos
  val pollInterval = 100.millis
  def timeLeft: FiniteDuration = (deadlineNanos - System.nanoTime()).nanos

  @tailrec
  def observe(head: AnyRef, left: FiniteDuration): (AnyRef, FiniteDuration) =
    if (left.toNanos <= 0 || (head ne null)) (head, left)
    else {
      // Sleeping for (left / 2) gives a geometric series limited by the
      // deadline, similar to (1/2)^n limited by 1, so it is very precise
      Thread.sleep(pollInterval.toMillis min (left / 2).toMillis)
      val leftAfterSleep = timeLeft
      // only look at the queue again while still inside the period
      val nextHead = if (leftAfterSleep.toNanos > 0) queue.peekFirst() else head
      observe(nextHead, leftAfterSleep)
    }

  val initialHead: AnyRef = queue.peekFirst()
  val (unexpected, left) = observe(initialHead, timeLeft)

  if (unexpected ne null) {
    // we pop the message, such that subsequent expectNoMessage calls can
    // assert on the next period without a message
    queue.pop()
    val elapsed = (max.toNanos - left.toNanos).nanos
    throw new java.lang.AssertionError(
      s"assertion failed: received unexpected message $unexpected after ${elapsed.toMillis} millis")
  } else {
    lastWasNoMsg = true
  }
}
/**
* Receive a series of messages until one does not match the given partial
* function or the idle timeout is met (disabled by default) or the overall
* maximum duration is elapsed or expected messages count is reached.
* Returns the sequence of messages, as transformed by the partial function.
*
* Note that it is not an error to hit the `max` duration in this case.
*
* A message for which the partial function is not defined is put back at
* the head of the queue, and `lastMessage` is reset to the last message
* that did match.
*
* One possible use of this method is for testing whether messages of
* certain characteristics are generated at a certain rate:
*
* {{{
* test ! ScheduleTicks(100.millis)
* val series = receiveWhile(750.millis) {
*   case Tick(count) => count
* }
* assert(series == (1 to 7).toList)
* }}}
*
* @param max overall time limit; `Duration.Undefined` means "remaining or default"
* @param idle maximum wait for each single message
* @param messages maximum number of messages to collect
* @param f partial function selecting and transforming the messages
*/
def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(
f: PartialFunction[AnyRef, T]): immutable.Seq[T] = {
val stop = now + remainingOrDilated(max)
// last message accepted by `f`; restored into `lastMessage` when the loop ends
var msg: Message = NullMessage
@tailrec
def doit(acc: List[T], count: Int): List[T] = {
if (count >= messages) acc.reverse
else {
// the result is read from `lastMessage`, which receiveOne updates
receiveOne((stop - now) min idle)
lastMessage match {
case NullMessage =>
// nothing arrived in time: stop without failing
lastMessage = msg
acc.reverse
case RealMessage(o, _) if f.isDefinedAt(o) =>
msg = lastMessage
doit(f(o) :: acc, count + 1)
case RealMessage(_, _) =>
// not matched: push it back so that the next receive sees it again
queue.offerFirst(lastMessage)
lastMessage = msg
acc.reverse
case unexpected => throw new RuntimeException(s"Unexpected: $unexpected") // exhaustiveness check
}
}
}
val ret = doit(Nil, 0)
// like expectNoMessage, this disables the max-time check of an enclosing `within`
lastWasNoMsg = true
ret
}
/**
 * Same as `receiveN(n, remainingOrDefault)` but correctly taking into
 * account Duration.timeFactor.
 */
def receiveN(n: Int): immutable.Seq[AnyRef] = {
  val timeout = remainingOrDefault
  receiveN_internal(n, timeout)
}

/**
 * Receive N messages in a row before the given (dilated) deadline.
 */
def receiveN(n: Int, max: FiniteDuration): immutable.Seq[AnyRef] = {
  val timeout = max.dilated
  receiveN_internal(n, timeout)
}

// Shared implementation: all `n` messages share one deadline; a timeout on
// any of them fails the assertion.
private def receiveN_internal(n: Int, max: Duration): immutable.Seq[AnyRef] = {
  val deadline = max + now
  (1 to n).map { index =>
    val received = receiveOne(deadline - now)
    assert(received ne null, s"timeout ($max) while expecting $n messages (got ${index - 1})")
    received
  }
}
/**
 * Receive one message from the internal queue of the TestActor. A zero
 * duration polls the queue without blocking, a finite duration waits at
 * most that long, and a non-finite duration blocks until a message arrives.
 *
 * This method does NOT automatically scale its Duration parameter!
 *
 * @return the received message, or `null` if none arrived in time
 */
def receiveOne(max: Duration): AnyRef = {
  val dequeued = max match {
    case finite: FiniteDuration =>
      if (finite.toNanos == 0) queue.pollFirst
      else queue.pollFirst(finite.length, finite.unit)
    case _ =>
      queue.takeFirst
  }

  lastWasNoMsg = false

  dequeued match {
    case null =>
      lastMessage = NullMessage
      null
    case RealMessage(payload, _) =>
      lastMessage = dequeued
      payload
    case unexpected => throw new RuntimeException(s"Unexpected: $unexpected") // exhaustiveness check
  }
}
/**
 * Shut down an actor system and wait for termination; delegates to
 * `TestKit.shutdownActorSystem`.
 * On failure debug output will be logged about the remaining actors in the system.
 *
 * If verifySystemShutdown is true, then an exception will be thrown on failure.
 */
def shutdown(
    actorSystem: ActorSystem = system,
    duration: Duration = Duration.Undefined,
    verifySystemShutdown: Boolean = false): Unit =
  TestKit.shutdownActorSystem(
    actorSystem = actorSystem,
    duration = duration,
    verifySystemShutdown = verifySystemShutdown)
/**
 * Spawns an actor as a child of this test actor, and returns the child's ActorRef.
 * @param props Props to create the child actor
 * @param name Actor name for the child actor
 * @param supervisorStrategy Strategy should decide what to do with failures in the actor.
 */
def childActorOf(props: Props, name: String, supervisorStrategy: SupervisorStrategy): ActorRef = {
  val request = Spawn(props, name = Some(name), strategy = Some(supervisorStrategy))
  testActor ! request
  // the test actor answers by enqueueing the new child's reference
  expectMsgType[ActorRef]
}

/**
 * Spawns an actor as a child of this test actor with an auto-generated name, and returns the child's ActorRef.
 * @param props Props to create the child actor
 * @param supervisorStrategy Strategy should decide what to do with failures in the actor.
 */
def childActorOf(props: Props, supervisorStrategy: SupervisorStrategy): ActorRef = {
  val request = Spawn(props, name = None, strategy = Some(supervisorStrategy))
  testActor ! request
  expectMsgType[ActorRef]
}

/**
 * Spawns an actor as a child of this test actor with a stopping supervisor strategy, and returns the child's ActorRef.
 * @param props Props to create the child actor
 * @param name Actor name for the child actor
 */
def childActorOf(props: Props, name: String): ActorRef = {
  val request = Spawn(props, name = Some(name), strategy = None)
  testActor ! request
  expectMsgType[ActorRef]
}

/**
 * Spawns an actor as a child of this test actor with an auto-generated name and stopping supervisor strategy, returning the child's ActorRef.
 * @param props Props to create the child actor
 */
def childActorOf(props: Props): ActorRef = {
  val request = Spawn(props, name = None, strategy = None)
  testActor ! request
  expectMsgType[ActorRef]
}
// Renders `d` in unit `u` with three decimals, followed by the lower-cased unit name.
private def format(u: TimeUnit, d: Duration): String = {
  val amount = d.toUnit(u)
  val unitName = u.toString.toLowerCase
  "%.3f %s".format(amount, unitName)
}
}
/**
* Test kit for testing actors. Inheriting from this class enables reception of
* replies from actors, which are queued by an internal actor and can be
* examined using the `expectMsg...` methods. Assertions and bounds concerning
* timing are available in the form of `within` blocks.
*
* {{{
* class Test extends TestKit(ActorSystem()) {
*   try {
*
*     val test = system.actorOf(Props[SomeActor])
*
*     within (1.second) {
*       test ! SomeWork
*       expectMsg(Result1) // bounded to 1 second
*       expectMsg(Result2) // bounded to the remainder of the 1 second
*     }
*
*   } finally {
*     system.terminate()
*   }
* }
* }}}
*
* Beware of two points:
*
* - the ActorSystem passed into the constructor needs to be shutdown,
* otherwise thread pools and memory will be leaked
* - this class is not thread-safe (only one actor with one queue, one stack
* of `within` blocks); it is expected that the code is executed from a
* constructor as shown above, which makes this a non-issue, otherwise take
* care not to run tests within a single test class instance in parallel.
*
* It should be noted that for CI servers and the like all maximum Durations
* are scaled using their Duration.dilated method, which uses the
* TestKitExtension.Settings.TestTimeFactor settable via pekko.conf entry "pekko.test.timefactor".
*
* @param _system the actor system exposed as the implicit `system` of this test kit
*/
@nowarn // 'early initializers' are deprecated on 2.13 and will be replaced with trait parameters on 2.14. https://github.com/akka/akka/issues/26753
class TestKit(_system: ActorSystem) extends TestKitBase {
implicit val system: ActorSystem = _system
}
/**
 * Helpers that do not need a test kit instance.
 */
object TestKit {

  /**
   * INTERNAL API
   */
  @InternalApi
  private[pekko] val testActorId = new AtomicInteger(0)

  /**
   * Block until `p` evaluates to `true` or the timeout expires, whichever
   * comes first. `max` is used as given.
   *
   * @param noThrow when `true`, a timeout yields `false` instead of an `AssertionError`
   * @return `true` once the condition holds
   */
  def awaitCond(p: => Boolean, max: Duration, interval: Duration = 100.millis, noThrow: Boolean = false): Boolean = {
    val deadline = now + max

    @tailrec
    def loop(): Boolean =
      if (p) true
      else {
        val left = deadline - now
        if (left > Duration.Zero) {
          Thread.sleep((left min interval).toMillis)
          loop()
        } else if (noThrow) false
        else throw new AssertionError(s"timeout $max expired")
      }

    loop()
  }

  /**
   * Obtain current timestamp as Duration for relative measurements (using System.nanoTime).
   */
  def now: Duration = System.nanoTime().nanos

  /**
   * Shut down an actor system and wait for termination.
   * On failure debug output will be printed about the remaining actors in the system.
   *
   * A finite `duration` is dilated by the timefactor; otherwise the wait is
   * 10 seconds dilated, capped at 10 seconds.
   *
   * If verifySystemShutdown is true, then an exception will be thrown on failure.
   */
  def shutdownActorSystem(
      actorSystem: ActorSystem,
      duration: Duration = Duration.Undefined,
      verifySystemShutdown: Boolean = false): Unit = {
    val timeout = duration match {
      case finite: FiniteDuration => finite.dilated(actorSystem)
      case _                      => 10.seconds.dilated(actorSystem).min(10.seconds)
    }

    actorSystem.terminate()

    try Await.ready(actorSystem.whenTerminated, timeout)
    catch {
      case _: TimeoutException =>
        val tree = actorSystem.asInstanceOf[ActorSystemImpl].printTree
        val report = s"Failed to stop [${actorSystem.name}] within [$timeout] \n$tree"
        if (!verifySystemShutdown) println(report)
        else throw new RuntimeException(report)
    }
  }
}
/**
 * TestKit-based probe which allows sending, reception and reply.
 *
 * @param _application actor system the probe's test actor is created in
 * @param name         prefix used for the name of the probe's test actor
 */
class TestProbe(_application: ActorSystem, name: String) extends TestKit(_application) {

  /** Creates a probe with the default name prefix "testProbe". */
  def this(_application: ActorSystem) = this(_application, "testProbe")

  /**
   * Shorthand to get the testActor.
   */
  def ref: ActorRef = testActor

  protected override def testActorName: String = name

  /**
   * Send `msg` to `actor` with the probe's TestActor as the sender, so that
   * replies can be inspected with all of TestKit's assertion methods.
   */
  def send(actor: ActorRef, msg: Any): Unit = actor.tell(msg, testActor)

  /**
   * Pass `msg` (by default the last received message) on to `actor`,
   * keeping the sender of the last received message.
   */
  def forward(actor: ActorRef, msg: Any = lastMessage.msg): Unit = actor.tell(msg, lastMessage.sender)

  /**
   * Get sender of last received message.
   */
  def sender(): ActorRef = lastMessage.sender

  /**
   * Send `msg` to the sender of the last dequeued message, with the probe
   * as the sender.
   */
  def reply(msg: Any): Unit = sender().tell(msg, ref)
}
/**
 * Factory methods creating a [[TestProbe]] from the implicit actor system.
 */
object TestProbe {

  /** Probe with the default name prefix. */
  def apply()(implicit system: ActorSystem): TestProbe = new TestProbe(system)

  /** Probe whose test actor name starts with `name`. */
  def apply(name: String)(implicit system: ActorSystem): TestProbe = new TestProbe(system, name)
}
/**
* Mix-in for a [[TestKitBase]] that provides the `testActor` as an implicit
* `ActorRef` named `self`.
*/
trait ImplicitSender { this: TestKitBase =>
implicit def self: ActorRef = testActor
}
/**
* Mix-in for a [[TestKitBase]] that provides `testKitSettings.DefaultTimeout`
* as an implicit `Timeout`.
*/
trait DefaultTimeout { this: TestKitBase =>
implicit val timeout: Timeout = testKitSettings.DefaultTimeout
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy