verify.asserts.RecorderMacro.scala Maven / Gradle / Ivy
The newest version!
/*
* Scala (https://www.scala-lang.org)
*
* Copyright EPFL and Lightbend, Inc.
*
* Licensed under Apache License 2.0
* (http://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/
package verify
package asserts
import scala.reflect.macros.blackbox.Context
import scala.util.Properties
import scala.util.control.NonFatal
class RecorderMacro[C <: Context](val context: C) {
import context.universe._
/**
 * Expands a call shaped like `assert(expr, message)`.
 *
 * The generated block does the following, in order:
 *  - declares a recorder runtime wired to the macro prefix's `listener` member;
 *  - records `message`;
 *  - records `value`.
 *
 * It ends with the runtime's `completeRecording()` call, whose result is the
 * value of the whole expansion.
 *
 * @param value   tree of the asserted expression
 * @param message tree of the assertion message
 * @tparam R      static type of the expansion's result
 * @return the expanded block as an `Expr[R]`
 */
def apply[A: context.WeakTypeTag, R: context.WeakTypeTag](value: context.Tree, message: context.Tree): Expr[R] = {
  val statements: List[Tree] =
    declareRuntime[A, R] :: recordMessage(message) :: recordExpressions(value)
  val expansion = Block(statements, completeRecording)
  context.Expr[R](expansion)
}
/**
 * Expands a call shaped like `assertEquals(expected, found)` (with a message).
 *
 * The generated block does the following, in order:
 *  - declares a recorder runtime wired to the macro prefix's
 *    `stringAssertEqualsListener` member;
 *  - records `message`;
 *  - records `expected`, then `found`.
 *
 * It ends with the runtime's `completeRecording()` call.
 *
 * @param expected tree of the expected-value argument
 * @param found    tree of the actual-value argument
 * @param message  tree of the assertion message
 * @tparam R       static type of the expansion's result
 * @return the expanded block as an `Expr[R]`
 */
def apply2[A: context.WeakTypeTag, R: context.WeakTypeTag](
  expected: context.Tree,
  found: context.Tree,
  message: context.Tree
): Expr[R] = {
  val header: List[Tree] = List(
    declareRuntime[A, R]("stringAssertEqualsListener"),
    recordMessage(message)
  )
  // Each recordExpressions call contributes its own reset + recorded-expression pair.
  val recorded: List[Tree] = recordExpressions(expected) ++ recordExpressions(found)
  context.Expr[R](Block(header ++ recorded, completeRecording))
}
/**
 * Wraps `s` as a term name of the given context's universe.
 *
 * The result type depends on the path `c`, so passing `context` yields a
 * `context.universe.TermName`.
 */
private[this] def termName(c: C)(s: String): c.universe.TermName =
  c.universe.TermName.apply(s)
/**
 * Builds the declaration of the recorder runtime val.
 *
 * The val is typed as `RecorderRuntime[A, R]` and initialised with
 * `new RecorderRuntime(prefix.<listener>)`, where `prefix` is the macro call's
 * prefix tree.
 *
 * @param listener name of the member selected on the macro prefix and passed
 *                 to the runtime constructor
 * @return the `ValDef` tree for the runtime declaration
 */
private[this] def declareRuntime[A: context.WeakTypeTag, R: context.WeakTypeTag](listener: String): Tree = {
  val runtimeSymbol = context.mirror.staticClass(classOf[RecorderRuntime[_, _]].getName())
  val declaredType = TypeTree(weakTypeOf[RecorderRuntime[A, R]])
  val constructor = Select(New(Ident(runtimeSymbol)), termNames.CONSTRUCTOR)
  val listenerArg = Select(context.prefix.tree, termName(context)(listener))
  // The val name is fixed because the other generated calls refer to it by this exact name.
  ValDef(
    Modifiers(),
    termName(context)("$scala_verify_recorderRuntime"),
    declaredType,
    Apply(constructor, List(listenerArg))
  )
}
/**
 * Builds the recorder runtime declaration wired to the macro prefix's default
 * `listener` member.
 *
 * This delegates to the listener-parameterised overload rather than duplicating
 * its tree construction, so the two declarations cannot drift apart. The
 * generated `ValDef` is identical to the previous hand-written copy.
 */
private[this] def declareRuntime[A: context.WeakTypeTag, R: context.WeakTypeTag]: Tree =
  declareRuntime[A, R]("listener")
/**
 * Produces the statements that record one asserted expression: a call to the
 * runtime's `resetValues()` followed by the instrumented expression.
 *
 * If rewriting fails, the failure is rethrown as a `RuntimeException` with the
 * expression's source text and raw AST attached, to ease debugging.
 *
 * Only non-fatal exceptions are wrapped. Fatal errors and control throwables
 * propagate unchanged; `context.abort` signals errors with a control throwable,
 * so a positioned compile error is no longer hidden behind a generic wrapper.
 *
 * @param recording tree of the expression to record
 * @return the reset call followed by the recorded expression
 * @throws RuntimeException if the rewrite fails with a non-fatal exception
 */
private[this] def recordExpressions(recording: Tree): List[Tree] = {
  val source = getSourceCode(recording)
  val ast = showRaw(recording)
  try {
    List(resetValues, recordExpression(source, ast, recording))
  } catch {
    case NonFatal(e) =>
      throw new RuntimeException(s"Expecty: Error rewriting expression.\nText: $source\nAST : $ast", e)
  }
}
/**
 * Builds the tree for `<runtime>.recordMessage(message)`.
 *
 * `<runtime>` is the recorder runtime val declared by `declareRuntime`.
 *
 * @param message tree of the assertion message
 */
private[this] def recordMessage(message: Tree): Tree = {
  val receiver = Ident(termName(context)("$scala_verify_recorderRuntime"))
  val callee = Select(receiver, termName(context)("recordMessage"))
  Apply(callee, message :: Nil)
}
/**
 * Builds the tree for `<runtime>.completeRecording()`.
 *
 * `apply` and `apply2` use this as the final (result) expression of the
 * generated block.
 */
private[this] def completeRecording: Tree = {
  val receiver = Ident(termName(context)("$scala_verify_recorderRuntime"))
  val callee = Select(receiver, termName(context)("completeRecording"))
  Apply(callee, Nil)
}
/**
 * Builds the tree for `<runtime>.resetValues()`.
 *
 * `recordExpressions` emits this call ahead of each recorded expression.
 */
private[this] def resetValues: Tree = {
  val receiver = Ident(termName(context)("$scala_verify_recorderRuntime"))
  val callee = Select(receiver, termName(context)("resetValues"))
  Apply(callee, Nil)
}
// emit recorderRuntime.recordExpression(