zio.mock.internal.ProxyFactory.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2019-2022 John A. De Goes and the ZIO Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package zio.mock.internal
import zio.mock.{Capability, Expectation, Proxy}
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.test.Assertion
import zio.{EnvironmentTag, IO, Trace, ULayer, ZIO, ZLayer}
import scala.annotation.tailrec
import scala.util.Try
object ProxyFactory {
import Debug._
import Expectation._
import ExpectationState._
import InvalidCall._
import MockException._
/** Given initial `MockState[R]`, constructs a `Proxy` running that state.
*/
def mockProxy[R: EnvironmentTag](state: MockState[R])(implicit trace: Trace): ULayer[Proxy] =
ZLayer.succeed(new Proxy {
debug(s"::: new Proxy created")
def invoke[RIn, ROut, I, E, A](invoked: Capability[RIn, I, E, A], args: I): ZIO[ROut, E, A] = {
debug(s"::: invoked $invoked")
sealed trait MatchResult
object MatchResult {
case object UnexpectedCall extends MatchResult
case class Success(value: Matched[R, E, A]) extends MatchResult
case class Failure(failures: List[InvalidCall]) extends MatchResult
}
@tailrec
def findMatching(scopes: List[Scope[R]], failedMatches: List[InvalidCall]): MatchResult = {
debug(s"::: invoked $invoked\n${prettify(scopes)}")
scopes match {
case Nil => MatchResult.UnexpectedCall
case Scope(expectation, id, update0) :: nextScopes =>
val update: Expectation[R] => Expectation[R] = updated => {
debug(s"::: updated state to: ${updated.state}")
update0(updated)
}
expectation match {
case anyExpectation if anyExpectation.state == Saturated =>
debug("::: skipping saturated expectation")
findMatching(nextScopes, failedMatches)
case NoCalls(_) =>
findMatching(nextScopes, failedMatches)
case call @ Call(capability, assertion, returns, _, invocations) if invoked isEqual capability =>
debug(s"::: matched call $capability")
assertion.asInstanceOf[Assertion[I]].test(args) match {
case true =>
val result = returns.asInstanceOf[I => IO[E, A]](args)
val updated = call
.asInstanceOf[Call[R, I, E, A]]
.copy(
state = Saturated,
invocations = id :: invocations
)
MatchResult.Success(Matched[R, E, A](update(updated), result))
case false =>
handleLeafFailure(
InvalidArguments(invoked, args, assertion.asInstanceOf[Assertion[Any]]),
nextScopes,
failedMatches
)
}
case Call(capability, assertion, _, _, _) =>
debug(s"::: invalid call $capability")
val invalidCall =
if (invoked.id == capability.id) InvalidPolyType(invoked, args, capability, assertion)
else InvalidCapability(invoked, capability, assertion)
handleLeafFailure(invalidCall, nextScopes, failedMatches)
case self @ Chain(children, _, invocations, _) =>
children.zipWithIndex.dropWhile(_._1.state == Saturated) match {
case (child1 @ Repeated(_, _, state, _, _, _), idx1) :: (child2, idx2) :: _ if state == Satisfied =>
val scope1 = Scope[R](
child1,
id,
updatedChild => {
val updatedChildren = children.updated(idx1, updatedChild)
update(
self.copy(
children = updatedChildren,
state = minimumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
val scope2 = Scope[R](
child2,
id,
updatedChild => {
val updatedChildren = children
.updated(idx1, child1.copy(state = Saturated))
.updated(idx2, updatedChild)
update(
self.copy(
children = updatedChildren,
state = minimumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
findMatching(scope1 :: scope2 :: Nil, failedMatches)
case (child, idx) :: _ =>
val scope = Scope[R](
child,
id,
updatedChild => {
val updatedChildren = children.updated(idx, updatedChild)
update(
self.copy(
children = updatedChildren,
state = minimumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
findMatching(scope :: nextScopes, failedMatches)
case Nil =>
findMatching(nextScopes, failedMatches)
}
case self @ And(children, _, invocations, _) =>
val scopes = children.zipWithIndex.collect {
case (child, index) if child.state < Saturated =>
Scope[R](
child,
id,
updatedChild => {
val updatedChildren = children.updated(index, updatedChild)
update(
self.copy(
children = updatedChildren,
state = minimumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
}
findMatching(scopes ++ nextScopes, failedMatches)
case self @ Or(children, _, invocations, _) =>
children.zipWithIndex.find(_._1.state == PartiallySatisfied) match {
case Some((child, index)) =>
val scope = Scope[R](
child,
id,
updatedChild => {
val updatedChildren = children.updated(index, updatedChild)
update(
self.copy(
children = updatedChildren,
state = maximumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
findMatching(scope :: nextScopes, failedMatches)
case None =>
val scopes = children.zipWithIndex.collect { case (child, index) =>
Scope[R](
child,
id,
updatedChild => {
val updatedChildren = children.updated(index, updatedChild)
update(
self.copy(
children = updatedChildren,
state = maximumState(updatedChildren),
invocations = id :: invocations
)
)
}
)
}
findMatching(scopes ++ nextScopes, failedMatches)
}
case self @ Repeated(expectation, range, _, invocations, started, completed) =>
val scope = Scope[R](
expectation,
id,
updatedChild => {
val updatedStarted =
if (started == completed) started + 1
else started
val updatedCompleted =
if (updatedChild.state == Saturated) completed + 1
else completed
val inRepeatsRange: Boolean =
if (range.end != -1) range contains updatedStarted
else {
val fakeUnboundedRange = range.start to Int.MaxValue by range.step
fakeUnboundedRange contains updatedStarted
}
val maxRepeatsReached: Boolean =
Try(range.max == updatedCompleted).getOrElse(false)
update(
self.copy(
child = if (updatedChild.state == Saturated) resetTree(updatedChild) else updatedChild,
state = updatedChild.state match {
case Saturated =>
if (!inRepeatsRange) PartiallySatisfied
else if (maxRepeatsReached) Saturated
else Satisfied
case Satisfied =>
if (!inRepeatsRange) PartiallySatisfied
else Satisfied
case childState => childState
},
invocations = id :: invocations,
started = updatedStarted,
completed = updatedCompleted
)
)
}
)
findMatching(scope :: nextScopes, failedMatches)
case self @ Exactly(expectation, times, _, invocations, completed) =>
val scope = Scope[R](
expectation,
id,
updatedChild => {
val updatedCompleted =
if (updatedChild.state == Saturated) completed + 1
else completed
update(
self.copy(
child = if (updatedChild.state == Saturated) resetTree(updatedChild) else updatedChild,
state = if (times == updatedCompleted) Saturated else PartiallySatisfied,
invocations = id :: invocations,
completed = updatedCompleted
)
)
}
)
findMatching(scope :: nextScopes, failedMatches)
}
}
}
def minimumState(children: List[Expectation[R]]): ExpectationState = {
val states = children.map(_.state)
val min = states.min
val max = states.max
if (min >= Satisfied) min
else if (max >= Satisfied) PartiallySatisfied
else Unsatisfied
}
def maximumState(children: List[Expectation[R]]): ExpectationState =
children.map(_.state).max
def handleLeafFailure(
failure: => InvalidCall,
nextScopes: List[Scope[R]],
failedMatches: List[InvalidCall]
): MatchResult = {
val nextFailed = failure :: failedMatches
if (nextScopes.isEmpty) MatchResult.Failure(nextFailed)
else findMatching(nextScopes, nextFailed)
}
def resetTree(expectation: Expectation[R]): Expectation[R] =
expectation match {
case self: Call[R, _, _, _] =>
self.copy(state = Unsatisfied)
case self: Chain[R] =>
self.copy(
children = self.children.map(resetTree),
state = Unsatisfied
)
case self: And[R] =>
self.copy(
children = self.children.map(resetTree),
state = Unsatisfied
)
case self: Or[R] =>
self.copy(
children = self.children.map(resetTree),
state = Unsatisfied
)
case self: NoCalls[R] => self
case self: Repeated[R] =>
self.copy(
child = resetTree(self.child),
state = Unsatisfied,
completed = 0
)
case self: Exactly[R] =>
self.copy(
child = resetTree(self.child),
state = Unsatisfied,
completed = 0
)
}
debug(s"::: invoked ${invoked} before 'for'.")
for {
id <- state.callsCountRef.updateAndGet(_ + 1)
_ = debug(s"::: invoked ${invoked} before setting `matchResult`.")
matchResult <-
state.expectationRef.modify { root =>
val scope = Scope[R](root, id, identity)
val res = findMatching(scope :: Nil, Nil)
res match {
case MatchResult.Success(matched) => res -> matched.expectation
case MatchResult.UnexpectedCall => res -> root
case MatchResult.Failure(_) => res -> root
}
}
_ = debug(s"::: invoked $invoked\n ::: matchedResult: ${matchResult.toString}")
matched <-
matchResult match {
case MatchResult.Success(matched) => ZIO.succeed(matched)
case MatchResult.UnexpectedCall => ZIO.die(UnexpectedCallException(invoked, args))
case MatchResult.Failure(failures) => ZIO.die(InvalidCallException(failures))
}
_ = debug(s"::: setting root to\n${prettify(matched.expectation)}")
output <- matched.result
} yield output
}
})
}