![JAR search and dependency download from the Maven repository](/logo.png)
com.outworkers.phantom.udt.macros.UdtRootMacro.scala Maven / Gradle / Ivy
package com.outworkers.phantom.udt.macros
import java.nio.BufferUnderflowException
import com.datastax.driver.core.exceptions.InvalidTypeException
import com.outworkers.phantom.macros.RootMacro
import com.outworkers.phantom.udt.{Udt, UdtRoot}
import scala.reflect.macros.whitebox
import scala.collection.compat._
@macrocompat.bundle
trait UdtRootMacro extends RootMacro {
// Whitebox macro context of this bundle; every tree and type below is built against its universe.
val c: whitebox.Context
import c.universe._
// Term path of the phantom UDT package, spliced into generated trees.
val udtPackage = q"com.outworkers.phantom.udt"
// Same path as `udtPackage`; both names are spliced into generated trees elsewhere in this trait.
val prefix = q"com.outworkers.phantom.udt"
// Type tree of the keyspace taken implicitly by the generated schema methods.
val keySpaceTpe = tq"com.outworkers.phantom.dsl.KeySpace"
// Type tree pointing at com.datastax.driver.core.UDTValue.
val udtValueTpe = tq"com.datastax.driver.core.UDTValue"
// Resolves the type symbol of `A` from its weak type tag.
def typed[A : c.WeakTypeTag]: Symbol = weakTypeOf[A].typeSymbol
// Type of the executable query produced by the generated schema methods.
val executableQuery: Type = typeOf[com.outworkers.phantom.builder.query.execution.ExecutableCqlQuery]
// Tree referencing the empty query options handed to generated schema queries.
val emptyOptions = q"com.outworkers.phantom.builder.query.QueryOptions.empty"
/**
 * Type symbols the macro compares field types against: the immutable List, Set and Map
 * collections and the phantom UDT primitive.
 */
object Symbols {
  import scala.collection.{immutable => im}

  val listSymbol: Symbol = typed[im.List[_]]
  val setSymbol: Symbol = typed[im.Set[_]]
  val mapSymbol: Symbol = typed[im.Map[_, _]]
  val udtSymbol: Symbol = typed[com.outworkers.phantom.udt.UDTPrimitive[_]]
}
/**
 * Checks if the field type on a case class is a case class, looking through one level of [[scala.Option]].
 * This logic requires deep checking to deal with the fact that we provide DSL level [[scala.Option]]
 * support for all Cassandra types.
 *
 * Nested UDT types require freezing on Cassandra and if the type is wrapped inside an [[Option]] it is not enough
 * to check if the outer type is a case class, so the first type argument of the option is inspected instead.
 *
 * Types that conform to `Option[_]` without exposing any type argument (for example `None.type`)
 * are reported as `false` rather than failing on an empty type argument list.
 *
 * @param tpe The type of the field to inspect.
 * @return True if the type is a case class, or if it is an Option whose first type argument is a case class.
 */
def deepCaseClass(tpe: Type): Boolean = {
  if (tpe <:< typeOf[scala.Option[_]]) {
    // `typeArgs` can be empty for an Option subtype, so `.head` is not safe here.
    tpe.typeArgs.headOption.exists(arg => isCaseClass(arg))
  } else {
    isCaseClass(tpe)
  }
}
/**
 * Checks whether the symbol behind a type carries an annotation whose type conforms to [[UdtRoot]].
 * @param tpe The type whose symbol annotations are inspected.
 * @return True if at least one annotation on the type symbol conforms to [[UdtRoot]].
 */
def hasAnnotation(tpe: Type): Boolean = {
  val sym = tpe.typeSymbol
  // Result intentionally discarded.
  // NOTE(review): presumably forces the symbol to initialise so its annotations are loaded - confirm.
  sym.typeSignature
  sym.annotations.exists(annot => annot.tree.tpe <:< c.weakTypeOf[UdtRoot])
}
/**
 * Describes a single case class field as seen by the macro.
 * @param name The term name of the field.
 * @param paramType The type of the field.
 * @param annotations The annotation trees attached to the field, empty by default.
 */
case class Accessor(name: TermName, paramType: Type, annotations: List[Tree] = Nil) {

  /** The type symbol of the field type. */
  def symbol: Symbol = paramType.typeSymbol

  /** The name of the field's type symbol, as a type name. */
  def tpe: TypeName = {
    val symbolName = symbol.name
    symbolName.toTypeName
  }
}
object Accessors {
  /**
   * Converts a collection of (name, type) pairs into a collection of the same shape holding
   * [[Accessor]] values without annotations.
   * @param source The (name, type) pairs to convert.
   * @param cbf The factory used to build the resulting collection.
   * @return One accessor per input pair, in iteration order.
   */
  def apply[M[X] <: IterableOnce[X]](source: M[(Name, Type)])(
    implicit cbf: Factory[Accessor, M[Accessor]]
  ): M[Accessor] = {
    val out = cbf.newBuilder
    source.iterator.foreach { case (fieldName, fieldType) =>
      out += Accessor(fieldName.toTermName, fieldType, Nil)
    }
    out.result()
  }
}
/**
 * Builds an [[Accessor]] for every constructor parameter of a case class.
 * Each accessor carries the parameter name, its type (obtained by type checking the declared
 * type tree in type mode) and the annotations found on the parameter's modifiers.
 *
 * Example:
 *
 * {{{
 *   case class Test(id: UUID, name: String, age: Int)
 *
 *   // yields accessors named id, name and age with types UUID, String and Int
 * }}}
 *
 * @param params The list of params retrieved from the case class.
 * @return One accessor per parameter, in declaration order.
 */
def accessors(
  params: Seq[ValDef]
): Iterable[Accessor] = {
  params.map { param =>
    val declaredType = param.tpt
    val resolved = c.typecheck(tq"$declaredType", c.TYPEmode).tpe
    Accessor(param.name, resolved, param.mods.annotations)
  }
}
/**
 * The base implementation of an UDT extractor derived from a case class field.
 */
trait Extractor {

  /**
   * The corresponding case class accessor that encloses the relevant type information.
   * This includes the name and the type of the underlying case class field.
   * @return A reference to a case class accessor.
   */
  def accessor: Accessor

  /**
   * The decoded string name of the case class field.
   * It is used as the left hand side of the `schema` and `serializer` pairs, so UDT fields
   * carry the same name as the one typed into the case class by users.
   * @return A string name for the case class field.
   */
  def field: String = accessor.name.decodedName.toString

  /**
   * A tree producing the schema level name for this field.
   * When the field type carries an UDT root annotation the tree selects the `name` of the
   * UDT primitive for that type, otherwise it is a literal holding the decoded field name.
   * Note that `schema` below is built from `field`, not from this member.
   * @return A tree evaluating to the name described above.
   */
  def schemaField: Tree = {
    val fieldType = accessor.paramType
    if (!hasAnnotation(fieldType)) {
      val plainName = accessor.name.decodedName.toString
      q"$plainName"
    } else {
      q"$udtPackage.UDTPrimitive[$fieldType].name"
    }
  }

  /**
   * A term name made of the field name followed by the `Opt` suffix, intended for naming
   * the value extracted for this field in generated code.
   *
   * Example of the kind of expansion this naming scheme is meant for: {{{
   * @Udt case class Test(id: UUID, name: String)
   *
   * new UDTPrimitive[Test] {
   *   def fromRow(udt: UDTValue): Option[Test] = {
   *     for {
   *       idOpt <- Primitive[UUID].fromRow("id", udt)
   *       nameOpt <- Primitive[String].fromRow("name", udt)
   *     } yield Test.apply(idOpt, nameOpt)
   *   }
   * }
   *
   * }}}
   * @return The field name suffixed with `Opt`, as a term name.
   */
  def extractorName: TermName = TermName(field + "Opt")

  /**
   * The tree referencing the primitive used to handle the type of this field.
   * @return A tree on which members such as `asCql` and `cassandraType` are selected.
   */
  def typeQualifier: Tree

  /**
   * Whether or not this extractor is derived for an underlying UDT type.
   * The parent type of the extractor is always an UDT, but this property marks
   * whether or not the current field/accessor is itself a case class, possibly wrapped
   * in an option, that needs to be treated as an UDT type.
   * @return A boolean value that marks whether or not the current extractor should be
   *         derived for an UDT type.
   */
  def isUdtType: Boolean = deepCaseClass(accessor.paramType)

  /**
   * The tree enclosing the definition of the Cassandra type information of the extractor.
   * This is provided by implementors as it is based on a combination of inner and outer types.
   * @return The CQL type tree that will be used to produce the CQL type string during schema auto-generation.
   */
  def cassandraType: Tree

  /**
   * A tree pairing the field name with its Cassandra type.
   * @return A tree of the form `field -> cassandraType`.
   */
  def schema: Tree = q"$field -> $cassandraType"

  /**
   * A tree pairing the field name with the CQL rendering of the field value.
   * @return The serializer tree supplied by implementors.
   */
  def serializer: Tree
}
/**
 * Extractor for a field whose value is passed directly to its primitive.
 * @param accessor The case class field this extractor is derived from.
 * @param typeQualifier The tree referencing the primitive of the field type.
 */
case class RegularExtractor(
  accessor: Accessor,
  override val typeQualifier: Tree
) extends Extractor {

  /** A frozen reference to the UDT name for UDT fields, otherwise the primitive's own Cassandra type. */
  def cassandraType: Tree = {
    if (!isUdtType) {
      q"$typeQualifier.cassandraType"
    } else {
      val nameOfUdt = q"$prefix.UDTPrimitive[${accessor.paramType}].name"
      q"$prefix.Helper.frozen($nameOfUdt)"
    }
  }

  /** Pairs the field name with the CQL rendering of the field read from `instance`. */
  override def serializer: Tree = {
    val fieldValue = q"instance.${accessor.name}"
    q"$field -> $typeQualifier.asCql($fieldValue)"
  }
}
/**
 * Extractor for a field whose value is folded over before being passed to its primitive,
 * rendering the string "null" when the fold takes its empty branch.
 * @param accessor The case class field this extractor is derived from.
 * @param typeQualifier The tree referencing the primitive of the wrapped type.
 */
case class NestedExtractor(
  accessor: Accessor,
  override val typeQualifier: Tree
) extends Extractor {

  /** A frozen reference to the UDT name for UDT fields, otherwise the primitive's own Cassandra type. */
  def cassandraType: Tree = {
    if (!isUdtType) {
      q"$typeQualifier.cassandraType"
    } else {
      val nameOfUdt = q"$prefix.UDTPrimitive[${accessor.paramType}].name"
      q"$prefix.Helper.frozen($nameOfUdt)"
    }
  }

  /** Pairs the field name with the CQL rendering of the wrapped value, or "null" when absent. */
  override def serializer: Tree = {
    val optionalValue = q"instance.${accessor.name}"
    q"""$field -> $optionalValue.fold("null")(item => $typeQualifier.asCql(item))"""
  }
}
/**
 * Picks the extractor used for a case class field.
 * This does not yield a method call of any kind, it only builds a reference to the
 * primitive of the relevant type and wraps it in an extractor.
 *
 * A field that is an option of a case class gets a [[NestedExtractor]] built around the
 * option's first type argument; every other field gets a [[RegularExtractor]] built around
 * the field type itself.
 * @param accessor The case class field to compute the extractor for.
 * @return The extractor holding the accessor and the tree referencing its primitive.
 */
private[this] def derivePrimitive(accessor: Accessor): Extractor = {
  val fieldType = accessor.paramType
  val optionalCaseClass = fieldType <:< typeOf[scala.Option[_]] && deepCaseClass(fieldType)

  if (!optionalCaseClass) {
    RegularExtractor(accessor, q"$primitivePkg.Primitive[$fieldType]")
  } else {
    val nestedType = fieldType.typeArgs.head
    val unwrapped = accessor.copy(paramType = nestedType)
    NestedExtractor(unwrapped, q"$primitivePkg.Primitive[$nestedType]")
  }
}
/**
 * Checks whether a symbol is a class symbol flagged as a case class.
 * @param sym The symbol to inspect.
 * @return True only if the symbol is a class and that class is a case class.
 */
def isCaseClass(sym: Symbol): Boolean = {
  if (sym.isClass) sym.asClass.isCaseClass else false
}

/**
 * Checks whether the type symbol of a type is a case class.
 * @param tpe The type to inspect.
 * @return True if the type symbol of the given type is a case class.
 */
def isCaseClass(tpe: Type): Boolean = {
  val underlying: Symbol = tpe.typeSymbol
  isCaseClass(underlying)
}
/**
 * Collects the case class types a set of fields depends on.
 * A field that is itself a case class contributes its own type. A field whose type symbol is
 * the immutable List, Set or Map contributes those of its type arguments that are case classes.
 * Any other field contributes nothing.
 * @param params The fields to inspect.
 * @return The dependent case class types, in field order.
 */
def udtDeps(params: Iterable[Accessor]): Seq[Type] = {
  params.iterator.flatMap { accessor =>
    val fieldType = accessor.paramType
    if (isCaseClass(fieldType)) {
      List(fieldType)
    } else {
      val sym = accessor.symbol
      val isKnownCollection =
        Symbols.listSymbol == sym || Symbols.setSymbol == sym || Symbols.mapSymbol == sym
      if (isKnownCollection) fieldType.typeArgs.filter(arg => isCaseClass(arg)) else Nil
    }
  }.toList
}
/**
 * Builds one tree per dependent case class type, each referencing the UDT primitive of that type.
 * @param params The fields whose dependencies are computed through `udtDeps`.
 * @return The primitive reference trees, in dependency order.
 */
def typeDependencies(params: Iterable[Accessor]): Seq[Tree] = {
  for (dep <- udtDeps(params)) yield q"$udtPackage.UDTPrimitive[$dep]"
}

/**
 * Builds one tree per dependent case class type, each selecting `schemaQuery` on the UDT
 * primitive of that type.
 * @param params The fields whose dependencies are computed through `udtDeps`.
 * @return The schema query trees, in dependency order.
 */
def queryDependencies(params: Iterable[Accessor]): Seq[Tree] = {
  for (dep <- udtDeps(params)) yield q"$udtPackage.UDTPrimitive[$dep].schemaQuery"
}
// Term name `el<i>`: bound to the per-field size, the per-field write result and the raw field bytes in udtPrimitive.
def elTerm(i: Int): TermName = TermName(s"el$i")
// Term name `n<i>`: bound to the length prefix read for field i in udtPrimitive.
def fieldTerm(i: Int): TermName = TermName(s"n$i")
// Term name `fq<i>`: bound to the deserialized value of field i in udtPrimitive.
def fqTerm(i: Int): TermName = TermName(s"fq$i")
// Fully qualified trees and types spliced into the code generated by udtPrimitive.
// Type tree for scala.Boolean, used as the result type of the generated `frozen` and `shouldFreeze`.
private[this] val boolType = tq"_root_.scala.Boolean"
// Term tree for the driver's CodecUtils, whose readBytes is called when deserializing.
private[this] val codecUtils = q"_root_.com.datastax.driver.core.CodecUtils"
// Type tree for the driver's ProtocolVersion, the second parameter of serialize/deserialize.
private[this] val pVersion = tq"_root_.com.datastax.driver.core.ProtocolVersion"
// Type tree for java.nio.ByteBuffer.
private[this] val bufferType = tq"_root_.java.nio.ByteBuffer"
// Term tree for java.nio.ByteBuffer, used to call its static allocate method.
private[this] val bufferCompanion = q"_root_.java.nio.ByteBuffer"
// Exception type caught by the generated deserialize method.
private[this] val bufferException = typeOf[BufferUnderflowException]
// Exception type thrown by the generated deserialize method when the buffer underflows.
private[this] val invalidTypeException = typeOf[InvalidTypeException]
/**
 * Generates the tree of an anonymous UDTPrimitive implementation for a case class.
 * The generated instance defines `name`, `deps`, `typeDependencies`, `asCql`, `schemaQuery`,
 * `dataType`, `serialize`, `deserialize`, `frozen` and `shouldFreeze`.
 *
 * NOTE(review): with no params the sum expression below stays an empty tree and the generated
 * for comprehensions have no enumerators - confirm callers never pass an empty field list.
 *
 * @param tpe The type name of the case class the primitive is generated for.
 * @param stringName The name returned by the generated `name` and `dataType`; it is lower cased
 *                   only inside the generated CREATE TYPE statement.
 * @param params The accessors describing the fields of the case class.
 * @return The tree of the generated primitive instance.
 */
def udtPrimitive(tpe: TypeName, stringName: String, params: Iterable[Accessor]): Tree = {
// Parameter names of the generated serialize/deserialize methods. Several fragments below spell
// `source` literally, so they rely on this term being named "source".
val sourceTerm = TermName("source")
val versionTerm = TermName("version")
val source = params
val indexedFields = source.zipWithIndex
val udtFields = source map derivePrimitive
// One enumerator per field: serializes the field and yields 4 plus the number of remaining bytes,
// or 4 alone when the serialized buffer is null.
val sizeComp = indexedFields.map { case (vd, i) =>
val term = elTerm(i)
fq"""
$term <- {
val $term = $primitivePkg.Primitive[${vd.paramType}].serialize(source.${vd.name}, $versionTerm)
Some((4 + { if ($term == null) 0 else $term.remaining()}))
}
"""
}
// One enumerator per field: writes an Int length prefix (-1 for a null buffer) followed by the
// serialized bytes into `res`, the buffer allocated in the generated serialize method.
val serializedComponents = indexedFields.map { case (vd, i) =>
fq""" ${elTerm(i)} <- {
val serialized = $primitivePkg.Primitive[${vd.paramType}].serialize(source.${vd.name}, $versionTerm)
val buf = if (serialized == null) {
res.putInt(-1)
} else {
res.putInt(serialized.remaining())
res.put(serialized.duplicate())
}
Some(buf)
}
"""
}
val inputTerm = TermName("input")
// One enumerator per field: reads the Int length prefix, reads that many bytes (null when the
// prefix is negative) and deserializes them with the primitive of the field type.
val deserializedFields = indexedFields.map { case (vd, i) =>
val tm = fieldTerm(i)
val el = elTerm(i)
fq"""
${fqTerm(i)} <- {
val $tm = $inputTerm.getInt()
val $el = if ($tm < 0) { null } else { $codecUtils.readBytes($inputTerm, $tm) }
Some($primitivePkg.Primitive[${vd.paramType}].deserialize($el, $versionTerm))
}
"""
}
// Right fold producing the sum of all per-field size terms, starting from an empty tree.
val sumTerm = indexedFields.foldRight(q"") { case ((_, pos), acc) =>
acc match {
case t if t.isEmpty => q"${elTerm(pos)}"
case _ => q"${elTerm(pos)} + $acc"
}
}
// The deserialized field terms, passed in field order to the constructor of the case class.
val extractorTerms = indexedFields.map { case (_, i) => fqTerm(i) }
val fieldExtractor = q"for (..$deserializedFields) yield new $tpe(..$extractorTerms)"
// The generated primitive itself; the fragments above are spliced into serialize and deserialize.
val tree = q"""
new $prefix.UDTPrimitive[$tpe] {
def name: $strTpe = $stringName
def deps()(implicit space: $keySpaceTpe): $collections.Seq[$udtPackage.UDTPrimitive[_]] = {
$collections.Seq.apply[$prefix.UDTPrimitive[_]](..${typeDependencies(source)})
}
def typeDependencies()(implicit space: $keySpaceTpe): $collections.Seq[$executableQuery] = {
$collections.Seq.apply[$executableQuery](..${queryDependencies(source)})
}
override def asCql(instance: $tpe): String = {
val baseString = $collections.List(..${udtFields.map(_.serializer)}).map {
case (name, ext) => name + ": " + ext
} mkString(", ")
"{" + baseString + "}"
}
def schemaQuery()(implicit space: $keySpaceTpe): $executableQuery = {
val membersList = $collections.List(..${udtFields.map(_.schema)}).map {
case (name, casType) => name + " " + casType
}
val base = "CREATE TYPE IF NOT EXISTS " + space.name + "." + ${stringName.toLowerCase} + " (" + membersList.mkString(", ") + ")"
new $executableQuery(new $enginePkg.CQLQuery(base), $emptyOptions, Nil)
}
override def dataType: $strTpe = $stringName
override def serialize($sourceTerm: $tpe, $versionTerm: $pVersion): $bufferType = {
if ($sourceTerm == null) {
null
} else {
val size = {for (..$sizeComp) yield ($sumTerm) } get
val res = $bufferCompanion.allocate(size)
val buf = for (..$serializedComponents) yield ()
buf.get
res.flip().asInstanceOf[$bufferType]
}
}
override def deserialize($sourceTerm: $bufferType, $versionTerm: $pVersion): $tpe = {
if ($sourceTerm == null) {
null
} else {
try {
val $inputTerm = $sourceTerm.duplicate()
$fieldExtractor.get
} catch {
case e: $bufferException =>
throw new $invalidTypeException("Not enough bytes to deserialize an UDT value", e)
}
}
}
override def frozen: $boolType = true
override def shouldFreeze: $boolType = true
}"""
// Prints the generated code at the macro call site when tree debugging is switched on.
if (showTrees) {
c.echo(c.enclosingPosition, showCode(tree))
}
tree
}
/**
 * Generates an implicit val holding the UDT primitive of a case class.
 * The val gets a fresh name derived from the decoded class name followed by `_udt_primitive`,
 * and its value is the tree produced by `udtPrimitive` for the class fields.
 * @param tpe The type name of the case class.
 * @param name The term name of the case class, decoded to obtain the UDT name.
 * @param params The constructor parameters of the case class.
 * @return The tree of the implicit val definition.
 */
def makePrimitive(
  tpe: TypeName,
  name: TermName,
  params: List[ValDef]
): Tree = {
  val udtName = name.decodedName.toString
  val implicitName = TermName(c.freshName(udtName + "_udt_primitive"))
  val instance = udtPrimitive(tpe, udtName, accessors(params))
  val generated = q"implicit val $implicitName: $prefix.UDTPrimitive[$tpe] = $instance"

  // Prints the generated definition at the macro call site when tree debugging is switched on.
  if (showTrees) {
    c.echo(c.enclosingPosition, s"Generated tree for $tpe:\n${showCode(tree = generated)}")
  }
  generated
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy