All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.outworkers.phantom.macros.RootMacro.scala Maven / Gradle / Ivy

/*
 * Copyright 2013 - 2020 Outworkers Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.outworkers.phantom.macros

import com.outworkers.phantom.builder.query.sasi.Mode
import com.outworkers.phantom.column.AbstractColumn
import com.outworkers.phantom.connectors.KeySpace
import com.outworkers.phantom.keys.SASIIndex
import com.outworkers.phantom.macros.toolbelt.{HListHelpers, WhiteboxToolbelt}
import com.outworkers.phantom.{CassandraTable, SelectTable}

import scala.collection.compat._
import scala.collection.immutable.ListMap
import scala.reflect.macros.whitebox
import scala.Iterable

@macrocompat.bundle
trait RootMacro extends HListHelpers with WhiteboxToolbelt {
  val c: whitebox.Context
  import c.universe._

  // Fully qualified package trees and types spliced into generated code. They are
  // all anchored at `_root_` so that imports at the macro call site cannot shadow them.
  protected[this] val primitivePkg = q"_root_.com.outworkers.phantom.builder.primitives"
  protected[this] val rowType = tq"_root_.com.outworkers.phantom.Row"
  protected[this] val builder = q"_root_.com.outworkers.phantom.builder.QueryBuilder"
  protected[this] val macroPkg = q"_root_.com.outworkers.phantom.macros"
  protected[this] val builderPkg = q"_root_.com.outworkers.phantom.builder.query"
  protected[this] val enginePkg = q"_root_.com.outworkers.phantom.builder.query.engine"
  protected[this] val strTpe = tq"_root_.java.lang.String"
  protected[this] val colType = typeOf[com.outworkers.phantom.column.AbstractColumn[_]]
  protected[this] val sasiIndexTpe = typeOf[SASIIndex[_ <: Mode]]
  protected[this] val collections = q"_root_.scala.collection.immutable"

  // Term names used for the parameters of generated methods.
  protected[this] val rowTerm = TermName("row")
  protected[this] val tableTerm = TermName("table")
  protected[this] val inputTerm = TermName("input")
  protected[this] val keyspaceType = typeOf[KeySpace]
  protected[this] val nothingTpe: Type = typeOf[scala.Nothing]

  val knownList = List("Any", "Object", "RootConnector")

  // Symbols of the framework base types.
  val tableSym: Symbol = typeOf[CassandraTable[_, _]].typeSymbol
  val selectTable: Symbol = typeOf[SelectTable[_, _]].typeSymbol
  // This used to be computed from SelectTable[_, _], which made it an exact duplicate
  // of `selectTable` and meant the RootConnector symbol was never actually referenced.
  val rootConn: Symbol = typeOf[com.outworkers.phantom.connectors.RootConnector].typeSymbol
  val colSymbol: Symbol = typeOf[AbstractColumn[_]].typeSymbol

  // `Predef.???` and the name of the `fromRow` method.
  val notImplementedName: TermName = TermName("???")
  val notImplemented: Symbol = typeOf[Predef.type].member(notImplementedName)
  val fromRowName: TermName = TermName("fromRow")

  /**
    * Common shape of anything the macro treats as a named, typed member:
    * a field of a record or a column of a table.
    */
  trait RootField {
    /** The term name of the member. */
    def name: TermName

    /** The Scala type of the member. */
    def tpe: Type

    /** A "name : type" rendering of this member. */
    def debugString: String = {
      val renderedName = q"$name"
      val renderedType = printType(tpe)
      s"$renderedName : $renderedType"
    }
  }

  object Record {
    /** A record member together with its position inside the record. */
    case class Field(
      name: TermName,
      tpe: Type,
      index: Int
    ) extends RootField
  }

  object Column {
    /** A table column, described by its name and its type. */
    case class Field(
      name: TermName,
      tpe: Type
    ) extends RootField
  }

  /**
    * Lists the case accessors declared by the given type.
    *
    * @param tpe The type whose declarations are inspected.
    * @return One (name, type signature) pair for every declaration that is both
    *         a `val` and a case accessor, with the name trimmed of whitespace.
    */
  def caseFields(tpe: Type): Seq[(Name, Type)] = {
    tpe.decls.toSeq.collect {
      case accessor: TermSymbol if accessor.isVal && accessor.isCaseAccessor =>
        val accessorName: Name = TermName(accessor.name.toString.trim)
        accessorName -> accessor.typeSignature
    }
  }

  /**
    * Extension methods for a sequence of fields.
    */
  implicit class FieldOps(val col: Seq[RootField]) {

    /**
      * Groups the names of the fields by their type.
      *
      * @return An ordered map from each type to the names of the fields of that type,
      *         with names accumulated in the order the fields are traversed.
      */
    def typeMap: ListMap[Type, Seq[TermName]] = {
      val seed = ListMap.empty[Type, Seq[TermName]]

      col.foldLeft(seed) { (grouped, field) =>
        val namesSoFar = grouped.getOrElse(field.tpe, Seq.empty[TermName])
        grouped + (field.tpe -> (namesSoFar :+ field.name))
      }
    }
  }

  /**
    * Builds a tuple accessor name such as `_1` from a numeric index.
    *
    * @param index The base index.
    * @param aug The amount added to the index before rendering, 1 by default.
    * @return The term name made of an underscore followed by `index + aug`.
    */
  def tupleTerm(index: Int, aug: Int = 1): TermName = TermName(s"_${index + aug}")

  /**
    * The outcome of trying to pair a record field with a table column.
    */
  trait RecordMatch

  /**
    * A record field that was not paired with a column.
    *
    * @param field The record field.
    * @param reason A free text explanation, empty by default.
    */
  case class Unmatched(field: Record.Field, reason: String = "") extends RecordMatch

  /**
    * A record field paired with a table column.
    *
    * @param left The record field.
    * @param right The column.
    */
  case class MatchedField(left: Record.Field, right: Column.Field) extends RecordMatch {

    /** Alias for [[left]]. */
    def record: Record.Field = left

    /** Alias for [[right]]. */
    def column: Column.Field = right
  }

  /**
    * Extension methods for an ordered map whose values are collections.
    */
  implicit class ListMapOps[K, V, M[X] <: Iterable[X]](
    val lm: ListMap[K, M[V]]
  )(implicit cbf: Factory[V, M[V]]) {

    /**
      * Removes an element from the collection stored under a key.
      *
      * @param key The key whose collection should be pruned.
      * @param elem The element to drop; every entry equal to it is removed.
      * @return The map with the pruned collection stored under the key, or the
      *         map unchanged when the key is absent.
      */
    def remove(key: K, elem: V): ListMap[K, M[V]] = {
      lm.get(key).fold(lm) { existing =>
        val remaining = existing.filterNot(candidate => elem == candidate)
        lm + (key -> cbf.fromSpecific(remaining))
      }
    }
  }

  /**
    * Describes how the columns of a table line up with the fields of its record type
    * and derives the `fromRow` and `store` trees from that alignment.
    *
    * @param tableTpe The type of the table.
    * @param recordType The type of the record.
    * @param members The columns of the table.
    * @param matches The match results accumulated so far, both matched and unmatched.
    */
  case class TableDescriptor(
    tableTpe: Type,
    recordType: Type,
    members: Seq[Column.Field],
    matches: Seq[RecordMatch] = Nil
  ) {

    /**
      * The columns that no matched record field points to.
      */
    def unmatchedColumns: Seq[Column.Field] = {
      members.filterNot(m => matched.exists(r => r.right.name == m.name))
    }

    /**
      * @param m The record match.
      * @return An immutable copy of the table descriptor with the match appended.
      */
    def withMatch(m: RecordMatch): TableDescriptor = {
      this.copy(matches = matches :+ m)
    }

    /**
      * This is just done for the naming convenience, but the functionality of distinguishing between
      * matched and unmatched is implemented
      * using an ADT and collect, so it doesn't actually matter if we append to the same place.
      *
      * @param m The record match.
      * @return An immutable copy of the table descriptor with one extra unmatched record.
      */
    def withoutMatch(m: RecordMatch): TableDescriptor = withMatch(m)

    /** The record fields that were not paired with a column. */
    def unmatched: Seq[Unmatched] = matches.collect {
      case u: Unmatched => u
    }

    /** The record fields that were paired with a column. */
    def matched: Seq[MatchedField] = matches.collect {
      case m: MatchedField => m
    }

    /**
      * Builds the tree that instantiates the record from a row, reading every matched
      * column in the order of the record fields.
      *
      * @return The constructor call, or `None` if any record field is unmatched.
      */
    def fromRow: Option[Tree] = {
      if (unmatched.isEmpty) {
        val columnNames = matched.sortBy(_.left.index).map { m =>
          q"$tableTerm.${m.right.name}.apply($rowTerm)"
        }

        Some(q"""new $recordType(..$columnNames)""")
      } else {
        None
      }
    }

    /** Renders every field as "name: type". */
    def debugList(fields: Seq[RootField]): Seq[String] = fields.map(u =>
      s"${u.name.decodedName}: ${printType(u.tpe)}"
    )

    /**
      * Creates a map to show users how record fields map to columns inside the table.
      * This is done when they want to inspect the generated macro trees and report
      * bugs and as a convenience feature for us at debugging time.
      * @return An interpolated quoted tree that contains a `Map[String, String]` definition,
      *         keyed by "recordField:type" and pointing to "column:type".
      */
    def debugMap: Tree = {
      val tuples = matched.map(m => {
        val recordTerm = m.record.name.decodedName.toString
        // The right hand side must describe the column. It was previously derived from
        // the record field, which made both sides of the map show the same name.
        val colTerm = m.column.name.decodedName.toString
        val recordTpe = printType(m.record.tpe)
        val colTpe = printType(m.column.tpe)

        q"""
           _root_.scala.Tuple2($recordTerm + ":" + $recordTpe, $colTerm + ":" + $colTpe)
        """
      })

      q"_root_.scala.collection.immutable.Map.apply[String, String](..$tuples)"
    }

    /**
      * The reference term points to the position of the record inside the store input.
      * If the Cassandra table has more columns than the record has fields, such as when users
      * chose to store a denormalised variant of a record indexed by a new ID, the store
      * input type is made of the types of the unmatched columns followed by the record type.
      *
      * So in effect: {{{
      *   case class Record(name: String, timestamp: DateTime)
      *
      *   class Records extends CassandraTable[Records, Record] {
      *
      *     object id extends UUIDColumn with PartitionKey
      *     object name extends StringColumn with PrimaryKey
      *     object timestamp extends DateTimeColumn
      *
      *     // The store input is made of the UUID of the id column, followed by the Record.
      *   }
      * }}}
      *
      * The record is always the last element of the input, so its zero based index is
      * equal to the number of unmatched columns (found in the table but not the record).
      * When every column is matched that index is simply 0.
      *
      * @return A tree of the form `input.apply(shapeless.Nat._N)`, always defined.
      */
    val referenceTerm: Option[Tree] = Some(hlistNatRef(unmatchedColumns.size))

    /** Serialises the value found at `ref` using the given unmatched column. */
    protected[this] def unmatchedValue(field: Column.Field, ref: Tree): Tree = {
      q"$enginePkg.CQLQuery($tableTerm.${field.name}.asCql($ref))"
    }

    /**
      * Serialises the value of a matched record field using its column. The field is read
      * from `refTerm` when one is provided and directly from the input otherwise.
      */
    protected[this] def valueTerm(field: MatchedField, refTerm: Option[Tree]): Tree = {
      refTerm match {
        case Some(ref) => q"$enginePkg.CQLQuery($tableTerm.${field.right.name}.asCql($ref.${field.left.name}))"
        case None => q"$enginePkg.CQLQuery($tableTerm.${field.right.name}.asCql($inputTerm.${field.left.name}))"
      }
    }

    /**
      * Shortcut method to create a CQL query holding the name of a particular column
      * inside a table. This will create something like the following:
      * {{{
      *  _root_.com.outworkers.phantom.builder.query.engine.CQLQuery(table.fieldName.name)
      * }}}
      */
    def tableField(fieldName: TermName): Tree = {
      q"$enginePkg.CQLQuery($tableTerm.$fieldName.name)"
    }

    /**
      * References the element found at the given zero based index of the store input.
      *
      * @param index The position inside the input.
      * @return A tree of the form `input.apply(shapeless.Nat._index)`.
      */
    def hlistNatRef(index: Int): Tree = {
      val indexTerm = TermName("_" + index.toString)

      q"$inputTerm.apply(_root_.shapeless.Nat.$indexTerm)"
    }

    /**
      * Builds the body of the store method: a call to `insertValues` with one
      * (column name -> serialised value) pair for every column. Unmatched columns
      * are read from the leading elements of the input, matched ones from the record.
      *
      * @return The tree, or `None` if there is no store type or any record field is unmatched.
      */
    def storeMethod: Option[Tree] = storeType flatMap { sTpe =>
      if (unmatched.isEmpty) {
        val unmatchedColumnInserts = unmatchedColumns.zipWithIndex map { case (field, index) =>
          q"${tableField(field.name)} -> ${unmatchedValue(field, hlistNatRef(index))}"
        }

        val insertions = matched map { field =>
          q"${tableField(field.right.name)} -> ${valueTerm(field, referenceTerm)}"
        }

        val finalDefinitions = unmatchedColumnInserts ++ insertions

        info(s"Inferred store input type: ${printType(sTpe)} for ${printType(tableTpe)}")

        val tree = q"""$tableTerm.`insertValues`(..$finalDefinitions)"""
        Some(tree)
      } else {
        None
      }
    }

    // Size above which a warning about the number of fields is emitted.
    private[this] val maxTupleSize = 22

    /**
      * Same computation as [[storeType]], except that the presence of unmatched
      * columns is reported as a compiler warning.
      */
    def hListStoreType: Option[Type] = {
      if (unmatchedColumns.isEmpty) {
        Some(mkHListType(List(recordType)))
      } else {
        c.warning(
          c.enclosingPosition,
          s"Found unmatched columns for ${printType(tableTpe)}: ${debugList(unmatchedColumns)}"
        )

        val cols = unmatchedColumns.map(_.tpe) :+ recordType

        if (cols.size > maxTupleSize) {
          c.warning(
            c.enclosingPosition,
            s"Created an HList type of ${cols.size} fields, consider reducing the column count for clarity"
          )
        }

        Some(mkHListType(cols.toList))
      }
    }

    /**
     * Automatically creates a [[shapeless.HList]] from the types found in a table as described in the documentation.
     * The types of the unmatched columns come first and the record type comes last.
     */
    def storeType: Option[Type] = {
      if (unmatchedColumns.isEmpty) {
        Some(mkHListType(recordType :: Nil))
      } else {
        info(s"Found unmatched columns for ${printType(tableTpe)}: ${debugList(unmatchedColumns)}")

        val cols = unmatchedColumns.map(_.tpe) :+ recordType

        if (cols.size > maxTupleSize) {
          c.warning(
            c.enclosingPosition,
            s"Table ${printType(tableTpe)} has ${cols.size} fields, consider reducing the number of columns"
          )
        }
        Some(mkHListType(cols.toList))
      }
    }

    /** Renders one "rec.field -> table.column | type" line for every matched field. */
    def showExtractor: String = matched.map(f =>
      s"rec.${f.left.name} -> table.${f.right.name} | ${printType(f.right.tpe)}"
    ) mkString "\n"
  }

  object TableDescriptor {

    /**
      * Creates a descriptor that never produces any generated code.
      *
      * @param table The type of the table.
      * @param rec The type of the record.
      * @param members The columns of the table.
      * @return A descriptor without matches whose `fromRow`, `storeType` and
      *         `storeMethod` are all `None`.
      */
    def empty(table: Type, rec: Type, members: Seq[Column.Field]): TableDescriptor = {
      new TableDescriptor(table, rec, members) {
        override def fromRow: Option[Tree] = None

        override def storeType: Option[Type] = None

        override def storeMethod: Option[Tree] = None
      }
    }
  }

  /**
    * Produces the list of fields of a record type. Two kinds of records are supported:
    * - Tuples, where the fields are named `_1`, `_2` and so on after the type arguments.
    * - Case classes, where the fields are the case accessors.
    *
    * Any other type yields no fields at all.
    *
    * @param tpe The underlying record type that was passed as the second argument to a Cassandra table.
    * @return The fields of the record, each holding its name, its type and its position.
    */
  def extractRecordMembers(tpe: Type): Seq[Record.Field] = {
    val sym = tpe.typeSymbol

    if (sym.fullName.startsWith("scala.Tuple")) {
      tpe.typeArgs.zipWithIndex.map { case (argTpe, position) =>
        Record.Field(tupleTerm(position), argTpe, position)
      }
    } else if (sym.isClass && sym.asClass.isCaseClass) {
      caseFields(tpe).zipWithIndex.map { case ((fieldName, fieldTpe), position) =>
        Record.Field(fieldName.toTermName, fieldTpe, position)
      }
    } else {
      Seq.empty[Record.Field]
    }
  }

  /**
    * Collects the members of a type, base class by base class, whose type signature
    * conforms to `Filter`.
    *
    * @param tpe The type whose base classes are walked, in reverse linearisation order.
    * @param exclusions Decides for every base class whether it is kept (`Some`) or skipped (`None`).
    * @tparam Filter The type the signature of a member must conform to.
    * @return The matching members, without duplicates.
    */
  def filterMembers[Filter : TypeTag](
    tpe: Type,
    exclusions: Symbol => Option[Symbol]
  ): Seq[Symbol] = {
    val retained = tpe.baseClasses.reverse.flatMap(exclusions(_))

    retained.flatMap { baseClass =>
      baseClass.typeSignature.members.sorted.filter(_.typeSignature <:< typeOf[Filter])
    }.distinct
  }

  /**
    * Variant of `filterMembers` that takes the inspected type as a type argument.
    *
    * @param exclusions Decides which base classes are kept; every base class is kept by default.
    * @tparam T The type whose members are collected.
    * @tparam Filter The type the signature of a member must conform to.
    * @return The matching members, without duplicates.
    */
  def filterMembers[T : WeakTypeTag, Filter : TypeTag](
    exclusions: Symbol => Option[Symbol] = (sym: Symbol) => Some(sym)
  ): Seq[Symbol] = {
    val target = weakTypeOf[T]
    filterMembers[Filter](target, exclusions)
  }

  /**
    * Keeps the types that have the type symbol of `Filter` among their base classes.
    *
    * @param columns The candidate types.
    * @tparam Filter The type whose symbol is looked up.
    * @return The candidates that passed the check, in their original order.
    */
  def filterColumns[Filter : TypeTag](columns: Seq[Type]): Seq[Type] = {
    columns.filter { column =>
      column.baseClasses.exists(base => typeOf[Filter].typeSymbol == base)
    }
  }


  /**
    * Keeps the types that conform to the given type.
    *
    * @param columns The candidate types.
    * @param filter The type every retained candidate must conform to.
    * @return The candidates that passed the check, in their original order.
    */
  def filterColumns(columns: Seq[Type], filter: Type): Seq[Type] = columns.filter(_ <:< filter)

  /**
    * Describes every column of a table by its name and by the type argument it
    * passes to `AbstractColumn[_]`.
    *
    * For every member we look at its type signature as seen from the table, find
    * `AbstractColumn` among its base classes and resolve the single type parameter
    * of `AbstractColumn` from the point of view of the member.
    *
    * Compilation is aborted if a member does not have `AbstractColumn` among its
    * base classes or if `AbstractColumn` does not have exactly one type parameter.
    *
    * @param table The type of the table.
    * @param columns The column members of the table; each is expected to be a module.
    * @return One field per column, in the order of the input.
    */
  def extractColumnMembers(table: Type, columns: List[Symbol]): List[Column.Field] = {
    columns.map { member =>
      val memberType = member.typeSignatureIn(table)

      val root = memberType.baseClasses.find(base => colSymbol == base).getOrElse {
        c.abort(
          c.enclosingPosition,
          s"Could not find root column type for ${member.asModule.name}"
        )
      }

      // root points to the AbstractColumn[_] symbol, so a single type parameter is expected.
      root.typeSignature.typeParams match {
        case param :: Nil =>
          val columnName = member.asModule.name.toTermName
          // asSeenFrom tells us what type was passed through to AbstractColumn[_].
          val valueType = param.asType.toType.asSeenFrom(memberType, colSymbol)
          Column.Field(columnName, valueType)

        case _ =>
          c.abort(
            c.enclosingPosition,
            "Expected exactly one type parameter provided for root column type"
          )
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy