package tethys.derivation

import tethys.derivation.builder.{
  ReaderDerivationConfig,
  WriterDerivationConfig
}
import tethys.writers.tokens.TokenWriter
import tethys.readers.{FieldName, ReaderError}
import tethys.readers.tokens.TokenIterator
import tethys.{
  JsonObjectWriter,
  JsonReader,
  JsonWriter,
  ReaderBuilder,
  WriterBuilder
}
import scala.Tuple2
import scala.annotation.tailrec
import scala.compiletime.{constValueTuple, summonInline}
import scala.quoted.*
import scala.collection.mutable
import scala.deriving.Mirror

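/** Inline entry points for tethys derivation.
  *
  * Each `inline def` simply splices the matching implementation from
  * [[DerivationMacro]]. The `*Legacy` variants take the deprecated
  * `ReaderDerivationConfig` / `WriterDerivationConfig` and exist only for
  * backwards compatibility.
  */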
private[tethys] object Derivation:

  inline def deriveJsonWriterForProduct[T](
      inline config: WriterBuilder[T]
  ): JsonObjectWriter[T] =
    ${ DerivationMacro.deriveJsonWriterForProduct[T]('{ config }) }

  inline def deriveJsonWriterForSum[T]: JsonObjectWriter[T] =
    ${ DerivationMacro.deriveJsonWriterForSum[T] }

  inline def deriveJsonReaderForProduct[T](
      inline config: ReaderBuilder[T]
  ): JsonReader[T] =
    ${ DerivationMacro.deriveJsonReaderForProduct[T]('{ config }) }

  @deprecated
  inline def deriveJsonReaderForProductLegacy[T](
      inline config: ReaderDerivationConfig
  )(using mirror: Mirror.ProductOf[T]): JsonReader[T] =
    ${
      DerivationMacro
        .deriveJsonReaderForProductLegacy[T]('{ config }, '{ mirror })
    }

  @deprecated
  inline def deriveJsonWriterForProductLegacy[T](
      inline config: WriterDerivationConfig
  )(using mirror: Mirror.ProductOf[T]): JsonObjectWriter[T] =
    ${
      DerivationMacro
        .deriveJsonWriterForProductLegacy[T]('{ config }, '{ mirror })
    }

  @deprecated
  inline def deriveJsonWriterForSumLegacy[T](
      inline config: WriterDerivationConfig
  ): JsonObjectWriter[T] =
    ${ DerivationMacro.deriveJsonWriterForSumLegacy[T]('{ config }) }

  inline def deriveJsonReaderForSum[T]: JsonReader[T] =
    ${ DerivationMacro.deriveJsonReaderForSum[T] }

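/** Macro entry points: each method instantiates [[DerivationMacro]] with the
  * caller's `Quotes` and delegates to the corresponding derivation.
  */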
object DerivationMacro:
  def deriveJsonWriterForProduct[T: Type](config: Expr[WriterBuilder[T]])(using
      quotes: Quotes
  ): Expr[JsonObjectWriter[T]] =
    new DerivationMacro(quotes).deriveJsonWriterForProduct[T](config)

  def deriveJsonWriterForSum[T: Type](using
      quotes: Quotes
  ): Expr[JsonObjectWriter[T]] =
    new DerivationMacro(quotes).deriveJsonWriterForSum[T](None)

  def deriveJsonReaderForProduct[T: Type](config: Expr[ReaderBuilder[T]])(using
      quotes: Quotes
  ): Expr[JsonReader[T]] =
    new DerivationMacro(quotes).deriveJsonReaderForProduct[T](config)

  def deriveJsonReaderForSum[T: Type](using
      quotes: Quotes
  ): Expr[JsonReader[T]] =
    new DerivationMacro(quotes).deriveJsonReaderForSum[T]

  @deprecated
  def deriveJsonReaderForProductLegacy[T: Type](
      config: Expr[ReaderDerivationConfig],
      mirror: Expr[Mirror.ProductOf[T]]
  )(using quotes: Quotes): Expr[JsonReader[T]] =
    new DerivationMacro(quotes)
      .deriveJsonReaderForProductLegacy[T](config, mirror)

  @deprecated
  def deriveJsonWriterForProductLegacy[T: Type](
      config: Expr[WriterDerivationConfig],
      mirror: Expr[Mirror.ProductOf[T]]
  )(using quotes: Quotes): Expr[JsonObjectWriter[T]] =
    new DerivationMacro(quotes)
      .deriveJsonWriterForProductLegacy[T](config, mirror)

  @deprecated
  def deriveJsonWriterForSumLegacy[T: Type](
      config: Expr[WriterDerivationConfig]
  )(using quotes: Quotes): Expr[JsonObjectWriter[T]] =
    new DerivationMacro(quotes).deriveJsonWriterForSumLegacy[T](config)

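/** Implements the actual derivation on top of the quotes reflection API.
  * Configuration and field-parsing helpers (`prepareWriterProductFields`,
  * `prepareReaderProductFields`, `parseSumConfig`, `lookup`, `lookupOpt`,
  * `getAllChildren`, ...) are inherited from [[ConfigurationMacroUtils]].
  */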
private[derivation] class DerivationMacro(val quotes: Quotes)
    extends ConfigurationMacroUtils:
  import quotes.reflect.*

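  /** Derives a writer for a product type: `writeValues` writes every
    * configured field, using either a locally derived writer (see
    * [[deriveMissingWriters]]) or a `JsonWriter` obtained via `lookup`.
    */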
  def deriveJsonWriterForProduct[T: Type](
      config: Expr[WriterBuilder[T]]
  ): Expr[JsonObjectWriter[T]] =
    val fields = prepareWriterProductFields(config)
    val (missingWriters, refs) =
      deriveMissingWriters(TypeRepr.of[T], fields.map(_.tpe))
    val writer = Block(
      missingWriters,
      '{
        new JsonObjectWriter[T]:
          override def writeValues(value: T, tokenWriter: TokenWriter): Unit =
            ${
              Expr.block(
                fields.map { field =>
                  field.tpe.asType match
                    case '[f] =>
                      val writer = refs
                        .get(field.tpe)
                        .fold(lookup[JsonWriter[f]])(_.asExprOf[JsonWriter[f]])
                      '{
                        ${ writer }.write(
                          ${ field.label },
                          ${ field.value('{ value }.asTerm).asExprOf[f] },
                          tokenWriter
                        )
                      }
                },
                '{}
              )
            }
      }.asTerm
    )
    writer.asExprOf[JsonObjectWriter[T]]

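  /** Derives a writer for a sealed hierarchy: optionally writes the legacy
    * discriminator (the child's type name, chosen by `Mirror.SumOf#ordinal`),
    * then the configured discriminator field (read from `value.<label>`), and
    * finally matches on the runtime child type and delegates to that child's
    * `JsonObjectWriter`.
    */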
  def deriveJsonWriterForSum[T: Type](
      legacyConfig: Option[DiscriminatorConfig]
  ): Expr[JsonObjectWriter[T]] =
    val tpe = TypeRepr.of[T]
    val parsedConfig = parseSumConfig[T]
    val types = getAllChildren(tpe)
    val (missingWriters, refs) = deriveMissingWritersForSum(types)
    val mirror = '{ summonInline[Mirror.SumOf[T]] }
    val writer = Block(
      missingWriters,
      '{
        new JsonObjectWriter[T]:
          override def writeValues(value: T, tokenWriter: TokenWriter): Unit =
            ${
              legacyConfig.fold('{}) {
                case DiscriminatorConfig(label, tpe, values) =>
                  '{
                    JsonWriter.stringWriter.write(
                      name = ${ Expr(label) },
                      value = ${
                        Expr.ofList(
                          types.map(t =>
                            Expr(t.typeSymbol.name.filterNot(_ == '$'))
                          )
                        )
                      }.apply(${ mirror }.ordinal(value)),
                      tokenWriter = tokenWriter
                    )
                  }
              }
            }
            ${
              parsedConfig.discriminator.fold('{}) {
                case DiscriminatorConfig(label, tpe, discriminators) =>
                  tpe.asType match
                    case '[discriminatorType] =>
                      '{
                        ${ lookup[JsonWriter[discriminatorType]] }.write(
                          name = ${ Expr(label) },
                          value = ${
                            Select
                              .unique('{ value }.asTerm, label)
                              .asExprOf[discriminatorType]
                          },
                          tokenWriter = tokenWriter
                        )
                      }
              }
            }
            ${
              matchByTypeAndWrite(
                term = '{ value }.asTerm,
                types = types,
                write = (ref, tpe) =>
                  tpe.asType match
                    case '[t] =>
                      val writer = refs
                        .get(tpe)
                        .fold(lookup[JsonObjectWriter[t]])(
                          _.asExprOf[JsonObjectWriter[t]]
                        )
                      '{
                        ${ writer }
                          .writeValues(${ ref.asExprOf[t] }, tokenWriter)
                      }
              )
            }
      }.asTerm
    )
    writer.asExprOf[JsonObjectWriter[T]]

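  /** Synthesizes a local `given` val deriving a `JsonObjectWriter` for every
    * child type that has none in scope, returning the val definitions and
    * references to them keyed by child type.
    */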
  private def deriveMissingWritersForSum(
      types: List[TypeRepr]
  ): (List[ValDef], Map[TypeRepr, Ref]) =
    val (stats, refs) = types.flatMap { tpe =>
      tpe.asType match
        case '[t] =>
          val symbol = Symbol.newVal(
            Symbol.spliceOwner,
            s"given_JsonWriter_${tpe.show(using Printer.TypeReprShortCode)}",
            TypeRepr.of[JsonObjectWriter[t]],
            Flags.Given,
            Symbol.noSymbol
          )
          val valDef = Option.when(lookupOpt[JsonObjectWriter[t]].isEmpty)(
            ValDef(
              symbol,
              Some('{
                JsonObjectWriter.derived[t](using ${ lookup[Mirror.Of[t]] })
              }.asTerm)
            )
          )
          valDef.map(valDef => (valDef, (tpe, Ref(valDef.symbol))))
    }.unzip
    (stats, refs.toMap)

  private def tpeAsString(tpe: TypeRepr) =
    tpe.dealias.show(using Printer.TypeReprCode)

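  /** Ensures a `JsonWriter` exists for every distinct, non-recursive type in
    * `tpes`. Writers already available as a stable identifier are used as is;
    * writers found as more complex expressions are cached in a local `given`
    * val; missing writers are derived here ([[deriveOrTypeJsonWriter]] for
    * union types, `JsonObjectWriter.derived` otherwise). Recursive
    * occurrences of `thisTpe` are skipped and left to the regular lookup.
    */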
  private def deriveMissingWriters(
      thisTpe: TypeRepr,
      tpes: List[TypeRepr]
  ): (List[ValDef], Map[TypeRepr, Ref]) =
    val (stats, refs) = distinct(tpes)
      .filterNot(isRecursive(thisTpe, _))
      .flatMap { tpe =>
        tpe.asType match
          case '[t] =>
            lookupOpt[JsonWriter[t]].map {
              _.asTerm match
                case ident: Ident =>
                  Left(ident)
                case other =>
                  Right(other)
            } match
              case Some(Left(_)) =>
                None

              case other =>
                val valDef = ValDef(
                  Symbol.newVal(
                    Symbol.spliceOwner,
                    s"given_JsonWriter_${tpe.show(using Printer.TypeReprShortCode)}",
                    TypeRepr.of[JsonWriter[t]],
                    Flags.Given,
                    Symbol.noSymbol
                  ),
                  Some(
                    other
                      .flatMap(_.toOption)
                      .getOrElse {
                        tpe match
                          case or: OrType =>
                            deriveOrTypeJsonWriter[t].asTerm
                          case _ =>
                            '{
                              JsonObjectWriter.derived[t](using
                                ${ lookup[scala.deriving.Mirror.Of[t]] }
                              )
                            }.asTerm
                      }
                  )
                )
                Some((valDef, (tpe, Ref(valDef.symbol))))
      }
      .unzip
    (stats, refs.toMap)

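  /** Derives a `JsonWriter` for a union type by deriving missing writers for
    * its alternatives and dispatching on the runtime type with
    * [[matchByTypeAndWrite]].
    */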
  private def deriveOrTypeJsonWriter[T: Type]: Expr[JsonWriter[T]] =
    def collectTypes(tpe: TypeRepr, acc: List[TypeRepr] = Nil): List[TypeRepr] =
      tpe match
        case OrType(left, right) =>
          collectTypes(left, Nil) ::: acc ::: collectTypes(right, Nil)
        case other => other :: acc

    val types = collectTypes(TypeRepr.of[T])
    val (missingWriters, refs) = deriveMissingWriters(TypeRepr.of[T], types)
    val term = Block(
      missingWriters,
      '{
        new JsonWriter[T]:
          def write(value: T, tokenWriter: TokenWriter): Unit =
            ${
              matchByTypeAndWrite(
                term = '{ value }.asTerm,
                types = types,
                (ref, tpe) =>
                  tpe.asType match
                    case '[t] =>
                      val writer = refs
                        .get(tpe)
                        .fold(lookup[JsonWriter[t]])(_.asExprOf[JsonWriter[t]])
                      '{ ${ writer }.write(${ ref.asExprOf[t] }, tokenWriter) }
              )
            }
      }.asTerm
    )
    term.asExprOf[JsonWriter[T]]

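  /** Builds a `match` over `types`: each case binds `term` narrowed to the
    * given type and evaluates the `write` callback with a reference to that
    * binding.
    */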
  private def matchByTypeAndWrite(
      term: Term,
      types: List[TypeRepr],
      write: (Ref, TypeRepr) => Expr[Unit]
  ): Expr[Unit] =
    Match(
      term,
      types.map { tpe =>
        tpe.asType match
          case '[t] =>
            val valDef = ValDef(
              Symbol.newVal(
                Symbol.spliceOwner,
                "value",
                tpe,
                Flags.EmptyFlags,
                Symbol.noSymbol
              ),
              Some(Typed(term, TypeTree.of[t]))
            )
            CaseDef(
              pattern = Bind(valDef.symbol, Typed(Wildcard(), TypeTree.of[t])),
              guard = None,
              rhs = write(Ref(valDef.symbol), tpe).asTerm
            )
      }
    ).asExprOf[Unit]

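  /** Derives a reader for a product type. Case objects get a constant reader;
    * otherwise the generated reader walks the JSON object, dispatches on each
    * field name to fill a local slot, rejects unknown fields in strict mode,
    * checks that every field without a default was initialized, runs the
    * configured extractors and finally calls the primary constructor.
    */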
  def deriveJsonReaderForProduct[T: Type](
      config: Expr[ReaderBuilder[T]]
  ): Expr[JsonReader[T]] =
    val tpe = TypeRepr.of[T]
    val (fields, isStrict) = prepareReaderProductFields[T](config)
    val existingLabels = fields.map(_.name).toSet
    val fieldsWithoutReader = fields.collect {
      case field: ReaderField.Extracted if field.reader => field.name
    }

    val (basicFields, extractedFields) = fields.partitionMap {
      case field: ReaderField.Basic     => Left(field)
      case field: ReaderField.Extracted => Right(field)
    }

    val expectedFieldNames =
      basicFields.map(_.name).toSet ++ extractedFields.flatMap(
        _.extractors.map(_._1)
      ) -- extractedFields.map(_.name)

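    // Fails with ReaderError if any basic field without a default value was
    // not initialized while reading the object, listing the missing fields.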
    def failIfNotInitialized(fieldName: Expr[FieldName]): Expr[Unit] =
      basicFields.filterNot(_.default.nonEmpty) match
        case refs @ head :: tail =>
          val boolExpr = tail.foldLeft('{
            !${ head.initRef.asExprOf[Boolean] }
          }) { (acc, el) =>
            '{ ${ acc } || !${ el.initRef.asExprOf[Boolean] } }
          }
          '{
            if { $boolExpr } then
              val uninitializedFields =
                new scala.collection.mutable.ArrayBuffer[String](${
                  Expr(refs.size)
                })
              ${
                Expr.block(
                  refs.map { ref =>
                    '{
                      if !${ ref.initRef.asExprOf[Boolean] } then
                        uninitializedFields += ${ Expr(ref.name) }
                    }
                  },
                  '{}
                )
              }
              ReaderError.wrongJson(
                "Can not extract fields from json: " + uninitializedFields
                  .mkString(", ")
              )(${ fieldName })
          }

        case Nil =>
          '{}

    if tpe.typeSymbol.flags.is(Flags.Module) then
      '{ JsonReader.const(${ Ref(tpe.termSymbol).asExprOf[T] }) }
    else
      val (missingReaders, refs) =
        deriveMissingReaders(tpe, basicFields.map(_.tpe))
      val term = Block(
        missingReaders,
        '{
          new JsonReader[T]:
            given JsonReader[T] = this
            override def read(it: TokenIterator)(using fieldName: FieldName) =
              if !it.currentToken().isObjectStart then
                ReaderError.wrongJson(
                  "Expected object start but found: " + it
                    .currentToken()
                    .toString
                )
              else
                it.nextToken()
                ${
                  Block(
                    fields.flatMap(_.initialize),
                    '{
                      while (!it.currentToken().isObjectEnd)
                        val jsonName = it.fieldName()
                        it.nextToken()
                        ${
                          Match(
                            selector = '{ jsonName }.asTerm,
                            cases = fields.flatMap(
                              _.initializeFieldCase(
                                refs,
                                '{ it },
                                '{ fieldName }
                              )
                            ) :+
                              CaseDef(
                                Wildcard(),
                                None,
                                if isStrict then
                                  '{
                                    ReaderError.wrongJson(
                                      s"unexpected field '$jsonName', expected one of ${${ Expr(expectedFieldNames.mkString("'", "', '", "'")) }}"
                                    )
                                  }.asTerm
                                else '{ it.skipExpression(); () }.asTerm
                              )
                          ).asExprOf[Unit]
                        }
                      it.nextToken()

                      ${ failIfNotInitialized('{ fieldName }) }

                      ${
                        val allRefs =
                          fields.map(field => field.name -> field.ref).toMap
                        Expr.block(
                          extractedFields
                            .flatMap(_.extract(allRefs, '{ fieldName }))
                            .map(_.asExprOf[Unit]),
                          '{}
                        )
                      }

                      ${
                        New(TypeTree.of[T])
                          .select(tpe.classSymbol.get.primaryConstructor)
                          .appliedToTypes(tpe.typeArgs)
                          .appliedToArgs(
                            fields
                              .filterNot(_.idx == -1)
                              .sortBy(_.idx)
                              .map(_.ref)
                          )
                          .asExprOf[T]
                      }

                    }.asTerm
                  ).asExprOf[T]
                }
        }.asTerm
      )
      term.asExprOf[JsonReader[T]]

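  /** Derives a reader for a sealed hierarchy. A discriminator configured via
    * `@selector` is required: the generated reader reads the discriminator
    * field and selects the child reader whose discriminator value matches,
    * failing on unexpected values.
    */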
  def deriveJsonReaderForSum[T: Type]: Expr[JsonReader[T]] =
    val tpe = TypeRepr.of[T]
    val parsed = parseSumConfig[T]
    val children = getAllChildren(tpe)
    parsed.discriminator match
      case Some(DiscriminatorConfig(label, tpe, discriminators)) =>
        tpe.asType match
          case '[discriminator] =>
            val (discriminatorStats, discriminatorRefs) =
              discriminators.zipWithIndex
                .map((term, idx) =>
                  val stat = ValDef(
                    Symbol.newVal(
                      Symbol.spliceOwner,
                      s"Discriminator_$idx",
                      term.tpe,
                      Flags.Private,
                      Symbol.noSymbol
                    ),
                    Some(term)
                  )
                  (stat, Ref(stat.symbol))
                )
                .unzip
            val (readers, refs) = deriveMissingReaders(TypeRepr.of[T], children)
            val term = Block(
              readers ++ discriminatorStats,
              '{
                JsonReader.builder
                  .addField[discriminator](
                    name = ${ Expr(label) },
                    jsonReader = ${ lookup[JsonReader[discriminator]] }
                  )
                  .selectReader[T] { discriminator =>
                    ${
                      Match(
                        '{ discriminator }.asTerm,
                        children
                          .zip(discriminatorRefs)
                          .map((tpe, branchDiscriminator) =>
                            tpe.asType match
                              case '[t] =>
                                CaseDef(
                                  branchDiscriminator,
                                  None,
                                  Typed(
                                    refs.getOrElse(
                                      tpe,
                                      lookup[JsonReader[t]].asTerm
                                    ),
                                    TypeTree.of[JsonReader[? <: T]]
                                  )
                                )
                          ) :+ CaseDef(
                          Wildcard(),
                          None,
                          '{
                            ReaderError.wrongJson(
                              s"Unexpected discriminator found: $discriminator"
                            )(using FieldName(${ Expr(label) }))
                          }.asTerm
                        )
                      ).asExprOf[JsonReader[? <: T]]
                    }
                  }
              }.asTerm
            )
            term.asExprOf[JsonReader[T]]

      case None =>
        report.errorAndAbort(
          "Discriminator is required to derive JsonReader for sum type. Use @selector annotation"
        )

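  /** Deduplicates types up to `=:=`. */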
  private def distinct(tpes: List[TypeRepr]) =
    tpes.foldLeft(List.empty[TypeRepr]) { (acc, tpe) =>
      if (acc.exists(_ =:= tpe)) acc
      else tpe :: acc
    }

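  /** True if `childTpe` refers to `tpe` directly or within its type
    * arguments.
    */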
  private def isRecursive(tpe: TypeRepr, childTpe: TypeRepr): Boolean =
    tpe =:= childTpe || (childTpe match
      case AppliedType(_, types) => types.exists(isRecursive(tpe, _))
      case _                     => false
    )

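  /** Reader counterpart of [[deriveMissingWriters]]: synthesizes local `given`
    * `JsonReader` vals for distinct, non-recursive types that are not already
    * available as a stable identifier, deriving them via `JsonReader.derived`.
    */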
  private def deriveMissingReaders(
      thisTpe: TypeRepr,
      tpes: List[TypeRepr]
  ): (List[ValDef], Map[TypeRepr, Ref]) =
    val (stats, refs) = distinct(tpes)
      .filterNot(isRecursive(thisTpe, _))
      .flatMap { tpe =>
        tpe.asType match
          case '[t] =>
            lookupOpt[JsonReader[t]].map {
              _.asTerm match
                case ident: Ident =>
                  Left(ident)
                case other =>
                  Right(other)
            } match
              case Some(Left(_)) =>
                None

              case other =>
                val valDef = ValDef(
                  Symbol.newVal(
                    Symbol.spliceOwner,
                    s"given_JsonReader_${tpe.show(using Printer.TypeReprShortCode)}",
                    TypeRepr.of[JsonReader[t]],
                    Flags.Given,
                    Symbol.noSymbol
                  ),
                  Some(
                    other
                      .flatMap(_.toOption)
                      .getOrElse {
                        '{
                          JsonReader.derived[t](using
                            ${ lookup[scala.deriving.Mirror.Of[t]] }
                          )
                        }.asTerm
                      }
                  )
                )
                Some((valDef, (tpe, Ref(valDef.symbol))))
      }
      .unzip
    (stats, refs.toMap)

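  // Legacy derivations: the deprecated configs are translated into the current
  // configuration (via the parseLegacy* helpers) and the derivations above are
  // reused.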
  @deprecated
  def deriveJsonReaderForProductLegacy[T: Type](
      config: Expr[ReaderDerivationConfig],
      mirror: Expr[Mirror.ProductOf[T]]
  ): Expr[JsonReader[T]] =
    deriveJsonReaderForProduct(
      parseLegacyReaderDerivationConfig(config, mirror)
    )

  @deprecated
  def deriveJsonWriterForProductLegacy[T: Type](
      config: Expr[WriterDerivationConfig],
      mirror: Expr[Mirror.ProductOf[T]]
  ): Expr[JsonObjectWriter[T]] =
    deriveJsonWriterForProduct(
      parseLegacyWriterDerivationConfig(config, mirror)
    )

  @deprecated
  def deriveJsonWriterForSumLegacy[T: Type](
      config: Expr[WriterDerivationConfig]
  ): Expr[JsonObjectWriter[T]] =
    deriveJsonWriterForSum(Some(parseLegacyDiscriminator(config)))
