All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.github.zeal18.zio.mongodb.bson.codecs.Macro.scala Maven / Gradle / Ivy

package io.github.zeal18.zio.mongodb.bson.codecs

import scala.quoted.*
import scala.compiletime
import scala.deriving.Mirror
import org.bson.BsonInvalidOperationException
import org.bson.BsonReader
import org.bson.BsonSerializationException
import org.bson.BsonType
import org.bson.BsonWriter
import org.bson.codecs.DecoderContext
import org.bson.codecs.EncoderContext
import io.github.zeal18.zio.mongodb.bson.codecs.error.BsonError
import scala.annotation.tailrec
import scala.deriving.*
import scala.quoted.*
import EasierValDef.*
import io.github.zeal18.zio.mongodb.bson.annotations.BsonId
import io.github.zeal18.zio.mongodb.bson.annotations.BsonProperty
import io.github.zeal18.zio.mongodb.bson.annotations.BsonIgnore

/** Scala 3 macro implementation backing `Codec` derivation.
  *
  * [[Macro.derived]] inspects the target type and generates:
  *   - `CaseObjectCodec` for case objects;
  *   - `CaseClassCodec` (wrapping a `FlatCaseClassCodec`) for case classes;
  *   - `EnumCodec` for sealed hierarchies whose children are all objects, and for Scala 3 enums
  *     whose cases take no parameters;
  *   - `CoproductCodec` (discriminator field `"_t"`) for any other sealed hierarchy.
  *
  * Case-class fields honour `@BsonId` (stored under `_id`), `@BsonProperty(label)` (stored under
  * `label`) and `@BsonIgnore` (not stored; the field must then have a default value).
  */
object Macro:
  /** Summons an `A` where the generated code is finally inlined rather than inside the macro;
    * used by `Expr.summonLater`.
    */
  inline def delayedSummonInline[A]: A = compiletime.summonInline[A]

  /** Derives a `Codec[A]`.
    *
    * Dispatches on the shape of `A`: case object, case class (product) or sealed trait / enum
    * (coproduct). Any other type aborts compilation with an error naming the type.
    *
    * @param logCode when `true`, prints derivation progress (and for some shapes the generated
    *                code) at compile time
    */
  def derived[A: Type](logCode: Boolean)(using Quotes): Expr[Codec[A]] =
    if isCaseObject[A] then deriveCaseObject(logCode)
    else if isValueClass[A] then
      // `isValueClass` is currently always false; abort cleanly rather than crash with `???`.
      quotes.reflect.report.errorAndAbort(
        s"Derivation of value class '${Type.show[A]}' is not supported",
      )
    else if isProduct[A] || isCoproduct[A] then derivedADT(logCode)
    else
      quotes.reflect.report.errorAndAbort(
        s"Derivation of selected type is not supported: '${Type.show[A]}'",
      )

  /** Derives a product or coproduct codec, threading one shared `Preparations` through all
    * nested derivations.
    */
  private def derivedADT[A: Type](logCode: Boolean)(using Quotes): Expr[Codec[A]] =
    Preparations(logCode) { preparations =>
      derivedADTWithPreparations(logCode, preparations)
    }

  /** Summons `Mirror.Of[A]` and dispatches to product or coproduct derivation.
    *
    * Aborts compilation if no mirror is found or if it is neither a product nor a sum mirror.
    */
  private def derivedADTWithPreparations[A: Type](logCode: Boolean, preparations: Preparations)(
    using Quotes,
  ): Prepared[A] =
    Expr.summonOrError[Mirror.Of[A]] match
      case '{ $m: Mirror.ProductOf[A] } =>
        deriveProduct(logCode, preparations, m)
      case '{ $m: Mirror.SumOf[A] } =>
        deriveCoproduct(logCode, preparations, m)
      case _ =>
        quotes.reflect.report.errorAndAbort(
          s"Mirror for '${Type.show[A]}' is neither a product nor a sum mirror",
        )

  /** `true` when the type symbol of `A` carries the `Module` flag (i.e. `A` is an object type). */
  private def isCaseObject[A: Type](using q: Quotes): Boolean =
    import q.reflect.*

    TypeRepr.of[A].typeSymbol.flags.is(Flags.Module)

  /** Value-class detection is not implemented yet; always `false`. */
  private def isValueClass[A: Type](using Quotes): Boolean = false

  /** `true` when the type symbol of `A` carries the `Case` flag. */
  private def isProduct[A: Type](using q: Quotes): Boolean =
    import q.reflect.*

    TypeRepr.of[A].typeSymbol.flags.is(Flags.Case)

  /** `true` when `A` is a sealed type or a Scala 3 enum. */
  private def isCoproduct[A: Type](using q: Quotes): Boolean =
    import q.reflect.*

    val flags = TypeRepr.of[A].typeSymbol.flags

    flags.is(Flags.Sealed) || flags.is(Flags.Enum)

  /** Registers (via `preparations.addDeferred`) the derivation of a case-class codec.
    *
    * Each non-ignored case field becomes a `CaseClass.Field` holding its BSON label, a lazily
    * obtained codec for the field type and an accessor function. Decoding goes through
    * `deriveCaseClassDecode` or, for generic case classes, `deriveCaseClassDecodeWithTypeParams`.
    */
  private def deriveProduct[A: Type](
    logCode: Boolean,
    preparations: Preparations,
    mirror: Expr[Mirror.ProductOf[A]],
  )(using q: Quotes): Prepared[A] =
    preparations.addDeferred[A] {
      import q.reflect.*

      val caseClassTypeInfo = TypeInfo[A]

      if logCode then println(s"Deriving product: ${Type.show[A]}")

      val caseClass                   = TypeRepr.of[A]
      val caseClassName: Expr[String] = Expr(caseClassTypeInfo.short)

      val fields   = caseClass.typeSymbol.caseFields
      val defaults = getDefaultValues[A]

      // Field name -> annotation expressions; fields without annotations are omitted.
      val annotations = fields
        .map { f =>
          f.name -> f.annotations.map(_.asExpr)
        }
        .filter(_._2.nonEmpty)
        .toMap

      // Field name -> BSON key, only for fields renamed via @BsonId / @BsonProperty.
      // If several renaming annotations are present, the first one wins.
      val labels: Map[String, String] = fields
        .map { f =>
          val newName =
            annotations.getOrElse(f.name, Nil).collectFirst {
              case '{ new BsonId() }             => "_id"
              case '{ new BsonProperty($label) } => label.valueOrAbort
            }

          f.name -> newName
        }
        .collect { case (k, Some(v)) => k -> v }
        .toMap

      // Names of fields annotated with @BsonIgnore.
      val ignored: Set[String] = fields.flatMap { f =>
        annotations.getOrElse(f.name, Nil).collectFirst { case '{ new BsonIgnore() } =>
          f.name
        }
      }.toSet

      val decode =
        if caseClass.typeArgs.isEmpty then deriveCaseClassDecode[A](defaults, labels, ignored)
        else deriveCaseClassDecodeWithTypeParams[A](mirror, defaults, labels, ignored)

      // Ignored fields are never encoded; on decode they are filled from their default values.
      val fieldsToProcess = fields.filterNot(f => ignored.contains(f.name))

      val preparedFields = Expr.ofList(fieldsToProcess.map { f =>
        val tpe = caseClass.memberType(f)
        tpe.asType match {
          case '[t] =>
            val codecExpr = '{ () => ${ preparations.getStablisedImplicitInstance[t] } }

            // Build `(source: A) => source.<field>` as a synthetic method plus closure.
            val derefDefSymbol = Symbol.newMethod(
              Symbol.spliceOwner,
              "_",
              MethodType(List("source"))(
                _ => List(caseClass),
                _ => tpe,
              ),
            )

            val functionDef = DefDef(
              derefDefSymbol,
              {
                case (src :: Nil) :: Nil =>
                  Some(Select(src.asExprOf[A].asTerm, f).changeOwner(derefDefSymbol))
                case _ => None
              },
            )

            val deref =
              Block(List(functionDef), Closure(Ref(derefDefSymbol), None)).asExprOf[Function1[A, t]]
            val label: Expr[String] = Expr(labels.getOrElse(f.name, f.name))

            '{ CaseClass.Field($label, $codecExpr, $deref) }
        }
      })

      val caseClassExpr     = '{ CaseClass[A]($caseClassName, $decode, $preparedFields) }
      val caseClassFullName = Expr(caseClassTypeInfo.full)

      '{ CaseClassCodec[A]($caseClassFullName, FlatCaseClassCodec[A]($caseClassExpr)) }
    }
  end deriveProduct

  /** This is a hack for instantiating classes with a type parameter.
    *
    * For example
    * ```scala
    * case class Leaf[A](a: A) derives Codec
    * ```
    *
    * would generate the following constructor call
    * ```scala
    * new Leaf[A](a = ???)
    * ```
    *
    * which fails with `constructor Leaf in class Leaf does not take parameters`.
    * Possible solutions are to find a way to drop the type parameter from the constructor call,
    * or to instantiate through the Mirror. This hack takes the second route.
    */
  private def deriveCaseClassDecodeWithTypeParams[A: Type](
    mirror: Expr[Mirror.ProductOf[A]],
    defaults: Map[String, Expr[Any]],
    labels: Map[String, String],
    ignored: Set[String],
  )(using q: Quotes): Expr[Map[String, Any] => A] =
    val caseClassTypeInfo = TypeInfo[A]
    val fields            = Fields.fromMirror(mirror)

    def fieldsToA(data: Expr[Map[String, Any]]): Expr[A] =
      val elems = fields.map { f =>
        import f.{Type as F, typeInstance}

        fieldValueExpr[F](data, f.name, caseClassTypeInfo.full, defaults, labels, ignored)
      }

      '{ $mirror.fromProduct(Tuple.fromArray(Array(${ Varargs(elems) }*))) }

    '{ map => ${ fieldsToA('map) } }

  /** Builds a decoder that calls `A`'s primary constructor with named arguments read from the
    * decoded field map. Varargs fields are re-typed as `T*` so they can be passed through.
    */
  private def deriveCaseClassDecode[A: Type](
    defaults: Map[String, Expr[Any]],
    labels: Map[String, String],
    ignored: Set[String],
  )(using q: Quotes): Expr[Map[String, Any] => A] =
    import q.reflect.*

    val caseClass         = TypeRepr.of[A]
    val caseClassTypeInfo = TypeInfo[A]

    val fields = caseClass.typeSymbol.caseFields

    def fieldToArg(data: Expr[Map[String, Any]], field: Symbol): NamedArg =
      val tpe = caseClass.memberType(field)

      tpe.asType match
        case '[t] =>
          val valueExpr =
            fieldValueExpr[t](data, field.name, caseClassTypeInfo.full, defaults, labels, ignored)

          val argTerm =
            if isRepeated(tpe) then
              // need to change type from 'Seq[Int]' to 'Int*' for a varargs argument
              val subtype = TypeRepr.of[t].typeArgs.head
              subtype.asType match
                case '[st] =>
                  val repeatedAnyTypeTree =
                    Applied(TypeIdent(defn.RepeatedParamClass), List(TypeTree.of[st]))
                  Typed(valueExpr.asTerm, repeatedAnyTypeTree)
            else valueExpr.asTerm

          NamedArg(field.name, argTerm)

    '{ map =>
      ${
        val fieldsList = fields.map(fieldToArg('map, _))
        New(Inferred(TypeRepr.of[A]))
          .select(caseClass.typeSymbol.primaryConstructor)
          .appliedToArgs(fieldsList)
          .asExprOf[A]
      }
    }

  /** Builds the expression that produces one constructor argument from the decoded field map.
    *
    * - Ignored fields always use their default value; compilation aborts if there is none.
    * - Other fields are looked up under their BSON key (the renamed label, if any). They fall
    *   back to the default value when one exists, and otherwise throw
    *   `BsonSerializationException` at runtime.
    *
    * @param data          expression of the decoded key/value map
    * @param fieldName     Scala name of the constructor field
    * @param ownerFullName full name of the case class, used in compile errors
    */
  private def fieldValueExpr[T: Type](
    data: Expr[Map[String, Any]],
    fieldName: String,
    ownerFullName: String,
    defaults: Map[String, Expr[Any]],
    labels: Map[String, String],
    ignored: Set[String],
  )(using q: Quotes): Expr[T] =
    import q.reflect.*

    val (name, notFound) = labels.get(fieldName) match
      case None => (Expr(fieldName), Expr(s"Missing field: '$fieldName'"))
      case Some(label) =>
        (Expr(label), Expr(s"Missing field: '$label' (renamed from '$fieldName')"))

    if ignored.contains(fieldName) then
      defaults.get(fieldName) match
        case None =>
          report.errorAndAbort(
            s"Field '$ownerFullName#$fieldName' is ignored but doesn't have a default value",
          )
        case Some(default) => default.asExprOf[T]
    else
      defaults.get(fieldName) match
        case None =>
          '{
            $data
              .getOrElse($name, throw new org.bson.BsonSerializationException($notFound))
              .asInstanceOf[T]
          }
        case Some(default) => '{ $data.getOrElse($name, $default).asInstanceOf[T] }

  /** Chooses between `EnumCodec` (all children are singletons) and a discriminated
    * `CoproductCodec` for a sealed hierarchy or enum.
    */
  private def deriveCoproduct[A: Type](
    logCode: Boolean,
    preparations: Preparations,
    mirror: Expr[Mirror.SumOf[A]],
  )(using q: Quotes): Prepared[A] =
    import q.reflect.*

    val coproductSymbol = TypeRepr.of[A].typeSymbol
    val children        = coproductSymbol.children
    // All children are objects, or A is an enum whose cases take no parameters: such values are
    // handled by `EnumCodec`, which maps each value to and from its name.
    val isEnum = children.forall(_.flags.is(Flags.Module)) ||
      (coproductSymbol.flags.is(Flags.Enum) && children.forall(_.caseFields.isEmpty))

    if isEnum then Prepared(deriveEnum(logCode))
    else deriveMixedCoproduct(logCode, preparations, mirror)

  /** Derives an `EnumCodec` that converts each singleton child of `A` to and from its name. */
  private def deriveEnum[A: Type](logCode: Boolean)(using q: Quotes): Expr[Codec[A]] =
    import q.reflect.*

    val enumTypeInfo = TypeInfo[A]

    if logCode then println(s"Deriving enum type: ${enumTypeInfo.full}")

    val children = TypeRepr.of[A].typeSymbol.children

    // `case Child => "Child"`
    def childToName(child: Symbol): CaseDef =
      CaseDef(Ident(child.termRef), None, Block(Nil, Literal(StringConstant(child.name))))

    // `case "Child" => Some(Child)`
    def nameToChild(child: Symbol): CaseDef =
      val someChild = '{ Some(${ Ident(child.termRef).asExprOf[A] }) }
      CaseDef(Literal(StringConstant(child.name)), None, Block(Nil, someChild.asTerm))

    val nameByValue: Expr[A => String] = '{ (a: A) =>
      ${
        val cases = children.map(childToName)
        Match('a.asTerm, cases).asExprOf[String]
      }
    }
    val valueByName: Expr[String => Option[A]] = '{ (s: String) =>
      ${
        val cases = children.map(nameToChild) :+
          CaseDef(Wildcard(), None, '{ None }.asTerm)

        Match('s.asTerm, cases).asExprOf[Option[A]]
      }
    }

    val result = '{ EnumCodec[A](${ Expr(enumTypeInfo.full) }, $nameByValue, $valueByName) }

    if logCode then printResult(result)

    result
  end deriveEnum

  /** Derives a `CoproductCodec` for a sealed hierarchy with at least one non-singleton child.
    *
    * Nested sealed subtypes are flattened, and every leaf is identified by its simple name under
    * the `"_t"` discriminator, so leaf names must be unique. Each leaf's codec is resolved as
    * follows:
    * - objects use `ConstCodec`;
    * - leaves with an existing implicit codec use that codec;
    * - all other leaves are derived through `preparations`.
    */
  private def deriveMixedCoproduct[A: Type](
    logCode: Boolean,
    preparations: Preparations,
    mirror: Expr[Mirror.SumOf[A]],
  )(using q: Quotes): Prepared[A] =
    import q.reflect.*

    val coproductTypeInfo = TypeInfo[A]

    if logCode then println(s"Deriving coproduct type: ${Type.show[A]}")

    val fields = flattenCoproductSubtypes(mirror).zipWithIndex
    if logCode then
      println(s"Flattened coproduct subtypes are: ${fields.map(_._1.name).mkString(", ")}")

    // Leaf names double as discriminator values, so duplicates would be undecodable.
    fields.map(_._1.name).groupBy(identity).filter(_._2.size > 1).toList match
      case Nil => ()
      case ambiguous =>
        report.errorAndAbort(
          s"Error deriving '${coproductTypeInfo.full}': Ambiguous subtypes: '${ambiguous.map(_._1).mkString("', '")}'\n" +
            "Make sure all sealed subtypes and their sealed subtypes have unique names",
        )

    // Codecs known up front (objects and leaves with an implicit codec), indexed like `fields`.
    // `None` means the leaf is derived through `preparations` and fetched via `need` later.
    val nonRecursiveCases = Array.fill[Option[Expr[Codec[?]]]](fields.size)(None)

    def fieldSymbol(field: Field): Symbol =
      val t = field.typeRepr
      if t.termSymbol.flags.is(Flags.Case) then t.termSymbol
      else t.typeSymbol

    fields.foreach { case (field, idx) =>
      import field.{Type as F, typeInstance}

      val symbol = fieldSymbol(field)

      if symbol.flags.is(Flags.Module) || symbol.flags.is(Flags.StableRealizable) then
        if logCode then println(s"Deriving object child: ${Type.show[F]}")

        val obj = Ident(symbol.termRef).asExprOf[F]

        val codec = '{ ConstCodec($obj) }
        nonRecursiveCases(idx) = Some(codec)
      else if symbol.isClassDef then
        if logCode then println(s"Deriving class child: ${Type.show[F]}")

        if Expr.summon[Codec[F]].isDefined then
          if logCode then println(s"Found implicit codec for ${Type.show[F]}")
          // The instance itself is resolved later, at the expansion site (see `summonLater`).
          val codec = preparations.stabliseInstance(Expr.summonLater[Codec[F]])
          nonRecursiveCases(idx) = Some(codec)
        else
          if logCode then println(s"Missing implicit codec, deriving codec for ${Type.show[F]}")
          // Registers F's derivation with `preparations`; its codec is looked up with
          // `preparations.need[F]` when the deferred coproduct codec is assembled below.
          derivedADTWithPreparations[F](logCode, preparations)
      else
        report.errorAndAbort(
          s"Unsupported coproduct element: '${field.name}' with type '${Type.show[F]}'",
        )
    }

    preparations.addDeferred[A] {
      // Leaf name -> codec; case-class codecs are unwrapped to their `flat` variant.
      val childrenCodecs: Map[String, Expr[Codec[?]]] =
        fields.map { case (field, idx) =>
          import field.{Type as F, typeInstance}

          val c = nonRecursiveCases(idx).getOrElse(preparations.need[F].varRef)
          val codec: Expr[Codec[?]] = '{
            if $c.isInstanceOf[CaseClassCodec[F]] then $c.asInstanceOf[CaseClassCodec[F]].flat
            else $c
          }

          field.name -> codec
        }.toMap

      // Every leaf was registered above, so a miss indicates a bug in this macro.
      def codecOf(field: Field): Expr[Codec[?]] =
        childrenCodecs.getOrElse(
          field.name,
          report.errorAndAbort(
            s"Error deriving '${coproductTypeInfo.full}': no codec prepared for subtype '${field.name}'",
          ),
        )

      // `case _: Child => ("Child", codec)` for classes, `case Child => ...` for singletons.
      def childToName(field: Field): CaseDef =
        val codec  = codecOf(field)
        val tuple  = '{ (${ Expr(field.name) }, $codec) }
        val symbol = fieldSymbol(field)
        if symbol.isClassDef then
          CaseDef(Typed(Wildcard(), TypeTree.of[field.Type]), None, Block(Nil, tuple.asTerm))
        else CaseDef(Ident(symbol.termRef), None, Block(Nil, tuple.asTerm))

      // `case "Child" => Some(codec)`
      def nameToChild(field: Field): CaseDef =
        val codec     = codecOf(field)
        val someChild = '{ Some($codec.asInstanceOf[Codec[A]]) }
        CaseDef(Literal(StringConstant(field.name)), None, Block(Nil, someChild.asTerm))

      val codecByValue: Expr[A => (String, Codec[?])] = '{ (a: A) =>
        ${
          val cases = fields.map { case (f, _) => childToName(f) }
          Match('a.asTerm, cases).asExprOf[(String, Codec[?])]
        }
      }
      val codecByName: Expr[String => Option[Codec[A]]] = '{ (s: String) =>
        ${
          val cases = fields.map { case (f, _) => nameToChild(f) } :+
            CaseDef(Wildcard(), None, '{ None }.asTerm)

          Match('s.asTerm, cases).asExprOf[Option[Codec[A]]]
        }
      }
      val discriminator = Expr("_t")

      '{
        CoproductCodec[A](
          ${ Expr(coproductTypeInfo.full) },
          $discriminator,
          $codecByValue,
          $codecByName,
        )
      }
    }
  end deriveMixedCoproduct

  /** Derives a `CaseObjectCodec` that always decodes to the single instance of object `A`. */
  private def deriveCaseObject[A: Type](logCode: Boolean)(using q: Quotes): Expr[Codec[A]] =
    import q.reflect.*

    val objectTypeInfo = TypeInfo[A]

    if logCode then println(s"Deriving object: ${objectTypeInfo.full}")

    val obj = Ident(TypeRepr.of[A].typeSymbol.companionModule.termRef).asExprOf[A]

    val result = '{
      CaseObjectCodec(${ Expr(objectTypeInfo.short) }, ${ Expr(objectTypeInfo.full) }, $obj)
    }

    if logCode then printResult(result)

    result

  /** Prints a derived expression, with common package prefixes stripped for readability. */
  private[codecs] def printResult(result: Expr[?])(using q: Quotes): Unit =
    import q.reflect.*

    // Applied in order, exactly as listed.
    val replacements = List(
      "io.github.zeal18.zio.mongodb.bson.codecs." -> "",
      ".apply("                                   -> "(",
      "scala."                                    -> "",
      "RecursiveDerivationTest."                  -> "",
    )
    val shortened = replacements.foldLeft(result.asTerm.show) { case (code, (from, to)) =>
      code.replace(from, to)
    }

    println(s"\nDerived ${result.asTerm.tpe.show}:\n$shortened\n")

  /** `true` when `typeRepr` carries an annotation whose type is named `Repeated`, which is how
    * varargs (`T*`) parameter types show up here.
    */
  private def isRepeated(using q: Quotes)(typeRepr: q.reflect.TypeRepr): Boolean =
    typeRepr match
      case a: q.reflect.AnnotatedType =>
        a.annotation.tpe match
          case tr: q.reflect.TypeRef => tr.name == "Repeated"
          case _                     => false
      case _ => false

  /** Collects default-value expressions of `T`'s primary-constructor value parameters, keyed by
    * parameter name. The defaults are read from the compiler-generated default getters on the
    * companion.
    *
    * NOTE(review): `Symbol.tree` is only guaranteed to be available for definitions in the
    * current compilation run or with `-Yretain-trees`; when it is not a `DefDef`, that default is
    * skipped rather than crashing the macro.
    */
  private def getDefaultValues[T: Type](using Quotes): Map[String, Expr[Any]] =
    import quotes.reflect.*

    val classSymbol = TypeRepr.of[T].typeSymbol

    classSymbol.primaryConstructor.paramSymss.flatten
      .filter(_.isValDef) // drop type parameters
      .zipWithIndex
      .flatMap { case (param, i) =>
        // Default getters are named `$lessinit$greater$default$N`, with N starting at 1.
        classSymbol.companionClass
          .declaredMethod(s"$$lessinit$$greater$$default$$${i + 1}")
          .headOption
          .flatMap { getter =>
            getter.tree match
              case defDef: DefDef => defDef.rhs
              case _              => None
          }
          .map(param.name -> _.asExprOf[Any])
      }
      .toMap

  /** Lists the leaf (product) subtypes of a sum type, recursing into nested sums. */
  private def flattenCoproductSubtypes[A: Type](mirror: Expr[Mirror.SumOf[A]])(using
    q: Quotes,
  ): List[Field] =
    val fields = Fields.fromMirror(mirror)
    fields.flatMap { f =>
      import f.{Type as F, typeInstance}

      Expr.summonOrError[Mirror.Of[F]] match
        case '{ $m: Mirror.ProductOf[F] } => List(f)
        case '{ $m: Mirror.SumOf[F] }     => flattenCoproductSubtypes[F](m)
        case _ =>
          q.reflect.report.errorAndAbort(
            s"Mirror for '${Type.show[F]}' is neither a product nor a sum mirror",
          )
    }

  extension (unused: Expr.type)
    /** Emits a call that summons `A` at the expansion site instead of inside this macro.
      *
      * Requires that macro be transparent.
      *
      * https://github.com/lampepfl/dotty/issues/12359
      */
    def summonLater[A: Type](using Quotes): Expr[A] =
      '{ Macro.delayedSummonInline[A] }

    /** Summons an `A` now, aborting compilation with the implicit-search explanation on failure. */
    def summonOrError[A](using Type[A])(using q: Quotes): Expr[A] =
      import q.reflect.*
      Implicits.search(TypeRepr.of[A]) match
        case iss: ImplicitSearchSuccess => iss.tree.asExpr.asInstanceOf[Expr[A]]
        case isf: ImplicitSearchFailure => report.errorAndAbort(isf.explanation)

  /** Used as a placeholder to make the type checker happy; it is replaced by the real codec
    * afterwards. Both methods throw `NotImplementedError` if ever called.
    */
  private[codecs] class FakeCodec[A] extends Codec[A]:
    override def decode(reader: BsonReader, decoderContext: DecoderContext): A              = ???
    override def encode(writer: BsonWriter, value: A, encoderContext: EncoderContext): Unit = ???

end Macro




© 2015 - 2025 Weber Informatics LLC | Privacy Policy