All downloads are free. Search and download functionality uses the official Maven repository.

scala.meta.internal.trees.ast.scala Maven / Gradle / Ivy

Go to download

Bag of private and public helpers used in scala.meta's APIs and implementations

The newest version!
package scala.meta
package internal
package trees

import org.scalameta.internal.MacroCompat

import scala.annotation.StaticAnnotation
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.math.Ordered.orderingToOrdered
import scala.reflect.macros.whitebox.Context

// @ast is a specialized version of @org.scalameta.adt.leaf for scala.meta ASTs.
// Annotating a class with @ast expands via AstNamerMacros.impl, which rewrites the annotated
// class and its companion into a trait, a companion object and a nested implementation class.
// NOTE(review): as a macro annotation this presumably needs macro annotations enabled in the
// build (paradise / -Ymacro-annotations) — confirm against the build definition.
class ast extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro AstNamerMacros.impl
}

class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
  import AstNamerMacros._
  import c.universe.Flag._
  import c.universe._

  // Universe and mirror of the current macro context. These presumably satisfy abstract
  // members declared by the Reflection mixin (not visible here) — TODO confirm.
  lazy val u: c.universe.type = c.universe
  lazy val mirror = c.mirror

  /** Companion statements generated for a single field version.
    *
    * `primary` members end up in the per-version object itself, while `lowPrio` members end up in
    * that object's `...LowPriority` parent trait, so overload resolution prefers `primary`.
    */
  private class Mstats(
      val primary: ListBuffer[Tree] = ListBuffer[Tree](),
      val lowPrio: ListBuffer[Tree] = ListBuffer[Tree]()
  )

  /** Macro implementation behind `@ast`.
    *
    * Given `@ast class X(params) extends ... { body }` and its companion, it produces:
    *   - a trait `X` with abstract getters, `copy` overloads (unless the class defines its own
    *     `copy`), Object/Equals overrides and `Product` as an extra parent;
    *   - a final implementation class named after the qualified name (e.g. `TermNameImpl` for
    *     `scala.meta.Term.Name`), placed in the companion and holding each parameter in a
    *     mutable backing field `_p`;
    *   - companion `apply`/`unapply` overloads, including binary-compatible ones for fields marked
    *     with @newField / @replacedField, grouped into per-version objects and LowPriority traits;
    *   - for non-quasi classes, the quasi definition returned by `mkQuasi`.
    *
    * Unsupported shapes abort compilation: sealed/final/case/abstract modifiers, a non-public
    * primary constructor, zero or several parameter lists, or body statements other than
    * imports, definitions, `checkFields(...)` and `checkParent(...)`.
    */
  def impl(annottees: Tree*): Tree = annottees.transformAnnottees(new ImplTransformer {
    override def transformClass(cdef: ClassDef, mdef: ModuleDef): List[ImplDef] = {
      val owner = c.internal.enclosingOwner
      val fullName = owner.fullName + "." + cdef.name.toString
      val isQuasi = isQuasiClass(cdef)
      // may not return other classes/modules at package level
      val isTopLevel = owner.isPackage
      val q"$imods class $iname[..$tparams] $ctorMods(...$rawparamss) extends { ..$earlydefns } with ..$iparents { $aself => ..$stats }" =
        cdef
      // NOTE: For stack traces, we'd like to have short class names, because stack traces print full names anyway.
      // However debugging macro expansion errors is much-much easier with full names for Api and Impl classes
      // because the typechecker only uses short names in error messages.
      // E.g. compare:
      //  class Impl needs to be abstract, since method withDenot in trait Name
      //  of type (denot: scala.meta.internal.semantic.Denotation)Impl.this.ThisType is not defined
      // and:
      //  class NameAnonymousImpl needs to be abstract, since method withDenot in trait Name
      //  of type (denot: scala.meta.internal.semantic.Denotation)NameAnonymousImpl.this.ThisType is not defined
      val descriptivePrefix = fullName.stripPrefix("scala.meta.").replace(".", "")
      val name = TypeName(descriptivePrefix + "Impl")
      val q"$mmods object $mname extends { ..$mearlydefns } with ..$mparents { $mself => ..$mstats }" =
        mdef
      // Naming convention for the accumulators below:
      //   i* -> the generated public trait (interface), m* -> the companion object,
      //   unprefixed (stats1, paramss1, ...) -> the generated final implementation class.
      val paramss1 = ListBuffer[List[ValDef]]() // payload params
      val iself = noSelfType
      val self = aself
      val istats1 = ListBuffer[Tree]()
      val stats1 = ListBuffer[Tree]()
      val ianns1 = ListBuffer[Tree]() ++ imods.annotations
      def imods1 = imods.mapAnnotations(_ => ianns1.toList)
      def mods1 = Modifiers(FINAL, mname.toTypeName, List(SerialVersionUIDAnnotation(1L)))
      val iparents1 = ListBuffer[Tree]() ++ iparents
      def parents1 = List(tq"$iname")
      val mstats1 = ListBuffer[Tree]() ++ mstats
      val mstatsLatest = ListBuffer[Tree]()
      val mstats1LowPriority = ListBuffer.empty[Tree]
      val mstatsLatestLowPriority = ListBuffer.empty[Tree]
      val manns1 = ListBuffer[Tree]() ++ mmods.annotations
      def mmods1 = mmods.mapAnnotations(_ => manns1.toList)
      val quasiCopyExtraParamss = ListBuffer[List[ValDef]]()
      val quasiExtraAbstractDefs = ListBuffer[ValOrDefDef]()

      // step 1: validate the shape of the class
      if (imods.hasFlag(SEALED)) c.abort(cdef.pos, "sealed is redundant for @ast classes")
      if (imods.hasFlag(FINAL)) c.abort(cdef.pos, "final is redundant for @ast classes")
      if (imods.hasFlag(CASE)) c.abort(cdef.pos, "case is redundant for @ast classes")
      if (imods.hasFlag(ABSTRACT)) c.abort(cdef.pos, "@ast classes cannot be abstract")
      if (ctorMods.flags != NoFlags) c
        .abort(cdef.pos, "@ast classes must define a public primary constructor")
      if (rawparamss.isEmpty) c
        .abort(cdef.pos, "@ast classes must define a non-empty parameter list")
      if (rawparamss.lengthCompare(1) > 0) c
        .abort(cdef.pos, "@ast classes must define a single parameter list")
      val params = rawparamss.head

      // step 1a: identify modified fields of the class
      val (versionedParams, paramsVersions) =
        if (isQuasi) (Nil, Nil) else getVersionedParams(params, stats)
      val replacedFields = versionedParams.flatMap(_.replaced.flatMap { field =>
        field.oldDefs.map { case (oldDef, _) => field.version -> oldDef }
      })
      val mstatsPerVersion = paramsVersions.map(ver => (ver, new Mstats()))
      def paramsForVersion(v: Version): List[ValDef] =
        positionVersionedParams(versionedParams.flatMap(_.getApplyDeclDefnBefore(v)._1))

      // step 2: validate the body of the class

      var needCopies = !isQuasi
      val importsBuilder = List.newBuilder[Import]
      val checkFieldsBuilder = List.newBuilder[Tree]
      val checkParentsBuilder = List.newBuilder[Tree]

      stats.foreach {
        case x: Import => importsBuilder += x
        // a user-defined copy suppresses the generated copy overloads
        case x: DefDef if !isQuasi && x.name == TermName("copy") => istats1 += x; needCopies = false
        case x: ValOrDefDef =>
          if (x.mods.hasFlag(Flag.ABSTRACT) || x.rhs.isEmpty) c
            .abort(x.pos, "definition without a value")
          // definitions that stand in for a replaced field get a deprecation annotation
          val p = replacedFields.collectFirst { case (version, `x`) =>
            val mods = x.mods.mapAnnotations(getDeprecatedAnno(version) :: _)
            q"$mods def ${x.name}: ${x.tpt} = ${x.rhs}"
          }.getOrElse(x)
          // final definitions stay on the trait; others go to the impl class (and the quasi)
          if (x.mods.hasFlag(Flag.FINAL)) istats1 += p else quasiExtraAbstractDefs += p
        case q"checkFields($arg)" => checkFieldsBuilder += arg
        case x @ q"checkParent($what)" => checkParentsBuilder += x
        case x =>
          val error =
            "only checkFields(...), checkParent(...) and definitions are allowed in @ast classes"
          c.abort(x.pos, error)
      }
      val imports = importsBuilder.result()
      val fieldChecks = checkFieldsBuilder.result()
      val parentChecks = checkParentsBuilder.result()

      // step 3: propagate the class's imports to every generated scope, and move the
      // non-final definitions collected above into the implementation class
      istats1 ++= imports
      mstats1 ++= imports
      stats1 ++= imports
      stats1 ++= quasiExtraAbstractDefs

      // step 4: implement the unimplemented methods in InternalTree (part 1)
      val privateFields = getPrivateFields(iname)
      val privateParams = privateFields.asList
      val bparams1 = privateParams.map(_.field)
      val privateApplyParams = privateParams.collect { case PrivateField(p, true) =>
        val annots = privateFieldAnnot :: p.mods.annotations
        val mods = Modifiers(p.mods.flags | OVERRIDE | DEFERRED, p.mods.privateWithin, annots)
        istats1 += declareGetter(p.name, p.tpt, mods)
        p
      }

      // step 5: turn all parameters into vars (backing fields `_name`) and create getters
      params.foreach { p =>
        istats1 += declareGetter(p.name, p.tpt, astFieldAnnot :: p.mods.annotations)
        val pmods = if (p.mods.hasFlag(OVERRIDE)) Modifiers(OVERRIDE) else NoMods
        stats1 += defineGetter(p.name, p.tpt, pmods)
      }
      paramss1 += params.map { p =>
        val mods1 = p.mods.mkMutable.unPrivate.unOverride.unDefault
        q"$mods1 val ${internalize(p.name)}: ${p.tpt}"
      }

      // step 6: implement the unimplemented methods in InternalTree (part 2)
      // The purpose of privateCopy is to provide extremely cheap cloning
      // in the case when a tree changes its parent (because that happens often in our framework,
      // e.g. when we create a quasiquote and then insert it into a bigger quasiquote,
      // or when we parse something and build the trees from the ground up).
      // In such a situation, we copy all private state verbatim (tokens, denotations, etc)
      // and create lazy initializers that will take care of recursively copying the children.
      // Compare this with the `copy` method (described below), which additionally flushes the private state.
      // This method is private[meta] because the state that it's managing is not supposed to be touched
      // by the users of the framework.
      val privateCopyArgs = params
        .map(p => q"$CommonTyperMacrosModule.initField(this.${internalize(p.name)})")
      val privateCopyParentChecks =
        if (parentChecks.isEmpty) q""
        else q"""
            if (destination != null) {
              def checkParent(fn: ($name, $TreeClass, $StringClass) => $BooleanClass): $UnitClass = {
                val parentCheckOk = fn(this, parent, destination)
                if (!parentCheckOk) {
                  val parentPrefix = parent.productPrefix
                  _root_.org.scalameta.invariants.require(parentCheckOk && _root_.org.scalameta.debug(this, parentPrefix, destination))
                }
              }
              ..$parentChecks
            }
          """
      stats1 +=
        q"""
        private[meta] def privateCopy(
            prototype: $TreeClass = this,
            parent: $TreeClass = ${privateFields.parent.field.name},
            destination: $StringClass = null,
            origin: $OriginClass = ${privateFields.origin.field.name}): Tree = {
          $privateCopyParentChecks
          $DataTyperMacrosModule.nullCheck(origin)
          new $name(prototype.asInstanceOf[$iname], parent, origin)(..$privateCopyArgs)
        }
      """
      // step 7: create the copy method
      // The purpose of this method is to provide a facility to change small parts of the tree
      // without modifying the other parts, much like the standard case class copy works.
      // In such a situation, the tree is going to be recreated.
      // NOTE: Can't generate XXX.Quasi.copy, because XXX.Quasi already inherits XXX.copy,
      // and there can't be multiple overloaded methods with default parameters.
      // Not a big deal though, since XXX.Quasi is an internal class.
      def getParamArg(p: ValOrDefDef) = q"${p.name}"
      // declares an abstract copy on the trait and its implementation (delegating to apply)
      def addCopy(params: List[ValDef], annots: Tree*) = {
        val mods = getDeferredModifiers(annots.toList)
        istats1 +=
          q"""
            $mods def copy(..$params): $iname
          """
        val args = params.map(getParamArg)
        stats1 +=
          q"""
            final override def copy(..$params): $iname = {
              $mname.apply(..$args)
            }
          """
        quasiCopyExtraParamss += params
      }
      if (!isQuasi) {
        def getCopyParamWithDefault(p: ValOrDefDef): ValDef = asValDefn(p, q"this.${p.name}")
        val fullCopyParams = params.map(getCopyParamWithDefault)
        val iFullCopy = q"private[meta] def fullCopy(..$fullCopyParams): $iname"
        istats1 += iFullCopy
        quasiExtraAbstractDefs += iFullCopy
        stats1 +=
          q"""
            private[meta] final override def fullCopy(..$fullCopyParams): $iname = {
              $mname.apply(..${params.map(_.name)})
            }
          """
        if (needCopies)
          if (versionedParams.isEmpty) addCopy(fullCopyParams)
          else {
            // add primary copy with default values
            val defaultCopyParams =
              positionVersionedParams(versionedParams.flatMap(_.getDefaultCopyDef()))
                .map(getCopyParamWithDefault)
            addCopy(defaultCopyParams)

            val defaultCopyParamNames = defaultCopyParams.map(_.name.toString).toSet
            def allInDefaults(cp: Iterable[ValDef]): Boolean = cp
              .forall(x => defaultCopyParamNames.contains(x.name.toString))

            // add full copy without defaults
            if (!allInDefaults(params)) addCopy(params.map(asValDecl))
            // add secondary copy
            paramsVersions.foreach { version =>
              val copyParams = paramsForVersion(version)
              if (copyParams.length != defaultCopyParams.length || !allInDefaults(copyParams))
                addCopy(copyParams.map(asValDecl), getDeprecatedAnno(version))
            }
          }
      }

      // step 7a: override the Object and Equals methods (trees use reference identity)
      if (!isQuasi) {
        istats1 +=
          q"final override def canEqual(that: Any): $BooleanClass = that.isInstanceOf[$iname]"
        istats1 +=
          q"final override def equals(that: Any): $BooleanClass = this eq that.asInstanceOf[AnyRef]"
        istats1 += q"final override def hashCode: $IntClass = System.identityHashCode(this)"
        istats1 +=
          q"final override def toString: $StringClass = scala.meta.internal.prettyprinters.TreeToString(this)"
      }

      // step 8: create the children method
      stats1 +=
        q"def children: $ListClass[$TreeClass] = $CommonTyperMacrosModule.children[$iname, $TreeClass]"

      // step 9: generate boilerplate required by the @ast infrastructure
      ianns1 += q"new $AstMetadataModule.astClass"
      ianns1 += q"new $AdtMetadataModule.leafClass"
      manns1 += q"new $AstMetadataModule.astCompanion"
      manns1 += q"new $AdtMetadataModule.leafCompanion"

      // step 10: generate boilerplate required by the classifier infrastructure
      mstats1 ++= mkClassifier(iname)
      mstats1 += mkAstInfo(iname)

      // step 11: implement Product
      iparents1 += tq"$ProductClass"

      stats1 +=
        q"override def productPrefix: $StringClass = $CommonTyperMacrosModule.productPrefix[$iname]"
      stats1 += q"override def productArity: $IntClass = ${params.length}"

      // one `case i => ...` per parameter, plus a fallback throwing IndexOutOfBoundsException
      def patternMatchClauses(fromField: (ValDef, Int) => Tree) = {
        val pelClauses = ListBuffer[Tree]()
        pelClauses ++= params.zipWithIndex.map(fromField.tupled)
        pelClauses += cq"_ => throw new $IndexOutOfBoundsException(n.toString)"
        pelClauses.toList
      }

      val pelClauses = patternMatchClauses((vr, i) => cq"$i => this.${vr.name}")
      stats1 += q"override def productElement(n: $IntClass): Any = n match { case ..$pelClauses }"
      stats1 +=
        q"override def productIterator: $IteratorClass[$AnyClass] = $ScalaRunTimeModule.typedProductIterator(this)"
      val productFields = params.map(_.name.toString)
      stats1 +=
        q"override def productFields: $ListClass[$StringClass] = _root_.scala.List(..$productFields)"

      // step 11a: add productElementName (only when MacroCompat reports it is available, i.e. 2.13)
      if (MacroCompat.productFieldNamesAvailable) {
        val penClauses = patternMatchClauses { (vr, i) =>
          val lit = Literal(Constant(vr.name.toString()))
          cq"""$i => $lit """
        }
        stats1 +=
          q"override def productElementName(n: $IntClass): java.lang.String = n match { case ..$penClauses }"
      }

      // step 12: generate serialization logic
      stats1 +=
        q"""
          protected def writeReplace(): $AnyRefClass = {
            ..${params.map(loadField)}
            this
          }
        """

      // step 13: generate Companion.apply
      val internalBody = ListBuffer[Tree]()
      internalBody += q"$CommonTyperMacrosModule.hierarchyCheck[$iname]"
      params.foreach { p =>
        val local = p.name
        internalBody += q"$DataTyperMacrosModule.nullCheck($local)"
        internalBody += q"$DataTyperMacrosModule.emptyCheck($local)"
      }
      internalBody ++= imports
      fieldChecks.foreach { x =>
        val fieldCheck = q"_root_.org.scalameta.invariants.require($x)"
        var hasErrors = false
        // field checks run inside the companion's apply, where `this`/`parent` are not the node
        object errorChecker extends Traverser {
          private val nmeParent = TermName("parent")
          override def traverse(tree: Tree): Unit = tree match {
            case _: This =>
              hasErrors = true; c.error(tree.pos, "cannot refer to this in @ast field checks")
            case Ident(`nmeParent`) =>
              hasErrors = true
              c.error(
                tree.pos,
                "cannot refer to parent in @ast field checks; use checkParent instead"
              )
            case _ => super.traverse(tree)
          }
        }
        errorChecker.traverse(fieldCheck)
        if (!hasErrors) internalBody += fieldCheck
      }
      val paramInits = params.map(p => q"$CommonTyperMacrosModule.initParam(${p.name})")
      privateParams.foreach { p =>
        if (p.persist) internalBody += q"$DataTyperMacrosModule.nullCheck(${p.field.name})"
        else internalBody += asValDefn(p.field)
      }
      val internalArgs = params.map(getParamArg)
      val bparamCtorArgs = bparams1.map { p =>
        if (p eq privateFields.origin.field) q"""
               $OriginModule.first(
                 alternativeOrigin,
                 $OriginModule.DialectOnly.getFromArgs(..$internalArgs)
               )
             """
        else getParamArg(p)
      }
      internalBody +=
        q"""
          val node = new $name(
            ..$bparamCtorArgs
          )(
            ..$paramInits
          )
        """
      params.foreach(p => internalBody += storeField(p))
      internalBody += q"node"
      val applyParamDefns = params.map(asValDefn)
      val applyParamDecls = params.map(asValDecl)
      val bparamDecls = privateApplyParams.map(asValDecl)
      val fullApplyParamDecls = bparamDecls ++ applyParamDecls
      val fullInternalArgs = privateApplyParams.map(getParamArg) ++ internalArgs
      val bparamRhsInternalArgs = privateApplyParams.map(p => p.rhs) ++ internalArgs
      if (isTopLevel) {
        mstats1 +=
          q"""
            def apply(..$applyParamDefns): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1 +=
          q"""
            def apply(..$fullApplyParamDecls): $iname = {
              val alternativeOrigin = origin
              ..$internalBody
            }
          """
      } else {
        mstats1 +=
          q"""
            def apply(..$applyParamDefns)(implicit dialect: $DialectClass): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1 +=
          q"""
            def apply(..$fullApplyParamDecls)(implicit dialect: $DialectClass): $iname = {
              val alternativeOrigin =
                $OriginModule.first(origin, implicitly[$OriginModule.DialectOnly])
              ..$internalBody
            }
          """
        mstats1LowPriority +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$applyParamDecls): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1LowPriority +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$fullApplyParamDecls): $iname = {
              $mname.apply(..$fullInternalArgs)
            }
          """
      }
      mstatsLatest +=
        q"""
          @$InlineAnnotation def apply(..$fullApplyParamDecls)(implicit dialect: $DialectClass): $iname =
            $mname.apply(..$fullInternalArgs)
        """
      mstatsLatestLowPriority +=
        q"""
          @$InlineAnnotation @$deprecatedSince_4_9_0 def apply(..$fullApplyParamDecls): $iname =
            $mname.apply(..$fullInternalArgs)
        """
      mstatsLatest +=
        q"""
          @$InlineAnnotation def apply(..$applyParamDefns)(implicit dialect: $DialectClass): $iname =
            $mname.apply(..$internalArgs)
        """
      mstatsLatestLowPriority +=
        q"""
          @$InlineAnnotation @$deprecatedSince_4_9_0 def apply(..$applyParamDecls): $iname =
            $mname.apply(..$internalArgs)
        """

      // step 13a: generate additional companion apply for added and replaced fields
      // generate new applies for each new field added
      // with field A, B and additional binary compat ones C, D and E, we generate:
      // apply(A, B, C), apply(A, B, C, D), apply(A, B, C, D, E)
      mstatsPerVersion.foreach { case (v, verMstats) =>
        val applyParamsBuilder = List.newBuilder[(ValDef, Int)]
        val applyBodyBuilder = List.newBuilder[Tree]
        versionedParams.foreach { vp =>
          val (decl, defn) = vp.getApplyDeclDefnBefore(v)
          decl.foreach(applyParamsBuilder += _)
          defn.foreach(applyBodyBuilder += _)
        }
        val paramDefns = positionVersionedParams(applyParamsBuilder.result())
        val applyBody = applyBodyBuilder.result()
        val applyCall = q"$mname.apply(..$internalArgs)"
        val paramDecls = paramDefns.map(asValDecl)
        val fullParamDecls = bparamDecls ++ paramDecls
        val fullParamDeclNames = fullParamDecls.map(_.name)
        verMstats.lowPrio +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$fullParamDecls): $iname = {
              $mname.apply(..$fullParamDeclNames)
            }
          """
        verMstats.primary +=
          q"""
            def apply(..$fullParamDecls)(implicit dialect: $DialectClass): $iname = {
              $mname.apply(..$fullParamDeclNames)
            }
          """
        verMstats.lowPrio +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$paramDecls): $iname = {
              ..$applyBody
              $applyCall
            }
          """
        verMstats.primary +=
          q"""
            def apply(..$paramDefns)(implicit dialect: $DialectClass): $iname = {
              ..$applyBody
              $applyCall
            }
          """
        if (isTopLevel) {
          mstats1 +=
            q"""
              def apply(..$fullParamDecls): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1 +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls): $iname = {
                ..$applyBody
                $applyCall
              }
            """

        } else {
          mstats1LowPriority +=
            q"""
              @$deprecatedSince_4_9_0 def apply(..$fullParamDecls): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1LowPriority +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls): $iname = {
                ..$applyBody
                $applyCall
              }
            """
          mstats1 +=
            q"""
              def apply(..$fullParamDecls)(implicit dialect: $DialectClass): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1 +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls)(implicit dialect: $DialectClass): $iname = {
                ..$applyBody
                $applyCall
              }
            """
        }
      }

      // step 14: generate Companion.unapply (unless the companion already defines one)
      val needsUnapply = !mstats.exists {
        case DefDef(_, TermName("unapply"), _, _, _, _) => true
        case _ => false
      }
      if (needsUnapply) {
        def getUnapply(unapplyParams: List[ValDef], annots: Tree*): Tree =
          if (unapplyParams.isEmpty) q"""
                @$InlineAnnotation @..$annots final def unapply(x: $iname): $BooleanClass =
                  x != null && x.isInstanceOf[$name]
              """
          else {
            val successTargs = tq"(..${unapplyParams.map(p => p.tpt)})"
            val successArgs = q"(..${unapplyParams.map(p => q"x.${p.name}")})"
            q"""
                @$InlineAnnotation @..$annots final def unapply(x: $iname): $OptionClass[$successTargs] =
                  if (x != null && x.isInstanceOf[$name]) $SomeModule($successArgs) else $NoneModule
              """
          }
        val latestTree = getUnapply(params)
        mstatsPerVersion match {
          case (headVer, headMstats) :: tail =>
            val headParams = paramsForVersion(headVer)
            headMstats.primary += getUnapply(headParams)
            tail.foreach { case (ver, verMstats) =>
              verMstats.primary += getUnapply(paramsForVersion(ver))
            }
            val afterLastVer = getAfterVersion(mstatsPerVersion.last._1)
            val anno = getDeprecatedAnno(headVer, s"; use `.$afterLastVer`")
            mstats1 += getUnapply(headParams, anno)
          case Nil => mstats1 += latestTree
        }
        mstatsLatest += latestTree
      }

      // step 15: finish codegen for Quasi
      if (isQuasi) stats1 +=
        q"""
          def become[T <: $TreeClass](implicit ev: $AstInfoClass[T]): T with $QuasiClass = {
            (this match {
              case $mname(0, tree) =>
                ev.quasi(0, tree)
              case $mname(rank, nested @ $mname(0, tree)) =>
                ev.quasi(rank, nested.become[T])
              case _ =>
                throw new Exception("complex ellipses are not supported yet")
            }).withOrigin(this.origin): T with $QuasiClass
          }
        """
      else mstats1 += mkQuasi(
        iname,
        iparents,
        params,
        quasiCopyExtraParamss,
        quasiExtraAbstractDefs.result(),
        "name",
        "value",
        "tpe"
      )

      // emit one object (plus LowPriority parent trait) per version, then the "latest" object
      val latestName = mstatsPerVersion
        .foldLeft(initialName) { case (afterPrevVerName, (ver, verMstats)) =>
          val lowPriority = TypeName(afterPrevVerName + "LowPriority")
          mstats1 += q"private[meta] trait $lowPriority { ..${verMstats.lowPrio} }"
          val afterPrevVer = TermName(afterPrevVerName)
          mstats1 += q"object $afterPrevVer extends $lowPriority { ..${verMstats.primary} }"
          getAfterVersion(ver)
        }
      val latestTermName = TermName(latestName)
      val latestLowPriority = TypeName(latestName + "LowPriority")
      mstats1 += q"private[meta] trait $latestLowPriority { ..$mstatsLatestLowPriority }"
      mstats1 += q"object $latestTermName extends $latestLowPriority { ..$mstatsLatest }"
      // to be ignored by Mima, use "internal"
      mstats1 += q"object internal { final val Latest = $latestTermName }"

      mstats1 += q"$mods1 class $name[..$tparams] $ctorMods(...${bparams1 +:
          paramss1}) extends { ..$earlydefns } with ..$parents1 { $self => ..$stats1 }"

      val res = ListBuffer.empty[ImplDef]

      // non-top-level companions also inherit a LowPriority trait carrying deprecated overloads;
      // at package level only the trait and the object may be returned (see isTopLevel above)
      val mparents1 =
        if (isTopLevel) mparents
        else {
          val lowPriority = TypeName(iname.toString() + "LowPriority")
          res += q"private[meta] trait $lowPriority { ..$mstats1LowPriority }"
          mparents :+ tq"$lowPriority"
        }
      val cdef1 = q"$imods1 trait $iname extends ..$iparents1 { $iself => ..$istats1 }"
      res += cdef1
      val mdef1 =
        q"$mmods1 object $mname extends { ..$mearlydefns } with ..$mparents1 { $mself => ..$mstats1 }"
      res += mdef1
      if (c.compilerSettings.contains("-Xprint:typer")) { println(cdef1); println(mdef1) }
      res.result()
    }
  })

  // Backing-field name for a parameter: exactly one leading underscore (`foo`/`_foo` -> `_foo`).
  private def internalize(name: String): TermName = TermName("_" + name.stripPrefix("_"))
  private def internalize(name: TermName): TermName = internalize(name.toString)
  // Setter name for a field: `foo`/`_foo` -> `setFoo`.
  private def setterName(name: String): TermName = {
    val bare = name.stripPrefix("_")
    TermName("set" + bare.capitalize)
  }
  private def setterName(name: TermName): TermName = setterName(name.toString)
  private def setterName(vr: ValOrDefDef): TermName = setterName(vr.name.toString)
  // Getter name for a field: the name with any leading underscore removed.
  private def getterName(name: String): TermName = TermName(name.stripPrefix("_"))
  private def getterName(name: TermName): TermName = getterName(name.toString)
  private def getterName(vr: ValOrDefDef): TermName = getterName(vr.name.toString)

  // Builds `$CommonTyperMacrosModule.loadField(this._x, "x")` for the field backing `x`.
  private def loadField(vr: ValOrDefDef): Tree = loadField(vr.name)
  private def loadField(name: TermName): Tree = loadField(internalize(name), name)
  private def loadField(internalName: TermName, name: TermName): Tree = {
    val fieldLabel = name.decodedName.toString
    q"$CommonTyperMacrosModule.loadField(this.$internalName, $fieldLabel)"
  }

  // Builds `$CommonTyperMacrosModule.storeField(node._x, x, "x")`; the splice site must have a
  // `node` in scope (the companion apply and defineSetter both bind one).
  private def storeField(vr: ValOrDefDef): Tree = storeField(vr.name)
  private def storeField(name: TermName): Tree = storeField(internalize(name), name)
  private def storeField(internalName: TermName, name: TermName): Tree = {
    val fieldLabel = name.decodedName.toString
    q"$CommonTyperMacrosModule.storeField(node.$internalName, $name, $fieldLabel)"
  }

  // Modifiers for an abstract (deferred), unqualified member carrying the given annotations.
  private def getDeferredModifiers(annots: List[Tree]): Modifiers = {
    val noQualifier = typeNames.EMPTY
    Modifiers(DEFERRED, noQualifier, annots)
  }

  // Marker annotations for generated getters: @ast payload fields vs. private fields.
  private val astFieldAnnot = q"new $AstMetadataModule.astField"
  private val privateFieldAnnot = q"new $AdtMetadataModule.privateField"

  // Abstract getter `def x: T` (underscore stripped from the name) with the given annotations.
  private def declareGetter(name: TermName, tpe: Tree, annots: List[Tree]): Tree = {
    val deferred = getDeferredModifiers(annots)
    declareGetter(name, tpe, deferred)
  }

  private def declareGetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val getter = getterName(name)
    q"$mods def $getter: $tpe"
  }

  // Concrete getter: runs loadField on the backing field `_x`, then returns `this._x`.
  private def defineGetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val backing = internalize(name)
    val getter = getterName(name)
    val load = loadField(backing, name)
    q"""
      $mods def $getter: $tpe = {
        $load
        this.$backing
      }
    """
  }

  // Abstract setter `def setX(x: T): Unit` with the given annotations.
  private def declareSetter(name: TermName, tpe: Tree, annots: List[Tree]): Tree = {
    val deferred = getDeferredModifiers(annots)
    declareSetter(name, tpe, deferred)
  }

  private def declareSetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val setter = setterName(name)
    q"$mods def $setter($name : $tpe): Unit"
  }

  // Concrete setter: binds `node = this` so the `node._x` reference built by storeField resolves.
  private def defineSetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val setter = setterName(name)
    val store = storeField(name)
    q"""
      $mods def $setter($name : $tpe): Unit = {
        val node = this
        $store
      }
    """
  }

  /** A constructor parameter of an @ast class together with its binary-compatibility history.
    *
    * @param param
    *   the parameter as currently declared
    * @param appended
    *   the version from the parameter's @newField annotation, if any
    * @param replaced
    *   fields this parameter replaced; the constructor check below treats `replaced.head` as the
    *   earliest replacement (presumably ReplacedField.getMap orders them by version — confirm)
    */
  private class VersionedParam(
      val param: ValDef,
      val appended: Option[Version],
      val replaced: Seq[ReplacedField]
  ) {
    // A field must be introduced (@newField) strictly before it may replace older fields.
    appended.foreach { aver =>
      replaced.headOption.foreach { rfield =>
        if (rfield.version <= aver) {
          val oldDef = rfield.oldDefs.head._1
          c.abort(
            param.pos,
            s"$aver [@newField for ${param.name}] must precede " +
              s"${rfield.version} [@replacedField for ${oldDef.name}]"
          )
        }
      }
    }

    /** How this parameter appears in a signature that predates `version`.
      *
      * @return
      *   the parameter declarations for that signature, each paired with an explicit position
      *   (`-1` = keep natural order, see positionVersionedParams), and an optional definition to
      *   emit in the generated body: the parameter itself when it was added at or after `version`,
      *   or the replacing field's definition when old fields are declared instead.
      */
    def getApplyDeclDefnBefore(version: Version): (List[(ValDef, Int)], Option[ValDef]) = {
      def checkVersion(ver: Version): Boolean = version <= ver
      if (appended.exists(checkVersion)) (Nil, Some(asValDefn(param)))
      else replaced.find(x => checkVersion(x.version)).map { rfield =>
        val decls = rfield.oldDefs.map { case (oldDef, pos) => asValDecl(oldDef) -> pos }
        (decls, Some(rfield.newValDefn))
      }.getOrElse((asValDefn(param) -> -1 :: Nil, None))
    }
    // Parameters used by the primary (defaulted) copy: the earliest replaced field's old
    // definitions if any, otherwise the parameter itself in natural order.
    def getDefaultCopyDef(): List[(ValOrDefDef, Int)] = replaced.headOption.map(_.oldDefs)
      .getOrElse((param, -1) :: Nil)
  }

  /** Merges explicitly positioned and unpositioned entries into one ordered list.
    *
    * Entries with a non-negative position are processed in ascending position order; before each
    * one is appended, unpositioned entries (negative position) are appended in their original
    * relative order until the result reaches that position or they run out. Any unpositioned
    * entries left over are appended at the end.
    *
    * @param params
    *   entries paired with their target index, or a negative number for "natural order"
    * @return
    *   the merged entries
    */
  private def positionVersionedParams[A](params: List[(A, Int)]): List[A] = {
    val res = new ListBuffer[A]
    // Consumed incrementally through hasNext/next only: reusing an iterator after calling
    // `take` on it (as the previous implementation did) is undefined per the Iterator docs.
    val unpositioned = params.iterator.filter(_._2 < 0).map(_._1)
    @tailrec
    def fillUpTo(len: Int): Unit =
      if (res.length < len && unpositioned.hasNext) {
        res += unpositioned.next()
        fillUpTo(len)
      }
    params.filter(_._2 >= 0).sortBy(_._2).foreach { case (v, pos) =>
      fillUpTo(pos)
      res += v
    }
    res ++= unpositioned
    res.toList
  }

  /** Collects @newField / @replacedField metadata for every constructor parameter.
    *
    * @return
    *   one VersionedParam per parameter (declaration order), and the sorted, de-duplicated list of
    *   every version mentioned by those annotations
    */
  private def getVersionedParams(
      params: List[ValDef],
      stats: List[Tree]
  ): (List[VersionedParam], List[Version]) = {
    // Evaluation order kept as before: getNewFieldVersions (which may abort) runs first.
    val appendedFields: Map[String, Version] = getNewFieldVersions(params)
    val replacedFields: Map[String, Seq[ReplacedField]] = ReplacedField.getMap(params, stats)
    val distinctVersions: Set[Version] =
      Set.empty[Version] ++ appendedFields.values ++
        replacedFields.values.flatMap(_.map(_.version))
    val versionedParams = params.map { param =>
      val key = param.name.toString
      new VersionedParam(param, appendedFields.get(key), replacedFields.getOrElse(key, Seq.empty))
    }
    (versionedParams, distinctVersions.toList.sorted)
  }

  /** Source text of an annotation argument. For a named argument (`name = value`) only the
    * value's text is returned.
    */
  private def getAnnotAttribute(value: Tree): String = {
    val argValue = value match {
      case named: AssignOrNamedArg => named.rhs
      case other => other
    }
    argValue.toString
  }

  /** Parses the version argument of annotation `@annot` (attribute `field`) as `major.minor.patch`.
    *
    * Surrounding double quotes are stripped from the argument's text before parsing. When the build
    * version is known, the parsed version may have neither an older major version nor be newer than
    * the build version.
    *
    * Macro expansion is aborted at `version.pos` on a parse failure or either violation.
    */
  private def parseVersionAnnot(version: Tree, annot: String, field: String): Version = {
    val unquoted = getAnnotAttribute(version).stripPrefix("\"").stripSuffix("\"")
    val parsed = Version.parse(unquoted)
      .getOrElse(c.abort(version.pos, s"@$annot must contain $field=major.minor.patch"))
    for (current <- buildVersion) {
      if (parsed.major < current.major)
        c.abort(version.pos, s"@$annot: obsolete, old major version (must be ${current.major})")
      if (parsed > current)
        c.abort(version.pos, s"@$annot can't refer to future versions (current is $current)")
    }
    parsed
  }

  /** Collects the parsed `@newField` version of every constructor param that carries one.
    *
    * Each of these rules, if violated, aborts macro expansion at the offending param:
    *   - once a param is marked `@newField`, every following param must be marked too;
    *   - `override` params may not be marked;
    *   - marked params must have a default value;
    *   - versions must be non-decreasing in declaration order.
    *
    * @return param name -> its `@newField` version
    */
  private def getNewFieldVersions(params: List[ValDef]): Map[String, Version] = {
    val builder = Map.newBuilder[String, Version]
    // version of the most recent `@newField` param seen so far (None until the first one)
    var prevVersion: Option[Version] = None
    params.foreach { x =>
      val sinceOpt = x.mods.annotations.collectFirst { case q"new newField($since)" => since }
      if (sinceOpt.isEmpty && prevVersion.isDefined) c
        .abort(x.pos, "must be marked @newField since previous field is")
      sinceOpt.foreach { since =>
        if (x.mods.hasFlag(Flag.OVERRIDE)) c
          .abort(x.pos, "override fields may not be marked @newField")
        if (x.rhs == EmptyTree) c.abort(x.pos, "@newField fields must provide a default value")
        val version = parseVersionAnnot(since, "newField", "after")
        prevVersion.foreach { prev =>
          if (version < prev) c.abort(x.pos, s"previous field marked with newer version: $prev")
        }
        prevVersion = Some(version)
        builder += x.name.toString -> version
      }
    }
    builder.result()
  }

  /** One group of `@replacedField` definitions superseded, at the same `version`, by the
    * constructor param `newVal`.
    *
    * @param version the `until` version shared by the `@replacedField` annotations in this group
    * @param newVal the constructor param that replaces the old definitions
    * @param ctor the converter from `newVal`'s `@replacesFields(since, ctor)` annotation for this
    *   version, or `null` when there is no such annotation
    * @param oldDefs the superseded definitions, each paired with its `@replacedField` position
    *   (`-1` when no position was given)
    */
  private class ReplacedField(
      val version: Version,
      val newVal: ValDef,
      ctor: Tree,
      val oldDefs: List[(ValOrDefDef, Int)]
  ) {
    /** Builds `val <newVal.name>: <newVal.tpt> = { ... }`, whose body computes the new value from
      * the names of the old definitions.
      *
      * With a single old definition, the body is either `ctor(oldName)` or, without a ctor,
      * `import scala.meta.trees._` followed by `oldName`. With several old definitions a ctor is
      * required; it is called with every old name as a named argument (`ctor(a = a, b = b)`).
      * Otherwise macro expansion aborts.
      */
    def newValDefn: ValDef = {
      // NOTE(review): without a ctor, the import presumably supplies an implicit conversion
      // from the old type to `newVal.tpt` — confirm.
      def bodyForSingleOldDef(oldDef: ValOrDefDef) =
        if (ctor eq null) q"""
            import scala.meta.trees._
            ${oldDef.name}
           """
        else q"""
            $ctor(${oldDef.name})
           """
      def bodyForMultipleOldDefs = {
        if (ctor eq null) c.abort(newVal.pos, s"${newVal.name} must define a ctor")
        // each old name is passed as `name = name`
        val names = oldDefs.map { case (oldDef, _) =>
          val name = q"${oldDef.name}"
          val arg = AssignOrNamedArg(name, name)
          q"$arg"
        }
        q"$ctor(..$names)"
      }
      val body = oldDefs match {
        case (oldDef, _) :: Nil => bodyForSingleOldDef(oldDef)
        case _ => bodyForMultipleOldDefs
      }
      q"""
        val ${newVal.name}: ${newVal.tpt} = {
          ..$body
        }
      """
    }
  }

  private object ReplacedField {

    /** Groups the `@replacedField` definitions among `stats` by the constructor param replacing
      * them.
      *
      * Every definition annotated `@replacedField(until)` or `@replacedField(until, pos)` must be
      * `final`. The name of its replacing field is extracted from its rhs (see `getNewField`),
      * and that name must be one of `params`. Each param's definitions are grouped by `until`
      * version. Each group is paired with the ctor from the param's
      * `@replacesFields(since, ctor)` annotation for that version, if any.
      * Violations abort macro expansion.
      *
      * @return new field name -> its replacements, sorted by version
      */
    def getMap(params: List[ValDef], stats: List[Tree]): Map[String, Seq[ReplacedField]] = {
      // param name -> (param, its `@replacesFields` ctors keyed by version)
      val fields: Map[String, (ValDef, Map[Version, Tree])] = params.map { p =>
        val ctorsByVersion = p.mods.annotations
          .collect { case q"new replacesFields($since, $ctor)" =>
            val version = parseVersionAnnot(since, "replacesFields", "after")
            version -> ctor
          }.toMap
        p.name.toString -> (p, ctorsByVersion)
      }.toMap
      // one (new field name, (old definition, until version, position)) per annotated stat
      val replacedFields = stats.flatMap {
        case p: ValOrDefDef =>
          val anno = p.mods.annotations.collectFirst {
            case q"new replacedField($until)" => (until, -1)
            case q"new replacedField($until, $pos)" => (until, parsePosition(pos))
          }
          anno.map { case (until, pos) =>
            if (!p.mods.hasFlag(Flag.FINAL)) c
              .abort(p.pos, "replacedField-annotated fields must be final")
            val version = parseVersionAnnot(until, "replacedField", "until")
            val newField = getNewField(p)
            newField -> (p, version, pos)
          }
        case _ => None
      }
      replacedFields.groupBy(_._1).map { case (k, v) =>
        val (newVal, ctorsByVersion) = fields
          .getOrElse(k, c.abort(v.head._2._1.pos, s"@replacedField: field `$k` is undefined"))
        val replacements = v.map(_._2).groupBy(_._2).toSeq.map { case (ver, oldFields) =>
          // no `@replacesFields` for this version => null ctor, handled by `newValDefn`
          val ctor = ctorsByVersion.get(ver).orNull
          val oldDefs = oldFields.map { case (oldField, _, pos) => (oldField, pos) }
          new ReplacedField(ver, newVal, ctor, oldDefs)
        }
        k -> replacements.sortBy(_.version)
      }
    }

    /** Parses the explicit position argument of `@replacedField(until, pos)`.
      *
      * A non-integer argument is reported with a positioned error instead of letting a
      * `NumberFormatException` escape from the macro.
      */
    private def parsePosition(arg: Tree): Int = {
      val text = getAnnotAttribute(arg)
      try text.toInt
      catch {
        case _: NumberFormatException =>
          c.abort(arg.pos, s"@replacedField: position must be an integer, got `$text`")
      }
    }

    /** Extracts the name of the field replacing `oldDef` from the shape of its rhs.
      *
      * The recognized shapes are `newField.member`, `f(newField)` and
      * `newField match { ... }`. Outer selections (`x.member`) and single-lambda applications
      * (`x(a => ...)`) are peeled off first. Any other shape aborts macro expansion.
      */
    private def getNewField(oldDef: ValOrDefDef): String = {
      @tailrec
      def iter(tree: Tree): Option[String] = tree match {
        case Select(Ident(TermName(newField)), _: TermName) => Some(newField)
        case Apply(_, Ident(TermName(newField)) :: Nil) => Some(newField)
        case Match(Ident(TermName(newField)), _) => Some(newField)
        case Select(x, _) => iter(x)
        case Apply(x, (_: Function) :: Nil) => iter(x)
        case _ => None
      }
      iter(oldDef.rhs).getOrElse(
        c.abort(oldDef.pos, s"@replacedField: can't find new field name (${showRaw(oldDef.rhs)})")
      )
    }
  }

  private val deprecatedSince_4_9_0 = getDeprecatedAnno("4.9.0")

  /** `@deprecated` annotation whose string argument is `v`'s text with `why` appended verbatim
    * (no separator is inserted).
    */
  private def getDeprecatedAnno(v: Version, why: String = ""): Tree = {
    val text = v.toString + why
    getDeprecatedAnno(text)
  }

  /** Builds the tree `new scala.deprecated("<since>")`.
    *
    * NOTE(review): the string is passed as the first positional argument of `scala.deprecated`,
    * which is its `message` parameter, not `since`. Confirm whether `since = ...` was intended.
    */
  private def getDeprecatedAnno(since: String): Tree = {
    val sinceLiteral = Literal(Constant(since))
    q"new scala.deprecated($sinceLiteral)"
  }

  /** Module name for version `v`: `afterNamePrefix` followed by `v` rendered with `_` separators. */
  private def getAfterVersion(v: Version) = s"$afterNamePrefix${v.asString('_')}"

  /** Abstract `val` declaration (no rhs) with `p`'s annotations, name and type; any other
    * modifiers of `p` are not carried over.
    */
  private def asValDecl(p: ValOrDefDef): ValDef = {
    val annots = p.mods.annotations
    q"@..$annots val ${p.name}: ${p.tpt}"
  }

  /** `val` definition mirroring `p`, keeping `p`'s own rhs. */
  private def asValDefn(p: ValOrDefDef): ValDef = {
    val ownRhs = p.rhs
    asValDefn(p, ownRhs)
  }

  /** `val` definition with `p`'s annotations, name and type, and the given `rhs`. */
  private def asValDefn(p: ValOrDefDef, rhs: Tree): ValDef = {
    val annots = p.mods.annotations
    q"@..$annots val ${p.name}: ${p.tpt} = $rhs"
  }

}

object AstNamerMacros {

  /** The `major.minor.patch` part of the scalameta build version, if usable.
    *
    * Everything from the first `-` or `+` onward is dropped before parsing. `None` is returned when
    * parsing fails or yields `Version.zero`.
    */
  private val buildVersion: Option[Version] = {
    val coreVersion = BuildInfo.version.takeWhile(ch => ch != '-' && ch != '+')
    // filter in case buildVersion is incorrectly set (Windows forces 0.0.0)
    Version.parse(coreVersion).toOption.filterNot(_ == Version.zero)
  }

  /** Module name chosen by `getLatestAfterName` only when no versioned name beats it. */
  val initialName = "Initial"

  /** Prefix of versioned module names; the remainder is a version with `_` separators. */
  val afterNamePrefix = "After_"

  /** Picks the module name with the highest version among `moduleNames`.
    *
    * Names are `afterNamePrefix` plus a `_`-separated version. A name whose suffix doesn't parse
    * is ignored, as is one whose version is not strictly above `Version.zero`. On a tie, the
    * earliest name wins. If no versioned name qualifies, `initialName` is returned when present,
    * otherwise `None`.
    */
  def getLatestAfterName(moduleNames: Iterable[String]): Option[String] = {
    val (_, latestName) =
      moduleNames.foldLeft[(Version, Option[String])]((Version.zero, None)) {
        case ((maxVersion, maxName), name) if name == initialName =>
          // `Initial` only counts while nothing has been picked yet
          (maxVersion, maxName.orElse(Some(name)))
        case (acc @ (maxVersion, _), name) if name.startsWith(afterNamePrefix) =>
          Version.parse(name.substring(afterNamePrefix.length), '_').toOption match {
            case Some(v) if v > maxVersion => (v, Some(name))
            case _ => acc
          }
        case (acc, _) => acc
      }
    latestName
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy