![JAR search and dependency download from the Maven repository](/logo.png)
com.schuwalow.delegate.Macros.scala Maven / Gradle / Ivy
package zio.delegate
import scala.reflect.macros.blackbox.Context
import scala.reflect.macros.TypecheckException
/**
 * Macro bundle implementing `Mix` derivation and the `@delegate` annotation.
 *
 * Both entry points generate code that forwards methods to other instances:
 *  - [[mixImpl]] expands to a `Mix[A, B]` whose `mix(a, b)` returns an instance of a fresh
 *    subclass of the components of `A` and `B` that forwards overlapping methods to `a` or `b`.
 *  - [[delegateImpl]] rewrites a class together with its `@delegate`-annotated member
 *    (received as a `ValDef` followed by the `ClassDef`) so it forwards methods to that member.
 *
 * @param c the blackbox macro context of the current expansion
 */
private[delegate] class Macros(val c: Context) {
  import c.universe._

  /**
   * Implementation of the `Mix[A, B]` materializer.
   *
   * `A` must not be final. Every component of `B` must be a trait, because it is mixed into the
   * result; a refinement `X with Y` is split into its parents. The generated `mix(a, b)`
   * instantiates a fresh subclass of all components. That subclass forwards each overlapping
   * method to `a` or `b`, except the `Object` methods recognised by `isObjectMethod`. When both
   * sides provide the same method symbol, `b` wins.
   */
  def mixImpl[A: WeakTypeTag, B: WeakTypeTag]: c.Tree = {
    val aTT = weakTypeOf[A]
    val bTT = weakTypeOf[B]
    // Refinements do not count as traits, so B is checked component by component.
    val bTTComps = getTypeComponents(bTT)
    // A may extend a class; B may not, as it will be mixed in.
    preconditions(
      (!aTT.typeSymbol.isFinal -> s"${aTT.typeSymbol.toString()} must be nonfinal class or trait.") ::
        bTTComps.map(t => t.typeSymbol.asClass.isTrait -> s"${t.typeSymbol.toString()} needs to be a trait."): _*
    )
    val aName = TermName(c.freshName("a"))
    val bName = TermName(c.freshName("b"))
    val resultType = {
      val candidate =
        (getTypeComponents(aTT) ++ bTTComps).map(t => localName(t.typeSymbol.asClass)).mkString(" with ")
      parseTypeString(candidate).fold(e => abort(s"Failed typechecking calculated type $candidate: $e"), identity)
    }
    val body = {
      // `++` keeps the right-hand entry on key collisions, so methods shared by A and B go to `b`.
      val methods = overlappingMethods(aTT, resultType).map((_, aName)).toMap ++
        overlappingMethods(bTT, resultType).map((_, bName)).toMap
      methods.filterNot { case (m, _) => isObjectMethod(m) }.map {
        case (m, owner) => delegateMethodDef(m, owner)
      }
    }
    val resultTypeName = TypeName(c.freshName("result"))
    q"""
      ${c.parse(s"abstract class $resultTypeName extends $resultType")}
      new Mix[$aTT, $bTT] {
        def mix($aName: $aTT, $bName: $bTT): ${resultType} = {
          new ${resultTypeName} {
            ..$body
          }
        }
      }
    """
  }

  /**
   * Implementation of the `@delegate` macro annotation.
   *
   * Expects the annottees to be the annotated `val` followed by its enclosing class. The class
   * is rewritten to extend a fresh abstract class built from its declared parents. When
   * `generateTraits` is set, that abstract class also mixes in every trait of the delegate's
   * type that the parents do not already provide. The class receives forwarders for the
   * visible, non-final methods of that combined type which the delegate's type also provides,
   * except names the class body already defines.
   *
   * Annotation arguments (each is evaluated at compile time as a `Boolean`):
   *  - `verbose` (default `false`): report the generated code as compiler info.
   *  - `forwardObjectMethods` (default `false`): also forward the `Object` methods.
   *  - `generateTraits` (default `true`): add the delegate's missing traits as parents.
   */
  def delegateImpl(annottees: c.Expr[Any]*): c.Tree = {
    case class Arguments(verbose: Boolean, forwardObjectMethods: Boolean, generateTraits: Boolean)

    val args: Arguments = c.prefix.tree match {
      case Apply(_, params) =>
        val verbose: Boolean = params.collectFirst {
          case q"verbose = $cfg" =>
            c.eval(c.Expr[Boolean](cfg))
        }.getOrElse(false)
        val forwardObjectMethods = params.collectFirst {
          case q"forwardObjectMethods = $cfg" =>
            c.eval(c.Expr[Boolean](cfg))
        }.getOrElse(false)
        val generateTraits = params.collectFirst {
          case q"generateTraits = $cfg" =>
            c.eval(c.Expr[Boolean](cfg))
        }.getOrElse(true)
        Arguments(verbose, forwardObjectMethods, generateTraits)
      case other => abort("not possible - macro invoked on type that does not have @delegate: " + showRaw(other))
    }

    // Object methods are only forwarded when explicitly requested.
    def isBlackListed(m: MethodSymbol) =
      if (!args.forwardObjectMethods) isObjectMethod(m) else false

    def modifiedClass(classDecl: ClassDef, delegateTo: ValDef): c.Tree = {
      // Match the full class shape so that type parameters, constructor modifiers, multiple
      // (e.g. implicit) parameter lists, early definitions and self types are all preserved.
      val q"..$mods class $className[..$tparams] $ctorMods(...$paramss) extends { ..$earlyDefs } with ..$bases { $self => ..$body }" =
        classDecl
      // Names the class already defines itself; these are never overwritten by forwarders.
      val existingMethods: Set[TermName] = body.collect {
        case DefDef(_, name, _, _, _, _) => name
        case ValDef(_, name, _, _)       => name
      }.toSet
      val (toName, toType) = typeCheckVal(delegateTo)
        .fold(e => abort(s"Failed typechecking annotated member. Is it defined in local scope?: $e"), identity)
      val additionalTraits: Set[ClassSymbol] =
        if (args.generateTraits)
          getTraits(toType) -- bases.flatMap(b => getTraits(c.typecheck(b, c.TYPEmode).tpe)).toSet
        else Set.empty
      val resultType = {
        val candidate = (bases.map(_.toString()) ++ additionalTraits.map(localName).toList).mkString(" with ")
        parseTypeString(candidate).fold(e => abort(s"Failed typechecking calculated type $candidate: $e"), identity)
      }
      val extensions = overlappingMethods(toType, resultType, !isBlackListed(_))
        .filterNot(m => existingMethods.contains(m.name))
        .map(delegateMethodDef(_, toName))
      val resultTypeName = TypeName(c.freshName())
      q"""
        ${c.parse(s"abstract class $resultTypeName extends $resultType")}
        $mods class $className[..$tparams] $ctorMods(...$paramss) extends { ..$earlyDefs } with $resultTypeName { $self => ..${body ++ extensions} }
      """
    }

    annottees.map(_.tree) match {
      case (valDecl: ValDef) :: (classDecl: ClassDef) :: Nil =>
        val modified = modifiedClass(classDecl, valDecl)
        if (args.verbose) showInfo(modified.toString())
        modified
      case _ => abort("Invalid annottee")
    }
  }

  /**
   * Builds a member that forwards `m` to the value named `to`.
   *
   * Members with neither type nor value parameter lists become `val`s, so they can also
   * implement abstract `val`s. Everything else becomes a `def` with the same type and value
   * parameters, forwarding all arguments; vararg parameters are expanded with `: _*`.
   * Concrete members are marked `override`.
   */
  final private[this] def delegateMethodDef(m: MethodSymbol, to: TermName): Tree = {
    val name  = m.name
    val rType = m.returnType
    val mods  = if (m.isAbstract) Modifiers() else Modifiers(Flag.OVERRIDE)
    if (m.paramLists.isEmpty && m.typeParams.isEmpty) {
      q"$mods val $name: $rType = $to.$name"
    } else {
      val typeParams = m.typeParams.map(internal.typeDef(_))
      val paramLists = m.paramLists.map(_.map(internal.valDef(_)))
      val arguments  = m.paramLists.map(_.map(forwardArg))
      q"""
        $mods def $name[..${typeParams}](...$paramLists): $rType = {
          $to.${name}(...$arguments)
        }
      """
    }
  }

  /** Argument tree passing `param` through unchanged; repeated (vararg) parameters get `: _*`. */
  final private[this] def forwardArg(param: Symbol): Tree = {
    val ref      = Ident(param.name.toTermName)
    val repeated = Set[Symbol](definitions.RepeatedParamClass, definitions.JavaRepeatedParamClass)
    if (repeated.contains(param.typeSignature.typeSymbol)) Typed(ref, Ident(typeNames.WILDCARD_STAR))
    else ref
  }

  // Fully qualified names of the `Object`/`Any` methods that are not forwarded by default.
  final private[this] val objectMethodNames: Set[String] = Set(
    "java.lang.Object.clone",
    "java.lang.Object.hashCode",
    "java.lang.Object.finalize",
    "java.lang.Object.equals",
    "java.lang.Object.toString",
    "scala.Any.getClass"
  )

  /** Whether `m` is one of the `Object`/`Any` methods listed in `objectMethodNames`. */
  final private[this] def isObjectMethod(m: MethodSymbol): Boolean =
    objectMethodNames.contains(m.fullName)

  /** All traits among the base classes of `t`, including `t`'s own class symbol if it is a trait. */
  final private[this] def getTraits(t: Type): Set[ClassSymbol] =
    t.baseClasses.map(_.asClass).filter(_.isTrait).toSet

  /**
   * Typechecks the declared type of `valDef`.
   *
   * @return the member name together with its type, or the `TypecheckException` raised when the
   *         type cannot be resolved at the expansion site
   */
  final private[this] def typeCheckVal(valDef: ValDef): Either[TypecheckException, (TermName, Type)] =
    try Right((valDef.name, c.typecheck(valDef.tpt.duplicate, c.TYPEmode).tpe))
    catch { case e: TypecheckException => Left(e) }

  /** Parses `str` as a type and typechecks it at the expansion site. */
  final private[this] def parseTypeString(str: String): Either[TypecheckException, Type] =
    try {
      Right(c.typecheck(c.parse(s"null.asInstanceOf[$str]"), c.TYPEmode).tpe)
    } catch {
      case e: TypecheckException => Left(e)
    }

  /**
   * Source text referring to `symbol` at the expansion site.
   *
   * Uses the fully qualified name when that typechecks here. Otherwise it drops the leading path
   * segments shared with `enclosing`, never dropping the symbol's own simple name.
   */
  final private[this] def localName(symbol: ClassSymbol): String =
    parseTypeString(symbol.fullName).fold(
      _ => {
        val path = "_root_" +: symbol.fullName.split('.')
        path
          .zip(("_root_" +: enclosing.split('.')).take(path.length - 1).padTo(path.length, ""))
          .dropWhile { case (l, r) => l == r }
          .map(_._1)
          .mkString(".")
      },
      _ => symbol.fullName
    )

  // Fully qualified name of the class enclosing the expansion, or of its package when there is none.
  final private[this] val enclosing: String = c.enclosingClass match {
    case clazz if clazz.isEmpty => c.enclosingPackage.symbol.fullName
    case clazz                  => clazz.symbol.fullName
  }

  /**
   * Methods of `to` that can be forwarded to a value of type `from`.
   *
   * Considers every base class of `to` whose self type `from` conforms to, and keeps those
   * members that:
   *  - are methods and are still what `to` resolves their name to;
   *  - are not constructors and not final;
   *  - are public, or qualified-private to a scope containing the expansion site;
   *  - pass `filter`.
   * Type members, nested classes/objects and members hidden in `to` are skipped rather than
   * causing a reflection error.
   */
  final private[this] def overlappingMethods(
    from: Type,
    to: Type,
    filter: MethodSymbol => Boolean = _ => true
  ): Set[MethodSymbol] = {
    // Qualified-private members are visible when the expansion site lies inside the qualifier,
    // compared on whole path segments (so `com.foo` does not grant access to `com.foobar`).
    def isVisible(m: MethodSymbol): Boolean =
      m.isPublic || {
        val scope = m.privateWithin
        scope != NoSymbol && {
          val scopeName = scope.fullName
          enclosing == scopeName || enclosing.startsWith(scopeName + ".")
        }
      }
    to.baseClasses
      .map(_.asClass.selfType)
      .filter(from <:< _)
      .flatMap { s =>
        s.members
          .collect { case m if m.isMethod => m.asMethod }
          .filter(m => to.member(m.name).alternatives.contains(m))
          .filter(m => !m.isConstructor && !m.isFinal && isVisible(m) && filter(m))
      }
      .toSet
  }

  /** Emits `s` as a forced compiler info message at the expansion site, prefixing each line. */
  final private[this] def showInfo(s: String) =
    c.info(c.enclosingPosition, s.split("\n").mkString("\n |---macro info---\n |", "\n |", ""), true)

  /** Aborts the macro expansion with `s` reported at the expansion site. */
  final private[this] def abort(s: String): Nothing =
    c.abort(c.enclosingPosition, s)

  /** Aborts with the message of the first condition that is `false`, checked in order. */
  final private[this] def preconditions(conds: (Boolean, String)*): Unit =
    conds.foreach {
      case (cond, s) =>
        if (!cond) abort(s)
    }

  /** Splits a (dealiased) refinement `X with Y` into its parents, recursively; other types are returned as-is. */
  final private[this] def getTypeComponents(t: Type): List[Type] = t.dealias match {
    case RefinedType(parents, _) => parents.flatMap(p => getTypeComponents(p))
    case other                   => List(other)
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy