All Downloads are FREE. Search and download functionalities are using the official Maven repository.

plugin.CommonPhase.scala Maven / Gradle / Ivy

The newest version!
package dfhdl.plugin

import dotty.tools.dotc.*
import plugins.*
import core.*
import Contexts.*
import Symbols.*
import Flags.*
import SymDenotations.*
import Decorators.*
import ast.Trees.*
import ast.{tpd, untpd}
import StdNames.nme
import Names.*
import Types.*
import Constants.Constant
import dotty.tools.dotc.ast.tpd.Tree
import annotation.tailrec
import scala.language.implicitConversions
import scala.compiletime.uninitialized
import scala.annotation.targetName

/** Lets `Nothing`-typed values be compared with `==` against any type (left operand). */
given canEqualNothingL: CanEqual[Nothing, Any] = CanEqual.canEqualAny[Nothing, Any]

/** Lets any value be compared with `==` against a `Nothing`-typed one (right operand). */
given canEqualNothingR: CanEqual[Any, Nothing] = CanEqual.canEqualAny[Any, Nothing]

/** Base class for the DFHDL compiler-plugin phases.
  *
  * Provides tree-construction helpers (`Option`/`List`/tuple/`Position` trees), tree rewrites
  * (proxy dropping, argument replacement), type and symbol queries shared by the concrete
  * phases, and resolves the DFHDL symbols they rely on at the start of every compilation unit
  * (see `prepareForUnit`).
  */
abstract class CommonPhase extends PluginPhase:

  import tpd._

  /** Selects, by source-file path, the units that get debug output. Disabled by default. */
  val debugFilter: String => Boolean = _ => false
  /** Path of the unit currently being transformed; reset to `""` by `transformUnit`. */
  var pluginDebugSource: String = ""

  /** Prints `str` comma-separated when `debugFilter` accepts the current unit. */
  def debug(str: => Any*): Unit =
    if (debugFilter(pluginDebugSource)) println(str.mkString(", "))

  /** Builds an `Option[String]` tree: `new Some[String](str)` for `Some(str)`, else `None`. */
  protected def mkOptionString(argOpt: Option[String])(using Context): Tree =
    argOpt match
      case Some(str) =>
        New(
          defn.SomeClass.typeRef.appliedTo(defn.StringType),
          Literal(Constant(str)) :: Nil
        )
      case None =>
        mkNone

  /** Builds `Some.apply[T](tree)`, where `T` is the (unwidened) type of `tree`. */
  protected def mkSome(tree: Tree)(using Context): Tree =
    ref(requiredMethod("scala.Some.apply"))
      .appliedToType(tree.tpe)
      .appliedTo(tree)

  /** Builds a reference to the `None` module. */
  protected def mkNone(using Context): Tree =
    ref(defn.NoneModule.termRef)

  /** Builds `Some(tree)` for a defined option and `None` otherwise. */
  protected def mkOption(optTree: Option[Tree])(using Context): Tree =
    optTree.map(mkSome).getOrElse(mkNone)

  /** Builds a `List` tree of the given elements.
    *
    * @param tree
    *   the element trees; an empty list yields `Nil`
    * @param tpeOpt
    *   explicit element type; defaults to the widened union of the element types
    */
  protected def mkList(tree: List[Tree], tpeOpt: Option[Type] = None)(using Context): Tree =
    if (tree.isEmpty) ref(defn.NilModule)
    else
      val tpe = tpeOpt.getOrElse(tree.view.map(_.tpe).reduce(_ | _).widenUnion)
      tpd.mkList(tree, TypeTree(tpe))

  /** Builds `TupleN.apply[T1, ..., TN](trees*)` with the widened element types.
    *
    * There is no `scala.Tuple0`, so an empty input yields `EmptyTuple` instead of failing the
    * `requiredMethod` lookup.
    */
  protected def mkTuple(trees: List[Tree])(using Context): Tree =
    if (trees.isEmpty) ref(defn.EmptyTupleModule)
    else
      ref(requiredMethod(s"scala.Tuple${trees.length}.apply"))
        .appliedToTypes(trees.map(_.tpe.widen))
        .appliedToArgs(trees)

  /** Rewrites blocks of the shape `{ val x$1 = a; ...; f(x$1, ...) }` into `f(a, ...)`.
    *
    * Synthetic `ValDef`s of a block are treated as proxies: references to them that appear
    * directly as (plain or named) arguments of the block's result `Apply` are replaced by the
    * proxy's right-hand side, and the proxy definitions are dropped. Arguments that do not
    * refer to a proxy are kept as-is, non-proxy statements are preserved, and blocks without
    * proxies or whose result is not an `Apply` (e.g. lambdas) are returned unchanged.
    */
  private val dropProxiesTreeMap = new TreeMap:
    override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
      super.transform(tree) match
        case block: Block =>
          val (proxyDefs, otherStats) = block.stats.partition {
            case vd: ValDef => vd.symbol.is(Synthetic)
            case _          => false
          }
          val proxyMap = proxyDefs.collect { case vd: ValDef => vd.name -> vd.rhs }.toMap
          block.expr match
            case Apply(fun, args) if proxyMap.nonEmpty =>
              val newApply = Apply(
                fun,
                args.map {
                  // an `Ident` argument may be an ordinary reference, not a proxy
                  case arg @ Ident(n) => proxyMap.getOrElse(n.toTermName, arg)
                  case arg @ NamedArg(x, Ident(n)) =>
                    proxyMap.get(n.toTermName).map(NamedArg(x, _)).getOrElse(arg)
                  case arg => arg
                }
              )
              if (otherStats.isEmpty) newApply
              else cpy.Block(block)(otherStats, newApply)
            // nothing to inline: keep the whole block (its statements included)
            case _ => block
        case other => other

  /** Applies `dropProxiesTreeMap` to `tree` (recursively, to all nested blocks). */
  protected def dropProxies(tree: Tree)(using Context): Tree =
    dropProxiesTreeMap.transform(tree)

  // Symbols and types of DFHDL/internal definitions. They are (re)resolved for every
  // compilation unit in `prepareForUnit` and are uninitialized before it runs.
  var metaContextTpe: TypeRef = uninitialized
  var metaContextCls: ClassSymbol = uninitialized
  var metaContextIgnoreAnnotSym: ClassSymbol = uninitialized
  var metaContextForwardAnnotSym: ClassSymbol = uninitialized
  var metaGenSym: Symbol = uninitialized
  var positionCls: ClassSymbol = uninitialized
  var contextFunctionSym: Symbol = uninitialized
  var hasDFCTpe: TypeRef = uninitialized
  var inlineAnnotSym: Symbol = uninitialized
  var dfValSym: Symbol = uninitialized
  var constModTpe: Type = uninitialized
  var genDesignParamSym: TermSymbol = uninitialized

  extension (tree: TypeDef)
    /** True when the defined type is a subtype of `dfhdl.core.HasDFC`. */
    def hasDFC(using Context): Boolean =
      (tree.tpe <:< hasDFCTpe) // && (dfSpecTpe == NoType || !(tree.tpe <:< dfSpecTpe))

  /** Replaces every `Ident`/`Select` in `expr` whose symbol is a key of `argMap` with the
    * mapped tree (used to substitute old argument references).
    */
  def replaceArgs(expr: Tree, argMap: Map[Symbol, Tree])(using Context): Tree =
    val replacer = new TreeMap():
      override def transform(tree: Tree)(using Context): Tree =
        tree match
          case id: (Select | Ident) if argMap.contains(id.symbol) =>
            argMap(id.symbol)
          case _ =>
            super.transform(tree)
    replacer.transform(expr)
  end replaceArgs

  extension (tpe: Type)(using Context)
    /** The dealiased type when its type symbol is `dfhdl.core.DFVal`, otherwise `None`. */
    def dfValTpeOpt: Option[Type] =
      tpe.dealias match
        case res if res.dealias.typeSymbol == dfValSym => Some(res)
        case _                                         => None
    /** True when the type is a two-argument applied type whose second (modifier) argument
      * dealiases to a four-argument applied type whose last argument is `ISCONST[true]`.
      */
    def isDFConst: Boolean =
      tpe.widenDealias match
        case AppliedType(_, _ :: modifierTpe :: Nil) =>
          modifierTpe.dealias match
            case AppliedType(_, _ :: _ :: _ :: pTpe :: Nil) =>
              pTpe =:= constModTpe
            case _ => false
        case _ =>
          false
  end extension

  extension (tree: ValOrDefDef)(using Context)
    /** `dfValTpeOpt` of the definition's declared type. */
    def dfValTpeOpt: Option[Type] =
      tree.tpt.tpe.dfValTpeOpt

  extension (tree: ValOrDefDef)(using Context)
    /** Builds a `Meta.gen(name, position, doc, annotations)` call describing the definition.
      * The name is validated through `nameCheck` (which may report an error).
      */
    def genMeta: Tree =
      val nameOptTree = mkOptionString(Some(tree.name.toString.nameCheck(tree)))
      val positionTree = tree.srcPos.positionTree
      val docOptTree = mkOptionString(tree.symbol.docString)
      val annotTree = mkList(tree.symbol.annotations.map(_.tree))
      ref(metaGenSym).appliedToArgs(
        nameOptTree :: positionTree :: docOptTree :: annotTree :: Nil
      )
  end extension

  extension (v: ValDef)(using Context)
    /** Creates a fresh synthetic val `<name>_plugin$N = genDesignParam[T](v, meta)(dfcTree)`.
      * The val is private when created inside a constructor.
      */
    def genDesignParamValDef(dfcTree: Tree): ValDef =
      val meta = v.genMeta
      val paramGen =
        ref(genDesignParamSym)
          .appliedToType(v.tpt.tpe)
          .appliedToArgs(List(ref(v.symbol), meta))
          .appliedTo(dfcTree)
      val uniqueName = NameKinds.UniqueName.fresh(s"${v.name}_plugin".toTermName)
      val flags: FlagSet = if (ctx.owner.isConstructor) Private else EmptyFlags
      SyntheticValDef(uniqueName, paramGen, flags)

  extension (sym: Symbol)
    /** True when annotated with `dfhdl.internals.metaContextIgnore`. */
    def ignoreMetaContext(using Context): Boolean =
      sym.hasAnnotation(metaContextIgnoreAnnotSym)
    /** True when annotated with `dfhdl.internals.metaContextForward`. */
    def forwardMetaContext(using Context): Boolean =
      sym.hasAnnotation(metaContextForwardAnnotSym)
    /** True when `sym` is a class with an ancestor (searched transitively through
      * `parentSyms`) whose full name equals `parentFullName`.
      */
    def inherits(parentFullName: String)(using Context): Boolean =
      if (sym.isClass)
        sym.asClass.parentSyms.exists(ps =>
          ps.fullName.toString == parentFullName || ps.inherits(parentFullName)
        )
      else false
    /** The doc-comment text of `sym`, if any.
      *
      * Parameters take the matching `@param` description from their owner's doc,
      * constructors use the owner's doc, and anything else uses the doc comment recorded in
      * the context (`Comments.ContextDoc`) with its delimiters and `*` margins stripped.
      */
    def docString(using Context): Option[String] =
      extension (c: Context)
        def docCtx: Option[Comments.ContextDocstrings] = c.property(Comments.ContextDoc)
      def removeLastLineWhitespace(input: String): String =
        val lines = input.split("\n")
        if (lines.length <= 1) input
        else
          val lastIndex = lines.length - 1
          val lastLineWithoutWhitespace = lines(lastIndex).trim
          lines.slice(0, lastIndex).mkString("\n") + "\n" + lastLineWithoutWhitespace
      def extract(input: String): String =
        val pattern = """(?s)/\*(.*?)\*/""".r
        val extractedText = pattern.findFirstMatchIn(input).map(_.group(1)).getOrElse("")
        removeLastLineWhitespace(extractedText).stripMargin('*')
      end extract
      def extractParamDescription(docstring: String, paramName: String): Option[String] =
        // quote the name: backticked or symbolic identifiers may contain regex metacharacters
        val quotedName = scala.util.matching.Regex.quote(paramName)
        val pattern = (s"@param\\s+$quotedName\\s+([^@]*)").r
        pattern.findFirstMatchIn(docstring).map(_.group(1).trim)
      if (sym.is(Param))
        sym.owner.docString.flatMap(d => extractParamDescription(d, sym.name.toString))
      else if (sym.isConstructor)
        sym.owner.docString
      else ctx.docCtx.flatMap(_.docstring(sym)).map(_.raw).map(extract)
    end docString

    /** The annotations of `sym` whose type is a subtype of `scala.annotation.StaticAnnotation`. */
    def staticAnnotations(using Context): List[Annotations.Annotation] =
      sym.annotations.collect {
        case a if a.tree.tpe <:< defn.StaticAnnotationClass.typeRef => a
      }
  end extension

  extension (name: String)
    /** The DFHDL name for `posTree`: its `@targetName` value if present, else `name`.
      * Reports an error at `posTree` when that name is not purely alphanumeric/underscore and
      * the symbol is not synthetic; the name is returned either way.
      */
    def nameCheck(posTree: Tree)(using Context): String =
      val finalName =
        posTree.symbol.getAnnotation(defn.TargetNameAnnot)
          .flatMap(_.argumentConstantString(0))
          .getOrElse(name)
      if (
        !finalName.matches("^[a-zA-Z0-9_]*$") && !posTree.symbol.flags.is(
          Flags.Synthetic
        )
      )
        report.error(
          s"""Unsupported DFHDL member name $finalName.
           |Only alphanumeric or underscore characters are supported.
           |You can leave the Scala name as-is and add @targetName("newName") annotation.""".stripMargin,
          posTree.srcPos
        )
      finalName

  // custom replacement for compiler defn.ContextFunctionType
  /** Extracts `(parameter types, result type)` from a context-function type. */
  object ContextFunctionType:
    def unapply(tp: Type)(using Context): Option[(List[Type], Type)] =
      defn.asContextFunctionType(tp) match
        case tp1 if tp1.exists =>
          val args = tp1.dropDependentRefinement.argInfos
          Some((args.init, args.last))
        case _ => None

  extension (tree: ValOrDefDef)(using Context)
    /** True when the definition is `inline` or annotated with `scala.inline`. */
    def isInline: Boolean =
      val sym = tree.symbol
      sym.is(Inline) || sym.hasAnnotation(inlineAnnotSym)

  extension (tpe: Type)(using Context)
    /** Strips a term reference (taking its underlying type), annotations and aliases. */
    def simple: Type =
      tpe match
        case tr: TermRef        => tr.underlying.dealias
        case ann: AnnotatedType => ann.parent.simple
        case _                  => tpe.dealias
    /** Converts a cons-style tuple type `A *: B *: EmptyTuple` into `Tuple2[A, B]` (etc.).
      *
      * When the chain does not end in `EmptyTuple` (e.g. an abstract tail), or no
      * `scala.TupleN` class exists for its arity (0 or more than 22), the ORIGINAL type is
      * returned unchanged rather than a partial tail.
      */
    def flattenConsTuple: Type =
      @tailrec def loop(cur: Type, revArgs: List[Type]): Type = cur.stripAnnots match
        // NOTE(review): confirm `typeSymbol` of `EmptyTuple` resolves to `EmptyTupleModule`
        case emptyTuple if emptyTuple.typeSymbol == defn.EmptyTupleModule =>
          // `scala.TupleN` classes only exist for 1 <= N <= 22
          if (revArgs.isEmpty || revArgs.length > 22) tpe
          else AppliedType(requiredClassRef(s"scala.Tuple${revArgs.length}"), revArgs.reverse)
        case AppliedType(tycon, head :: next :: Nil) if tycon.typeSymbol == defn.PairClass =>
          loop(next, head :: revArgs)
        case _ =>
          tpe
      loop(tpe, Nil)
  end extension
  extension (tp: Type)(using Context)
    /** Like `dfcFuncTpeOpt`, but also rewrites such context functions found (recursively) in
      * type arguments; `None` when nothing was rewritten.
      */
    def dfcFuncTpeOptRecur: Option[Type] =
      tp.dealias match
        case ContextFunctionType(ctx, res) if ctx.head <:< metaContextTpe => Some(res)
        case AppliedType(tycon, args) =>
          var requiresUpdate = false
          val updatedArgs = args.map { tp =>
            tp.dfcFuncTpeOptRecur match
              case Some(tp) =>
                requiresUpdate = true
                tp
              case None => tp
          }
          if (requiresUpdate) Some(AppliedType(tycon, updatedArgs))
          else None
        case _ => None
    /** The result type of a context function whose first parameter is a `MetaContext`. */
    def dfcFuncTpeOpt: Option[Type] =
      tp.dealias match
        case ContextFunctionType(ctx, res) if ctx.head <:< metaContextTpe =>
          Some(res)
        case _ => None
  end extension

  extension (srcPos: util.SrcPos)(using Context)
    /** Formats `path:line:col-line:col` from the raw `line`/`column` values (no +1
      * adjustment, unlike `positionTree`); intended for debug output.
      */
    def show: String =
      val pos = srcPos.startPos
      val endPos = srcPos.endPos
      s"${pos.source.path}:${pos.line}:${pos.column}-${endPos.line}:${endPos.column}"

  extension (tree: Apply)(using Context)
    /** Replaces arguments reference-equal (`eq`) to `fromArg` with `toArg`. If none is found
      * in this application's argument list, the curried inner `Apply` is searched instead.
      */
    def replaceArg(fromArg: Tree, toArg: Tree): Apply =
      var changed = false
      val repArgs = tree.args.map { a =>
        if (a eq fromArg)
          changed = true
          toArg
        else a
      }
      tree.fun match
        case apply: Apply if !changed =>
          Apply(apply.replaceArg(fromArg, toArg), tree.args)
        case _ =>
          Apply(tree.fun, repArgs)

  extension (tree: Apply)(using Context)
    /** True when the called symbol carries `dfhdl.internals.metaContextDelegate`. */
    def isContextDelegate: Boolean =
      tree.symbol.hasAnnotation(requiredClass("dfhdl.internals.metaContextDelegate"))

  extension (tree: ValOrDefDef)(using Context)
    /** An `Ident` referring to the definition by name, typed with the definition's type. */
    def ident: Tree =
      untpd.Ident(tree.name).withType(tree.tpe)

  /** Locates `MetaContext` trees in applications and definitions. */
  object ContextArg:
    /** The first `MetaContext`-typed argument of the outermost argument list, falling back
      * to the inner (curried) applications.
      */
    def unapply(tree: Tree)(using Context): Option[Tree] =
      tree match
        case Apply(tree, args) =>
          args
            .collectFirst {
              case a if a.tpe <:< metaContextTpe =>
                a
            }
            .orElse(unapply(tree))
        case _ => None

    /** For a method: a reference to its last `MetaContext`-typed parameter. For a class that
      * extends `HasDFC`: `this.dfc`. Otherwise `None`.
      */
    def at(tree: DefDef | TypeDef)(using Context): Option[Tree] =
      tree match
        case tree: DefDef =>
          tree.paramss.flatten.view.reverse.collectFirst {
            case a @ ValDef(name, _, _) if a.tpe <:< metaContextTpe =>
              a.ident
          }
        case TypeDef(name, _: Template) if tree.tpe <:< hasDFCTpe =>
          Some(This(tree.symbol.asClass).select("dfc".toTermName))
        case _ => None
  end ContextArg

  /** Returns `absolutePathStr` relative to the JVM working directory (`user.dir`).
    *
    * Falls back to `absolutePathStr` itself when no relative path can be computed, e.g. when
    * the string is not a valid filesystem path (virtual sources) or lies on a different root
    * (another Windows drive), instead of crashing the compiler plugin.
    */
  def getRelativePath(absolutePathStr: String): String =
    import java.nio.file.Paths
    try
      val absolutePath = Paths.get(absolutePathStr).toAbsolutePath()
      val currentDir = Paths.get(System.getProperty("user.dir")).toAbsolutePath()
      currentDir.relativize(absolutePath).toString
    // InvalidPathException is a subclass of IllegalArgumentException
    catch case _: IllegalArgumentException => absolutePathStr

  extension (srcPos: util.SrcPos)(using Context)
    /** Builds a `new Position(file, lineStart, colStart, lineEnd, colEnd)` tree with 1-based
      * line/column values and a working-directory-relative file path, or
      * `Position.unknown` when the position has no span.
      */
    def positionTree: Tree =
      if (srcPos.span == util.Spans.NoSpan) ref(requiredMethod("dfhdl.internals.Position.unknown"))
      else
        val fileNameTree = Literal(Constant(getRelativePath(srcPos.startPos.source.path)))
        val lineStartTree = Literal(Constant(srcPos.startPos.line + 1))
        val columnStartTree = Literal(Constant(srcPos.startPos.column + 1))
        val lineEndTree = Literal(Constant(srcPos.endPos.line + 1))
        val columnEndTree = Literal(Constant(srcPos.endPos.column + 1))
        New(
          positionCls.typeRef,
          fileNameTree :: lineStartTree :: columnStartTree :: lineEndTree :: columnEndTree :: Nil
        )
  end extension

  /** Splits a curried application `f(a)(b)...` into `f` and its argument lists (outermost
    * last), and rebuilds it.
    */
  object ApplyFunArgs:
    @tailrec private def recurUnapply(fun: Tree, args: List[List[Tree]])(using
        Context
    ): (Tree, List[List[Tree]]) =
      fun match
        case Apply(f, a) => recurUnapply(f, a :: args)
        case f           => (f, args)

    def unapply(tree: Apply)(using Context): Option[(Tree, List[List[Tree]])] =
      Some(recurUnapply(tree, Nil))

    // NOTE(review): presumably requires a non-empty `args`, otherwise the result is `fun`
    // itself and the cast to `Apply` may fail.
    def apply(fun: Tree, args: List[List[Tree]])(using Context): Apply =
      fun.appliedToArgss(args).asInstanceOf[Apply]

  /** Resolves the DFHDL symbols used by the helpers for this unit and, when `debugFilter`
    * accepts the unit, prints the tree before the phase runs.
    */
  override def prepareForUnit(tree: Tree)(using Context): Context =
    pluginDebugSource = tree.source.path.toString
    dfValSym = requiredClass("dfhdl.core.DFVal")
    metaContextTpe = requiredClassRef("dfhdl.internals.MetaContext")
    metaContextCls = requiredClass("dfhdl.internals.MetaContext")
    metaContextIgnoreAnnotSym = requiredClass("dfhdl.internals.metaContextIgnore")
    metaContextForwardAnnotSym = requiredClass("dfhdl.internals.metaContextForward")
    metaGenSym = requiredMethod("dfhdl.compiler.ir.Meta.gen")
    positionCls = requiredClass("dfhdl.internals.Position")
    hasDFCTpe = requiredClassRef("dfhdl.core.HasDFC")
    inlineAnnotSym = requiredClass("scala.inline")
    constModTpe = requiredClassRef("dfhdl.core.ISCONST").appliedTo(ConstantType(Constant(true)))
    contextFunctionSym = defn.FunctionSymbol(1, isContextual = true)
    genDesignParamSym = requiredMethod("dfhdl.core.r__For_Plugin.genDesignParam")
    if (debugFilter(tree.source.path.toString))
      println(
        s"""===============================================================
           |Before: $phaseName
           |===============================================================
           |""".stripMargin
      )
      println(tree.show)

    ctx
  end prepareForUnit

  /** Clears the debug source and, when `debugFilter` accepts the unit, prints the tree after
    * the phase ran. The tree is returned unchanged.
    */
  override def transformUnit(tree: Tree)(using Context): Tree =
    pluginDebugSource = ""
    if (debugFilter(tree.source.path.toString))
      println(
        s"""===============================================================
           |After: $phaseName
           |===============================================================
           |""".stripMargin
      )
      println(tree.show)
    tree
end CommonPhase




© 2015 - 2024 Weber Informatics LLC | Privacy Policy