All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scala.annotation.MacroAnnotation.scala Maven / Gradle / Ivy

// TODO in which package should this class be located?
package scala
package annotation

import scala.quoted.*

/** Base trait for macro annotation implementations.
 *  Macro annotations can transform definitions and add new definitions.
 *
 *  A macro annotation is written by extending this trait and implementing `transform`.
 *
 *  See: `MacroAnnotation.transform`
 *
 *  @syntax markdown
 */
@experimental
trait MacroAnnotation extends StaticAnnotation:

  /** Transform the annotated `definition` and optionally add new definitions.
   *
   *  This method takes as argument the annotated definition.
   *  It returns a non-empty list containing the modified version of the annotated definition.
   *  The new tree for the definition must use the original symbol.
   *  New definitions can be added to the list before or after the transformed definition; this order
   *  will be retained. New definitions will not be visible from outside the macro expansion.
   *
   *  #### Restrictions
   *   - All definitions in the result must have the same owner. The owner can be recovered from `Symbol.spliceOwner`.
   *     - Special case: an annotated top-level `def`, `val`, `var`, `lazy val` can return a `class`/`object`
   *                     definition that is owned by the package or package object.
   *   - Cannot return a `type`.
   *   - Annotated top-level `class`/`object` cannot return top-level `def`, `val`, `var`, `lazy val`.
   *   - User-written code cannot see the new definitions.
   *
   *  #### Good practices
   *   - Make your new definitions private if you can.
   *   - New definitions added as class members should use a fresh name (`Symbol.freshName`) to avoid collisions.
   *   - New top-level definitions should use a fresh name (`Symbol.freshName`) that includes the name of the annotated
   *     member as a prefix to avoid collisions of definitions added in other files.
   *
   *  **IMPORTANT**: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
   *
   *  #### Example 1
   *  This example shows how to modify a `def` and add a `val` next to it using a macro annotation.
   *  ```scala
   *  import scala.quoted.*
   *  import scala.collection.concurrent
   *
   *  class memoize extends MacroAnnotation:
   *    def transform(using Quotes)(
   *      definition: quotes.reflect.Definition,
   *      companion: Option[quotes.reflect.Definition]
   *    ): List[quotes.reflect.Definition] =
   *      import quotes.reflect.*
   *      definition match
   *        case DefDef(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(rhsTree)) =>
   *          (param.tpt.tpe.asType, tpt.tpe.asType) match
   *            case ('[t], '[u]) =>
   *              val cacheName = Symbol.freshName(name + "Cache")
   *              val cacheSymbol = Symbol.newVal(Symbol.spliceOwner, cacheName, TypeRepr.of[concurrent.Map[t, u]], Flags.Private, Symbol.noSymbol)
   *              val cacheRhs =
   *                given Quotes = cacheSymbol.asQuotes
   *                '{ concurrent.TrieMap.empty[t, u] }.asTerm
   *              val cacheVal = ValDef(cacheSymbol, Some(cacheRhs))
   *              val newRhs =
   *                given Quotes = definition.symbol.asQuotes
   *                val cacheRefExpr = Ref(cacheSymbol).asExprOf[concurrent.Map[t, u]]
   *                val paramRefExpr = Ref(param.symbol).asExprOf[t]
   *                val rhsExpr = rhsTree.asExprOf[u]
   *                '{ $cacheRefExpr.getOrElseUpdate($paramRefExpr, $rhsExpr) }.asTerm
   *              val newTree = DefDef.copy(definition)(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(newRhs))
   *              List(cacheVal, newTree)
   *        case _ =>
   *          report.error("Annotation only supported on `def` with a single argument")
   *          List(definition)
   *    end transform
   *  ```
   *  with this macro annotation a user can write
   *  ```scala
   *  //{
   *  class memoize extends scala.annotation.StaticAnnotation
   *  //}
   *  @memoize
   *  def fib(n: Int): Int =
   *    println(s"compute fib of $n")
   *    if n <= 1 then n else fib(n - 1) + fib(n - 2)
   *  ```
   *  and the macro will modify the definition to create
   *  ```scala
   *   val fibCache$macro$1 =
   *     scala.collection.concurrent.TrieMap.empty[Int, Int]
   *   def fib(n: Int): Int =
   *     fibCache$macro$1.getOrElseUpdate(
   *       n,
   *       {
   *         println(s"compute fib of $n")
   *         if n <= 1 then n else fib(n - 1) + fib(n - 2)
   *       }
   *     )
   *  ```
   *
   *  #### Example 2
   *  This example shows how to modify a `class` using a macro annotation.
   *  It shows how to override inherited members and add new ones.
   *  ```scala
   *  import scala.annotation.{experimental, MacroAnnotation}
   *  import scala.quoted.*
   *
   *  @experimental
   *  class equals extends MacroAnnotation:
   *    def transform(using Quotes)(
   *      definition: quotes.reflect.Definition,
   *      companion: Option[quotes.reflect.Definition]
   *    ): List[quotes.reflect.Definition] =
   *      import quotes.reflect.*
   *      definition match
   *        case ClassDef(className, ctr, parents, self, body) =>
   *          val cls = definition.symbol
   *
   *          val constructorParameters = ctr.paramss.collect { case clause: TermParamClause => clause }
   *          if constructorParameters.size != 1 || constructorParameters.head.params.isEmpty then
   *            report.errorAndAbort("@equals class must have a single argument list with at least one argument", ctr.pos)
   *          def checkNotOverridden(sym: Symbol): Unit =
   *            if sym.overridingSymbol(cls).exists then
   *              report.error(s"Cannot override ${sym.name} in a @equals class")
   *
   *          val fields = body.collect {
   *            case vdef: ValDef if vdef.symbol.flags.is(Flags.ParamAccessor) =>
   *              Select(This(cls), vdef.symbol).asExpr
   *          }
   *
   *          val equalsSym = Symbol.requiredMethod("java.lang.Object.equals")
   *          checkNotOverridden(equalsSym)
   *          val equalsOverrideSym = Symbol.newMethod(cls, "equals", equalsSym.info, Flags.Override, Symbol.noSymbol)
   *          def equalsOverrideDefBody(argss: List[List[Tree]]): Option[Term] =
   *            given Quotes = equalsOverrideSym.asQuotes
   *            cls.typeRef.asType match
   *              case '[c] =>
   *                Some(equalsExpr[c](argss.head.head.asExpr, fields).asTerm)
   *          val equalsOverrideDef = DefDef(equalsOverrideSym, equalsOverrideDefBody)
   *
   *          val hashSym = Symbol.newVal(cls, Symbol.freshName("hash"), TypeRepr.of[Int], Flags.Private | Flags.Lazy, Symbol.noSymbol)
   *          val hashVal = ValDef(hashSym, Some(hashCodeExpr(className, fields)(using hashSym.asQuotes).asTerm))
   *
   *          val hashCodeSym = Symbol.requiredMethod("java.lang.Object.hashCode")
   *          checkNotOverridden(hashCodeSym)
   *          val hashCodeOverrideSym = Symbol.newMethod(cls, "hashCode", hashCodeSym.info, Flags.Override, Symbol.noSymbol)
   *          val hashCodeOverrideDef = DefDef(hashCodeOverrideSym, _ => Some(Ref(hashSym)))
   *
   *          val newBody = equalsOverrideDef :: hashVal :: hashCodeOverrideDef :: body
   *          List(ClassDef.copy(definition)(className, ctr, parents, self, newBody))
   *        case _ =>
   *          report.error("Annotation only supports `class`")
   *          List(definition)
   *    end transform
   *
   *    private def equalsExpr[T: Type](that: Expr[Any], thisFields: List[Expr[Any]])(using Quotes): Expr[Boolean] =
   *      '{
   *        $that match
   *          case that: T @unchecked =>
   *            ${
   *              val thatFields: List[Expr[Any]] =
   *                import quotes.reflect.*
   *                thisFields.map(field => Select('{that}.asTerm, field.asTerm.symbol).asExpr)
   *              thisFields.zip(thatFields)
   *                .map { case (thisField, thatField) => '{ $thisField == $thatField } }
   *                .reduce { case (pred1, pred2) => '{ $pred1 && $pred2 } }
   *            }
   *          case _ => false
   *      }
   *
   *    private def hashCodeExpr(className: String, thisFields: List[Expr[Any]])(using Quotes): Expr[Int] =
   *      '{
   *        var acc: Int = ${ Expr(scala.runtime.Statics.mix(-889275714, className.hashCode)) }
   *        ${
   *          Expr.block(
   *            thisFields.map {
   *              case '{ $field: Boolean } => '{ if $field then 1231 else 1237 }
   *              case '{ $field: Byte } => '{ $field.toInt }
   *              case '{ $field: Char } => '{ $field.toInt }
   *              case '{ $field: Short } => '{ $field.toInt }
   *              case '{ $field: Int } => field
   *              case '{ $field: Long } => '{ scala.runtime.Statics.longHash($field) }
   *              case '{ $field: Double } => '{ scala.runtime.Statics.doubleHash($field) }
   *              case '{ $field: Float } => '{ scala.runtime.Statics.floatHash($field) }
   *              case '{ $field: Null } => '{ 0 }
   *              case '{ $field: Unit } => '{ 0 }
   *              case field => '{ scala.runtime.Statics.anyHash($field) }
   *            }.map(hash => '{ acc = scala.runtime.Statics.mix(acc, $hash) }),
   *            '{ scala.runtime.Statics.finalizeHash(acc, ${Expr(thisFields.size)}) }
   *          )
   *        }
   *      }
   *  ```
   *  with this macro annotation a user can write
   *  ```scala
   *  //{
   *  class equals extends scala.annotation.StaticAnnotation
   *  //}
   *  @equals class User(val name: String, val id: Int)
   *  ```
   *  and the macro will modify the class definition to generate the following code
   *  ```scala
   *  class User(val name: String, val id: Int):
   *    override def equals(that: Any): Boolean =
   *      that match
   *        case that: User => this.name == that.name && this.id == that.id
   *        case _ => false
   *    private lazy val hash$macro$1: Int =
   *      var acc = 515782504 // scala.runtime.Statics.mix(-889275714, "User".hashCode)
   *      acc = scala.runtime.Statics.mix(acc, scala.runtime.Statics.anyHash(name))
   *      acc = scala.runtime.Statics.mix(acc, id)
   *      scala.runtime.Statics.finalizeHash(acc, 2)
   *    override def hashCode(): Int = hash$macro$1
   *  ```
   *
   *  @param Quotes     Implicit instance of Quotes used for tree reflection
   *  @param definition Tree that will be transformed
   *  @param companion  Tree for the companion class or module if the definition is respectively a module or a class
   *  @return           Non-empty list containing the transformed definition together with any new definitions,
   *                    in the order in which they are to be retained
   *
   *  @syntax markdown
   */
  def transform(using Quotes)(definition: quotes.reflect.Definition, companion: Option[quotes.reflect.Definition]): List[quotes.reflect.Definition]




© 2015 - 2025 Weber Informatics LLC | Privacy Policy