All downloads are free. The search and download functionality uses the official Maven repository.

io.circe.generic.util.macros.DerivationMacros.scala Maven / Gradle / Ivy

There is a newer version: 0.15.0-M1
Show newest version
package io.circe.generic.util.macros

import io.circe.{ Decoder, Encoder }
import io.circe.generic.decoding.ReprDecoder
import macrocompat.bundle
import scala.annotation.tailrec
import scala.reflect.macros.whitebox
import shapeless.{ CNil, Coproduct, HList, HNil, Lazy }
import shapeless.labelled.KeyTag

/**
 * Shared implementation of the macros that construct decoder and encoder
 * instances for `shapeless.HList` and `shapeless.Coproduct` representation
 * types whose elements are labelled (key-tagged) fields.
 *
 * Concrete bundles supply the method names, extra arguments and per-member
 * code fragments; this class walks the representation type, resolves a
 * `Decoder` or `Encoder` instance for every distinct member type and assembles
 * the final anonymous instance.
 *
 * @tparam RD type constructor of the instance produced by `constructDecoder`
 * @tparam RE type constructor of the instance produced by `constructEncoder`
 * @tparam DD type constructor searched for (wrapped in `Lazy`) when no plain
 *            `Decoder` is found for a coproduct member
 * @tparam DE type constructor searched for (wrapped in `Lazy`) when no plain
 *            `Encoder` is found for a coproduct member
 */
@bundle
abstract class DerivationMacros[RD[_], RE[_], DD[_], DE[_]] {
  val c: whitebox.Context

  import c.universe._

  // Type tags for the four type constructors; only their `.tpe` is used below.
  protected[this] def RD: TypeTag[RD[_]]
  protected[this] def RE: TypeTag[RE[_]]
  protected[this] def DD: TypeTag[DD[_]]
  protected[this] def DE: TypeTag[DE[_]]

  /** Tree used as the complete decoder when the representation is an empty `HList`. */
  protected[this] def hnilReprDecoder: Tree

  /** Name of the generated (fail-fast) decoding method. */
  protected[this] def decodeMethodName: TermName

  /** Extra parameters placed in a second parameter list after the cursor; none by default. */
  protected[this] def decodeMethodArgs: List[Tree] = Nil

  /** Name of the generated accumulating decoding method. */
  protected[this] def decodeAccumulatingMethodName: TermName

  /** Extra parameters of the accumulating method; defaults to `decodeMethodArgs`. */
  protected[this] def decodeAccumulatingMethodArgs: List[Tree] = decodeMethodArgs

  /** Name of the generated encoding method. */
  protected[this] def encodeMethodName: TermName

  /** Extra parameters placed in a second parameter list after the value; none by default. */
  protected[this] def encodeMethodArgs: List[Tree] = Nil

  // Per-field decoding fragments. `name` is the member label and `decode` the
  // name of the `val` holding the resolved instance. The results are passed
  // straight to `ReprDecoder.consResults`.
  protected[this] def decodeField(name: String, decode: TermName): Tree
  protected[this] def decodeFieldAccumulating(name: String, decode: TermName): Tree

  // Per-constructor decoding fragments. The generated code matches the result
  // against `Some(result)` / `None`, so these must evaluate to an `Option`.
  protected[this] def decodeSubtype(name: String, decode: TermName): Tree
  protected[this] def decodeSubtypeAccumulating(name: String, decode: TermName): Tree

  // Per-member encoding fragments. `value` is the name bound to the member's
  // value in the generated pattern match. `encodeField` results are collected
  // into the `Vector` given to `JsonObject.fromIterable`; `encodeSubtype`
  // results become the body of a case in the generated encode method.
  protected[this] def encodeField(name: String, encode: TermName, value: TermName): Tree
  protected[this] def encodeSubtype(name: String, encode: TermName, value: TermName): Tree

  /**
   * Parameter lists of the generated decoding method: the cursor (named `c`
   * in the generated code) followed by the optional extra arguments.
   *
   * `tpe` is not used here; the signature parallels `fullEncodeMethodArgs`.
   */
  private[this] def fullDecodeMethodArgs(tpe: Type): List[List[Tree]] =
    List(q"c: _root_.io.circe.HCursor") :: (if (decodeMethodArgs.isEmpty) Nil else List(decodeMethodArgs))

  /** Same as `fullDecodeMethodArgs`, but for the accumulating decoding method. */
  private[this] def fullDecodeAccumulatingMethodArgs(tpe: Type): List[List[Tree]] =
    List(q"c: _root_.io.circe.HCursor") :: (
      if (decodeAccumulatingMethodArgs.isEmpty) Nil else List(decodeAccumulatingMethodArgs)
    )

  /**
   * Parameter lists of the generated encoding method: the value (named `a` in
   * the generated code, of type `tpe`) followed by the optional extra arguments.
   */
  private[this] def fullEncodeMethodArgs(tpe: Type): List[List[Tree]] =
    List(q"a: $tpe") :: (if (encodeMethodArgs.isEmpty) Nil else List(encodeMethodArgs))

  /**
   * Crash the attempted derivation because of a failure related to the
   * specified type.
   *
   * Note that these failures are not generally visible to users.
   */
  private[this] def fail(tpe: Type): Nothing =
    c.abort(c.enclosingPosition, s"Cannot generically derive instance: $tpe")

  /**
   * Represents an element at the head of a `shapeless.HList` or
   * `shapeless.Coproduct`.
   *
   * @param label     the member's name, taken from its singleton key type
   * @param keyType   the key (tag) type of the labelled member
   * @param valueType the untagged type of the member's value
   * @param acc       the cons type starting at this member
   * @param accTail   the type of everything after this member
   */
  private[this] case class Member(label: String, keyType: Type, valueType: Type, acc: Type, accTail: Type)

  /**
   * Represents an `shapeless.HList` or `shapeless.Coproduct` type in a way
   * that's more convenient to work with.
   */
  private[this] class Members(val underlying: List[Member]) {
    /**
     * Fold over the elements of this (co-)product while accumulating instances
     * of some type class for each.
     *
     * Members are visited from last to first. An instance is resolved once per
     * distinct value type, where "distinct" is decided with `=:=` so that
     * equivalent types spelled differently share a single instance, and the
     * resulting definitions are returned in member order so that the generated
     * code is the same on every run.
     *
     * @param resolver produces the instance tree for a value type
     * @param init     the accumulator for the empty (co-)product
     * @param f        combines a member, the name of its instance and the
     *                 accumulator for the members after it
     * @return the `val` definitions for the resolved instances and the final
     *         accumulator
     */
    def fold[Z](resolver: Type => Tree)(init: Z)(f: (Member, TermName, Z) => Z): (List[Tree], Z) = {
      val (instances, result) = underlying.foldRight((List.empty[(Type, TermName, Tree)], init)) {
        case (member, (resolved, acc)) =>
          resolved.find(_._1 =:= member.valueType) match {
            case Some((_, instanceName, _)) =>
              (resolved, f(member, instanceName, acc))
            case None =>
              val instanceName = TermName(c.freshName)
              val instance = resolver(member.valueType)

              ((member.valueType, instanceName, instance) :: resolved, f(member, instanceName, acc))
          }
      }

      val instanceDefs = instances.map {
        case (_, instanceName, instance) => q"private[this] val $instanceName = $instance"
      }

      (instanceDefs, result)
    }
  }

  private[this] val HListType: Type = typeOf[HList]
  private[this] val CoproductType: Type = typeOf[Coproduct]

  private[this] object Members {
    // Symbols and types used to recognise the shapeless encodings below.
    private[this] val ShapelessSym = typeOf[HList].typeSymbol.owner
    private[this] val HNilSym = typeOf[HNil].typeSymbol
    private[this] val HConsSym = typeOf[shapeless.::[_, _]].typeSymbol
    private[this] val CNilSym = typeOf[CNil].typeSymbol
    private[this] val CConsSym = typeOf[shapeless.:+:[_, _]].typeSymbol
    private[this] val ShapelessLabelledType = typeOf[shapeless.labelled.type]
    private[this] val KeyTagSym = typeOf[KeyTag[_, _]].typeSymbol
    private[this] val ShapelessTagType = typeOf[shapeless.tag.type]
    private[this] val ScalaSymbolType = typeOf[scala.Symbol]

    case class Entry(label: String, keyType: Type, valueType: Type)

    object Entry {
      /**
       * Extract `(label, keyType, valueType)` from a labelled member type.
       *
       * The type must be a refinement `fieldType with KeyTag[tagType, fieldType]`
       * (with `KeyTag` prefixed by `shapeless.labelled`), and `tagType` must in
       * turn be a refinement of `scala.Symbol` with a type prefixed by
       * `shapeless.tag` whose single argument is a constant string; that
       * string is the label. Anything else yields `None`.
       */
      def unapply(tpe: Type): Option[(String, Type, Type)] = tpe.dealias match {
        case RefinedType(List(fieldType, TypeRef(lt, KeyTagSym, List(tagType, taggedFieldType))), _)
          if lt =:= ShapelessLabelledType && fieldType =:= taggedFieldType =>
            tagType.dealias match {
              case RefinedType(List(st, TypeRef(tt, ts, ConstantType(Constant(fieldKey: String)) :: Nil)), _)
                if st =:= ScalaSymbolType && tt =:= ShapelessTagType =>
                  Some((fieldKey, tagType, fieldType))
              case _ => None
            }
        case _ => None
      }
    }

    /**
     * Decompose an `HList` or `Coproduct` type into its labelled members.
     *
     * Aborts the derivation (via `fail`) if the type is not built from
     * `HNil`/`::` or `CNil`/`:+:`, or if any element is not a labelled entry.
     */
    def fromType(tpe: Type): Members = tpe.dealias match {
      case TypeRef(ThisType(ShapelessSym), HNilSym | CNilSym, Nil) => new Members(Nil)
      case acc @ TypeRef(ThisType(ShapelessSym), HConsSym | CConsSym, List(fieldType, tailType)) =>
        fieldType match {
          case Entry(label, keyType, valueType) =>
            new Members(Member(label, keyType, valueType, acc, tailType) :: fromType(tailType).underlying)
          case _ => fail(tpe)
        }
      case _ => fail(tpe)
    }
  }

  /**
   * Find an implicit instance for `tpe`, trying each candidate type class in
   * order.
   *
   * Each candidate is a type class together with a flag saying whether it
   * should be searched for wrapped in `shapeless.Lazy`; a lazily found
   * instance is unwrapped with `.value`. The search is silent, so a missing
   * instance simply moves on to the next candidate; the derivation is aborted
   * only when every candidate has been exhausted.
   */
  @tailrec
  private[this] def resolveInstance(tcs: List[(Type, Boolean)])(tpe: Type): Tree = tcs match {
    case (tc, lazily) :: rest =>
      val applied = c.universe.appliedType(tc.typeConstructor, List(tpe))
      val target = if (lazily) c.universe.appliedType(typeOf[Lazy[_]].typeConstructor, List(applied)) else applied
      val inferred = c.inferImplicitValue(target, silent = true)

      inferred match {
        case EmptyTree => resolveInstance(rest)(tpe)
        case instance if lazily => q"{ $instance }.value"
        case instance => instance
      }
    case Nil => fail(tpe)
  }

  /** Symbol of the `ReprDecoder` companion, spliced into generated code to reach its helpers. */
  val ReprDecoderUtils = symbolOf[ReprDecoder.type].asClass.module

  /**
   * Build the instance definitions and the two result expressions (fail-fast
   * and accumulating) that decode an `HList`.
   *
   * Starting from the `hnilResult` helpers, each member's decoded field is
   * prepended to the result for the members after it with `consResults`.
   */
  private[this] def hlistDecoderParts(members: Members): (List[c.Tree], (c.Tree, c.Tree)) = members.fold(
    resolveInstance(List((typeOf[Decoder[_]], false)))
  )((q"$ReprDecoderUtils.hnilResult": Tree, q"$ReprDecoderUtils.hnilResultAccumulating": Tree)) {
    case (Member(name, nameTpe, tpe, _, accTail), instanceName, (acc, accAccumulating)) => (
      q"""
        $ReprDecoderUtils.consResults[
          _root_.io.circe.Decoder.Result,
          $nameTpe,
          $tpe,
          $accTail
        ](
          ${ decodeField(name, instanceName) },
          $acc
        )(_root_.io.circe.Decoder.resultInstance)
      """,
      q"""
        $ReprDecoderUtils.consResults[
          _root_.io.circe.AccumulatingDecoder.Result,
          $nameTpe,
          $tpe,
          $accTail
        ](
          ${ decodeFieldAccumulating(name, instanceName) },
          $accAccumulating
        )(_root_.io.circe.AccumulatingDecoder.resultInstance)
      """
    )
  }

  // Failure results used when no coproduct member matched. Inside these
  // quasiquotes `c` is the cursor parameter of the generated method (see
  // `fullDecodeMethodArgs`), not the macro context.
  private[this] val cnilResult: Tree = q"""
    _root_.scala.util.Left[_root_.io.circe.DecodingFailure, _root_.shapeless.CNil](
      _root_.io.circe.DecodingFailure("CNil", c.history)
    ): _root_.scala.util.Either[_root_.io.circe.DecodingFailure, _root_.shapeless.CNil]
  """

  private[this] val cnilResultAccumulating: Tree = q"""
    _root_.cats.data.Validated.invalidNel[_root_.io.circe.DecodingFailure, _root_.shapeless.CNil](
      _root_.io.circe.DecodingFailure("CNil", c.history)
    )
  """

  /**
   * Build the instance definitions and the two result expressions (fail-fast
   * and accumulating) that decode a `Coproduct`.
   *
   * For each member the generated code tries `decodeSubtype`; a `Some` result
   * is injected on the left, while `None` falls through to the remaining
   * members and wraps their result in `Inr`. Instances are looked up as a
   * plain `Decoder` first and as a lazy `DD` second.
   */
  private[this] def coproductDecoderParts(members: Members): (List[c.Tree], (c.Tree, c.Tree)) = members.fold(
    resolveInstance(List((typeOf[Decoder[_]], false), (DD.tpe, true)))
  )((cnilResult, cnilResultAccumulating)) {
    case (Member(name, nameTpe, tpe, _, accTail), instanceName, (acc, accAccumulating)) => (
      q"""
        ${ decodeSubtype(name, instanceName) } match {
          case _root_.scala.Some(result) => result.right.map(v =>
            $ReprDecoderUtils.injectLeftValue[$nameTpe, $tpe, $accTail](v)
          )
          case _root_.scala.None => $acc.right.map(_root_.shapeless.Inr(_))
        }
      """,
      q"""
        ${ decodeSubtypeAccumulating(name, instanceName) } match {
          case _root_.scala.Some(result) => result.map(v =>
            $ReprDecoderUtils.injectLeftValue[$nameTpe, $tpe, $accTail](v)
          )
          case _root_.scala.None => $accAccumulating.map(_root_.shapeless.Inr(_))
        }
      """
    )
  }

  /**
   * Construct a tree for an `RD[R]` instance, where `R` is an `HList` or a
   * `Coproduct` of labelled members.
   *
   * An empty `HList` yields `hnilReprDecoder` directly. Otherwise an anonymous
   * instance is generated containing one `val` per distinct member type and
   * the two decoding methods. The derivation is aborted when `R` is neither an
   * `HList` nor a `Coproduct`.
   */
  protected[this] def constructDecoder[R](implicit R: c.WeakTypeTag[R]): c.Tree = {
    val isHList = R.tpe <:< HListType
    val isCoproduct = !isHList && R.tpe <:< CoproductType

    if (!isHList && !isCoproduct) fail(R.tpe) else {
      val members = Members.fromType(R.tpe)

      if (isHList && members.underlying.isEmpty) q"$hnilReprDecoder" else {
        val (instanceDefs, (result, resultAccumulating)) =
          if (isHList) hlistDecoderParts(members) else coproductDecoderParts(members)
        val instanceType = appliedType(RD.tpe.typeConstructor, List(R.tpe))

        q"""
          new $instanceType {
            ..$instanceDefs

            final def $decodeMethodName(
              ...${ fullDecodeMethodArgs(R.tpe) }
            ): _root_.io.circe.Decoder.Result[$R] = $result

            final override def $decodeAccumulatingMethodName(
              ...${ fullDecodeAccumulatingMethodArgs(R.tpe) }
            ): _root_.io.circe.AccumulatingDecoder.Result[$R] = $resultAccumulating
          }: $instanceType
        """
      }
    }
  }

  /**
   * Build the instance definitions and the body that encodes an `HList`.
   *
   * The body destructures the value `a` with a single nested `::` pattern
   * ending in `HNil`, binding every member to a fresh name, and builds a
   * `JsonObject` from the `encodeField` fragments in member order.
   */
  protected[this] def hlistEncoderParts(members: Members): (List[c.Tree], c.Tree) = {
    val (instanceDefs, (pattern, fields)) = members.fold(resolveInstance(List((typeOf[Encoder[_]], false))))(
      (pq"_root_.shapeless.HNil": Tree, List.empty[Tree])
    ) {
      case (Member(name, _, _, _, _), instanceName, (patternAcc, fieldsAcc)) =>
        val currentName = TermName(c.freshName)

        (
          pq"_root_.shapeless.::($currentName, $patternAcc)",
          encodeField(name, instanceName, currentName) :: fieldsAcc
        )
    }

    (
      instanceDefs,
      q"""
        a match {
          case $pattern =>
            _root_.io.circe.JsonObject.fromIterable(_root_.scala.collection.immutable.Vector(..$fields))
        }
      """
    )
  }

  /**
   * Build the instance definitions and the body that encodes a `Coproduct`.
   *
   * Every member contributes a case of the shape
   * `Inr(tail) => tail match { case Inl(x) => ...; case <next member> }`,
   * with the innermost case raising an error for `CNil`. Because each case,
   * including the outermost one, expects an `Inr`, the value `a` is wrapped in
   * `Inr` before being matched. Instances are looked up as a plain `Encoder`
   * first and as a lazy `DE` second.
   */
  private[this] def coproductEncoderParts(members: Members): (List[c.Tree], c.Tree) = {
    val (instanceDefs, patternAndCase) = members.fold(
      resolveInstance(List((typeOf[Encoder[_]], false), (DE.tpe, true)))
    )(
      cq"""_root_.shapeless.Inr(_) => _root_.scala.sys.error("Cannot encode CNil")"""
    ) {
      case (Member(name, _, _, _, _), instanceName, acc) =>
        val tailName = TermName(c.freshName)
        val currentName = TermName(c.freshName)

        cq"""
          _root_.shapeless.Inr($tailName) => $tailName match {
            case _root_.shapeless.Inl($currentName) =>
              ${ encodeSubtype(name, instanceName, currentName) }
            case $acc
          }
        """
    }

    (instanceDefs, q"_root_.shapeless.Inr(a) match { case $patternAndCase }")
  }

  /**
   * Construct a tree for an `RE[R]` instance, where `R` is an `HList` or a
   * `Coproduct` of labelled members.
   *
   * The generated anonymous instance contains one `val` per distinct member
   * type and the encoding method, which returns a `JsonObject`. The
   * derivation is aborted when `R` is neither an `HList` nor a `Coproduct`.
   */
  protected[this] def constructEncoder[R](implicit R: c.WeakTypeTag[R]): c.Tree = {
    val isHList = R.tpe <:< HListType
    val isCoproduct = !isHList && R.tpe <:< CoproductType

    if (!isHList && !isCoproduct) fail(R.tpe) else {
      val members = Members.fromType(R.tpe)

      val (instanceDefs, instanceImpl) =
        if (isHList) hlistEncoderParts(members) else coproductEncoderParts(members)

      val instanceType = appliedType(RE.tpe.typeConstructor, List(R.tpe))

      q"""
        new $instanceType {
          ..$instanceDefs

          final def $encodeMethodName(...${ fullEncodeMethodArgs(R.tpe) }): _root_.io.circe.JsonObject = $instanceImpl
        }: $instanceType
      """
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy