
// wjson.macros.ADTMappingMacro.scala (Maven / Gradle / Ivy)
package wjson.macros
import scala.quoted.*
import scala.deriving.*
import wjson.{*, given}
import scala.Symbol as _
/**
 * Macro to generate a JsValueMapper for a given ADT: a case class (product), a sum type,
 * a union type, or a supported collection/Option of those.
 */
object ADTMappingMacro:

  /**
   * Macro entry point: builds the `JsValueMapper[T]` expression by delegating to a
   * fresh [[ADTMappingMacro]] instance bound to the current `Quotes`.
   */
  def genADTImpl[T: Type](using Quotes): Expr[JsValueMapper[T]] =
    new ADTMappingMacro(quotes).genADTImpl[T]

  /**
   * Per-thread flag. While it is `true`, a nested `genADTImpl` expansion does not generate
   * a mapper: it only bumps [[NOT_EXPAND_ADTS]] and returns a `???` placeholder.
   */
  private val NO_EXPAND_ADT = new ThreadLocal[Boolean]:
    override def initialValue(): Boolean = false

  /**
   * Per-thread counter used to trace recursive calls: it counts how many nested
   * `genADTImpl` expansions were skipped while [[NO_EXPAND_ADT]] was set.
   */
  private val NOT_EXPAND_ADTS = new ThreadLocal[Int]():
    override def initialValue(): Int = 0

  /**
   * To enable macro debug output, run e.g. `sbt -Dwjson.printMacroCode=true compile`.
   *
   *  1. `-Dwjson.printMacroCode=all` dumps the generated code of every type.
   *  1. `-Dwjson.printMacroCode=class1,class2` dumps every type whose full name
   *     contains `class1` or `class2`.
   */
  private val PRINT_MACRO_CODE: String|Null = System.getProperty("wjson.printMacroCode")

  /**
   * Flattens a tuple type `T` (e.g. a `MirroredElemTypes`) into the list of its element types.
   *
   * @throws AssertionError if `T` is neither `h *: t` nor `EmptyTuple`
   */
  def extractElemTypes[T: Type](using Quotes): List[quotes.reflect.TypeRepr] =
    Type.of[T] match
      case '[t *: ts] => quotes.reflect.TypeRepr.of[t] :: extractElemTypes[ts]
      case '[EmptyTuple] => Nil
      // previously fell through to a bare MatchError
      case _ => throw new AssertionError("Expected a tuple type but got: " + quotes.reflect.TypeRepr.of[T].show)

  /**
   * Flattens a tuple of String constant types (e.g. a `MirroredElemLabels`) into the list
   * of the constant values.
   *
   * @throws AssertionError if an element is not a String constant type, or if `T` is not a tuple type
   */
  def extractElemLabels[T: Type](using Quotes): List[String] =
    import quotes.reflect.*
    Type.of[T] match
      case '[t *: ts] =>
        TypeRepr.of[t] match
          case ConstantType(StringConstant(name)) => name :: extractElemLabels[ts]
          case _ => throw new AssertionError("Expected a String constant type")
      case '[EmptyTuple] => Nil
      // previously fell through to a bare MatchError
      case _ => throw new AssertionError("Expected a tuple type but got: " + TypeRepr.of[T].show)

  /**
   * Collects the default values of the case fields of `T`.
   *
   * @return a map from field name to an expression that selects the matching
   *         `$lessinit$greater$default$N` method on the companion module
   * @throws AssertionError if `T` has no companion class
   */
  def extractDefaultCaseParams[T: Type](using Quotes): Map[String, Expr[Any]] =
    import quotes.reflect.*

    val sym = TypeTree.of[T].symbol
    val comp = sym.companionClass
    if (comp == Symbol.noSymbol)
      throw new AssertionError("no companionClass found for type:" + TypeRepr.of[T].show(using Printer.TypeReprCode))

    val module = Ref(sym.companionModule)
    val names = for p <- sym.caseFields if p.flags.is(Flags.HasDefault) yield p.name
    val body = comp.tree.asInstanceOf[ClassDef].body
    val idents: List[Expr[?]] =
      for case deff@DefDef(name, _, _, _) <- body if name.startsWith("$lessinit$greater$default$")
        yield module.select(deff.symbol).asExpr

    // NOTE(review): the pairing is positional, i.e. it assumes the default methods appear in
    // the companion body in the same order as the defaulted case fields -- confirm.
    names.zip(idents).toMap

  /**
   * A generator produces the `JsValueMapper[T]` expression for one type `T` of the
   * dependency graph.
   */
  trait Generator[T: Type]:

    /** The `TypeRepr` of `T`. */
    def baseTpe(using quotes: Quotes): quotes.reflect.TypeRepr = quotes.reflect.TypeRepr.of[T]

    /** The `TypeRepr` of `JsValueMapper[T]`. */
    def mapperTpe(using quotes: Quotes): quotes.reflect.TypeRepr = quotes.reflect.TypeRepr.of[JsValueMapper[T]]

    /**
     * Generates the mapper expression.
     *
     * @param deps Map[ TypeRepr: TypeRepr.of[t], Ref: JsValueMapper[t] ], i.e. references to
     *             the mappers that are generated alongside this one
     */
    def generate(using Quotes)(deps: Map[quotes.reflect.TypeRepr, quotes.reflect.Ref]): Expr[JsValueMapper[T]]

    /** Same as the type-parameter overload below, for a type that is only known as a `TypeRepr`. */
    def summonJsValueMapper(using Quotes)(tpe: quotes.reflect.TypeRepr, deps: Map[quotes.reflect.TypeRepr, quotes.reflect.Ref]): Option[Expr[JsValueMapper[_]]] =
      tpe.asType match
        case '[t] => summonJsValueMapper[t](using quotes)(deps)

    /**
     * Looks up a `JsValueMapper[t]`: first in `deps`, otherwise through `Expr.summon`.
     *
     * @return `None` when nothing is found, or when the summon hit at least one ADT
     *         expansion that was skipped (so the summoned tree contains a placeholder)
     */
    def summonJsValueMapper[t: Type](using Quotes)(deps: Map[quotes.reflect.TypeRepr, quotes.reflect.Ref]): Option[Expr[JsValueMapper[t]]] =
      import quotes.reflect.*
      deps.get(TypeRepr.of[t]) match
        case Some(ref) => Some(ref.asExprOf[JsValueMapper[t]])
        case None =>
          // remember the flag so that it is restored, not blindly cleared, afterwards
          val previous = ADTMappingMacro.NO_EXPAND_ADT.get().nn
          try
            ADTMappingMacro.NO_EXPAND_ADT.set(true)
            ADTMappingMacro.NOT_EXPAND_ADTS.set(0)
            val found = Expr.summon[JsValueMapper[t]]
            // a positive counter means a nested genADTImpl returned its placeholder
            if ADTMappingMacro.NOT_EXPAND_ADTS.get().nn > 0 then None
            else found
          finally
            ADTMappingMacro.NO_EXPAND_ADT.set(previous)
/**
 * Implementation of the ADT mapping macro for one expansion.
 *
 * @param q the `Quotes` of the expansion; the keys of the dependency map are `q.reflect.TypeRepr`s
 */
class ADTMappingMacro(q: Quotes):

  import ADTMappingMacro.*

  /**
   * Generates the `JsValueMapper[T]` expression.
   *
   * While `NO_EXPAND_ADT` is set (i.e. this expansion was triggered by a summon of an outer
   * expansion) nothing is generated: the skip counter is incremented and a `???` placeholder
   * is returned. Otherwise the dependency map of `T` is collected with the flag set, and
   * one mapper per collected type is emitted.
   *
   * @throws AssertionError if no Product or Sum `Mirror` can be summoned for `T`
   */
  private def genADTImpl[T: Type](using Quotes): Expr[JsValueMapper[T]] =
    if ADTMappingMacro.NO_EXPAND_ADT.get().nn then
      ADTMappingMacro.NOT_EXPAND_ADTS.set(ADTMappingMacro.NOT_EXPAND_ADTS.get().nn + 1)
      '{ ??? }
    else
      val dependencies =
        try
          ADTMappingMacro.NO_EXPAND_ADT.set(true)
          // the mirror is only used to validate T; a missing mirror now ends in the
          // AssertionError below instead of an opaque `None.get` failure
          Expr.summon[Mirror.Of[T]] match
            case Some('{ $m: Mirror.ProductOf[T] }) =>
              visit[T](Map.empty)
            case Some('{ $m: Mirror.SumOf[T] { type MirroredElemTypes = elemTypes; type MirroredElemLabels = elemNames } }) =>
              visit[T](Map.empty)
            case _ => throw new AssertionError("Expected a Product or a Sum")
        finally
          ADTMappingMacro.NO_EXPAND_ADT.set(false)
      genMultiMapperBlock[T](dependencies)

  /**
   * Generates a block with one lazy val per generator, so they can reference each other,
   * and returns the reference to the mapper of `T` as the result of the block.
   *
   * The generated code is printed when the `wjson.printMacroCode` system property selects `T`.
   */
  private def genMultiMapperBlock[T: Type](needGenTypes: GeneratorMap)(using Quotes): Expr[JsValueMapper[T]] =
    import quotes.reflect.*

    // one lazy val symbol per type, named mapper_1, mapper_2, ... in map iteration order
    val valSyms: Map[TypeRepr, Symbol] =
      needGenTypes.zipWithIndex.map { case ((tpe, generator), index) =>
        val sym = Symbol.newVal(Symbol.spliceOwner, s"mapper_${index + 1}", generator.mapperTpe, Flags.Lazy, Symbol.noSymbol)
        (tpe.asInstanceOf[TypeRepr], sym)
      }.toMap

    val refs: Map[TypeRepr, Ref] = valSyms map { (tpe, sym) => (tpe, Ref(sym)) }

    val valDefs: Map[TypeRepr, ValDef] = needGenTypes map: (tpe, generator) =>
      val sym: Symbol = valSyms(tpe.asInstanceOf[TypeRepr])
      val nested = sym.asQuotes // the nested generator should be in the scope of the lazy val
      val refs2 = refs.asInstanceOf[Map[nested.reflect.TypeRepr, nested.reflect.Ref]]
      val mapper: Term = generator.generate(using nested)(refs2).asTerm
      (tpe.asInstanceOf[TypeRepr], ValDef(sym, Some(mapper)))

    val term = Block(valDefs.values.toList, refs(TypeRepr.of[T]))

    val debug =
      val typeName = TypeRepr.of[T] match
        case x if x.typeSymbol != Symbol.noSymbol => x.typeSymbol.fullName
        case x if x.termSymbol != Symbol.noSymbol => x.termSymbol.fullName
        case _ => ""
      if "all".equalsIgnoreCase(ADTMappingMacro.PRINT_MACRO_CODE) then true
      else if PRINT_MACRO_CODE != null then
        val parts = PRINT_MACRO_CODE.split(",").nn
        parts.exists(part => typeName.contains(part))
      else false

    if debug then
      println("generated JsValueMapper[" + TypeRepr.of[T].show(using Printer.TypeReprAnsiCode) + "] = "
        + term.show(using Printer.TreeAnsiCode))

    term.asExpr.asInstanceOf[Expr[JsValueMapper[T]]]

  /** The types that need a generated mapper, each with the generator that produces it. */
  private type GeneratorMap = Map[q.reflect.TypeRepr, Generator[?]]

  /**
   * Recursively visits a type and the types it depends on, and builds the dependency map.
   *
   *  - for Product types, the dependencies are the types of its case fields.
   *  - for Sum types, the dependencies are its element types, visited recursively.
   *  - for List[T] (and the other supported containers), the dependency is T, visited recursively.
   *  - for U|V, the dependencies are U and V, visited recursively.
   *
   * @param acc the generators collected so far; an empty map marks `T` as the root type
   * @return `acc` extended with every type that needs a generated mapper
   */
  private def visit[T: Type](acc: GeneratorMap): GeneratorMap =
    given Quotes = q
    import q.reflect.*

    // registers `generator` for the applied type T, then visits each of T's type arguments
    def visitAppliedType1[T: Type](value: GeneratorMap, generator: Generator[T]): GeneratorMap =
      val acc2 = value + (TypeRepr.of[T] -> generator)
      TypeRepr.of[T].asInstanceOf[AppliedType].args.foldLeft(acc2):
        case (acc, arg) => arg.asType match
          case '[t] => visit[t](acc)

    // dispatches on the Mirror of T; uses its own `value` parameter (not the enclosing `acc`)
    def visitADT[T: Type](value: GeneratorMap): GeneratorMap =
      Expr.summon[Mirror.Of[T]] match
        case Some('{ $m: Mirror.ProductOf[T] }) => visitProduct[T](value)
        case Some('{
          $m: Mirror.SumOf[T] {
            type MirroredElemTypes = elemTypes
            type MirroredElemLabels = elemLabels
          }
        }) => visitSum[T, elemTypes, elemLabels](value)
        case Some(_) =>
          throw new NotImplementedError("Unsupported Mirror for type:" + TypeRepr.of[T].show)
        case None =>
          throw new NotImplementedError("Not a Product or Sum or OrType:" + TypeRepr.of[T].show)

    // registers a ProductGenerator for T, then visits the type of every case field
    def visitProduct[T: Type](acc: GeneratorMap): GeneratorMap =
      val generator = ProductGenerator[T]()
      val acc2 = acc + (TypeRepr.of[T] -> generator)
      TypeTree.of[T].symbol.caseFields.foldLeft(acc2): (acc, field) =>
        field.tree.asInstanceOf[ValDef].tpt.tpe.asType match
          case '[t] => visit[t](acc)

    // registers a SumGenerator for T, then visits the element types that have no term symbol
    def visitSum[T: Type, elemTypes: Type, elemNames: Type](result: GeneratorMap): GeneratorMap =
      val generator = SumGenerator[T]()
      val r2 = result + (TypeRepr.of[T] -> generator)
      val tpes = extractElemTypes[elemTypes]
      tpes.foldLeft(r2): (acc, tpe) =>
        val isParameterizedCase = tpe.termSymbol == Symbol.noSymbol
        if isParameterizedCase then
          tpe.asType match
            case '[t] => visit[t](acc)
        else acc // simple case falls into the SumGenerator process

    // A | (B | C) => List(A, B, C)
    def flatternOrTypeElements(tpe: TypeRepr): List[TypeRepr] =
      tpe match
        case OrType(left, right) => flatternOrTypeElements(left) ++ flatternOrTypeElements(right)
        case _ => List(tpe)

    // registers an OrTypeGenerator for T, then visits every member of the union
    def visitOrType[T: Type](acc: GeneratorMap): GeneratorMap =
      val generator = OrTypeGenerator[T]()
      val r2 = acc + (TypeRepr.of[T] -> generator)
      val flatterned: List[q.reflect.TypeRepr] = flatternOrTypeElements(TypeRepr.of[T])
      flatterned.foldLeft(r2): (acc, tpe) =>
        tpe.asType match
          case '[t] => visit[t](acc)

    // picks the generator for T from its shape; anything unrecognized is treated as an ADT
    def visitInside[T: Type](acc: GeneratorMap): GeneratorMap =
      // val tpe = TypeRepr.of[T] // TODO should we support TypeDefs
      TypeRepr.of[T] match
        case tpe if tpe <:< TypeRepr.of[List[_]] => visitAppliedType1[T](acc, ListGenerator[T]())
        // Vector must be tested before Seq: every Vector is also a Seq, so the Seq case
        // would otherwise always win and this case would be unreachable
        case tpe if tpe <:< TypeRepr.of[Vector[_]] => visitAppliedType1[T](acc, VectorGenerator[T]())
        case tpe if tpe <:< TypeRepr.of[Seq[_]] => visitAppliedType1[T](acc, SeqGenerator[T]())
        case tpe if tpe <:< TypeRepr.of[Array[_]] => visitAppliedType1[T](acc, ArrayGenerator[T]())
        case tpe if tpe <:< TypeRepr.of[Set[_]] => visitAppliedType1[T](acc, SetGenerator[T]())
        case tpe if tpe <:< TypeRepr.of[Option[_]] => visitAppliedType1[T](acc, OptionGenerator[T]())
        case OrType(_, _) => visitOrType[T](acc)
        case _ => visitADT[T](acc)

    if TypeRepr.of[T] =:= TypeRepr.of[Null] then acc
    else if acc.isEmpty then
      // this is the root type: don't summon it, in the derived case the summoned
      // instance may be a not yet initialized value
      visitInside[T](acc)
    else if acc.contains(TypeRepr.of[T]) then acc // already visited // TODO TypeRepr is not a good key
    else
      // a new type: first try to summon it; skip it on success, otherwise visit it
      ADTMappingMacro.NOT_EXPAND_ADTS.set(0)
      val found = Expr.summon[JsValueMapper[T]] // nested ADTs are not expanded here, they only increase the counter
      val containsSkippedAdt = ADTMappingMacro.NOT_EXPAND_ADTS.get().nn > 0
      if containsSkippedAdt || found.isEmpty then visitInside[T](acc)
      else acc
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy