All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dotty.tools.dotc.transform.Splicer.scala Maven / Gradle / Ivy

package dotty.tools.dotc
package transform

import java.io.{PrintWriter, StringWriter}
import java.lang.reflect.Method

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags.Package
import dotty.tools.dotc.core.NameKinds.FlatName
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD
import dotty.tools.dotc.core.quoted._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.TypeErasure
import dotty.tools.dotc.tastyreflect.TastyImpl

import scala.util.control.NonFatal
import dotty.tools.dotc.util.Positions.Position

import scala.reflect.ClassTag

/** Utility object to splice quoted expressions */
object Splicer {
  import tpd._

  /** Splice the tree of a quoted expression: `~'(xyz)` becomes `xyz`, while for `~xyz` the tree
   *  of `xyz` is interpreted and the resulting expression is returned as a `Tree`.
   *
   *  See: `ReifyQuotes`
   *
   *  @param tree        the spliced tree; if it is `Quoted(t)`, `t` is returned directly
   *  @param call        the call to the macro implementation; its symbol is the method invoked
   *                     reflectively and its arguments are lifted with `getLiftedArgs`
   *  @param bindings    definitions whose `ValDef`s replace matching argument trees by references
   *  @param pos         position at which errors are reported
   *  @param classLoader class loader used to load the macro implementation classes
   *  @return the expanded tree; `tree` itself if the macro could not be interpreted, or
   *          `EmptyTree` if evaluating the macro threw an exception
   */
  def splice(tree: Tree, call: Tree, bindings: List[Tree], pos: Position, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
    case Quoted(quotedTree) => quotedTree
    case _ =>
      val liftedArgs = getLiftedArgs(call, bindings)
      val interpreter = new Interpreter(pos, classLoader)
      val interpreted = interpreter.interpretCallToSymbol[Seq[Any] => Object](call.symbol)
      val tctx = new TastyImpl(ctx)
      evaluateMacro(pos) {
        // Some parts of the macro are evaluated during the unpickling performed in quotedExprToTree
        val evaluated = interpreted.map(lambda => lambda(tctx :: liftedArgs).asInstanceOf[scala.quoted.Expr[Nothing]])
        evaluated.fold(tree)(PickledQuotes.quotedExprToTree)
      }
  }

  /** Given the inline code and bindings, compute the lifted arguments that will be used to execute the macro
   *  - Type parameters are lifted to quoted.Types.TreeType
   *  - Inline parameters are listed as their value
   *  - Other parameters are lifted to quoted.Types.TreeExpr (may reference a binding)
   */
  private def getLiftedArgs(call: Tree, bindings: List[Tree])(implicit ctx: Context): List[Any] = {
    // Maps the right-hand side of each binding to a reference to the bound symbol
    val bindMap = bindings.collect {
      case vdef: ValDef => (vdef.rhs, ref(vdef.symbol).withPos(vdef.rhs.pos))
    }.toMap
    // Collects the argument lists of `call`, outermost application last
    def allArgs(call: Tree, acc: List[List[Tree]]): List[List[Tree]] = call match {
      case call: Apply => allArgs(call.fun, call.args :: acc)
      case call: TypeApply => allArgs(call.fun, call.args :: acc)
      case _ => acc
    }
    // Walks the method type alongside the argument lists, lifting each argument list in turn
    def liftArgs(tpe: Type, args: List[List[Tree]]): List[Any] = tpe match {
      case tp: MethodType =>
        val args1 = args.head.zip(tp.paramInfos).map {
          case (arg: Literal, paramTp) if paramTp.hasAnnotation(defn.InlineParamAnnot) => arg.const.value
          case (arg, paramTp) =>
            assert(!paramTp.hasAnnotation(defn.InlineParamAnnot))
            // Replace argument by its binding
            val arg1 = bindMap.getOrElse(arg, arg)
            new scala.quoted.Exprs.TastyTreeExpr(arg1)
        }
        args1 ::: liftArgs(tp.resType, args.tail)
      case tp: PolyType =>
        val args1 = args.head.map(targ => new scala.quoted.Types.TreeType(targ))
        args1 ::: liftArgs(tp.resType, args.tail)
      case _ => Nil
    }

    liftArgs(call.symbol.info, allArgs(call, Nil))
  }

  /** Evaluate the code of the macro and report exceptions thrown during evaluation as errors at `pos`.
   *
   *  A `scala.quoted.QuoteError` is reported with its own message. Any other non-fatal exception is
   *  reported with its class, message and the stack frames preceding the first `Splicer$` frame.
   *  In both cases `EmptyTree` is returned.
   */
  private def evaluateMacro(pos: Position)(code: => Tree)(implicit ctx: Context): Tree = {
    try code
    catch {
      case ex: scala.quoted.QuoteError =>
        ctx.error(ex.getMessage, pos)
        EmptyTree
      case NonFatal(ex) =>
        // Keep the frames above the first `Splicer$` frame, minus the last of them.
        // `dropRight(1)` instead of `init`: the filtered trace is empty when the exception was thrown
        // directly from a `Splicer$` frame (e.g. the cast in `splice`) or carries no stack trace,
        // and `init` would then throw while reporting the original error.
        val macroFrames = ex.getStackTrace.takeWhile(_.getClassName != "dotty.tools.dotc.transform.Splicer$").dropRight(1)
        val msg =
          s"""Failed to evaluate inlined quote.
             |  Caused by ${ex.getClass}: ${if (ex.getMessage == null) "" else ex.getMessage}
             |    ${macroFrames.mkString("\n    ")}
         """.stripMargin
        ctx.error(msg, pos)
        EmptyTree
    }
  }

  /** Tree interpreter that can interpret calls to static methods with their default arguments
   *
   *  The interpreter assumes that all calls in the trees are to code that was
   *  previously compiled and is present in the classpath of the current context.
   *
   *  @param pos         position at which interpretation errors are reported
   *  @param classLoader class loader used to load macro classes and parameter classes
   */
  private class Interpreter(pos: Position, classLoader: ClassLoader)(implicit ctx: Context) {

    /** Reflectively invoke `sym` on the instance of its owner object, passing the JVM default value
     *  (`false`, zero or `null`) for every parameter.
     *
     *  @return `Some` of the result if it is an instance of `T`; otherwise `None`, after reporting an error.
     */
    def interpretCallToSymbol[T](sym: Symbol)(implicit ct: ClassTag[T]): Option[T] = {
      try {
        val (clazz, instance) = loadModule(sym.owner)
        val paramClasses = paramsSig(sym)
        val interpretedArgs = paramClasses.map(defaultValue)
        val method = getMethod(clazz, sym.name, paramClasses)
        stopIfRuntimeException(method.invoke(instance, interpretedArgs: _*)) match {
          case obj: T => Some(obj)
          case obj =>
            // TODO upgrade to a full type tag check or something similar
            // A null result never matches `T`, so it must not be dereferenced here
            val actual = if (obj == null) "null" else obj.getClass.toString
            ctx.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was $actual", pos)
            None
        }
      } catch {
        case ex: StopInterpretation =>
          ctx.error(ex.msg, ex.pos)
          None
      }
    }

    /** Load the class of the object `sym` and obtain an instance of it.
     *
     *  For a top-level object, the instance is read from its static `MODULE_INSTANCE_FIELD`.
     *  For an object nested in an object, the flat-named class is loaded and a new instance is created.
     *  NOTE(review): nested objects get a fresh instance rather than the module instance — confirm intended.
     */
    private def loadModule(sym: Symbol): (Class[_], Object) = {
      if (sym.owner.is(Package)) {
        // is top level object
        val moduleClass = loadClass(sym.fullName)
        val moduleInstance = moduleClass.getField(MODULE_INSTANCE_FIELD).get(null)
        (moduleClass, moduleInstance)
      } else {
        // nested object in an object
        val clazz = loadClass(sym.fullNameSeparated(FlatName))
        (clazz, clazz.newInstance().asInstanceOf[Object])
      }
    }

    /** Load class `name`, turning a missing class into a `StopInterpretation`. */
    private def loadClass(name: Name): Class[_] = {
      try classLoader.loadClass(name.toString)
      catch {
        case _: ClassNotFoundException =>
          val msg = s"Could not find macro class $name in classpath$extraMsg"
          throw new StopInterpretation(msg, pos)
      }
    }

    /** Look up public method `name` with the given parameter classes, turning absence into a `StopInterpretation`. */
    private def getMethod(clazz: Class[_], name: Name, paramClasses: List[Class[_]]): Method = {
      try clazz.getMethod(name.toString, paramClasses: _*)
      catch {
        case _: NoSuchMethodException =>
          val msg = s"Could not find inline macro method ${clazz.getCanonicalName}.$name with parameters $paramClasses$extraMsg"
          throw new StopInterpretation(msg, pos)
      }
    }

    private def extraMsg = ". The most common reason for that is that you cannot use inline macro implementations in the same compilation run that defines them"

    /** Evaluate `thunk`, converting exceptions raised while running the macro into a `StopInterpretation`.
     *
     *  `Method.invoke` wraps anything thrown by the invoked method in an `InvocationTargetException`
     *  (a checked exception), so that wrapper is unwrapped here too; fatal causes are rethrown.
     */
    private def stopIfRuntimeException[T](thunk: => T): T = {
      try thunk
      catch {
        case ex: RuntimeException =>
          throw new StopInterpretation(exceptionReport("A runtime exception occurred while executing macro expansion\n", ex), pos)
        case ex: java.lang.reflect.InvocationTargetException =>
          val cause = if (ex.getCause == null) ex else ex.getCause
          cause match {
            case NonFatal(nonFatal) =>
              throw new StopInterpretation(exceptionReport("An exception occurred while executing macro expansion\n", nonFatal), pos)
            case fatal =>
              throw fatal
          }
      }
    }

    /** Render `header`, the exception message and its stack trace into a single report string. */
    private def exceptionReport(header: String, ex: Throwable): String = {
      val sw = new StringWriter()
      sw.write(header)
      sw.write(if (ex.getMessage == null) "" else ex.getMessage)
      sw.write("\n")
      ex.printStackTrace(new PrintWriter(sw))
      sw.write("\n")
      sw.toString
    }

    /** Load a parameter class by its JVM name or array descriptor, without initializing it. */
    private def loadParamClass(name: String): Class[_] = {
      try java.lang.Class.forName(name, false, classLoader)
      catch {
        case _: ClassNotFoundException =>
          val msg = s"Could not find class $name in classpath$extraMsg"
          throw new StopInterpretation(msg, pos)
      }
    }

    /** List of classes of the parameters of the erased signature of `sym` */
    private def paramsSig(sym: Symbol): List[Class[_]] = {
      TypeErasure.erasure(sym.info) match {
        case meth: MethodType =>
          meth.paramInfos.map { param =>
            // Strips nested Java array types, counting the nesting depth
            def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
              case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
              case _ => (tpe, depth)
            }
            // JVM descriptor of an array type, e.g. `[[I` or `[Lpkg.Cls;`
            def javaArraySig(tpe: Type): String = {
              val (elemType, depth) = arrayDepth(tpe, 0)
              val sym = elemType.classSymbol
              val suffix =
                if (sym == defn.BooleanClass) "Z"
                else if (sym == defn.ByteClass) "B"
                else if (sym == defn.ShortClass) "S"
                else if (sym == defn.IntClass) "I"
                else if (sym == defn.LongClass) "J"
                else if (sym == defn.FloatClass) "F"
                else if (sym == defn.DoubleClass) "D"
                else if (sym == defn.CharClass) "C"
                else "L" + javaSig(elemType) + ";"
              ("[" * depth) + suffix
            }
            def javaSig(tpe: Type): String = tpe match {
              case tpe: JavaArrayType => javaArraySig(tpe)
              case _ =>
                // Take the flatten name of the class and the full package name
                val pack = tpe.classSymbol.topLevelClass.owner
                val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
                packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
            }

            val sym = param.classSymbol
            if (sym == defn.BooleanClass) classOf[Boolean]
            else if (sym == defn.ByteClass) classOf[Byte]
            else if (sym == defn.CharClass) classOf[Char]
            else if (sym == defn.ShortClass) classOf[Short]
            else if (sym == defn.IntClass) classOf[Int]
            else if (sym == defn.LongClass) classOf[Long]
            else if (sym == defn.FloatClass) classOf[Float]
            else if (sym == defn.DoubleClass) classOf[Double]
            else loadParamClass(javaSig(param))
          }
        case _ => Nil
      }
    }

    /** Get the default value for the given class, boxed to the wrapper type `Method.invoke` expects */
    private def defaultValue(clazz: Class[_]): Object = {
      if (clazz == classOf[Boolean]) false.asInstanceOf[Object]
      else if (clazz == classOf[Byte]) 0.toByte.asInstanceOf[Object]
      else if (clazz == classOf[Char]) 0.toChar.asInstanceOf[Object]
      else if (clazz == classOf[Short]) 0.toShort.asInstanceOf[Object] // must box to java.lang.Short, not Integer
      else if (clazz == classOf[Int]) 0.asInstanceOf[Object]
      else if (clazz == classOf[Long]) 0L.asInstanceOf[Object]
      else if (clazz == classOf[Float]) 0f.asInstanceOf[Object]
      else if (clazz == classOf[Double]) 0d.asInstanceOf[Object]
      else null
    }

    /** Exception that stops interpretation if some issue is found */
    private class StopInterpretation(val msg: String, val pos: Position) extends Exception

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy