// dotty.tools.dotc.transform.Splicer.scala Maven / Gradle / Ivy
// The newest version!
package dotty.tools.dotc
package transform
import java.io.{PrintWriter, StringWriter}
import java.lang.reflect.{InvocationTargetException, Method}
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds.FlatName
import dotty.tools.dotc.core.Names.{Name, TermName}
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.core.quoted._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.{NameKinds, TypeErasure}
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.tastyreflect.ReflectionImpl
import scala.util.control.NonFatal
import dotty.tools.dotc.util.SourcePosition
import dotty.tools.repl.AbstractFileClassLoader
import scala.reflect.ClassTag
/** Utility class to splice quoted expressions */
object Splicer {
import tpd._
/** Splice the Tree for a Quoted expression. `${'{xyz}}` becomes `xyz`
 *  and for `$xyz` the tree of `xyz` is interpreted for which the
 *  resulting expression is returned as a `Tree`
 *
 *  If the interpreter produces no result (it has then already reported an error), the
 *  original `tree` is returned. If the evaluation throws a non-fatal exception, an error
 *  is reported at `pos` and `EmptyTree` is returned.
 *
 *  See: `Staging`
 *
 *  @param tree        the content of the splice
 *  @param pos         the position at which errors are reported
 *  @param classLoader the class loader handed to the interpreter to load the macro code
 */
def splice(tree: Tree, pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
  case Quoted(quotedTree) => quotedTree
  case _ =>
    val interpreter = new Interpreter(pos, classLoader)
    try {
      // Some parts of the macro are evaluated during the unpickling performed in quotedExprToTree
      val interpretedExpr = interpreter.interpret[scala.quoted.Expr[Any]](tree)
      interpretedExpr.fold(tree)(x => PickledQuotes.quotedExprToTree(x))
    }
    catch {
      case NonFatal(ex) =>
        // Keep the frames that precede the first `Splicer$` frame, minus the last of them.
        // `dropRight(1)` is used instead of `init` because `init` throws on an empty array,
        // which would replace the macro failure with an unrelated exception from this handler.
        val macroFrames =
          ex.getStackTrace.takeWhile(_.getClassName != "dotty.tools.dotc.transform.Splicer$").dropRight(1)
        val msg =
          s"""Failed to evaluate macro.
             | Caused by ${ex.getClass}: ${if (ex.getMessage == null) "" else ex.getMessage}
             | ${macroFrames.mkString("\n ")}
             |""".stripMargin
        ctx.error(msg, pos)
        EmptyTree
    }
}
/** Check that `tree` is an acceptable body for a macro splice and report an error otherwise.
 *
 *  A quoted body is always accepted. Any other body has to be a single call accepted by
 *  `isInterpretableTarget`, optionally wrapped in a `Typed` or in a `Block` whose statements
 *  are synthetic `val`s. Its arguments may be quotes, type quotes, literals, the macro
 *  `QuoteContext`, inline or synthetic identifiers, named arguments, sequence literals, or
 *  nested calls of the same accepted shape.
 *
 *  See: `Staging`
 */
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = {

  /** Whether `fn` refers to a constructor of a class owned by a package, a module,
   *  a static member, or a member selected on a static module.
   */
  def isInterpretableTarget(fn: RefTree): Boolean = {
    val target = fn.symbol
    (target.isConstructor && target.owner.owner.is(Package)) ||
      target.is(Module) || target.isStatic ||
      (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
  }

  /** Only synthetic `val`s with a valid argument as right-hand side are allowed as statements. */
  def checkValidStat(stat: Tree): Unit = stat match {
    case vdef: ValDef if vdef.symbol.is(Synthetic) =>
      // Check val from `foo(j = x, i = y)` which it is expanded to
      // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
      checkIfValidArgument(vdef.rhs)
    case _ =>
      ctx.error("Macro should not have statements", stat.sourcePos)
  }

  /** Accept the argument shapes listed in the error message below, recursing into wrappers. */
  def checkIfValidArgument(arg: Tree): Unit = arg match {
    case Block(Nil, inner) =>
      checkIfValidArgument(inner)
    case Typed(inner, _) =>
      checkIfValidArgument(inner)
    case Apply(TypeApply(fn, _), _ :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
      () // quoted expression
    case TypeApply(fn, _ :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
      () // quoted type
    case Literal(Constant(_)) =>
      () // literal constant
    case _ if arg.symbol == defn.QuoteContext_macroContext =>
      () // the macro's quote context
    case Call(fn, argss) if isInterpretableTarget(fn) =>
      argss.flatten.foreach(checkIfValidArgument)
    case NamedArg(_, inner) =>
      checkIfValidArgument(inner)
    case SeqLiteral(elems, _) =>
      elems.foreach(checkIfValidArgument)
    case id: Ident if id.symbol.is(Inline) || id.symbol.is(Synthetic) =>
      () // reference to an inline parameter or to a synthetic val
    case _ =>
      ctx.error(
        """Malformed macro parameter
          |
          |Parameters may be:
          | * Quoted parameters or fields
          | * References to inline parameters
          | * Literal values of primitive types
          |""".stripMargin, arg.sourcePos)
  }

  /** Accept a call to an interpretable target, possibly wrapped in a `Block` or a `Typed`. */
  def checkIfValidStaticCall(body: Tree): Unit = body match {
    case Block(stats, expr) =>
      stats.foreach(checkValidStat)
      checkIfValidStaticCall(expr)
    case Typed(expr, _) =>
      checkIfValidStaticCall(expr)
    case Call(fn, argss) if isInterpretableTarget(fn) =>
      argss.flatten.foreach(checkIfValidArgument)
    case _ =>
      ctx.error(
        """Malformed macro.
          |
          |Expected the splice ${...} to contain a single call to a static method.
          |""".stripMargin, body.sourcePos)
  }

  tree match {
    case Quoted(_) => () // ok
    case _ => checkIfValidStaticCall(tree)
  }
}
/** Tree interpreter that evaluates the tree */
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {
type Env = Map[Name, Object]
/** Interprets `tree` in an empty environment and returns the result if it is an instance of `T`.
 *
 *  Returns `Some` of the result, or `None` after reporting an error when the interpretation
 *  was stopped or when the result is not an instance of the runtime class of `T`
 *  (a `null` result is reported as such).
 */
def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] = {
  try {
    interpretTree(tree)(Map.empty) match {
      case obj: T => Some(obj)
      case obj =>
        // TODO upgrade to a full type tag check or something similar
        // A `null` result does not match `T` above; describe it without calling `getClass` on it,
        // which would throw a `NullPointerException` instead of reporting this error.
        val actual = if (obj == null) "null" else obj.getClass.toString
        ctx.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was $actual", pos)
        None
    }
  } catch {
    case ex: StopInterpretation =>
      ctx.error(ex.msg, ex.pos)
      None
  }
}
/** Evaluates `tree` to a runtime value, resolving locally bound names through `env`.
 *
 *  Handles expression and type quotes, literals, the macro quote context, calls and
 *  references (see `Call`), blocks, named arguments, inlined code, type ascriptions and
 *  sequence literals. Any other tree stops the interpretation through `unexpectedTree`.
 */
def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
  case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
    val quotedBody = quoted match {
      case proxy: Ident if proxy.symbol.isAllOf(InlineByNameProxy) =>
        // inline proxy for by-name parameter: quote the proxy's definition body instead
        proxy.symbol.defTree.asInstanceOf[DefDef].rhs
      case Inlined(EmptyTree, _, inner) => inner
      case other => other
    }
    interpretQuote(quotedBody)

  case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
    interpretTypeQuote(quoted)

  case Literal(Constant(value)) =>
    interpretLiteral(value)

  case _ if tree.symbol == defn.QuoteContext_macroContext =>
    interpretQuoteContext()

  // TODO disallow interpreted method calls as arguments
  case Call(fn, args) =>
    val target = fn.symbol
    // Arguments are interpreted on demand, so each branch decides when (and whether) they are evaluated
    def argValues: List[Object] = args.flatten.map(interpretTree)
    if (target.isConstructor && target.owner.owner.is(Package))
      interpretNew(target, argValues)
    else if (target.is(Module))
      interpretModuleAccess(target)
    else if (target.isStatic) {
      val invoke = interpretedStaticMethodCall(target.owner, target)
      invoke(argValues)
    }
    else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
      val invoke = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, target)
      invoke(argValues)
    }
    else if (env.contains(fn.name))
      env(fn.name)
    else if (tree.symbol.is(InlineProxy))
      interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
    else
      unexpectedTree(tree)

  // Interpret `foo(j = x, i = y)` which it is expanded to
  // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
  case Block(stats, expr) =>
    interpretBlock(stats, expr)

  case NamedArg(_, arg) =>
    interpretTree(arg)

  case Inlined(_, bindings, expansion) =>
    interpretBlock(bindings, expansion)

  case Typed(expr, _) =>
    interpretTree(expr)

  case SeqLiteral(elems, _) =>
    interpretVarargs(elems.map(e => interpretTree(e)))

  case _ =>
    unexpectedTree(tree)
}
/** Interprets the `ValDef`s of `stats` in order, binding each name to its interpreted value,
 *  and then interprets `expr` in the extended environment.
 *
 *  A statement that is not a `ValDef` stops the interpretation: `unexpectedTree` throws a
 *  `StopInterpretation`, so no mutable "first unexpected result" needs to be tracked.
 */
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env): Object = {
  def loop(remaining: List[Tree], accEnv: Env): Object = remaining match {
    case (stat: ValDef) :: rest =>
      // Each right-hand side sees the bindings of the statements before it
      loop(rest, accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv)))
    case stat :: _ =>
      unexpectedTree(stat)
    case Nil =>
      interpretTree(expr)(accEnv)
  }
  loop(stats, env)
}
/** Wraps the quoted `tree` in an `Inlined` node carrying the span of `tree` and returns it
 *  as a `scala.internal.quoted.TastyTreeExpr`.
 */
private def interpretQuote(tree: Tree)(implicit env: Env): Object = {
  val quotedBody = Inlined(EmptyTree, Nil, tree).withSpan(tree.span)
  new scala.internal.quoted.TastyTreeExpr(quotedBody)
}
/** Returns the quoted type `tree` wrapped in a `scala.internal.quoted.TreeType`. */
private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object = {
  val quotedType = new scala.internal.quoted.TreeType(tree)
  quotedType
}
/** Returns the value of a literal constant as an `Object`. */
private def interpretLiteral(value: Any)(implicit env: Env): Object = {
  val boxed = value.asInstanceOf[Object]
  boxed
}
/** Returns the already interpreted elements of a sequence literal as a `Seq`. */
private def interpretVarargs(args: List[Object])(implicit env: Env): Object = {
  val elems = args.toSeq
  elems
}
/** Creates the `QuoteContext` passed to the macro, backed by a `ReflectionImpl` built from
 *  the current compiler context and the splice position.
 */
private def interpretQuoteContext()(implicit env: Env): Object = {
  val reflection = ReflectionImpl(ctx, pos)
  new scala.quoted.QuoteContext(reflection)
}
/** Resolves the method `fn` of `moduleClass` by reflection and returns a function that
 *  invokes it with the given arguments.
 *
 *  For a module class whose name starts with the REPL session line prefix, the class is loaded
 *  with `loadReplLineClass` and the method is invoked with a `null` receiver. Otherwise the
 *  module instance is loaded and used as the receiver.
 */
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
  val isReplLine = moduleClass.name.startsWith(str.REPL_SESSION_LINE)
  val receiver: Object = if (isReplLine) null else loadModule(moduleClass)
  val owner: Class[_] = if (isReplLine) loadReplLineClass(moduleClass) else receiver.getClass

  // A method whose result is an implicit function type is called through its direct method:
  // one `DirectMethodName` wrapper is added per nested implicit function type in the result.
  def directName(tp: Type, name: TermName): TermName = tp.widenDealias match {
    case fun: AppliedType if defn.isImplicitFunctionType(fun) =>
      directName(fun.args.last, NameKinds.DirectMethodName(name))
    case _ =>
      name
  }

  val methodName = directName(fn.info.finalResultType, fn.name.asTermName)
  val method = getMethod(owner, methodName, paramsSig(fn))
  (args: List[Object]) => stopIfRuntimeException(method.invoke(receiver, args: _*))
}
/** Loads the instance of the module whose module class is `fn.moduleClass`. */
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = {
  val moduleClass = fn.moduleClass
  loadModule(moduleClass)
}
/** Instantiates the class owning the constructor `fn` by reflection.
 *
 *  The invocation is guarded by `stopIfRuntimeException`, like static method calls are, so an
 *  exception thrown by the constructor is reported with its own message and stack trace rather
 *  than as a bare `InvocationTargetException`.
 *
 *  @param fn   the constructor symbol
 *  @param args the constructor arguments, evaluated after the constructor has been looked up
 */
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
  val clazz = loadClass(fn.owner.fullName.toString)
  val constr = clazz.getConstructor(paramsSig(fn): _*)
  // Force the by-name arguments outside of the guarded call: failures while interpreting
  // the arguments are not failures of the constructor.
  val argValues = args
  stopIfRuntimeException[Object](constr.newInstance(argValues: _*).asInstanceOf[Object])
}
/** Stops the interpretation with an error positioned at `tree`. Never returns normally. */
private def unexpectedTree(tree: Tree)(implicit env: Env): Object = {
  val message = s"Unexpected tree could not be interpreted: $tree"
  throw new StopInterpretation(message, tree.sourcePos)
}
/** Loads the runtime object for the module `sym`.
 *
 *  A module owned by a package is read from the static module instance field of its class.
 *  Any other module is instantiated through the no-argument constructor of its class, whose
 *  name is the flat name of `sym` prefixed by the package of its top-level class (if any).
 */
private def loadModule(sym: Symbol): Object =
  if (sym.owner.is(Package)) {
    // is top level object
    loadClass(sym.fullName.toString).getField(str.MODULE_INSTANCE_FIELD).get(null)
  }
  else {
    // nested object in an object
    val pack = sym.topLevelClass.owner
    val hasNoPackagePrefix = pack == defn.RootPackage || pack == defn.EmptyPackageClass
    val className =
      if (hasNoPackagePrefix) sym.flatName.toString
      else pack.showFullName + "." + sym.flatName
    loadClass(className).getConstructor().newInstance().asInstanceOf[Object]
  }
/** Loads the class named by the first part of the name of `moduleClass` through an
 *  `AbstractFileClassLoader` over the compiler output directory, with `classLoader` as parent.
 */
private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[_] = {
  val replLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
  val className = moduleClass.name.firstPart.toString
  replLoader.loadClass(className)
}
/** Loads the class `name` with `classLoader`, stopping the interpretation with an error at
 *  `pos` if the class cannot be found.
 */
private def loadClass(name: String): Class[_] =
  try classLoader.loadClass(name)
  catch {
    case _: ClassNotFoundException =>
      throw new StopInterpretation(s"Could not find class $name in classpath$extraMsg", pos)
  }
/** Looks up the public method `name` of `clazz` taking `paramClasses`, stopping the
 *  interpretation with an error at `pos` if there is no such method.
 */
private def getMethod(clazz: Class[_], name: Name, paramClasses: List[Class[_]]): Method =
  try clazz.getMethod(name.toString, paramClasses: _*)
  catch {
    case _: NoSuchMethodException =>
      throw new StopInterpretation(
        em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)$extraMsg",
        pos)
  }
/** Hint appended to the "could not find class/method" errors. */
private def extraMsg: String =
  ". The most common reason for that is that you apply macros in the compilation run that defines them"
/** Evaluates `thunk`, converting a `RuntimeException` or an `InvocationTargetException` into a
 *  `StopInterpretation` at `pos`. The message holds a headline, the message of the exception
 *  (of the target exception for an `InvocationTargetException`) and its stack trace.
 *  Other throwables propagate unchanged.
 */
private def stopIfRuntimeException[T](thunk: => T): T = {
  // Builds the report for `cause` under `headline` and stops the interpretation
  def stop(headline: String, cause: Throwable): Nothing = {
    val report = new StringWriter()
    report.write(headline)
    report.write(cause.getMessage)
    report.write("\n")
    cause.printStackTrace(new PrintWriter(report))
    report.write("\n")
    throw new StopInterpretation(report.toString, pos)
  }

  try thunk
  catch {
    case ex: RuntimeException =>
      stop("A runtime exception occurred while executing macro expansion\n", ex)
    case ex: InvocationTargetException =>
      stop("An exception occurred while executing macro expansion\n", ex.getTargetException)
  }
}
/** List of the runtime classes of the parameters in the erased signature of `sym`, followed
 *  by the erased parameters of any implicit function types in its result type.
 */
private def paramsSig(sym: Symbol): List[Class[_]] = {

  /** Name passed to `Class.forName` for a Java array type: one `[` per dimension followed
   *  by the element code (a primitive letter, or `L<class name>;`).
   */
  def javaArraySig(tpe: Type): String = {
    def unwrap(tp: Type, depth: Int): (Type, Int) = tp match {
      case JavaArrayType(elem) => unwrap(elem, depth + 1)
      case _ => (tp, depth)
    }
    val (elemType, depth) = unwrap(tpe, 0)
    val elemSym = elemType.classSymbol
    val primitiveCodes: List[(Symbol, String)] = List(
      defn.BooleanClass -> "Z",
      defn.ByteClass -> "B",
      defn.ShortClass -> "S",
      defn.IntClass -> "I",
      defn.LongClass -> "J",
      defn.FloatClass -> "F",
      defn.DoubleClass -> "D",
      defn.CharClass -> "C"
    )
    val elemCode = primitiveCodes
      .collectFirst { case (cls, code) if elemSym == cls => code }
      .getOrElse("L" + javaSig(elemType) + ";")
    ("[" * depth) + elemCode
  }

  /** Name passed to `Class.forName` for `tpe`. */
  def javaSig(tpe: Type): String = tpe match {
    case arr: JavaArrayType => javaArraySig(arr)
    case _ =>
      // Take the flatten name of the class and the full package name
      val cls = tpe.classSymbol
      val pack = cls.topLevelClass.owner
      val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
      packageName + cls.fullNameSeparated(FlatName).toString
  }

  /** Runtime class of a single erased parameter type. */
  def paramClass(param: Type): Class[_] = {
    val paramSym = param.classSymbol
    val primitives: List[(Symbol, Class[_])] = List(
      defn.BooleanClass -> classOf[Boolean],
      defn.ByteClass -> classOf[Byte],
      defn.CharClass -> classOf[Char],
      defn.ShortClass -> classOf[Short],
      defn.IntClass -> classOf[Int],
      defn.LongClass -> classOf[Long],
      defn.FloatClass -> classOf[Float],
      defn.DoubleClass -> classOf[Double]
    )
    primitives
      .collectFirst { case (cls, runtimeClass) if paramSym == cls => runtimeClass }
      .getOrElse(java.lang.Class.forName(javaSig(param), false, classLoader))
  }

  /** Erased parameter types of the (possibly nested) implicit function types ending `tp`. */
  def implicitFunctionParams(tp: Type): List[Type] = tp.widenDealias match {
    case fun: AppliedType if defn.isImplicitFunctionType(fun) =>
      // Call implicit function type direct method
      fun.args.init.map(arg => TypeErasure.erasure(arg)) ::: implicitFunctionParams(fun.args.last)
    case _ =>
      Nil
  }

  val extraParams = implicitFunctionParams(sym.info.finalResultType)
  val declaredParams = TypeErasure.erasure(sym.info) match {
    case meth: MethodType => meth.paramInfos
    case _ => Nil
  }
  (declaredParams ::: extraParams).map(paramClass)
}
/** Exception that stops interpretation if some issue is found
 *
 *  @param msg the error to report; also passed on as the exception message so that
 *             `getMessage` is not `null` if the exception is ever printed or logged
 *  @param pos the position at which the error is reported
 */
private class StopInterpretation(val msg: String, val pos: SourcePosition) extends Exception(msg)
}
object Call {
  /** Matches an expression that is either a field access or an application.
   *  It returns the `RefTree` of the accessed field or referenced method and the argument
   *  lists passed to it, in application order.
   */
  def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
    Call0.unapply(arg).map { case (fn, argss) => (fn, argss.reverse) }

  /** Same as `Call`, but the argument lists are accumulated with the last applied list first. */
  private object Call0 {
    def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
      // Selecting `apply` on a call whose function has an implicit function result type
      // is treated as that underlying call
      case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
        Some((fn, args))
      case ref: RefTree =>
        Some((ref, Nil))
      case Apply(inner @ Call0(fn, earlierArgs), lastArgs) =>
        // Argument lists of erased methods are dropped
        val collected =
          if (inner.tpe.widenDealias.isErasedMethod) earlierArgs
          else lastArgs :: earlierArgs
        Some((fn, collected))
      case TypeApply(Call0(fn, args), _) =>
        Some((fn, args))
      case _ =>
        None
    }
  }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy