package com.twitter.scalding.db.macros.impl
import scala.annotation.tailrec
import scala.reflect.macros.Context
import scala.util.{ Success, Failure }
import com.twitter.bijection.macros.impl.IsCaseClassImpl
import com.twitter.scalding.db.{ ColumnDefinition, ColumnDefinitionProvider, ResultSetExtractor }
import com.twitter.scalding.db.macros.impl.handler._
// Simple wrapper to pass around the string name format of fields.
// toString is overridden to return the wrapped string itself, so a FieldName
// interpolated into a message prints only the name rather than FieldName(...).
private[impl] case class FieldName(toStr: String) {
override def toString = toStr
// Macro implementations used to derive a ColumnDefinitionProvider[T] for a case class T.
// The apply method is the entry point; it combines the column definitions produced by
// getColumnDefn with the ResultSetExtractor produced by getExtractor.
object ColumnDefinitionProviderImpl {
// Takes a case class type and looks at the apply method of its companion object.
// Based on the parameters apply takes, gives back a map keyed by field name whose values
// are expressions that call the parameter's default-value getter on the companion and
// then .toString on the result. Parameters without a default yield no entry.
// Aborts macro expansion when the type has no companion symbol.
// NOTE(review): parts of this method appear to be missing in this copy of the file
// (the traversal of the first parameter list and the key of the Some( -> ...) pair);
// confirm against the upstream source before relying on it.
private[this] def getDefaultArgs(c: Context)(tpe: c.Type): Map[String, c.Expr[String]] = {
import c.universe._
val classSym = tpe.typeSymbol
val moduleSym = classSym.companionSymbol
if (moduleSym == NoSymbol) {
c.abort(c.enclosingPosition, s"No companion for case class ${tpe} available. Possibly a nested class? These do not work with this macro.")
// pick the last apply method which (anecdotally) gives us the defaults
// set in the case class declaration, not the companion object
val applyList = moduleSym.typeSignature.declaration(newTermName("apply")).asTerm.alternatives
val apply = applyList.last.asMethod
// can handle only default parameters from the first parameter list
// because subsequent parameter lists might depend on previous parameters{
case (p, i) =>
if (!p.isParamWithDefault) None
else {
// the compiler names default getters apply$default$N with a 1-based N,
// while i is the 0-based parameter position, hence i + 1
val getterName = newTermName("apply$default$" + (i + 1))
Some( -> c.Expr(q"${moduleSym}.$getterName.toString"))
// Builds the flat list of ColumnFormat entries for the case class T, descending into
// fields that are themselves case classes.
// Aborts macro expansion when T cannot be verified to be a case class, or when expanding
// any field produced a Failure (the Failure's message becomes the abort message).
// The resulting formats are also checked for duplicate field names.
// NOTE(review): several statements in this method are truncated in this copy of the file
// (the annotation lookup in expandMethod and the duplicate-field handling at the end);
// confirm against the upstream source.
private[scalding] def getColumnFormats[T](c: Context)(implicit T: c.WeakTypeTag[T]): List[ColumnFormat[c.type]] = {
import c.universe._
if (!IsCaseClassImpl.isCaseClassType(c)(T.tpe))
c.abort(c.enclosingPosition, s"""We cannot enforce ${T.tpe} is a case class, either it is not a case class or this macro call is possibly enclosed in a class.
This will mean the macro is operating on a non-resolved type.""")
// Field To JDBCColumn
// Maps one field type to its column format(s):
//  - String, Short, Int, Long, Double, Boolean and java.util.Date are delegated to
//    the matching type handler (Boolean goes through NumericTypeHandler as "BOOLEAN")
//  - Option[X] is unwrapped once and X is matched again with nullable = true
//  - a nested case class is expanded field by field via expandMethod
// Anything else, a nested Option, or an Option with a default value is a Failure.
def matchField(accessorTree: List[MethodSymbol],
oTpe: Type,
fieldName: FieldName,
defaultValOpt: Option[c.Expr[String]],
annotationInfo: List[(Type, Option[Int])],
nullable: Boolean): scala.util.Try[List[ColumnFormat[c.type]]] = {
oTpe match {
// String handling
case tpe if tpe =:= typeOf[String] => StringTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable)
case tpe if tpe =:= typeOf[Short] => NumericTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable, "SMALLINT")
case tpe if tpe =:= typeOf[Int] => NumericTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable, "INT")
case tpe if tpe =:= typeOf[Long] => NumericTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable, "BIGINT")
case tpe if tpe =:= typeOf[Double] => NumericTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable, "DOUBLE")
case tpe if tpe =:= typeOf[Boolean] => NumericTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable, "BOOLEAN")
case tpe if tpe =:= typeOf[java.util.Date] => DateTypeHandler(c)(accessorTree, fieldName, defaultValOpt, annotationInfo, nullable)
// nullable is only true when we arrived here by unwrapping an outer Option (see the
// recursive call below), so an Option at this point is an Option inside an Option
case tpe if tpe.erasure =:= typeOf[Option[Any]] && nullable == true =>
Failure(new Exception(s"Case class ${T.tpe} has field ${fieldName} which contains a nested option. This is not supported by this macro."))
case tpe if tpe.erasure =:= typeOf[Option[Any]] && nullable == false =>
if (defaultValOpt.isDefined)
Failure(new Exception(s"Case class ${T.tpe} has field ${fieldName}: ${oTpe.toString}, with a default value. Options cannot have default values"))
else {
// re-match on the Option's type argument, now marked nullable and without a default
matchField(accessorTree, tpe.asInstanceOf[TypeRefApi].args.head, fieldName, None, annotationInfo, true)
case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) => expandMethod(accessorTree, tpe)
// default
case _ => Failure(new Exception(s"Case class ${T.tpe} has field ${fieldName}: ${oTpe.toString}, which is not supported for talking to JDBC"))
// Expands every case accessor of outerTpe through matchField, extending the accessor
// path outerAccessorTree with the accessor itself, and concatenates the per-field
// results in declaration order. The first Failure encountered is the overall result.
def expandMethod(outerAccessorTree: List[MethodSymbol], outerTpe: Type): scala.util.Try[List[ColumnFormat[c.type]]] = {
val defaultArgs = getDefaultArgs(c)(outerTpe)
// Initializes the type info
// We have to build this up front as if the case class definition moves to another file
// the annotation moves from the value onto the getter method?
val annotationData: Map[String, List[(Type, List[Tree])]] = outerTpe
.map { m =>
val mappedAnnotations = => (t.tpe, t.scalaArgs)) -> mappedAnnotations
}.groupBy(_._1).map {
case (k, l) =>
(k, ++ _))
}.filter {
case (_, v) =>
.collect { case m: MethodSymbol if m.isCaseAccessor => m }
.map { m =>
val fieldName =
val defaultVal = defaultArgs.get(fieldName)
val annotationInfo: List[(Type, Option[Int])] = annotationData.getOrElse(, Nil)
.collect {
// a size annotation must carry a single literal Int argument; any other argument
// shape aborts. Other ScaldingDBAnnotation subtypes are kept without a value.
case (tpe, List(Literal(Constant(siz: Int)))) if tpe =:= typeOf[com.twitter.scalding.db.macros.size] => (tpe, Some(siz))
case (tpe, _) if tpe =:= typeOf[com.twitter.scalding.db.macros.size] => c.abort(c.enclosingPosition, "Hit a size macro where we couldn't parse the value. Probably not a literal constant. Only literal constants are supported.")
case (tpe, _) if tpe <:< typeOf[com.twitter.scalding.db.macros.ScaldingDBAnnotation] => (tpe, None)
matchField(outerAccessorTree :+ m, m.returnType, FieldName(fieldName), defaultVal, annotationInfo, false)
// This algorithm returns the error from the first exception we run into.
.foldLeft(scala.util.Try[List[ColumnFormat[c.type]]](Nil)) {
case (pTry, nxt) =>
(pTry, nxt) match {
case (Success(l), Success(r)) => Success(l ::: r)
case (f @ Failure(_), _) => f
case (_, f @ Failure(_)) => f
val formats = expandMethod(Nil, T.tpe) match {
case Success(s) => s
case Failure(e) => (c.abort(c.enclosingPosition, e.getMessage))
val duplicateFields = formats
.filter(_._2.size > 1)
if (duplicateFields.nonEmpty) {
Duplicate field names found: ${duplicateFields.mkString(",")}.
Please check your nested case classes.
} else {
// Builds the ColumnDefinition expressions for T from the ColumnFormats returned by
// getColumnFormats, using each format's nullable flag and fieldType.
// NOTE(review): the body is truncated in this copy of the file (the nullable value,
// the qualifier of the field type selection and the constructed expression are cut
// off); confirm against the upstream source.
def getColumnDefn[T](c: Context)(implicit T: c.WeakTypeTag[T]): List[c.Expr[ColumnDefinition]] = {
import c.universe._
val columnFormats = getColumnFormats[T](c) {
case cf: ColumnFormat[_] =>
val nullableVal = if (cf.nullable)
val fieldTypeSelect = Select(q"", newTermName(cf.fieldType))
val res = q"""new
// Generates the ResultSetExtractor[T] for the case class T. The emitted instance has
// two members:
//  - validate: runs the generated checks inside a scala.util.Try. For every column
//    (1-based position in the metadata) the upper-cased column type name must equal the
//    expected field type or one of its accepted synonyms, otherwise an exception with a
//    "Mismatched type for column" message is thrown; a nullability check against
//    isNullable follows.
//  - toCaseClass: reads each column from the ResultSet by field name with the getter
//    matching its field type, replaces the value with null when wasNull reports that the
//    column was NULL (boxing primitives otherwise), packs the values into a cascading
//    Tuple wrapped in a TupleEntry and passes it to the supplied converter.
// NOTE(review): several quasiquotes and fully qualified type names in this method are
// truncated in this copy of the file (exception types, boxing functions, the condition
// of the nullability check, the extractor and converter types); confirm against the
// upstream source.
def getExtractor[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[ResultSetExtractor[T]] = {
import c.universe._
val columnFormats = getColumnFormats[T](c)
val rsmdTerm = newTermName(c.fresh("rsmd"))
// we validate two things from ResultSetMetadata
// 1. the column types match with actual DB schema
// 2. all non-nullable fields are indeed non-nullable in DB schema
val checks = {
case (cf: ColumnFormat[_], pos: Int) =>
val fieldName = cf.fieldName.toStr
val typeNameTerm = newTermName(c.fresh(s"colTypeName_$pos"))
// MySQL uses names like `DATE`, `INTEGER` and `VARCHAR`;
// Vertica uses names like `Date`, `Integer` and `Varchar`
val typeName = q"""
val $typeNameTerm = $rsmdTerm.getColumnTypeName(${pos + 1}).toUpperCase(java.util.Locale.US)
// certain types have synonyms, so we group them together here
// note: this is mysql specific
val typeValidation = cf.fieldType match {
case "VARCHAR" => q"""List("VARCHAR", "CHAR").contains($typeNameTerm)"""
case "BOOLEAN" | "TINYINT" => q"""List("BOOLEAN", "BOOL", "TINYINT").contains($typeNameTerm)"""
case "INT" => q"""List("INTEGER", "INT").contains($typeNameTerm)"""
// In Vertica, `INTEGER`, `INT`, `BIGINT`, `INT8`, `SMALLINT`, and `TINYINT` are all 64 bits
// In MySQL, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `INT`, and `BIGINT` are all <= 64 bits
// As the user has told us this field can store a `BIGINT`, we can safely accept any of these
// types from the database.
case "BIGINT" =>
q"""List("INTEGER", "INT", "BIGINT", "INT8", "SMALLINT",
"TINYINT", "SMALLINT", "MEDIUMINT").contains($typeNameTerm)"""
case f => q"""$f == $typeNameTerm"""
val typeAssert = q"""
if (!$typeValidation) {
throw new
"Mismatched type for column '" + $fieldName + "'. Expected " + ${cf.fieldType} +
" but set to " + $typeNameTerm + " in DB.")
val nullableTerm = newTermName(c.fresh(s"isNullable_$pos"))
val nullableValidation = q"""
val $nullableTerm = $rsmdTerm.isNullable(${pos + 1})
if ($nullableTerm == && ${cf.nullable}) {
throw new
"Column '" + $fieldName + "' is not nullable in DB.")
val rsTerm = newTermName(c.fresh("rs"))
val formats = {
case cf: ColumnFormat[_] => {
val fieldName = cf.fieldName.toStr
// java boxed types needed below to populate cascading's Tuple
val (box: Option[Tree], primitiveGetter: Tree) = cf.fieldType match {
case "VARCHAR" | "TEXT" =>
(None, q"""$rsTerm.getString($fieldName)""")
case "BOOLEAN" | "TINYINT" =>
(Some(q""""""), q"""$rsTerm.getBoolean($fieldName)""")
case "DATE" | "DATETIME" =>
(None, q"""Option($rsTerm.getTimestamp($fieldName)).map { ts => new java.util.Date(ts.getTime) }.orNull""")
// dates set to null are populated as None by tuple converter
// if the corresponding case class field is an Option[Date]
case "DOUBLE" =>
(Some(q""""""), q"""$rsTerm.getDouble($fieldName)""")
case "BIGINT" =>
(Some(q""""""), q"""$rsTerm.getLong($fieldName)""")
case "INT" | "SMALLINT" =>
(Some(q""""""), q"""$rsTerm.getInt($fieldName)""")
case f =>
(None, q"""sys.error("Invalid format " + $f + " for " + $fieldName)""")
// note: UNSIGNED BIGINT is currently unsupported
val valueTerm = newTermName(c.fresh("colValue"))
val boxed = { b => q"""$b($valueTerm)""" }.getOrElse(q"""$valueTerm""")
// primitiveGetter needs to be invoked before we can use wasNull
// to check if the column value that was read is null or not
{ val $valueTerm = $primitiveGetter; if ($rsTerm.wasNull) null else $boxed }
val tcTerm = newTermName(c.fresh("conv"))
val res = q"""
new[$T] {
def validate($rsmdTerm: _root_.scala.util.Try[Unit] = _root_.scala.util.Try { ..$checks }
def toCaseClass($rsTerm: java.sql.ResultSet, $tcTerm:[$T]): $T =
$tcTerm(new _root_.cascading.tuple.TupleEntry(new _root_.cascading.tuple.Tuple(..$formats)))
// ResultSet -> TupleEntry -> case class
// Macro entry point: produces the ColumnDefinitionProvider[T] expression for T.
// The generated instance overrides columns with the list of definitions from
// getColumnDefn and resultSetExtractor with the extractor built by getExtractor.
// NOTE(review): the instantiated type names inside the quasiquote are truncated in this
// copy of the file; confirm against the upstream source.
def apply[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[ColumnDefinitionProvider[T]] = {
import c.universe._
val columns = getColumnDefn[T](c)
val resultSetExtractor = getExtractor[T](c)
val res = q"""
new[$T] with {
override val columns = List(..$columns)
override val resultSetExtractor = $resultSetExtractor