// All downloads are free. Search and download functionality uses the official Maven repository.

// scala.meta.internal.trees.ast.scala (Maven / Gradle / Ivy)

package scala.meta
package internal
package trees

import org.scalameta.internal.MacroCompat

import scala.annotation.StaticAnnotation
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.math.Ordered.orderingToOrdered
import scala.reflect.macros.whitebox.Context

// @ast is a specialized version of @org.scalameta.adt.leaf for scala.meta ASTs.
// It is a macro annotation: `macroTransform` is implemented by the macro
// `AstNamerMacros.impl`, which receives the annotated definitions as trees.
class ast extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro AstNamerMacros.impl
}

class AstNamerMacros(val c: Context) extends Reflection with CommonNamerMacros {
  import AstNamerMacros._
  import c.universe.Flag._
  import c.universe._

  lazy val u: c.universe.type = c.universe
  lazy val mirror = c.mirror

  // Per-version accumulator of generated companion members.
  // `primary` members are emitted into the per-version object and `lowPrio` members
  // into the `...LowPriority` trait that this object extends.
  private class Mstats(
      val primary: ListBuffer[Tree] = ListBuffer.empty[Tree],
      val lowPrio: ListBuffer[Tree] = ListBuffer.empty[Tree]
  )

  def impl(annottees: Tree*): Tree = annottees.transformAnnottees(new ImplTransformer {
    override def transformClass(cdef: ClassDef, mdef: ModuleDef): List[ImplDef] = {
      val owner = c.internal.enclosingOwner
      val fullName = owner.fullName + "." + cdef.name.toString
      val isQuasi = isQuasiClass(cdef)
      // may not return other classes/modules at package level
      val isTopLevel = owner.isPackage
      val q"$imods class $iname[..$tparams] $ctorMods(...$rawparamss) extends { ..$earlydefns } with ..$iparents { $aself => ..$stats }" =
        cdef
      // NOTE: For stack traces, we'd like to have short class names, because stack traces print full names anyway.
      // However debugging macro expansion errors is much-much easier with full names for Api and Impl classes
      // because the typechecker only uses short names in error messages.
      // E.g. compare:
      //  class Impl needs to be abstract, since method withDenot in trait Name
      //  of type (denot: scala.meta.internal.semantic.Denotation)Impl.this.ThisType is not defined
      // and:
      //  class NameAnonymousImpl needs to be abstract, since method withDenot in trait Name
      //  of type (denot: scala.meta.internal.semantic.Denotation)NameAnonymousImpl.this.ThisType is not defined
      val descriptivePrefix = fullName.stripPrefix("scala.meta.").replace(".", "")
      val name = TypeName(descriptivePrefix + "Impl")
      val q"$mmods object $mname extends { ..$mearlydefns } with ..$mparents { $mself => ..$mstats }" =
        mdef
      val paramss1 = ListBuffer[List[ValDef]]() // payload params
      val iself = noSelfType
      val self = aself
      val istats1 = ListBuffer[Tree]()
      val stats1 = ListBuffer[Tree]()
      val ianns1 = ListBuffer[Tree]() ++ imods.annotations
      def imods1 = imods.mapAnnotations(_ => ianns1.toList)
      def mods1 = Modifiers(FINAL, mname.toTypeName, List(SerialVersionUIDAnnotation(1L)))
      val iparents1 = ListBuffer[Tree]() ++ iparents
      def parents1 = List(tq"$iname")
      val mstats1 = ListBuffer[Tree]() ++ mstats
      val mstatsLatest = ListBuffer[Tree]()
      val mstats1LowPriority = ListBuffer.empty[Tree]
      val mstatsLatestLowPriority = ListBuffer.empty[Tree]
      val manns1 = ListBuffer[Tree]() ++ mmods.annotations
      def mmods1 = mmods.mapAnnotations(_ => manns1.toList)
      val quasiCopyExtraParamss = ListBuffer[List[ValDef]]()
      val quasiExtraAbstractDefs = ListBuffer[ValOrDefDef]()

      // step 1: validate the shape of the class
      // Modifiers that the annotation treats as redundant or unsupported abort the expansion.
      if (imods.hasFlag(SEALED)) c.abort(cdef.pos, "sealed is redundant for @ast classes")
      if (imods.hasFlag(FINAL)) c.abort(cdef.pos, "final is redundant for @ast classes")
      if (imods.hasFlag(CASE)) c.abort(cdef.pos, "case is redundant for @ast classes")
      if (imods.hasFlag(ABSTRACT)) c.abort(cdef.pos, "@ast classes cannot be abstract")
      // Any flag on the primary constructor (e.g. private/protected) is rejected.
      if (ctorMods.flags != NoFlags)
        c.abort(cdef.pos, "@ast classes must define a public primary constructor")
      // Exactly one parameter list is required: the code below reads `rawparamss.head`.
      if (rawparamss.isEmpty)
        c.abort(cdef.pos, "@ast classes must define a non-empty parameter list")
      if (rawparamss.lengthCompare(1) > 0)
        c.abort(cdef.pos, "@ast classes must define a single parameter list")
      val params = rawparamss.head

      // step 1a: identify modified fields of the class
      val (versionedParams, paramsVersions) =
        if (isQuasi) (Nil, Nil) else getVersionedParams(params, stats)
      val replacedFields = versionedParams.flatMap(_.replaced.flatMap { field =>
        field.oldDefs.map { case (oldDef, _) => field.version -> oldDef }
      })
      val mstatsPerVersion = paramsVersions.map(ver => (ver, new Mstats()))
      // Collects, for every versioned parameter, the declarations returned by
      // `getApplyDeclDefnBefore(v)` and orders them with `positionVersionedParams`.
      def paramsForVersion(v: Version): List[ValDef] = {
        val declsWithPositions = versionedParams.flatMap { vp =>
          val (decls, _) = vp.getApplyDeclDefnBefore(v)
          decls
        }
        positionVersionedParams(declsWithPositions)
      }

      // step 2: validate the body of the class

      var needCopies = !isQuasi
      val importsBuilder = List.newBuilder[Import]
      val checkFieldsBuilder = List.newBuilder[Tree]
      val checkParentsBuilder = List.newBuilder[Tree]

      stats.foreach {
        case x: Import => importsBuilder += x
        case x: DefDef if !isQuasi && x.name == TermName("copy") => istats1 += x; needCopies = false
        case x: ValOrDefDef =>
          if (x.mods.hasFlag(Flag.ABSTRACT) || x.rhs.isEmpty) c
            .abort(x.pos, "definition without a value")
          val p = replacedFields.collectFirst { case (version, `x`) =>
            val mods = x.mods.mapAnnotations(getDeprecatedAnno(version) :: _)
            q"$mods def ${x.name}: ${x.tpt} = ${x.rhs}"
          }.getOrElse(x)
          if (x.mods.hasFlag(Flag.FINAL)) istats1 += p else quasiExtraAbstractDefs += p
        case q"checkFields($arg)" => checkFieldsBuilder += arg
        case x @ q"checkParent($what)" => checkParentsBuilder += x
        case x =>
          val error =
            "only checkFields(...), checkParent(...) and definitions are allowed in @ast classes"
          c.abort(x.pos, error)
      }
      val imports = importsBuilder.result()
      val fieldChecks = checkFieldsBuilder.result()
      val parentChecks = checkParentsBuilder.result()

      istats1 ++= imports
      mstats1 ++= imports
      stats1 ++= imports
      stats1 ++= quasiExtraAbstractDefs

      // step 4: implement the unimplemented methods in InternalTree (part 1)
      val privateFields = getPrivateFields(iname)
      val privateParams = privateFields.asList
      val bparams1 = privateParams.map(_.field)
      val privateApplyParams = privateParams.collect { case PrivateField(p, true) =>
        val annots = privateFieldAnnot :: p.mods.annotations
        val mods = Modifiers(p.mods.flags | OVERRIDE | DEFERRED, p.mods.privateWithin, annots)
        istats1 += declareGetter(p.name, p.tpt, mods)
        p
      }

      // step 5: turn all parameters into vars, create getters and setters
      params.foreach { p =>
        istats1 += declareGetter(p.name, p.tpt, astFieldAnnot :: p.mods.annotations)
        val pmods = if (p.mods.hasFlag(OVERRIDE)) Modifiers(OVERRIDE) else NoMods
        stats1 += defineGetter(p.name, p.tpt, pmods)
      }
      paramss1 += params.map { p =>
        val mods1 = p.mods.mkMutable.unPrivate.unOverride.unDefault
        q"$mods1 val ${internalize(p.name)}: ${p.tpt}"
      }

      // step 6: implement the unimplemented methods in InternalTree (part 2)
      // The purpose of privateCopy is to provide extremely cheap cloning
      // in the case when a tree changes its parent (because that happens often in our framework,
      // e.g. when we create a quasiquote and then insert it into a bigger quasiquote,
      // or when we parse something and build the trees from the ground up).
      // In such a situation, we copy all private state verbatim (tokens, denotations, etc)
      // and create lazy initializers that will take care of recursively copying the children.
      // Compare this with the `copy` method (described below), which additionally flushes the private state.
      // This method is private[meta] because the state that it's managing is not supposed to be touched
      // by the users of the framework.
      val privateCopyArgs = params
        .map(p => q"$CommonTyperMacrosModule.initField(this.${internalize(p.name)})")
      val privateCopyParentChecks =
        if (parentChecks.isEmpty) q""
        else q"""
            if (destination != null) {
              def checkParent(fn: ($name, $TreeClass, $StringClass) => $BooleanClass): $UnitClass = {
                val parentCheckOk = fn(this, parent, destination)
                if (!parentCheckOk) {
                  val parentPrefix = parent.productPrefix
                  _root_.org.scalameta.invariants.require(parentCheckOk && _root_.org.scalameta.debug(this, parentPrefix, destination))
                }
              }
              ..$parentChecks
            }
          """
      stats1 +=
        q"""
        private[meta] def privateCopy(
            prototype: $TreeClass = this,
            parent: $TreeClass = ${privateFields.parent.field.name},
            destination: $StringClass = null,
            origin: $OriginClass = ${privateFields.origin.field.name}): Tree = {
          $privateCopyParentChecks
          $DataTyperMacrosModule.nullCheck(origin)
          new $name(prototype.asInstanceOf[$iname], parent, origin)(..$privateCopyArgs)
        }
      """
      // step 7: create the copy method
      // The purpose of this method is to provide a facility to change small parts of the tree
      // without modifying the other parts, much like the standard case class copy works.
      // In such a situation, the tree is going to be recreated.
      // NOTE: Can't generate XXX.Quasi.copy, because XXX.Quasi already inherits XXX.copy,
      // and there can't be multiple overloaded methods with default parameters.
      // Not a big deal though, since XXX.Quasi is an internal class.
      def getParamArg(p: ValOrDefDef) = q"${p.name}"
      // Registers one `copy` overload with the given parameter list:
      //  - a deferred declaration (carrying `annots`) is added to the trait members;
      //  - a final override forwarding to the companion `apply` is added to the class members;
      //  - the parameter list is recorded in `quasiCopyExtraParamss`.
      def addCopy(params: List[ValDef], annots: Tree*) = {
        val forwardedArgs = params.map(getParamArg)
        val declMods = getDeferredModifiers(annots.toList)
        val copyDecl = q"$declMods def copy(..$params): $iname"
        val copyImpl =
          q"""
            final override def copy(..$params): $iname = {
              $mname.apply(..$forwardedArgs)
            }
          """
        istats1 += copyDecl
        stats1 += copyImpl
        quasiCopyExtraParamss += params
      }
      if (!isQuasi) {
        def getCopyParamWithDefault(p: ValOrDefDef): ValDef = asValDefn(p, q"this.${p.name}")
        val fullCopyParams = params.map(getCopyParamWithDefault)
        val iFullCopy = q"private[meta] def fullCopy(..$fullCopyParams): $iname"
        istats1 += iFullCopy
        quasiExtraAbstractDefs += iFullCopy
        stats1 +=
          q"""
            private[meta] final override def fullCopy(..$fullCopyParams): $iname = {
              $mname.apply(..${params.map(_.name)})
            }
          """
        if (needCopies)
          if (versionedParams.isEmpty) addCopy(fullCopyParams)
          else {
            // add primary copy with default values
            val defaultCopyParams =
              positionVersionedParams(versionedParams.flatMap(_.getDefaultCopyDef()))
                .map(getCopyParamWithDefault)
            addCopy(defaultCopyParams)

            val defaultCopyParamNames = defaultCopyParams.map(_.name.toString).toSet
            def allInDefaults(cp: Iterable[ValDef]): Boolean = cp
              .forall(x => defaultCopyParamNames.contains(x.name.toString))

            // add full copy without defaults
            if (!allInDefaults(params)) addCopy(params.map(asValDecl))
            // add secondary copy
            paramsVersions.foreach { version =>
              val copyParams = paramsForVersion(version)
              if (copyParams.length != defaultCopyParams.length || !allInDefaults(copyParams))
                addCopy(copyParams.map(asValDecl), getDeprecatedAnno(version))
            }
          }
      }

      // step 7a: override the Object and Equals methods
      if (!isQuasi) {
        istats1 +=
          q"final override def canEqual(that: Any): $BooleanClass = that.isInstanceOf[$iname]"
        istats1 +=
          q"final override def equals(that: Any): $BooleanClass = this eq that.asInstanceOf[AnyRef]"
        istats1 += q"final override def hashCode: $IntClass = System.identityHashCode(this)"
        istats1 +=
          q"final override def toString: $StringClass = scala.meta.internal.prettyprinters.TreeToString(this)"
      }

      // step 8: create the children method
      stats1 +=
        q"def children: $ListClass[$TreeClass] = $CommonTyperMacrosModule.children[$iname, $TreeClass]"

      // step 9: generate boilerplate required by the @ast infrastructure
      ianns1 += q"new $AstMetadataModule.astClass"
      ianns1 += q"new $AdtMetadataModule.leafClass"
      manns1 += q"new $AstMetadataModule.astCompanion"
      manns1 += q"new $AdtMetadataModule.leafCompanion"

      // step 10: generate boilerplate required by the classifier infrastructure
      mstats1 ++= mkClassifier(iname)
      mstats1 += mkAstInfo(iname)

      // step 11: implement Product
      iparents1 += tq"$ProductClass"

      stats1 +=
        q"override def productPrefix: $StringClass = $CommonTyperMacrosModule.productPrefix[$iname]"
      stats1 += q"override def productArity: $IntClass = ${params.length}"

      // Builds the clauses of an index-based `match` over the class parameters:
      // one clause per parameter (produced by `fromField` from the parameter and its
      // zero-based index), followed by a fallback that throws for any other index.
      def patternMatchClauses(fromField: (ValDef, Int) => Tree) = {
        val perField = params.zipWithIndex.map { case (param, idx) => fromField(param, idx) }
        val outOfBounds = cq"_ => throw new $IndexOutOfBoundsException(n.toString)"
        perField ::: List[Tree](outOfBounds)
      }

      val pelClauses = patternMatchClauses((vr, i) => cq"$i => this.${vr.name}")
      stats1 += q"override def productElement(n: $IntClass): Any = n match { case ..$pelClauses }"
      stats1 +=
        q"override def productIterator: $IteratorClass[$AnyClass] = $ScalaRunTimeModule.typedProductIterator(this)"
      val productFields = params.map(_.name.toString)
      stats1 +=
        q"override def productFields: $ListClass[$StringClass] = _root_.scala.List(..$productFields)"

      // step 11a: add productElementName for 2.13
      if (MacroCompat.productFieldNamesAvailable) {
        val penClauses = patternMatchClauses { (vr, i) =>
          val lit = Literal(Constant(vr.name.toString()))
          cq"""$i => $lit """
        }
        stats1 +=
          q"override def productElementName(n: $IntClass): java.lang.String = n match { case ..$penClauses }"
      }

      // step 12: generate serialization logic
      stats1 +=
        q"""
          protected def writeReplace(): $AnyRefClass = {
            ..${params.map(loadField)}
            this
          }
        """

      // step 13: generate Companion.apply
      val internalBody = ListBuffer[Tree]()
      internalBody += q"$CommonTyperMacrosModule.hierarchyCheck[$iname]"
      params.foreach { p =>
        val local = p.name
        internalBody += q"$DataTyperMacrosModule.nullCheck($local)"
        internalBody += q"$DataTyperMacrosModule.emptyCheck($local)"
      }
      internalBody ++= imports
      fieldChecks.foreach { x =>
        val fieldCheck = q"_root_.org.scalameta.invariants.require($x)"
        var hasErrors = false
        // Traverses a field-check tree and reports constructs that are not allowed in it:
        // any `this` reference and the bare identifier `parent`. Each hit sets the enclosing
        // `hasErrors` flag and reports an error at the offending tree's position.
        object errorChecker extends Traverser {
          private val nmeParent = TermName("parent")
          override def traverse(tree: Tree): Unit = tree match {
            case _: This =>
              // the offending node itself is not traversed further
              hasErrors = true; c.error(tree.pos, "cannot refer to this in @ast field checks")
            case Ident(`nmeParent`) =>
              hasErrors = true
              c.error(
                tree.pos,
                "cannot refer to parent in @ast field checks; use checkParent instead"
              )
            // everything else: keep descending into children
            case _ => super.traverse(tree)
          }
        }
        errorChecker.traverse(fieldCheck)
        if (!hasErrors) internalBody += fieldCheck
      }
      val paramInits = params.map(p => q"$CommonTyperMacrosModule.initParam(${p.name})")
      privateParams.foreach { p =>
        if (p.persist) internalBody += q"$DataTyperMacrosModule.nullCheck(${p.field.name})"
        else internalBody += asValDefn(p.field)
      }
      val internalArgs = params.map(getParamArg)
      val bparamCtorArgs = bparams1.map { p =>
        if (p eq privateFields.origin.field) q"""
               $OriginModule.first(
                 alternativeOrigin,
                 $OriginModule.DialectOnly.getFromArgs(..$internalArgs)
               )
             """
        else getParamArg(p)
      }
      internalBody +=
        q"""
          val node = new $name(
            ..$bparamCtorArgs
          )(
            ..$paramInits
          )
        """
      params.foreach(p => internalBody += storeField(p))
      internalBody += q"node"
      val applyParamDefns = params.map(asValDefn)
      val applyParamDecls = params.map(asValDecl)
      val bparamDecls = privateApplyParams.map(asValDecl)
      val fullApplyParamDecls = bparamDecls ++ applyParamDecls
      val fullInternalArgs = privateApplyParams.map(getParamArg) ++ internalArgs
      val bparamRhsInternalArgs = privateApplyParams.map(p => p.rhs) ++ internalArgs
      if (isTopLevel) {
        mstats1 +=
          q"""
            def apply(..$applyParamDefns): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1 +=
          q"""
            def apply(..$fullApplyParamDecls): $iname = {
              val alternativeOrigin = origin
              ..$internalBody
            }
          """
      } else {
        mstats1 +=
          q"""
            def apply(..$applyParamDefns)(implicit dialect: $DialectClass): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1 +=
          q"""
            def apply(..$fullApplyParamDecls)(implicit dialect: $DialectClass): $iname = {
              val alternativeOrigin =
                $OriginModule.first(origin, implicitly[$OriginModule.DialectOnly])
              ..$internalBody
            }
          """
        mstats1LowPriority +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$applyParamDecls): $iname = {
              $mname.apply(..$bparamRhsInternalArgs)
            }
          """
        mstats1LowPriority +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$fullApplyParamDecls): $iname = {
              $mname.apply(..$fullInternalArgs)
            }
          """
      }
      mstatsLatest +=
        q"""
          @$InlineAnnotation def apply(..$fullApplyParamDecls)(implicit dialect: $DialectClass): $iname =
            $mname.apply(..$fullInternalArgs)
        """
      mstatsLatestLowPriority +=
        q"""
          @$InlineAnnotation @$deprecatedSince_4_9_0 def apply(..$fullApplyParamDecls): $iname =
            $mname.apply(..$fullInternalArgs)
        """
      mstatsLatest +=
        q"""
          @$InlineAnnotation def apply(..$applyParamDefns)(implicit dialect: $DialectClass): $iname =
            $mname.apply(..$internalArgs)
        """
      mstatsLatestLowPriority +=
        q"""
          @$InlineAnnotation @$deprecatedSince_4_9_0 def apply(..$applyParamDecls): $iname =
            $mname.apply(..$internalArgs)
        """

      // step 13a: generate additional companion apply for added and replaced fields
      // generate new applies for each new field added
      // with field A, B and additional binary compat ones C, D and E, we generate:
      // apply(A, B, C), apply(A, B, C, D), apply(A, B, C, D, E)
      mstatsPerVersion.foreach { case (v, verMstats) =>
        val applyParamsBuilder = List.newBuilder[(ValDef, Int)]
        val applyBodyBuilder = List.newBuilder[Tree]
        versionedParams.foreach { vp =>
          val (decl, defn) = vp.getApplyDeclDefnBefore(v)
          decl.foreach(applyParamsBuilder += _)
          defn.foreach(applyBodyBuilder += _)
        }
        val paramDefns = positionVersionedParams(applyParamsBuilder.result())
        val applyBody = applyBodyBuilder.result()
        val applyCall = q"$mname.apply(..$internalArgs)"
        val paramDecls = paramDefns.map(asValDecl)
        val fullParamDecls = bparamDecls ++ paramDecls
        val fullParamDeclNames = fullParamDecls.map(_.name)
        verMstats.lowPrio +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$fullParamDecls): $iname = {
              $mname.apply(..$fullParamDeclNames)
            }
          """
        verMstats.primary +=
          q"""
            def apply(..$fullParamDecls)(implicit dialect: $DialectClass): $iname = {
              $mname.apply(..$fullParamDeclNames)
            }
          """
        verMstats.lowPrio +=
          q"""
            @$deprecatedSince_4_9_0 def apply(..$paramDecls): $iname = {
              ..$applyBody
              $applyCall
            }
          """
        verMstats.primary +=
          q"""
            def apply(..$paramDefns)(implicit dialect: $DialectClass): $iname = {
              ..$applyBody
              $applyCall
            }
          """
        if (isTopLevel) {
          mstats1 +=
            q"""
              def apply(..$fullParamDecls): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1 +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls): $iname = {
                ..$applyBody
                $applyCall
              }
            """

        } else {
          mstats1LowPriority +=
            q"""
              @$deprecatedSince_4_9_0 def apply(..$fullParamDecls): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1LowPriority +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls): $iname = {
                ..$applyBody
                $applyCall
              }
            """
          mstats1 +=
            q"""
              def apply(..$fullParamDecls)(implicit dialect: $DialectClass): $iname = {
                ..$applyBody
                $mname.apply(..$fullInternalArgs)
              }
            """
          mstats1 +=
            q"""
              @${getDeprecatedAnno(v)} def apply(..$paramDecls)(implicit dialect: $DialectClass): $iname = {
                ..$applyBody
                $applyCall
              }
            """
        }
      }

      // step 14: generate Companion.unapply
      val needsUnapply = !mstats.exists {
        case DefDef(_, TermName("unapply"), _, _, _, _) => true
        case _ => false
      }
      if (needsUnapply) {
        // Generates a companion `unapply` for the given parameter list, carrying `annots`.
        // With no parameters it is a Boolean extractor; otherwise it returns an Option of
        // the tuple of field values. Both variants succeed only for non-null Impl instances.
        def getUnapply(unapplyParams: List[ValDef], annots: Tree*): Tree = unapplyParams match {
          case Nil =>
            q"""
              @$InlineAnnotation @..$annots final def unapply(x: $iname): $BooleanClass =
                x != null && x.isInstanceOf[$name]
            """
          case fields =>
            val fieldTypes = fields.map(field => field.tpt)
            val fieldReads = fields.map(field => q"x.${field.name}")
            val successTargs = tq"(..$fieldTypes)"
            val successArgs = q"(..$fieldReads)"
            q"""
              @$InlineAnnotation @..$annots final def unapply(x: $iname): $OptionClass[$successTargs] =
                if (x != null && x.isInstanceOf[$name]) $SomeModule($successArgs) else $NoneModule
            """
        }
        val latestTree = getUnapply(params)
        mstatsPerVersion match {
          case (headVer, headMstats) :: tail =>
            val headParams = paramsForVersion(headVer)
            headMstats.primary += getUnapply(headParams)
            tail.foreach { case (ver, verMstats) =>
              verMstats.primary += getUnapply(paramsForVersion(ver))
            }
            val afterLastVer = getAfterVersion(mstatsPerVersion.last._1)
            val anno = getDeprecatedAnno(headVer, s"; use `.$afterLastVer`")
            mstats1 += getUnapply(headParams, anno)
          case Nil => mstats1 += latestTree
        }
        mstatsLatest += latestTree
      }

      // step 15: finish codegen for Quasi
      if (isQuasi) stats1 +=
        q"""
          def become[T <: $TreeClass](implicit ev: $AstInfoClass[T]): T with $QuasiClass = {
            (this match {
              case $mname(0, tree) =>
                ev.quasi(0, tree)
              case $mname(rank, nested @ $mname(0, tree)) =>
                ev.quasi(rank, nested.become[T])
              case _ =>
                throw new Exception("complex ellipses are not supported yet")
            }).withOrigin(this.origin): T with $QuasiClass
          }
        """
      else mstats1 += mkQuasi(
        iname,
        iparents,
        params,
        quasiCopyExtraParamss,
        quasiExtraAbstractDefs.result(),
        "name",
        "value",
        "tpe"
      )

      val latestName = mstatsPerVersion
        .foldLeft(initialName) { case (afterPrevVerName, (ver, verMstats)) =>
          val lowPriority = TypeName(afterPrevVerName + "LowPriority")
          mstats1 += q"private[meta] trait $lowPriority { ..${verMstats.lowPrio} }"
          val afterPrevVer = TermName(afterPrevVerName)
          mstats1 += q"object $afterPrevVer extends $lowPriority { ..${verMstats.primary} }"
          getAfterVersion(ver)
        }
      val latestTermName = TermName(latestName)
      val latestLowPriority = TypeName(latestName + "LowPriority")
      mstats1 += q"private[meta] trait $latestLowPriority { ..$mstatsLatestLowPriority }"
      mstats1 += q"object $latestTermName extends $latestLowPriority { ..$mstatsLatest }"
      // to be ignored by Mima, use "internal"
      mstats1 += q"object internal { final val Latest = $latestTermName }"

      mstats1 += q"$mods1 class $name[..$tparams] $ctorMods(...${bparams1 +:
          paramss1}) extends { ..$earlydefns } with ..$parents1 { $self => ..$stats1 }"

      val res = ListBuffer.empty[ImplDef]

      val mparents1 =
        if (isTopLevel) mparents
        else {
          val lowPriority = TypeName(iname.toString() + "LowPriority")
          res += q"private[meta] trait $lowPriority { ..$mstats1LowPriority }"
          mparents :+ tq"$lowPriority"
        }
      val cdef1 = q"$imods1 trait $iname extends ..$iparents1 { $iself => ..$istats1 }"
      res += cdef1
      val mdef1 =
        q"$mmods1 object $mname extends { ..$mearlydefns } with ..$mparents1 { $mself => ..$mstats1 }"
      res += mdef1
      if (c.compilerSettings.contains("-Xprint:typer")) { println(cdef1); println(mdef1) }
      res.result()
    }
  })

  // Name of the backing field for a parameter: the name with a single leading
  // underscore removed (if present) and then prefixed with `_`.
  private def internalize(name: String): TermName = {
    val bare = name.stripPrefix("_")
    TermName("_" + bare)
  }
  private def internalize(name: TermName): TermName = {
    val asText = name.toString
    internalize(asText)
  }
  // Setter name for a field: `set` followed by the capitalized name,
  // after removing a single leading underscore if present.
  private def setterName(name: String): TermName = {
    val bare = name.stripPrefix("_")
    TermName("set" + bare.capitalize)
  }
  private def setterName(name: TermName): TermName = {
    val asText = name.toString
    setterName(asText)
  }
  private def setterName(vr: ValOrDefDef): TermName = {
    val fieldName = vr.name
    setterName(fieldName)
  }
  // Getter name for a field: the name with a single leading underscore removed, if present.
  private def getterName(name: String): TermName = TermName(name.stripPrefix("_"))
  private def getterName(name: TermName): TermName = {
    val asText = name.toString
    getterName(asText)
  }
  private def getterName(vr: ValOrDefDef): TermName = {
    val fieldName = vr.name
    getterName(fieldName)
  }

  // Generates a call to `CommonTyperMacrosModule.loadField` for a field of `this`,
  // passing the backing field and the decoded field name as a string literal.
  private def loadField(vr: ValOrDefDef): Tree = {
    val fieldName = vr.name
    loadField(fieldName)
  }
  private def loadField(name: TermName): Tree = {
    val backing = internalize(name)
    loadField(backing, name)
  }
  private def loadField(internalName: TermName, name: TermName): Tree = {
    val fieldLabel = name.decodedName.toString
    q"$CommonTyperMacrosModule.loadField(this.$internalName, $fieldLabel)"
  }

  // Generates a call to `CommonTyperMacrosModule.storeField` targeting a value named `node`
  // that must be in scope at the splice site; the arguments are the backing field on `node`,
  // a reference to `name`, and the decoded field name as a string literal.
  private def storeField(vr: ValOrDefDef): Tree = {
    val fieldName = vr.name
    storeField(fieldName)
  }
  private def storeField(name: TermName): Tree = {
    val backing = internalize(name)
    storeField(backing, name)
  }
  private def storeField(internalName: TermName, name: TermName): Tree = {
    val fieldLabel = name.decodedName.toString
    q"$CommonTyperMacrosModule.storeField(node.$internalName, $name, $fieldLabel)"
  }

  // Modifiers for an abstract (DEFERRED) member with no access qualifier and the given annotations.
  private def getDeferredModifiers(annots: List[Tree]): Modifiers = {
    val noQualifier = typeNames.EMPTY
    Modifiers(DEFERRED, noQualifier, annots)
  }

  // Annotation instantiation trees that are prepended to the annotations of generated
  // getter declarations (for payload fields and private fields respectively).
  private val astFieldAnnot = q"new $AstMetadataModule.astField"
  private val privateFieldAnnot = q"new $AdtMetadataModule.privateField"

  // Abstract getter declaration for `name` of type `tpe`.
  // The annotation-list overload wraps the annotations in deferred modifiers first.
  private def declareGetter(name: TermName, tpe: Tree, annots: List[Tree]): Tree = {
    val deferredMods = getDeferredModifiers(annots)
    declareGetter(name, tpe, deferredMods)
  }

  private def declareGetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val getter = getterName(name)
    q"$mods def $getter: $tpe"
  }

  // Concrete getter for `name`: first emits the `loadField` call for the backing field,
  // then returns the backing field's value.
  private def defineGetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val backing = internalize(name)
    val ensureLoaded = loadField(backing, name)
    val getter = getterName(name)
    q"""
      $mods def $getter: $tpe = {
        $ensureLoaded
        this.$backing
      }
    """
  }

  // Abstract setter declaration taking one parameter named `name` of type `tpe`.
  // The annotation-list overload wraps the annotations in deferred modifiers first.
  private def declareSetter(name: TermName, tpe: Tree, annots: List[Tree]): Tree = {
    val deferredMods = getDeferredModifiers(annots)
    declareSetter(name, tpe, deferredMods)
  }

  private def declareSetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val setter = setterName(name)
    q"$mods def $setter($name : $tpe): Unit"
  }

  // Concrete setter for `name`: binds `node` to `this` (the `storeField` splice refers to
  // `node`) and then emits the `storeField` call for the field.
  private def defineSetter(name: TermName, tpe: Tree, mods: Modifiers): Tree = {
    val setter = setterName(name)
    val assignField = storeField(name)
    q"""
      $mods def $setter($name : $tpe): Unit = {
        val node = this
        $assignField
      }
    """
  }

  // A constructor parameter together with its versioning information:
  //  - `appended`: the version attached to the parameter, if it was marked as a new field;
  //  - `replaced`: the replaced-field records associated with this parameter.
  //    NOTE(review): `headOption` below is treated as the earliest replacement, which
  //    presumes `replaced` is ordered by version; confirm against `ReplacedField.getMap`.
  private class VersionedParam(
      val param: ValDef,
      val appended: Option[Version],
      val replaced: Seq[ReplacedField]
  ) {
    // Construction-time check: a parameter that is both appended and replaces older
    // definitions must have been appended strictly before its first replacement.
    appended.foreach { aver =>
      replaced.headOption.foreach { rfield =>
        if (rfield.version <= aver) {
          val oldDef = rfield.oldDefs.head._1
          c.abort(
            param.pos,
            s"$aver [@newField for ${param.name}] must precede " +
              s"${rfield.version} [@replacedField for ${oldDef.name}]"
          )
        }
      }
    }

    // Describes how this parameter contributes to an `apply` overload for `version`.
    // Returns (declarations with their positions, optional local definition for the body):
    //  - appended at or after `version`: no declaration; the parameter becomes a local
    //    definition built with `asValDefn`;
    //  - otherwise, if some replacement happened at or after `version`: the old definitions
    //    are declared (with their recorded positions) and `newValDefn` goes into the body;
    //  - otherwise: the parameter itself is declared, with position -1 (no fixed position).
    def getApplyDeclDefnBefore(version: Version): (List[(ValDef, Int)], Option[ValDef]) = {
      def checkVersion(ver: Version): Boolean = version <= ver
      if (appended.exists(checkVersion)) (Nil, Some(asValDefn(param)))
      else replaced.find(x => checkVersion(x.version)).map { rfield =>
        val decls = rfield.oldDefs.map { case (oldDef, pos) => asValDecl(oldDef) -> pos }
        (decls, Some(rfield.newValDefn))
      }.getOrElse((asValDefn(param) -> -1 :: Nil, None))
    }

    // Definitions used for the primary `copy`: the old definitions of the first
    // replacement if there is one, otherwise the parameter itself with position -1.
    def getDefaultCopyDef(): List[(ValOrDefDef, Int)] = replaced.headOption.map(_.oldDefs)
      .getOrElse((param, -1) :: Nil)
  }

  // Orders parameters that may carry an explicit target index.
  //
  // Entries with a negative index are "floating": they keep their relative order and fill
  // the gaps. Entries with a non-negative index are processed in ascending index order
  // (stable for equal indices); before each one is appended, the result is padded with
  // floating entries up to that index, or until the floating entries run out. Remaining
  // floating entries are appended at the end.
  //
  // Implementation note: the floating entries are consumed with `List.splitAt` rather than
  // by calling `take` on a shared iterator and then continuing to use that iterator, since
  // reusing an iterator after `take` is not supported by the `Iterator` contract.
  private def positionVersionedParams[A](params: List[(A, Int)]): List[A] = {
    val (positioned, floating) = params.partition { case (_, pos) => pos >= 0 }
    val res = new ListBuffer[A]
    val leftover = positioned.sortBy(_._2).foldLeft(floating.map(_._1)) {
      case (unplaced, (value, pos)) =>
        // a non-positive count yields an empty prefix, i.e. no padding
        val (padding, remaining) = unplaced.splitAt(pos - res.length)
        res ++= padding
        res += value
        remaining
    }
    res ++= leftover
    res.toList
  }

  // Pairs every parameter with its versioning information and returns, alongside,
  // the sorted list of all distinct versions mentioned by new or replaced fields.
  private def getVersionedParams(
      params: List[ValDef],
      stats: List[Tree]
  ): (List[VersionedParam], List[Version]) = {
    val newFieldVersions: Map[String, Version] = getNewFieldVersions(params)
    val replacedByParam: Map[String, Seq[ReplacedField]] = ReplacedField.getMap(params, stats)

    val distinctVersions: Set[Version] =
      (newFieldVersions.values ++ replacedByParam.values.flatMap(_.map(_.version))).toSet
    val sortedVersions = distinctVersions.toList.sorted

    val perParam = params.map { p =>
      val key = p.name.toString
      new VersionedParam(p, newFieldVersions.get(key), replacedByParam.getOrElse(key, Seq.empty))
    }
    (perParam, sortedVersions)
  }

  // Text of an annotation argument: for a named argument the text of its right-hand side,
  // otherwise the text of the argument tree itself.
  private def getAnnotAttribute(value: Tree): String = {
    val payload = value match {
      case named: AssignOrNamedArg => named.rhs
      case positional => positional
    }
    payload.toString
  }

  /**
   * Parses the version carried by an annotation argument.
   *
   * The argument text has one leading and one trailing double quote removed before parsing.
   * Expansion is aborted at the argument's position when the text can't be parsed or, if a
   * build version is known, when the parsed version has an older major number than the build
   * or is newer than the build.
   *
   * @param version the annotation argument tree holding the version
   * @param annot annotation name, used in error messages
   * @param field attribute name, used in error messages
   * @return the parsed version
   */
  private def parseVersionAnnot(version: Tree, annot: String, field: String): Version = {
    val text = getAnnotAttribute(version).stripPrefix("\"").stripSuffix("\"")
    val parsed = Version.parse(text)
      .getOrElse(c.abort(version.pos, s"@$annot must contain $field=major.minor.patch"))
    for (bv <- buildVersion) {
      if (parsed.major < bv.major) {
        c.abort(version.pos, s"@$annot: obsolete, old major version (must be ${bv.major})")
      }
      if (parsed > bv) {
        c.abort(version.pos, s"@$annot can't refer to future versions (current is $bv)")
      }
    }
    parsed
  }

  /**
   * Maps the name of every `@newField`-annotated parameter to the version in the annotation.
   *
   * Expansion is aborted when:
   *   - an unannotated parameter follows an annotated one;
   *   - an annotated parameter is an override or has no default value;
   *   - an annotated parameter carries an older version than the previous annotated one.
   */
  private def getNewFieldVersions(params: List[ValDef]): Map[String, Version] = {
    val versions = Map.newBuilder[String, Version]
    // version of the most recent @newField parameter seen so far, if any
    var lastSeen: Option[Version] = None
    for (field <- params) {
      field.mods.annotations.collectFirst { case q"new newField($since)" => since } match {
        case None =>
          if (lastSeen.isDefined) {
            c.abort(field.pos, "must be marked @newField since previous field is")
          }
        case Some(since) =>
          if (field.mods.hasFlag(Flag.OVERRIDE)) {
            c.abort(field.pos, "override fields may not be marked @newField")
          }
          if (field.rhs == EmptyTree) {
            c.abort(field.pos, "@newField fields must provide a default value")
          }
          val version = parseVersionAnnot(since, "newField", "after")
          lastSeen.foreach { prev =>
            if (version < prev) {
              c.abort(field.pos, s"previous field marked with newer version: $prev")
            }
          }
          lastSeen = Some(version)
          versions += field.name.toString -> version
      }
    }
    versions.result()
  }

  /**
   * A replacement of one or more old definitions by a new value at a given version.
   *
   * @param version the version associated with the replacement
   * @param newVal the value which replaced the old definitions
   * @param ctor tree applied to the old definitions to build the new value; may be `null`
   * @param oldDefs the replaced definitions, each paired with a position
   */
  private class ReplacedField(
      val version: Version,
      val newVal: ValDef,
      ctor: Tree,
      val oldDefs: List[(ValOrDefDef, Int)]
  ) {

    /**
     * Builds a `val` named and typed like `newVal` whose value is computed from the old
     * definitions:
     *   - exactly one old definition and no ctor: a reference to the old name (with
     *     `scala.meta.trees._` imported);
     *   - exactly one old definition and a ctor: the ctor applied to the old name;
     *   - otherwise: the ctor applied to every old name as a named argument; expansion is
     *     aborted if no ctor was given.
     */
    def newValDefn: ValDef = {
      val noCtor = ctor eq null
      val body = oldDefs match {
        case (single, _) :: Nil =>
          if (noCtor) q"""
            import scala.meta.trees._
            ${single.name}
           """
          else q"""
            $ctor(${single.name})
           """
        case _ =>
          if (noCtor) c.abort(newVal.pos, s"${newVal.name} must define a ctor")
          val names = oldDefs.map { case (oldDef, _) =>
            val ident = q"${oldDef.name}"
            val arg = AssignOrNamedArg(ident, ident)
            q"$arg"
          }
          q"$ctor(..$names)"
      }
      q"""
        val ${newVal.name}: ${newVal.tpt} = {
          ..$body
        }
      """
    }
  }

  private object ReplacedField {

    /**
     * Collects all field replacements, keyed by the name of the replacing parameter.
     *
     * Parameters may carry `@replacesFields(version, ctor)` annotations; statements may carry
     * `@replacedField(version)` or `@replacedField(version, pos)` annotations. An annotated
     * statement must be `final`, and the name of the parameter replacing it is extracted from
     * its right-hand side (see `getNewField`). Expansion is aborted if that name is not among
     * `params`, or if `pos` is not an integer.
     *
     * @param params parameters which may replace older fields
     * @param stats statements to scan for `@replacedField`-annotated definitions
     * @return for every replacing parameter name, its replacements sorted by version; old
     *         definitions sharing a version are grouped into a single replacement
     */
    def getMap(params: List[ValDef], stats: List[Tree]): Map[String, Seq[ReplacedField]] = {
      // parameter name -> (parameter, ctor given in @replacesFields, by version)
      val fields: Map[String, (ValDef, Map[Version, Tree])] = params.map { p =>
        val ctorsByVersion = p.mods.annotations
          .collect { case q"new replacesFields($since, $ctor)" =>
            val version = parseVersionAnnot(since, "replacesFields", "after")
            version -> ctor
          }.toMap
        p.name.toString -> (p, ctorsByVersion)
      }.toMap
      // (replacing parameter name, (old definition, version, position)); position defaults to -1
      val replacedFields = stats.flatMap {
        case p: ValOrDefDef =>
          val anno = p.mods.annotations.collectFirst {
            case q"new replacedField($until)" => (until, -1)
            case q"new replacedField($until, $pos)" => (until, parsePosition(pos))
          }
          anno.map { case (until, pos) =>
            if (!p.mods.hasFlag(Flag.FINAL)) c
              .abort(p.pos, "replacedField-annotated fields must be final")
            val version = parseVersionAnnot(until, "replacedField", "until")
            val newField = getNewField(p)
            newField -> (p, version, pos)
          }
        case _ => None
      }
      replacedFields.groupBy(_._1).map { case (k, v) =>
        // `v` comes from groupBy and is therefore non-empty
        val (newVal, ctorsByVersion) = fields
          .getOrElse(k, c.abort(v.head._2._1.pos, s"@replacedField: field `$k` is undefined"))
        val replacements = v.map(_._2).groupBy(_._2).toSeq.map { case (ver, oldFields) =>
          // no matching @replacesFields annotation: ReplacedField expects null in that case
          val ctor = ctorsByVersion.get(ver).orNull
          val oldDefs = oldFields.map { case (oldField, _, pos) => (oldField, pos) }
          new ReplacedField(ver, newVal, ctor, oldDefs)
        }
        k -> replacements.sortBy(_.version)
      }
    }

    /**
     * Reads the position argument of `@replacedField` as an integer, aborting expansion at the
     * argument's position (rather than failing with a raw `NumberFormatException`) when the
     * argument is not an integer.
     */
    private def parsePosition(pos: Tree): Int = {
      val text = getAnnotAttribute(pos)
      scala.util.Try(text.toInt)
        .getOrElse(c.abort(pos.pos, s"@replacedField: position must be an integer (got $text)"))
    }

    /**
     * Extracts the name of the replacing field from the right-hand side of an old definition.
     *
     * Recognized shapes are `name.member`, `fun(name)` and `name match { ... }`; the search
     * also descends through the qualifier of a selection and through the callee of an
     * application whose only argument is a function literal. Expansion is aborted if no name
     * is found.
     */
    private def getNewField(oldDef: ValOrDefDef): String = {
      @tailrec
      def iter(tree: Tree): Option[String] = tree match {
        case Select(Ident(TermName(newField)), _: TermName) => Some(newField)
        case Apply(_, Ident(TermName(newField)) :: Nil) => Some(newField)
        case Match(Ident(TermName(newField)), _) => Some(newField)
        case Select(x, _) => iter(x)
        case Apply(x, (_: Function) :: Nil) => iter(x)
        case _ => None
      }
      iter(oldDef.rhs).getOrElse(
        c.abort(oldDef.pos, s"@replacedField: can't find new field name (${showRaw(oldDef.rhs)})")
      )
    }
  }

  // Tree for `new scala.deprecated("4.9.0")`, built once via the String overload of
  // getDeprecatedAnno.
  private val deprecatedSince_4_9_0 = getDeprecatedAnno("4.9.0")

  /**
   * Builds a `scala.deprecated` annotation tree whose single argument is the text of `v`
   * immediately followed by `why`.
   */
  private def getDeprecatedAnno(v: Version, why: String = ""): Tree = {
    val text = s"${v.toString}$why"
    getDeprecatedAnno(text)
  }

  /** Builds a `scala.deprecated` annotation tree with `since` as its single argument. */
  private def getDeprecatedAnno(since: String): Tree = {
    val arg = Literal(Constant(since))
    q"new scala.deprecated($arg)"
  }

  /** Returns `afterNamePrefix` followed by the version rendered with `_` as the separator. */
  private def getAfterVersion(v: Version) = s"$afterNamePrefix${v.asString('_')}"

  /**
   * Turns a definition into an abstract `val` declaration with the same annotations, name
   * and type; the right-hand side is dropped.
   */
  private def asValDecl(p: ValOrDefDef): ValDef = {
    val annots = p.mods.annotations
    q"@..$annots val ${p.name}: ${p.tpt}"
  }
  /** Turns a definition into a `val` definition, keeping its own right-hand side. */
  private def asValDefn(p: ValOrDefDef): ValDef = asValDefn(p, p.rhs)

  /**
   * Turns a definition into a `val` definition with the same annotations, name and type,
   * using `rhs` as its right-hand side.
   */
  private def asValDefn(p: ValOrDefDef, rhs: Tree): ValDef = {
    val annots = p.mods.annotations
    q"@..$annots val ${p.name}: ${p.tpt} = $rhs"
  }

}

object AstNamerMacros {

  /**
   * The version of the current build, if it is usable: anything from the first `-` or `+`
   * onwards is dropped before parsing, and an unparseable or all-zero version yields `None`.
   */
  private val buildVersion: Option[Version] = {
    val core = BuildInfo.version.takeWhile(ch => ch != '-' && ch != '+')
    // discard zero in case the build version is incorrectly set (Windows forces 0.0.0)
    Version.parse(core).toOption.filterNot(_ == Version.zero)
  }

  /** Name recognized by `getLatestAfterName` as the fallback when no versioned name wins. */
  val initialName = "Initial"

  /** Prefix of the versioned names recognized by `getLatestAfterName`. */
  val afterNamePrefix = "After_"

  /**
   * Picks the name carrying the highest version.
   *
   * Names starting with `afterNamePrefix` are candidates if the remainder parses as a version
   * (using `_` as the separator) which is greater than both `Version.zero` and every version
   * seen before it. `initialName` is chosen only while no other name has been chosen yet. All
   * other names are ignored.
   *
   * @return the selected name, or `None` if no name qualified
   */
  def getLatestAfterName(moduleNames: Iterable[String]): Option[String] = {
    val start: (Option[String], Version) = (None, Version.zero)
    val (latestName, _) = moduleNames.foldLeft(start) { case (acc @ (bestName, bestVersion), name) =>
      if (name == initialName) {
        if (bestName.isEmpty) (Some(name), bestVersion) else acc
      } else if (name.startsWith(afterNamePrefix)) {
        Version.parse(name.substring(afterNamePrefix.length), '_').toOption
          .filter(v => v > bestVersion)
          .fold(acc)(v => (Some(name), v))
      } else acc
    }
    latestName
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy