All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scala.scalanative.nscplugin.GenReflectiveInstantisation.scala Maven / Gradle / Ivy

The newest version!
package scala.scalanative
package nscplugin

import scala.language.implicitConversions

import dotty.tools.dotc.ast.tpd._
import dotty.tools.dotc.core
import core.Contexts._
import core.Symbols._
import core.StdNames._
import core.Flags._

import scala.collection.mutable
import scala.scalanative.util.ScopedVar

object GenReflectiveInstantisation {

  /** Well-known NIR global names referenced by the reflective-instantiation
    * code generator.
    */
  object nirSymbols {
    // Builds a top-level NIR global from a fully qualified class name.
    private def topLevel(fqcn: String) = nir.Global.Top(fqcn)

    // Superclasses used for the synthetic loader / instantiator classes.
    val AbstractFunction0Name = topLevel("scala.runtime.AbstractFunction0")
    val AbstractFunction1Name = topLevel("scala.runtime.AbstractFunction1")

    // Trait mixed into every synthetic helper class.
    val SerializableName = topLevel("java.io.Serializable")

    // Tuple2 pairs (parameter classes, instantiator) for each constructor.
    val Tuple2Name = topLevel("scala.Tuple2")
    val Tuple2Ref = nir.Type.Ref(Tuple2Name)
  }
}

/** Generates NIR that registers classes and modules for reflective
  * instantiation.
  *
  * A class is processed only when one of its base classes carries the
  * `EnableReflectiveInstantiation` annotation. For such a class a static
  * initializer (`nir.Sig.Clinit`) is emitted. It calls into
  * `defnNir.ReflectModule`, passing instances of synthetic helper classes
  * (module loaders or constructor instantiators). The definitions of those
  * helper classes are collected in `ReflectiveInstantiationBuffer`s.
  *
  * NOTE(review): `curFresh()` is called for every local and label below.
  * Reordering statements would therefore presumably change the generated
  * ids, so keep the statement order as is.
  */
trait GenReflectiveInstantisation(using Context) {
  self: NirCodeGen =>
  import GenReflectiveInstantisation._
  import positionsConversions.given

  // Every ReflectiveInstantiationBuffer appends itself here when constructed.
  // NOTE(review): the consumer of this buffer is not visible in this file;
  // presumably the surrounding code generator emits its contents — confirm.
  protected val reflectiveInstantiationBuffers =
    mutable.UnrolledBuffer.empty[ReflectiveInstantiationBuffer]

  /** Collects the definitions (methods and class declaration) of one
    * synthetic helper class.
    *
    * @param fqcn
    *   base name of the helper; the helper class itself is named
    *   `fqcn + "$scalanative$ReflectivelyInstantiate$"`.
    */
  protected class ReflectiveInstantiationBuffer(fqcn: String) {
    val name = nir.Global.Top(fqcn + "$scalanative$ReflectivelyInstantiate$")
    // Self-registration: merely constructing a buffer makes it tracked by
    // `reflectiveInstantiationBuffers`, even if nothing is ever added to it.
    reflectiveInstantiationBuffers += this
    private val buf = mutable.UnrolledBuffer.empty[nir.Defn]

    def +=(defn: nir.Defn): Unit = buf += defn
    def nonEmpty = buf.nonEmpty
    def toSeq = buf.toSeq
  }

  /** Entry point: emits reflective-instantiation support for `td`, if any of
    * its base classes is annotated with `EnableReflectiveInstantiation`.
    *
    * Generation runs with freshly scoped state: the current class symbol, a
    * new `nir.Fresh` id generator, the top-level scope id, no unwind handler
    * and no `this` value.
    *
    * @param td
    *   the class definition being compiled
    */
  def genReflectiveInstantiation(td: TypeDef): Unit = {
    val sym = td.symbol.asClass
    val enableReflectiveInstantiation =
      sym.baseClasses
        .exists(
          _.hasAnnotation(defnNir.EnableReflectiveInstantiationAnnotationClass)
        )

    if (enableReflectiveInstantiation) {
      ScopedVar.scoped(
        curClassSym := sym,
        curFresh := nir.Fresh(),
        curScopeId := nir.ScopeId.TopLevel,
        curUnwindHandler := None,
        curMethodThis := None
      ) {
        registerReflectiveInstantiation(td)
      }
    }
  }

  /** Builds the static-initializer body for the current class.
    *
    * When the body is non-empty, it is appended to `generatedDefns` as the
    * class's `Sig.Clinit` method with type `() => Unit`. No initializer is
    * emitted for skipped symbols, nor when `registerNormalClass` finds no
    * public constructors (it then returns `Nil`).
    */
  private def registerReflectiveInstantiation(td: TypeDef): Unit = {
    given nir.SourcePosition = td.span
    val sym: Symbol = curClassSym
    val owner = genTypeName(sym)
    val name = owner.member(nir.Sig.Clinit)

    val staticInitBody =
      // Module classes that are not `Lifted` register a module loader.
      if (curClassSym.get.is(flag = Module, butNot = Lifted))
        Some(registerModuleClass(td))
      // Any module reaching this branch is `Lifted`; it is skipped.
      else if (sym.is(Module))
        None // see: https://github.com/scala-js/scala-js/issues/3228
      // `Lifted` classes whose original owner is not a class are skipped.
      else if (sym.is(Lifted) && !sym.originalOwner.isClass)
        None // see: https://github.com/scala-js/scala-js/issues/3227
      else Some(registerNormalClass(td))

    staticInitBody
      .filter(_.nonEmpty)
      .foreach { body =>
        generatedDefns += new nir.Defn.Define(
          nir.Attrs(),
          name,
          nir.Type.Function(Seq.empty[nir.Type], nir.Type.Unit),
          body
        )
      }
  }

  /** Emits the static-initializer body for a module class.
    *
    * The body calls `defnNir.Reflect_registerLoadableModuleClass` with the
    * module's fully qualified name, its runtime class, and an instance of a
    * synthetic loader produced by `genModuleLoader`.
    *
    * @return
    *   the instructions of the body; always non-empty (label, call, return)
    */
  private def registerModuleClass(
      td: TypeDef
  ): Seq[nir.Inst] = {
    val fqSymId = curClassSym.get.fullName.mangledString
    val fqSymName = nir.Global.Top(fqSymId)
    val fqcnArg = nir.Val.String(fqSymId)
    val runtimeClassArg = nir.Val.ClassOf(fqSymName)

    given nir.SourcePosition = td.span
    // Buffer for the loader helper class; genModuleLoader picks it up as a
    // given. Constructing it also registers it (see the class above).
    given reflInstBuffer: ReflectiveInstantiationBuffer =
      ReflectiveInstantiationBuffer(fqSymId)

    withFreshExprBuffer { buf ?=>

      buf.label(curFresh(), Seq.empty)
      val loadModuleFunArg = genModuleLoader(fqSymName)
      buf.genApplyModuleMethod(
        defnNir.ReflectModule,
        defnNir.Reflect_registerLoadableModuleClass,
        Seq(fqcnArg, runtimeClassArg, loadModuleFunArg).map(ValTree(_)(td.span))
      )
      buf.ret(nir.Val.Unit)
      buf.toSeq
    }
  }

  /** Emits the static-initializer body for a regular (non-module) class.
    *
    * The body calls `defnNir.Reflect_registerInstantiatableClass` with the
    * class's fully qualified name, its runtime class, and the constructors
    * info built by `genClassConstructorsInfo` (one entry per public
    * constructor).
    *
    * @return
    *   the instructions of the body, or `Nil` when the class is abstract, is
    *   a trait, or has no public constructors
    */
  private def registerNormalClass(
      td: TypeDef
  ): Seq[nir.Inst] = {
    given nir.SourcePosition = td.span

    val fqSymId = curClassSym.get.fullName.mangledString
    val fqSymName = nir.Global.Top(fqSymId)
    val fqcnArg = nir.Val.String(fqSymId)
    val runtimeClassArg = nir.Val.ClassOf(fqSymName)

    // Collect public constructors.
    val ctors =
      if (curClassSym.get.isOneOf(AbstractOrTrait)) Nil
      else
        curClassSym.get.info
          .member(nme.CONSTRUCTOR)
          .alternatives
          .collect {
            case denot if denot.asSymDenotation.isPublic =>
              denot.asSymDenotation.underlyingSymbol
          }

    if (ctors.isEmpty) Nil
    else
      withFreshExprBuffer { buf ?=>
        buf.label(curFresh(), Seq.empty)
        val instantiateClassFunArg = genClassConstructorsInfo(fqSymName, ctors)
        buf.genApplyModuleMethod(
          defnNir.ReflectModule,
          defnNir.Reflect_registerInstantiatableClass,
          Seq(fqcnArg, runtimeClassArg, instantiateClassFunArg)
            .map(ValTree(_)(td.span))
        )
        buf.ret(nir.Val.Unit)
        buf.toSeq
      }
  }

  /** Generates the no-argument constructor of the helper class held by
    * `reflInstBuffer`, and adds it to that buffer.
    *
    * The constructor only invokes `superClass`'s no-argument constructor on
    * `this` and returns unit. The helper is expected to extend one of
    * `scala.runtime.AbstractFunctionX`.
    *
    * @param superClass
    *   the helper class's superclass
    */
  private def genConstructor(
      superClass: nir.Global.Top
  )(using
      nir.SourcePosition
  )(using reflInstBuffer: ReflectiveInstantiationBuffer): Unit = {
    withFreshExprBuffer { buf ?=>
      val body = {
        // first argument is this
        val thisArg =
          nir.Val.Local(curFresh(), nir.Type.Ref(reflInstBuffer.name))
        buf.label(curFresh(), Seq(thisArg))
        // call to super constructor
        buf.call(
          nir.Type.Function(Seq(nir.Type.Ref(superClass)), nir.Type.Unit),
          nir.Val
            .Global(superClass.member(nir.Sig.Ctor(Seq.empty)), nir.Type.Ptr),
          Seq(thisArg),
          unwind(curFresh)
        )
        buf.ret(nir.Val.Unit)
        buf.toSeq
      }

      reflInstBuffer += new nir.Defn.Define(
        nir.Attrs(),
        reflInstBuffer.name.member(nir.Sig.Ctor(Seq.empty)),
        nir.Type
          .Function(Seq(nir.Type.Ref(reflInstBuffer.name)), nir.Type.Unit),
        body
      )
    }
  }

  /** Allocates an object of class `name` and constructs it, emitting the
    * instructions into the given `ExprBuffer`.
    *
    * @param name
    *   class to instantiate
    * @param argTypes
    *   parameter types of the constructor to call, excluding the receiver;
    *   the receiver type `Ref(name)` is prepended here
    * @param args
    *   constructor arguments matching `argTypes`
    * @return
    *   the value of the freshly allocated object
    */
  private def allocAndConstruct(
      name: nir.Global.Top,
      argTypes: Seq[nir.Type],
      args: Seq[nir.Val]
  )(using pos: nir.SourcePosition, buf: ExprBuffer): nir.Val = {
    val alloc = buf.classalloc(name, unwind(curFresh))
    buf.call(
      nir.Type.Function(nir.Type.Ref(name) +: argTypes, nir.Type.Unit),
      nir.Val.Global(name.member(nir.Sig.Ctor(argTypes)), nir.Type.Ptr),
      alloc +: args,
      unwind(curFresh)
    )
    alloc
  }

  /** Generates a synthetic module-loader class and returns a new instance of
    * it.
    *
    * The loader extends `scala.runtime.AbstractFunction0` and mixes in
    * `java.io.Serializable`. Its `apply` method returns the module instance,
    * or `null` when the module's original owner exists and is a class that
    * is not itself a module class.
    *
    * @param fqSymName
    *   NIR name of the module to load
    * @return
    *   the loader instance, allocated in the caller's `buf`
    */
  private def genModuleLoader(
      fqSymName: nir.Global.Top
  )(using
      pos: nir.SourcePosition,
      buf: ExprBuffer,
      reflInstBuffer: ReflectiveInstantiationBuffer
  ): nir.Val = {
    val applyMethodSig = nir.Sig.Method("apply", Seq(nir.Rt.Object))
    val enclosingClass = curClassSym.get.originalOwner

    // Generate the module loader class. The generated class extends
    // AbstractFunction0[Any], i.e. it has an apply method that loads the module.
    // We need a fresh ExprBuffer for this, since it is a different scope.
    // (Inside the block below, `buf` is the fresh buffer, shadowing the outer one.)
    withFreshExprBuffer { buf ?=>
      val body = {
        // first argument is this
        val thisArg =
          nir.Val.Local(curFresh(), nir.Type.Ref(reflInstBuffer.name))
        buf.label(curFresh(), Seq(thisArg))

        // NOTE(review): modules owned by a non-module class presumably need
        // an outer instance to be loaded, hence null here — confirm.
        val module =
          if (enclosingClass.exists && !enclosingClass.is(ModuleClass))
            nir.Val.Null
          else buf.module(fqSymName, unwind(curFresh))
        buf.ret(module)
        buf.toSeq
      }

      reflInstBuffer += new nir.Defn.Define(
        nir.Attrs(),
        reflInstBuffer.name.member(applyMethodSig),
        nir.Type
          .Function(Seq(nir.Type.Ref(reflInstBuffer.name)), nir.Rt.Object),
        body
      )
    }

    // Generate the module loader class constructor.
    genConstructor(nirSymbols.AbstractFunction0Name)

    reflInstBuffer += nir.Defn.Class(
      nir.Attrs(),
      reflInstBuffer.name,
      Some(nirSymbols.AbstractFunction0Name),
      Seq(nirSymbols.SerializableName)
    )

    // Allocate and return an instance of the generated class.
    // Here `buf` is the caller's buffer (the using parameter), not the
    // fresh one from the block above.
    allocAndConstruct(reflInstBuffer.name, Seq.empty, Seq.empty)(using pos, buf)
  }

  /** Allocates a `scala.Tuple2` and constructs it with `(arg1, arg2)`. Both
    * constructor parameters are typed as `Object`.
    *
    * @return
    *   the new tuple value
    */
  private def createTuple(arg1: nir.Val, arg2: nir.Val)(using
      nir.SourcePosition,
      ExprBuffer
  ): nir.Val = {
    allocAndConstruct(
      nirSymbols.Tuple2Name,
      Seq(nir.Rt.Object, nir.Rt.Object),
      Seq(arg1, arg2)
    )
  }

  /** Builds the constructors-info array for a class.
    *
    * For each constructor it generates a synthetic instantiator class
    * extending `scala.runtime.AbstractFunction1`. Entry `i` of the returned
    * array is a `Tuple2(paramClasses, instantiator)` for `ctors(i)`.
    *
    * @param fqSymName
    *   NIR name of the class to instantiate
    * @param ctors
    *   the constructors to expose (the caller passes the public ones)
    * @return
    *   the array value, allocated in the caller's `buf`
    */
  private def genClassConstructorsInfo(
      fqSymName: nir.Global.Top,
      ctors: Seq[Symbol]
  )(using pos: nir.SourcePosition, buf: ExprBuffer): nir.Val = {
    val applyMethodSig =
      nir.Sig.Method("apply", Seq(nir.Rt.Object, nir.Rt.Object))

    // Constructors info is an array of Tuple2 (tpes, inst), where:
    // - tpes is an array with the runtime classes of the constructor arguments.
    // - inst is a function that accepts an array of arguments matching tpes
    //   and returns a new instance of the class.
    val ctorsInfo = buf.arrayalloc(
      nir.Type.Array(nirSymbols.Tuple2Ref),
      nir.Val.Int(ctors.length),
      unwind(curFresh)
    )

    // For each (public) constructor C, generate a lambda responsible for
    // initialising and returning an instance of the class, using C.
    for ((ctor, ctorIdx) <- ctors.zipWithIndex) {
      val ctorSig = genMethodSig(ctor)
      given nir.SourcePosition = ctor.span
      // The first constructor uses the plain helper name; later ones append
      // "$<index>" so that every instantiator class gets a distinct name.
      val ctorSuffix = if (ctorIdx == 0) "" else s"$$$ctorIdx"
      given reflInstBuffer: ReflectiveInstantiationBuffer =
        ReflectiveInstantiationBuffer(fqSymName.id + ctorSuffix)

      // Lambda generation consists of generating a class which extends
      // scala.runtime.AbstractFunction1, with an apply method that accepts
      // the list of arguments, instantiates an instance of the class by
      // forwarding the arguments to C, and returns the instance.
      withFreshExprBuffer { buf ?=>
        val body = {
          // first argument is this
          val thisArg =
            nir.Val.Local(curFresh(), nir.Type.Ref(reflInstBuffer.name))
          // second argument is parameters sequence
          val argsArg = nir.Val.Local(curFresh(), nir.Type.Array(nir.Rt.Object))
          buf.label(curFresh(), Seq(thisArg, argsArg))

          // Extract and cast arguments to proper types.
          // `ctorSig.args.tail` drops the first signature parameter (presumably
          // the receiver); allocAndConstruct prepends the class type itself.
          // Element `argIdx` of the array feeds parameter `argIdx`.
          val argsVals =
            for (arg, argIdx) <- ctorSig.args.tail.zipWithIndex
            yield {
              val elem =
                buf.arrayload(
                  nir.Rt.Object,
                  argsArg,
                  nir.Val.Int(argIdx),
                  unwind(curFresh)
                )
              // If the expected argument type can be boxed (i.e. is a primitive
              // type), then we need to unbox it before passing it to C.
              nir.Type.box.get(arg) match {
                case Some(bt) => buf.unbox(bt, elem, unwind(curFresh))
                case None     => buf.as(arg, elem, unwind(curFresh))
              }
            }

          // Allocate a new instance and call constructor
          val alloc = allocAndConstruct(
            fqSymName,
            ctorSig.args.tail,
            argsVals
          )

          buf.ret(alloc)
          buf.toSeq
        }

        // NOTE(review): `applyMethodSig` declares the parameter as `Object`,
        // while the function type below uses `Array[Object]`. This appears
        // deliberate (the argument is consumed as an array above) — verify.
        reflInstBuffer += new nir.Defn.Define(
          nir.Attrs.None,
          reflInstBuffer.name.member(applyMethodSig),
          nir.Type.Function(
            Seq(
              nir.Type.Ref(reflInstBuffer.name),
              nir.Type.Array(nir.Rt.Object)
            ),
            nir.Rt.Object
          ),
          body
        )
      }

      // Generate the class instantiator constructor.
      genConstructor(nirSymbols.AbstractFunction1Name)

      reflInstBuffer += nir.Defn.Class(
        nir.Attrs(),
        reflInstBuffer.name,
        Some(nirSymbols.AbstractFunction1Name),
        Seq(nirSymbols.SerializableName)
      )

      // Allocate an instance of the generated class (in the caller's `buf`,
      // since this is outside the fresh-buffer block above).
      val instantiator =
        allocAndConstruct(reflInstBuffer.name, Seq.empty, Seq.empty)

      // Create the current constructor's info. We need:
      // - an array with the runtime classes of the ctor parameters.
      // - the instantiator function created above (instantiator).
      val rtClasses = buf.arrayalloc(
        nir.Rt.Class,
        nir.Val.Int(ctorSig.args.tail.length),
        unwind(curFresh)
      )
      for ((arg, argIdx) <- ctorSig.args.tail.zipWithIndex) {
        // Store the runtime class in the array.
        buf.arraystore(
          nir.Rt.Class,
          rtClasses,
          nir.Val.Int(argIdx),
          nir.Val.ClassOf(nir.Type.typeToName(arg)),
          unwind(curFresh)
        )
      }

      // Allocate a tuple to store the current constructor's info
      val to = createTuple(rtClasses, instantiator)

      buf.arraystore(
        nirSymbols.Tuple2Ref,
        ctorsInfo,
        nir.Val.Int(ctorIdx),
        to,
        unwind(curFresh)
      )
    }
    ctorsInfo
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy