All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dotty.tools.dotc.quoted.Interpreter.scala Maven / Gradle / Ivy

package dotty.tools.dotc
package quoted

import scala.language.unsafeNulls

import scala.collection.mutable
import scala.reflect.ClassTag

import java.io.{PrintWriter, StringWriter}
import java.lang.reflect.{InvocationTargetException, Method => JLRMethod}

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.Annotations.*
import dotty.tools.dotc.core.Constants.*
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.Denotations.staticRef
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.NameKinds.FlatName
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.TypeErasure
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.quoted.*
import dotty.tools.dotc.typer.ImportInfo.withRootImports
import dotty.tools.dotc.util.SrcPos
import dotty.tools.dotc.reporting.Message
import dotty.tools.repl.AbstractFileClassLoader
import dotty.tools.dotc.core.CyclicReference

/** Tree interpreter for metaprogramming constructs */
class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context):
  import Interpreter.*
  import tpd.*

  val classLoader =
    if ctx.owner.topLevelClass.name.startsWith(str.REPL_SESSION_LINE) then
        new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader0)
    else classLoader0

  /** Local variable environment */
  type Env = Map[Symbol, Object]
  def emptyEnv: Env = Map.empty
  inline def env(using e: Env): e.type = e

  /** Returns the result of interpreting the code in the tree.
   *  Return Some of the result or None if the result type is not consistent with the expected type.
   *  Throws a StopInterpretation if the tree could not be interpreted or a runtime exception occurred.
   */
  final def interpret[T](tree: Tree)(using ct: ClassTag[T]): Option[T] =
    interpretTree(tree)(using emptyEnv) match {
      case obj: T => Some(obj)
      case obj =>
        // TODO upgrade to a full type tag check or something similar
        report.error(em"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos)
        None
    }

  /** Returns the result of interpreting the code in the tree.
   *  Throws a StopInterpretation if the tree could not be interpreted or a runtime exception occurred.
   */
  protected def interpretTree(tree: Tree)(using Env): Object = tree match {
    case Literal(Constant(value)) =>
      interpretLiteral(value)

    case tree: Ident if tree.symbol.is(Inline, butNot = Method) =>
      tree.tpe.widenTermRefExpr match
        case ConstantType(c) => c.value.asInstanceOf[Object]
        case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos)

    // TODO disallow interpreted method calls as arguments
    case Call(fn, args) =>
      if (fn.symbol.isConstructor)
        interpretNew(fn.symbol, args.flatten.map(interpretTree))
      else if (fn.symbol.is(Module))
        interpretModuleAccess(fn.symbol)
      else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
        interpretedStaticMethodCall(fn.symbol.owner, fn.symbol, interpretArgs(args, fn.symbol.info))
      }
      else if fn.symbol.isStatic then
        assert(args.isEmpty)
        interpretedStaticFieldAccess(fn.symbol)
      else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
        if (fn.name == nme.asInstanceOfPM)
          interpretModuleAccess(fn.qualifier.symbol)
        else {
          interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol, interpretArgs(args, fn.symbol.info))
        }
      else if (env.contains(fn.symbol))
        env(fn.symbol)
      else if (tree.symbol.is(InlineProxy))
        interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
      else
        unexpectedTree(tree)

    case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) =>
      (obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj))

    // Interpret `foo(j = x, i = y)` which it is expanded to
    // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
    case Block(stats, expr) => interpretBlock(stats, expr)
    case NamedArg(_, arg) => interpretTree(arg)

    case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)

    case Typed(expr, _) =>
      interpretTree(expr)

    case SeqLiteral(elems, _) =>
      interpretVarargs(elems.map(e => interpretTree(e)))

    case _ =>
      unexpectedTree(tree)
  }

  private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = {
    def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] =
      assert(args.size == argTypes.size)
      val view =
        for (arg, info) <- args.lazyZip(argTypes) yield
          info match
            case _: ExprType => () => interpretTree(arg) // by-name argument
            case _ => interpretTree(arg) // by-value argument
      view.toList

    fnType.dealias match
      case fnType: MethodType if fnType.hasErasedParams => interpretArgs(argss, fnType.resType)
      case fnType: MethodType =>
        val argTypes = fnType.paramInfos
        assert(argss.head.size == argTypes.size)
        interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType)
      case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
        val argTypes :+ resType = fnType.args: @unchecked
        interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType)
      case fnType: PolyType => interpretArgs(argss, fnType.resType)
      case fnType: ExprType => interpretArgs(argss, fnType.resType)
      case _ =>
        assert(argss.isEmpty)
        Nil
  }

  private def interpretBlock(stats: List[Tree], expr: Tree)(using Env) = {
    var unexpected: Option[Object] = None
    val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match
      case stat: ValDef =>
        accEnv.updated(stat.symbol, interpretTree(stat.rhs)(using accEnv))
      case stat =>
        if (unexpected.isEmpty)
          unexpected = Some(unexpectedTree(stat))
        accEnv
    )
    unexpected.getOrElse(interpretTree(expr)(using newEnv))
  }

  private def interpretLiteral(value: Any): Object =
    value.asInstanceOf[Object]

  private def interpretVarargs(args: List[Object]): Object =
    args.toSeq

  private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: List[Object]): Object = {
    val inst =
      try loadModule(moduleClass)
      catch
        case MissingClassDefinedInCurrentRun(sym) =>
          suspendOnMissing(sym, pos)
    val clazz = inst.getClass
    val name = fn.name.asTermName
    val method = getMethod(clazz, name, paramsSig(fn))
    stopIfRuntimeException(method.invoke(inst, args: _*), method)
  }

  private def interpretedStaticFieldAccess(sym: Symbol): Object = {
    val clazz = loadClass(sym.owner.fullName.toString)
    val field = clazz.getField(sym.name.toString)
    field.get(null)
  }

  private def interpretModuleAccess(fn: Symbol): Object =
    loadModule(fn.moduleClass)

  private def interpretNew(fn: Symbol, args: List[Object]): Object = {
    val className = fn.owner.fullName.mangledString.replaceAll("\\$\\.", "\\$")
    val clazz = loadClass(className)
    val constr = clazz.getConstructor(paramsSig(fn): _*)
    constr.newInstance(args: _*).asInstanceOf[Object]
  }

  private def unexpectedTree(tree: Tree): Object =
    throw new StopInterpretation(em"Unexpected tree could not be interpreted: ${tree.toString}", tree.srcPos)

  private def loadModule(sym: Symbol): Object =
    if (sym.owner.is(Package)) {
      // is top level object
      val moduleClass = loadClass(sym.fullName.toString)
      moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null)
    }
    else {
      // nested object in an object
      val clazz = loadClass(sym.binaryClassName)
      clazz.getConstructor().newInstance().asInstanceOf[Object]
    }

  private def loadReplLineClass(moduleClass: Symbol): Class[?] = {
    val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
    lineClassloader.loadClass(moduleClass.name.firstPart.toString)
  }

  private def loadClass(name: String): Class[?] =
    try classLoader.loadClass(name)
    catch
      case MissingClassDefinedInCurrentRun(sym) =>
        suspendOnMissing(sym, pos)


  private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod =
    try clazz.getMethod(name.toString, paramClasses: _*)
    catch {
      case _: NoSuchMethodException =>
        val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)"
        throw new StopInterpretation(msg, pos)
      case MissingClassDefinedInCurrentRun(sym) =>
        suspendOnMissing(sym, pos)
    }

  private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T =
    try thunk
    catch {
      case ex: RuntimeException =>
        val sw = new StringWriter()
        sw.write("A runtime exception occurred while executing macro expansion\n")
        sw.write(ex.getMessage)
        sw.write("\n")
        ex.printStackTrace(new PrintWriter(sw))
        sw.write("\n")
        throw new StopInterpretation(sw.toString.toMessage, pos)
      case ex: InvocationTargetException =>
        ex.getTargetException match {
          case ex: scala.quoted.runtime.StopMacroExpansion =>
            throw ex
          case MissingClassDefinedInCurrentRun(sym) =>
            suspendOnMissing(sym, pos)
          case targetException =>
            val sw = new StringWriter()
            sw.write("Exception occurred while executing macro expansion.\n")
            if (!ctx.settings.Ydebug.value) {
              val end = targetException.getStackTrace.lastIndexWhere { x =>
                x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName
              }
              val shortStackTrace = targetException.getStackTrace.take(end + 1)
              targetException.setStackTrace(shortStackTrace)
              targetException.printStackTrace(new PrintWriter(sw))

              targetException match
                case _: CyclicReference => sw.write("\nSee full stack trace using -Ydebug")
                case _ =>
            } else {
              targetException.printStackTrace(new PrintWriter(sw))
            }
            sw.write("\n")
            throw new StopInterpretation(sw.toString.toMessage, pos)
        }
    }

  /** List of classes of the parameters of the signature of `sym` */
  private def paramsSig(sym: Symbol): List[Class[?]] = {
    def paramClass(param: Type): Class[?] = {
      def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
        case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
        case _ => (tpe, depth)
      }
      def javaArraySig(tpe: Type): String = {
        val (elemType, depth) = arrayDepth(tpe, 0)
        val sym = elemType.classSymbol
        val suffix =
          if (sym == defn.BooleanClass) "Z"
          else if (sym == defn.ByteClass) "B"
          else if (sym == defn.ShortClass) "S"
          else if (sym == defn.IntClass) "I"
          else if (sym == defn.LongClass) "J"
          else if (sym == defn.FloatClass) "F"
          else if (sym == defn.DoubleClass) "D"
          else if (sym == defn.CharClass) "C"
          else "L" + javaSig(elemType) + ";"
        ("[" * depth) + suffix
      }
      def javaSig(tpe: Type): String = tpe match {
        case tpe: JavaArrayType => javaArraySig(tpe)
        case _ =>
          // Take the flatten name of the class and the full package name
          val pack = tpe.classSymbol.topLevelClass.owner
          val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}."
          packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
      }

      val sym = param.classSymbol
      if (sym == defn.BooleanClass) classOf[Boolean]
      else if (sym == defn.ByteClass) classOf[Byte]
      else if (sym == defn.CharClass) classOf[Char]
      else if (sym == defn.ShortClass) classOf[Short]
      else if (sym == defn.IntClass) classOf[Int]
      else if (sym == defn.LongClass) classOf[Long]
      else if (sym == defn.FloatClass) classOf[Float]
      else if (sym == defn.DoubleClass) classOf[Double]
      else java.lang.Class.forName(javaSig(param), false, classLoader)
    }
    def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
      case tp: AppliedType if defn.isContextFunctionType(tp) =>
        // Call context function type direct method
        tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
      case _ => Nil
    }
    val extraParams = getExtraParams(sym.info.finalResultType)
    val allParams = TypeErasure.erasure(sym.info) match {
      case meth: MethodType => meth.paramInfos ::: extraParams
      case _ => extraParams
    }
    allParams.map(paramClass)
  }
end Interpreter

object Interpreter:
  /** Exception that stops interpretation if some issue is found */
  class StopInterpretation(val msg: Message, val pos: SrcPos) extends Exception

  object Call:
    import tpd.*
    /** Matches an expression that is either a field access or an application
    *  It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
    */
    def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] =
      Call0.unapply(arg).map((fn, args) => (fn, args.reverse))

    private object Call0 {
      def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match {
        case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) =>
          Some((fn, args))
        case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil))
        case fn: Select => Some((fn, Nil))
        case Apply(f @ Call0(fn, args1), args2) =>
          if (f.tpe.widenDealias.hasErasedParams) Some((fn, args1))
          else Some((fn, args2 :: args1))
        case TypeApply(Call0(fn, args), _) => Some((fn, args))
        case _ => None
      }
    }
  end Call

  object MissingClassDefinedInCurrentRun {
    def unapply(targetException: Throwable)(using Context): Option[Symbol] = {
      if !ctx.compilationUnit.isSuspendable then None
      else targetException match
        case _: NoClassDefFoundError | _: ClassNotFoundException =>
          val message = targetException.getMessage
          if message eq null then None
          else
            val className = message.replace('/', '.')
            val sym =
              if className.endsWith(str.MODULE_SUFFIX) then staticRef(className.toTermName).symbol.moduleClass
              else staticRef(className.toTypeName).symbol
            // If the symbol does not a a position we assume that it came from the current run and it has an error
            if sym.isDefinedInCurrentRun || (sym.exists && !sym.srcPos.span.exists) then Some(sym)
            else None
        case _ => None
    }
  }

  def suspendOnMissing(sym: Symbol, pos: SrcPos)(using Context): Nothing =
    if ctx.settings.XprintSuspension.value then
      report.echo(i"suspension triggered by a dependency on $sym", pos)
    ctx.compilationUnit.suspend() // this throws a SuspendException




© 2015 - 2025 Weber Informatics LLC | Privacy Policy