All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twilio.guardrail.generators.Scala.Http4sServerGenerator.scala Maven / Gradle / Ivy

The newest version!
package com.twilio.guardrail.generators.Scala

import cats.Monad
import cats.data.NonEmptyList
import cats.implicits._
import cats.Traverse
import com.twilio.guardrail.{ CustomExtractionField, RenderedRoutes, StrictProtocolElems, SwaggerUtil, Target, TracingField, UserError }
import com.twilio.guardrail.core.Tracker
import com.twilio.guardrail.extract.{ ServerRawResponse, TracingLabel }
import com.twilio.guardrail.generators.{ LanguageParameter, LanguageParameters }
import com.twilio.guardrail.generators.syntax._
import com.twilio.guardrail.generators.operations.TracingLabelFormatter
import com.twilio.guardrail.generators.syntax.Scala._
import com.twilio.guardrail.languages.ScalaLanguage
import com.twilio.guardrail.protocol.terms.{ ContentType, Header, Response, Responses }
import com.twilio.guardrail.protocol.terms.server._
import com.twilio.guardrail.shims._
import com.twilio.guardrail.terms.{ CollectionsLibTerms, RouteMeta, SecurityScheme }

import scala.meta._
import _root_.io.swagger.v3.oas.models.PathItem.HttpMethod
import _root_.io.swagger.v3.oas.models.Operation

object Http4sServerGenerator {
  def ServerTermInterp(implicit Cl: CollectionsLibTerms[ScalaLanguage, Target]): ServerTerms[ScalaLanguage, Target] =
    new ServerTermInterp
  class ServerTermInterp(implicit Cl: CollectionsLibTerms[ScalaLanguage, Target]) extends ServerTerms[ScalaLanguage, Target] {
    val customExtractionTypeName: Type.Name = Type.Name("E")

    def splitOperationParts(operationId: String): (List[String], String) = {
      val parts = operationId.split('.')
      (parts.drop(1).toList, parts.last)
    }
    implicit def MonadF: Monad[Target] = Target.targetInstances
    def generateResponseDefinitions(
        responseClsName: String,
        responses: Responses[ScalaLanguage],
        protocolElems: List[StrictProtocolElems[ScalaLanguage]]
    ) =
      Target.pure(Http4sHelper.generateResponseDefinitions(responseClsName, responses, protocolElems))

    def buildCustomExtractionFields(operation: Tracker[Operation], resourceName: List[String], customExtraction: Boolean) =
      for {
        _ <- Target.log.debug(s"buildCustomExtractionFields(${operation.unwrapTracker.showNotNull}, ${resourceName}, ${customExtraction})")
        res <- if (customExtraction) {
          for {
            operationId <- operation
              .downField("operationId", _.getOperationId())
              .raiseErrorIfEmpty("Missing operationId")
            operationId_ = Lit.String(splitOperationParts(operationId.unwrapTracker)._2)
          } yield Some(
            CustomExtractionField[ScalaLanguage](
              LanguageParameter.fromParam(param"extracted: $customExtractionTypeName"),
              q"""customExtract(${operationId_})"""
            )
          )
        } else Target.pure(None)
      } yield res

    def buildTracingFields(operation: Tracker[Operation], resourceName: List[String], tracing: Boolean) =
      Target.log.function("buildTracingFields")(for {
        _ <- Target.log.debug(s"Args: ${operation}, ${resourceName}, ${tracing}")
        res <- if (tracing) {
          for {
            operationId <- operation
              .downField("operationId", _.getOperationId())
              .map(_.map(splitOperationParts(_)._2))
              .raiseErrorIfEmpty("Missing operationId")
            label <- Target.fromOption[Lit.String](
              TracingLabel(operation)
                .map(Lit.String(_))
                .orElse(resourceName.lastOption.map(clientName => TracingLabelFormatter(clientName, operationId.unwrapTracker).toLit)),
              UserError(s"Missing client name (${operation.showHistory})")
            )
          } yield Some(TracingField[ScalaLanguage](LanguageParameter.fromParam(param"traceBuilder: TraceBuilder[F]"), q"""trace(${label})"""))
        } else Target.pure(None)
      } yield res)

    def generateRoutes(
        tracing: Boolean,
        resourceName: String,
        handlerName: String,
        basePath: Option[String],
        routes: List[GenerateRouteMeta[ScalaLanguage]],
        protocolElems: List[StrictProtocolElems[ScalaLanguage]],
        securitySchemes: Map[String, SecurityScheme[ScalaLanguage]]
    ) =
      for {
        renderedRoutes <- routes
          .traverse {
            case GenerateRouteMeta(
                operationId,
                methodName,
                responseClsName,
                customExtractionFields,
                tracingFields,
                sr @ RouteMeta(path, method, operation, securityRequirements),
                parameters,
                responses
                ) =>
              generateRoute(resourceName, basePath, methodName, responseClsName, sr, customExtractionFields, tracingFields, parameters, responses)
          }
          .map(_.flatten.sortBy(_.methodName))
        routeTerms = renderedRoutes.map(_.route)
        combinedRouteTerms <- combineRouteTerms(routeTerms)
        methodSigs = renderedRoutes.map(_.methodSig)
      } yield {
        RenderedRoutes[ScalaLanguage](
          List(combinedRouteTerms),
          List.empty,
          methodSigs,
          renderedRoutes
            .flatMap(_.supportDefinitions)
            .groupBy(_.structure)
            .flatMap(_._2.headOption)
            .toList
            .sortBy(_.toString()), // Only unique supportDefinitions by structure
          renderedRoutes.flatMap(_.handlerDefinitions)
        )
      }

    def renderHandler(
        handlerName: String,
        methodSigs: List[scala.meta.Decl.Def],
        handlerDefinitions: List[scala.meta.Stat],
        responseDefinitions: List[scala.meta.Defn],
        customExtraction: Boolean
    ) =
      Target.log.function("renderHandler")(for {
        _ <- Target.log.debug(s"Args: ${handlerName}, ${methodSigs}")
        extractType = List(tparam"-$customExtractionTypeName").filter(_ => customExtraction)
        tParams     = List(tparam"F[_]") ++ extractType
      } yield q"""
      trait ${Type.Name(handlerName)}[..$tParams] {
        ..${methodSigs ++ handlerDefinitions}
      }
    """)

    def getExtraRouteParams(customExtraction: Boolean, tracing: Boolean) =
      Target.log.function("getExtraRouteParams")(for {
        _ <- Target.log.debug(s"getExtraRouteParams(${tracing})")
        mapRoute = param"""mapRoute: (String, Request[F], F[Response[F]]) => F[Response[F]] = (_: String, _: Request[F], r: F[Response[F]]) => r"""
        customExtraction_ = if (customExtraction) {
          Option(param"""customExtract: String => Request[F] => $customExtractionTypeName""")
        } else Option.empty
        tracing_ = if (tracing) {
          Option(param"""trace: String => Request[F] => TraceBuilder[F]""")
        } else Option.empty
      } yield customExtraction_.toList ::: tracing_.toList ::: List(mapRoute))

    def generateSupportDefinitions(
        tracing: Boolean,
        securitySchemes: Map[String, SecurityScheme[ScalaLanguage]]
    ) =
      Target.pure(List.empty)

    def renderClass(
        resourceName: String,
        handlerName: String,
        annotations: List[scala.meta.Mod.Annot],
        combinedRouteTerms: List[scala.meta.Stat],
        extraRouteParams: List[scala.meta.Term.Param],
        responseDefinitions: List[scala.meta.Defn],
        supportDefinitions: List[scala.meta.Defn],
        customExtraction: Boolean
    ): Target[List[Defn]] =
      Target.log.function("renderClass")(for {
        _ <- Target.log.debug(s"Args: ${resourceName}, ${handlerName}, , ${extraRouteParams}")
        extractType     = List(customExtractionTypeName).map(x => tparam"$x").filter(_ => customExtraction)
        resourceTParams = List(tparam"F[_]") ++ extractType
        handlerTParams  = List(Type.Name("F")) ++ List(customExtractionTypeName).filter(_ => customExtraction)
        routesParams    = List(param"handler: ${Type.Name(handlerName)}[..$handlerTParams]")
      } yield q"""
        class ${Type.Name(resourceName)}[..$resourceTParams](..$extraRouteParams)(implicit F: Async[F]) extends Http4sDsl[F] with CirceInstances {

          ..${supportDefinitions};
          def routes(..${routesParams}): HttpRoutes[F] = HttpRoutes.of {
            ..${combinedRouteTerms}
          }
        }
      """ +: responseDefinitions)

    def getExtraImports(tracing: Boolean, supportPackage: NonEmptyList[String]) =
      Target.log.function("getExtraImports")(
        for {
          _ <- Target.log.debug(s"Args: ${tracing}")
        } yield List(
          q"import org.http4s.circe.CirceInstances",
          q"import org.http4s.dsl.Http4sDsl",
          q"import fs2.text._"
        )
      )

    def httpMethodToHttp4s(method: HttpMethod): Target[Term.Name] = method match {
      case HttpMethod.DELETE => Target.pure(Term.Name("DELETE"))
      case HttpMethod.GET    => Target.pure(Term.Name("GET"))
      case HttpMethod.PATCH  => Target.pure(Term.Name("PATCH"))
      case HttpMethod.POST   => Target.pure(Term.Name("POST"))
      case HttpMethod.PUT    => Target.pure(Term.Name("PUT"))
      case other             => Target.raiseUserError(s"Unknown method: ${other}")
    }

    def pathStrToHttp4s(basePath: Option[String], path: Tracker[String], pathArgs: List[LanguageParameter[ScalaLanguage]]): Target[(Pat, Option[Pat])] =
      (basePath.getOrElse("") + path.unwrapTracker).stripPrefix("/") match {
        case "" => Target.pure((p"${Term.Name("Root")}", None))
        case finalPath =>
          for {
            pathDirective <- SwaggerUtil.paths
              .generateUrlHttp4sPathExtractors(Tracker.cloneHistory(path, finalPath), pathArgs)
          } yield pathDirective
      }

    def directivesFromParams[T](
        required: LanguageParameter[ScalaLanguage] => Type => Target[T],
        multi: LanguageParameter[ScalaLanguage] => Type => (Term => Term) => Target[T],
        multiOpt: LanguageParameter[ScalaLanguage] => Type => (Term => Term) => Target[T],
        optional: LanguageParameter[ScalaLanguage] => Type => Target[T]
    )(params: List[LanguageParameter[ScalaLanguage]]): Target[List[T]] =
      for {
        directives <- params.traverse[Target, T] {
          case scalaParam @ LanguageParameter(_, param, _, _, argType) =>
            val containerTransformations = Map[String, Term => Term](
              "Iterable"   -> identity _,
              "List"       -> (term => q"$term.toList"),
              "Vector"     -> (term => q"$term.toVector"),
              "Seq"        -> (term => q"$term.toSeq"),
              "IndexedSeq" -> (term => q"$term.toIndexedSeq")
            )

            param match {
              case param"$_: Option[$container[$tpe]]" if containerTransformations.contains(container.syntax) =>
                multiOpt(scalaParam)(tpe)(containerTransformations(container.syntax))
              case param"$_: Option[$container[$tpe]] = $_" if containerTransformations.contains(container.syntax) =>
                multiOpt(scalaParam)(tpe)(containerTransformations(container.syntax))
              case param"$_: Option[$tpe]" =>
                optional(scalaParam)(tpe)
              case param"$_: Option[$tpe] = $_" =>
                optional(scalaParam)(tpe)
              case param"$_: $container[$tpe]" if containerTransformations.contains(container.syntax) =>
                multi(scalaParam)(tpe)(containerTransformations(container.syntax))
              case param"$_: $container[$tpe] = $_" if containerTransformations.contains(container.syntax) =>
                multi(scalaParam)(tpe)(containerTransformations(container.syntax))
              case _ => required(scalaParam)(argType)
            }
        }
      } yield directives

    def bodyToHttp4s(methodName: String, body: Option[LanguageParameter[ScalaLanguage]]): Target[Option[Term => Term]] =
      Target.pure(
        body.map {
          case LanguageParameter(_, _, paramName, _, _) =>
            content => q"req.decodeWith(${Term.Name(s"${methodName.uncapitalized}Decoder")}, strict = false) { ${param"$paramName"} => $content }"
        }
      )

    case class Param(generator: Option[Enumerator.Generator], matcher: Option[(Term, Pat)], handlerCallArg: Term)

    def headersToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] =
      directivesFromParams(
        arg => {
          case t"String" =>
            Target.pure(Param(None, Some((q"req.headers.get(${arg.argName.toLit}.ci).map(_.value)", p"Some(${Pat.Var(arg.paramName)})")), arg.paramName))
          case tpe =>
            Target.pure(
              Param(
                None,
                Some((q"req.headers.get(${arg.argName.toLit}.ci).map(_.value).map(Json.fromString(_).as[$tpe])", p"Some(Right(${Pat.Var(arg.paramName)}))")),
                arg.paramName
              )
            )
        },
        arg => _ => _ => Target.raiseUserError(s"Unsupported Iterable[${arg}"),
        arg => _ => _ => Target.raiseUserError(s"Unsupported Option[Iterable[${arg}]]"),
        arg => {
          case t"String" => Target.pure(Param(None, None, q"req.headers.get(${arg.argName.toLit}.ci).map(_.value)"))
          case tpe =>
            Target.pure(
              Param(
                None,
                Some((q"req.headers.get(${arg.argName.toLit}.ci).map(_.value).map(Json.fromString(_).as[$tpe]).sequence", p"Right(${Pat.Var(arg.paramName)})")),
                arg.paramName
              )
            )
        }
      )

    def qsToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[Option[Pat]] =
      params =>
        directivesFromParams(
          arg => _ => Target.pure(p"${Term.Name(s"${methodName.capitalize}${arg.argName.value.capitalize}Matcher")}(${Pat.Var(arg.paramName)})"),
          arg => _ => _ => Target.pure(p"${Term.Name(s"${methodName.capitalize}${arg.argName.value.capitalize}Matcher")}(${Pat.Var(arg.paramName)})"),
          arg => _ => _ => Target.pure(p"${Term.Name(s"${methodName.capitalize}${arg.argName.value.capitalize}Matcher")}(${Pat.Var(arg.paramName)})"),
          arg => _ => Target.pure(p"${Term.Name(s"${methodName.capitalize}${arg.argName.value.capitalize}Matcher")}(${Pat.Var(arg.paramName)})")
        )(params).map {
          case Nil => Option.empty
          case x :: xs =>
            Some(xs.foldLeft[Pat](x) { case (a, n) => p"${a} +& ${n}" })
        }

    def formToHttp4s: List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] =
      directivesFromParams(
        arg => {
          case t"String" =>
            Target.pure(
              Param(None, Some((q"urlForm.values.get(${arg.argName.toLit}).flatMap(_.headOption)", p"Some(${Pat.Var(arg.paramName)})")), arg.paramName)
            )
          case tpe =>
            Target.pure(
              Param(
                None,
                Some(
                  (
                    q"urlForm.values.get(${arg.argName.toLit}).flatMap(_.headOption).map(Json.fromString(_).as[$tpe])",
                    p"Some(Right(${Pat.Var(arg.paramName)}))"
                  )
                ),
                arg.paramName
              )
            )
        },
        arg => {
          case t"String" =>
            _ => Target.pure(Param(None, Some((q"urlForm.values.get(${arg.argName.toLit})", p"Some(${Pat.Var(arg.paramName)})")), q"${arg.paramName}.toList"))
          case tpe =>
            _ =>
              Target.pure(
                Param(
                  None,
                  Some(
                    (
                      q"urlForm.values.get(${arg.argName.toLit}).flatMap(_.toList).traverse(Json.fromString(_).as[$tpe])",
                      p"Some(Right(${Pat.Var(arg.paramName)}))"
                    )
                  ),
                  arg.paramName
                )
              )
        },
        arg => {
          case t"String" => _ => Target.pure(Param(None, None, q"urlForm.values.get(${arg.argName.toLit}).map(_.toList)"))
          case tpe =>
            _ =>
              Target.pure(
                Param(
                  None,
                  Some(
                    (
                      q"urlForm.values.get(${arg.argName.toLit}).flatMap(_.toList).map(Json.fromString(_).as[$tpe]).sequence.sequence",
                      p"Right(${Pat.Var(arg.paramName)})"
                    )
                  ),
                  arg.paramName
                )
              )
        },
        arg => {
          case t"String" =>
            Target.pure(Param(None, Some((q"urlForm.values.get(${arg.argName.toLit}).traverse(_.toList)", p"List(${Pat.Var(arg.paramName)})")), arg.paramName))
          case tpe =>
            Target.pure(
              Param(
                None,
                Some(
                  (
                    q"urlForm.values.get(${arg.argName.toLit}).traverse(_.toList).map(_.traverse(Json.fromString(_).as[$tpe]))",
                    p"List(Right(${Pat.Var(arg.paramName)}))"
                  )
                ),
                arg.paramName
              )
            )
        }
      )

    def asyncFormToHttp4s(methodName: String): List[LanguageParameter[ScalaLanguage]] => Target[List[Param]] =
      directivesFromParams(
        arg =>
          elemType =>
            if (arg.isFile) {
              Target.pure(
                Param(
                  None,
                  Some((q"multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body)", p"Some(${Pat.Var(arg.paramName)})")),
                  arg.paramName
                )
              )
            } else
              elemType match {
                case t"String" =>
                  Target.pure(
                    Param(
                      Some(
                        enumerator"${Pat.Var(Term.Name(s"${arg.argName.value}Option"))} <- multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid).sequence"
                      ),
                      Some(
                        (
                          Term.Name(s"${arg.argName.value}Option"),
                          p"Some(${Pat.Var(arg.paramName)})"
                        )
                      ),
                      arg.paramName
                    )
                  )
                case tpe =>
                  Target.pure(
                    Param(
                      Some(
                        enumerator"${Pat.Var(Term.Name(s"${arg.argName.value}Option"))} <- multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid.flatMap(str => F.fromEither(Json.fromString(str).as[$tpe]))).sequence"
                      ),
                      Some(
                        (
                          Term.Name(s"${arg.argName.value}Option"),
                          p"Some(${Pat.Var(arg.paramName)})"
                        )
                      ),
                      arg.paramName
                    )
                  )
              },
        arg =>
          elemType =>
            _ =>
              if (arg.isFile) {
                Target.pure(Param(None, None, q"multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body)"))
              } else
                elemType match {
                  case t"String" =>
                    Target.pure(
                      Param(
                        Some(
                          enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid).sequence"
                        ),
                        None,
                        arg.paramName
                      )
                    )
                  case tpe =>
                    Target.pure(
                      Param(
                        Some(
                          enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid.flatMap(str => F.fromEither(Json.fromString(str).as[$tpe]))).sequence"
                        ),
                        None,
                        arg.paramName
                      )
                    )
                },
        arg =>
          elemType =>
            _ =>
              if (arg.isFile) {
                Target.pure(Param(None, None, q"Option(multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body)).filter(_.nonEmpty)"))
              } else
                elemType match {
                  case t"String" =>
                    Target.pure(
                      Param(
                        Some(
                          enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid).sequence.map(Option(_).filter(_.nonEmpty))"
                        ),
                        None,
                        arg.paramName
                      )
                    )
                  case tpe =>
                    Target.pure(
                      Param(
                        Some(
                          enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.filter(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid.flatMap(str => F.fromEither(Json.fromString(str).as[$tpe]))).sequence.map(Option(_).filter(_.nonEmpty))"
                        ),
                        None,
                        arg.paramName
                      )
                    )
                },
        arg =>
          elemType =>
            if (arg.isFile) {
              Target.pure(Param(None, None, q"multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body)"))
            } else
              elemType match {
                case t"String" =>
                  Target.pure(
                    Param(
                      Some(
                        enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid).sequence"
                      ),
                      None,
                      arg.paramName
                    )
                  )
                case tpe =>
                  Target.pure(
                    Param(
                      Some(
                        enumerator"${Pat.Var(arg.paramName)} <- multipart.parts.find(_.name.contains(${arg.argName.toLit})).map(_.body.through(utf8Decode).compile.foldMonoid.flatMap(str => F.fromEither(Json.fromString(str).as[$tpe]))).sequence"
                      ),
                      None,
                      arg.paramName
                    )
                  )
              }
      )

    case class RenderedRoute(methodName: String, route: Case, methodSig: Decl.Def, supportDefinitions: List[Defn], handlerDefinitions: List[Stat])

    def generateRoute(
        resourceName: String,
        basePath: Option[String],
        methodName: String,
        responseClsName: String,
        route: RouteMeta,
        customExtractionFields: Option[CustomExtractionField[ScalaLanguage]],
        tracingFields: Option[TracingField[ScalaLanguage]],
        parameters: LanguageParameters[ScalaLanguage],
        responses: Responses[ScalaLanguage]
    ): Target[Option[RenderedRoute]] =
      // Generate the pair of the Handler method and the actual call to `complete(...)`
      Target.log.function("generateRoute")(for {
        _ <- Target.log.debug(s"Args: ${resourceName}, ${basePath}, ${route}, ${tracingFields}")
        RouteMeta(path, method, operation, securityRequirements) = route

        formArgs   <- prepareParameters(parameters.formParams)
        headerArgs <- prepareParameters(parameters.headerParams)
        pathArgs   <- prepareParameters(parameters.pathParams)
        qsArgs     <- prepareParameters(parameters.queryStringParams)
        bodyArgs   <- prepareParameters(parameters.bodyParams)

        http4sMethod <- httpMethodToHttp4s(method)
        pathWithQs   <- pathStrToHttp4s(basePath, path, pathArgs)
        (http4sPath, additionalQs) = pathWithQs
        http4sQs   <- qsToHttp4s(methodName)(qsArgs)
        http4sBody <- bodyToHttp4s(methodName, bodyArgs)
        asyncFormProcessing = formArgs.exists(_.isFile)
        http4sForm         <- if (asyncFormProcessing) asyncFormToHttp4s(methodName)(formArgs) else formToHttp4s(formArgs)
        http4sHeaders      <- headersToHttp4s(headerArgs)
        supportDefinitions <- generateSupportDefinitions(route, parameters)
      } yield {
        val (responseCompanionTerm, responseCompanionType) =
          (Term.Name(responseClsName), Type.Name(responseClsName))
        val responseType = ServerRawResponse(operation)
          .filter(_ == true)
          .fold[Type](t"$responseCompanionType")(Function.const(t"Response[F]"))
        val orderedParameters: List[List[LanguageParameter[ScalaLanguage]]] = List((pathArgs ++ qsArgs ++ bodyArgs ++ formArgs ++ headerArgs).toList) ++
              tracingFields
                .map(_.param)
                .map(List(_)) ++
              customExtractionFields
                .map(_.param)
                .map(List(_))

        val entityProcessor = http4sBody
          .orElse(Some((content: Term) => q"req.decode[UrlForm] { urlForm => $content }").filter(_ => formArgs.nonEmpty && formArgs.forall(!_.isFile)))
          .orElse(Some((content: Term) => q"req.decode[Multipart[F]] { multipart => $content }").filter(_ => formArgs.nonEmpty))
        val fullRouteMatcher =
          NonEmptyList.fromList(List(additionalQs, http4sQs).flatten).fold(p"$http4sMethod -> $http4sPath") { qs =>
            p"$http4sMethod -> $http4sPath :? ${qs.reduceLeft((a, n) => p"$a :& $n")}"
          }
        val fullRouteWithTracingMatcher = tracingFields
          .map(_ => p"$fullRouteMatcher ${Term.Name(s"usingFor${methodName.capitalize}")}(traceBuilder)")
          .getOrElse(fullRouteMatcher)
        val fullRouteWithTracingAndExtraction = customExtractionFields
          .map(_ => p"$fullRouteWithTracingMatcher ${Term.Name(s"extractorFor${methodName.capitalize}")}(extracted)")
          .getOrElse(fullRouteWithTracingMatcher)
        val handlerCallArgs: List[List[Term]] = List(List(responseCompanionTerm)) ++ List(
                (pathArgs ++ qsArgs ++ bodyArgs).map(_.paramName) ++ (http4sForm ++ http4sHeaders).map(_.handlerCallArg)
              ) ++
              tracingFields.map(_.param.paramName).map(List(_)) ++
              customExtractionFields.map(_.param.paramName).map(List(_))
        val handlerCall = q"handler.${Term.Name(methodName)}(...${handlerCallArgs})"
        val isGeneric   = Http4sHelper.isDefinitionGeneric(responses)
        val responseExpr = ServerRawResponse(operation)
          .filter(_ == true)
          .fold[Term] {
            val marshallers = responses.value.map {
              case Response(statusCodeName, valueType, headers) =>
                val responseTerm  = Term.Name(s"${statusCodeName.value}")
                val baseRespType  = Type.Select(responseCompanionTerm, Type.Name(statusCodeName.value))
                val respType      = if (isGeneric) Type.Apply(baseRespType, List(t"F")) else baseRespType
                val generatorName = Term.Name(s"$methodName${statusCodeName}EntityResponseGenerator")
                val encoderName   = Term.Name(s"$methodName${statusCodeName}Encoder")
                (valueType, headers.value) match {
                  case (None, Nil) =>
                    p"case $responseCompanionTerm.$responseTerm => F.pure(Response[F](status = org.http4s.Status.${statusCodeName}))"
                  case (Some(_), Nil) =>
                    p"case resp: $respType => $generatorName(resp.value)(F,$encoderName)"
                  case (None, headersList) =>
                    val (http4sHeaders, http4sHeadersDefinitions) = createHttp4sHeaders(headersList)
                    p"""case resp: $respType =>
                          ..$http4sHeadersDefinitions
                          F.pure(Response[F](status = org.http4s.Status.$statusCodeName, headers = Headers($http4sHeaders)))
                      """
                  case (Some(_), headersList) =>
                    val (http4sHeaders, http4sHeadersDefinitions) = createHttp4sHeaders(headersList)
                    val valueTerm                                 = q"resp.value"
                    p"""case resp: $respType =>
                          ..$http4sHeadersDefinitions
                          $generatorName($valueTerm, $http4sHeaders:_*)(F,$encoderName)
                      """
                }
            }
            q"$handlerCall flatMap ${Term.PartialFunction(marshallers)}"
          }(_ => handlerCall)
        val matchers = (http4sForm ++ http4sHeaders).flatMap(_.matcher)
        val responseInMatch = NonEmptyList.fromList(matchers).fold(responseExpr) {
          case NonEmptyList((expr, pat), Nil) =>
            Term.Match(expr, List(Case(pat, None, responseExpr), Case(p"_", None, q"""BadRequest("Invalid data")""")))
          case matchers @ NonEmptyList(_, _) =>
            val NonEmptyList(head, xs) = matchers.reverse
            val (base, rest)           = xs.splitAt(21).bimap(left => NonEmptyList(head, left).reverse, _.grouped(21).map(_.reverse.unzip).toList)
            val (buildTerms, buildPat) = rest.foldLeft[(Term => Term, Pat => Pat)]((identity, identity)) {
              case ((accTerm, accPat), (nextTermGroup, nextPatGroup)) =>
                (next => accTerm(q"(..${nextTermGroup :+ next})"), next => accPat(p"(..${nextPatGroup :+ next})"))
            }

            val (fullTerm, fullPat) = base match {
              case NonEmptyList((term, pat), Nil) =>
                (buildTerms(term), buildPat(pat))
              case NonEmptyList((term, pat), xs) =>
                val (terms, pats) = xs.unzip
                (buildTerms(q"(..${term +: terms})"), buildPat(p"(..${pat +: pats})"))
            }
            Term.Match(
              fullTerm,
              List(
                Case(fullPat, None, responseExpr),
                Case(Pat.Wildcard(), None, q"""BadRequest("Invalid data")""")
              )
            )
        }
        val responseInMatchInFor = (http4sForm ++ http4sHeaders).flatMap(_.generator) match {
          case Nil        => responseInMatch
          case generators => q"for {..${generators :+ enumerator"response <- $responseInMatch"}} yield response"
        }
        val routeBody = entityProcessor.fold[Term](responseInMatchInFor)(_.apply(responseInMatchInFor))

        val fullRoute: Case =
          p"""case req @ $fullRouteWithTracingAndExtraction =>
             mapRoute($methodName, req, {$routeBody})
            """

        val respond: List[List[Term.Param]] = List(List(param"respond: $responseCompanionTerm.type"))

        val params: List[List[Term.Param]] = respond ++ orderedParameters.map(
                _.map(
                  scalaParam =>
                    scalaParam.param.copy(
                      decltpe =
                        (
                          if (scalaParam.isFile) {
                            if (scalaParam.required) {
                              Some(t"Stream[F, Byte]")
                            } else {
                              Some(t"Option[Stream[F, Byte]]")
                            }
                          } else {
                            scalaParam.param.decltpe
                          }
                        )
                    )
                )
              )

        val consumes = operation.unwrapTracker.consumes.toList.flatMap(ContentType.unapply(_))
        val produces = operation.unwrapTracker.produces.toList.flatMap(ContentType.unapply(_))
        val codecs   = if (ServerRawResponse(operation).getOrElse(false)) Nil else generateCodecs(methodName, bodyArgs, responses, consumes, produces)
        val respType = if (isGeneric) t"$responseType[F]" else responseType
        Some(
          RenderedRoute(
            methodName,
            fullRoute,
            q"""def ${Term.Name(methodName)}(...${params}): F[$respType]""",
            supportDefinitions ++ generateQueryParamMatchers(methodName, qsArgs) ++ codecs ++
                tracingFields
                  .map(_.term)
                  .map(generateTracingExtractor(methodName, _)) ++
                customExtractionFields
                  .map(_.term)
                  .map(generateCustomExtractionFieldsExtractor(methodName, _)),
            List.empty //handlerDefinitions
          )
        )
      })

    def createHttp4sHeaders(headers: List[Header[ScalaLanguage]]): (Term.Name, List[Defn.Val]) = {
      val (names, definitions) = headers.map {
        case Header(name, required, _, termName) =>
          val nameLiteral = Lit.String(name)
          val headerName  = Term.Name(s"${name.toCamelCase}Header")
          val pattern     = Pat.Var(headerName)
          val v           = if (required) q"List(Header($nameLiteral, resp.$termName))" else q"resp.$termName.map(Header($nameLiteral,_)).toList"
          (headerName, q"val $pattern = $v")
      }.unzip
      val headersTerm = Term.Name("responseHeaders")
      val allHeaders  = q"val ${Pat.Var(headersTerm)} = List(..$names).flatten"
      (headersTerm, definitions :+ allHeaders)
    }

    def combineRouteTerms(terms: List[Case]): Target[Term] =
      Target.log.function("combineRouteTerms")(for {
        _      <- Target.log.debug(s"Args: <${terms.length} routes>")
        routes <- Target.fromOption(NonEmptyList.fromList(terms), UserError("Generated no routes, no source to generate"))
        _      <- routes.traverse(route => Target.log.debug(route.toString))
      } yield scala.meta.Term.PartialFunction(routes.toList))

    def generateSupportDefinitions(route: RouteMeta, parameters: LanguageParameters[ScalaLanguage]): Target[List[Defn]] =
      for {
        operation <- Target.pure(route.operation)

        pathArgs = parameters.pathParams
      } yield {
        generatePathParamExtractors(pathArgs)
      }

    def generatePathParamExtractors(pathArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] =
      pathArgs
        .map(_.argType)
        .flatMap({
          // Strip out provided type extractors (see dsl/src/main/scala/org/http4s/dsl/impl/Path.scala)
          case t"Int"            => None
          case t"Long"           => None
          case t"String"         => None
          case t"java.util.UUID" => None
          // Attempt to provide useful extractor names
          case tpe @ Type.Name(name)                 => Some(name -> tpe)
          case tpe @ Type.Select(_, Type.Name(name)) => Some(name -> tpe)
          // Give up for non-primitive type and rely on backtick-escaping
          case tpe => Some(tpe.toString -> tpe)
        })
        .distinctBy(_._1)
        .map({
          case (name, tpe) =>
            q"""
            object ${Term.Name(s"${name}Var")} {
              def unapply(str: String): Option[${tpe}] = {
                if (!str.isEmpty)
                  Json.fromString(str).as[${tpe}].toOption
                else
                  None
              }
            }
          """
        })

    def generateQueryParamMatchers(methodName: String, qsArgs: List[LanguageParameter[ScalaLanguage]]): List[Defn] = {
      val (decoders, matchers) = qsArgs
        .traverse({
          case LanguageParameter(_, param, _, argName, argType) =>
            val containerTransformations = Map[String, Term => Term](
              "Iterable"   -> identity _,
              "List"       -> (term => q"$term.toList"),
              "Vector"     -> (term => q"$term.toVector"),
              "Seq"        -> (term => q"$term.toSeq"),
              "IndexedSeq" -> (term => q"$term.toIndexedSeq")
            )
            val matcherName = Term.Name(s"${methodName.capitalize}${argName.value.capitalize}Matcher")
            val (queryParamMatcher, elemType) = param match {
              case param"$_: Option[$container[$tpe]]" if containerTransformations.contains(container.syntax) =>
                (q"""
                  object ${matcherName} {
                    val delegate = new OptionalMultiQueryParamDecoderMatcher[$tpe](${argName.toLit}) {}
                    def unapply(params: Map[String, Seq[String]]): Option[Option[$container[$tpe]]] = delegate.unapply(params).collectFirst {
                      case cats.data.Validated.Valid(value) => Option(value).filter(_.nonEmpty).map(x => ${containerTransformations(container.syntax)(q"x")})
                    }
                  }
                 """, tpe)
              case param"$_: Option[$container[$tpe]] = $_" if containerTransformations.contains(container.syntax) =>
                (q"""
                  object ${matcherName} {
                    val delegate = new OptionalMultiQueryParamDecoderMatcher[$tpe](${argName.toLit}) {}
                    def unapply(params: Map[String, Seq[String]]): Option[Option[$container[$tpe]]] = delegate.unapply(params).collectFirst {
                      case cats.data.Validated.Valid(value) => Option(value).filter(_.nonEmpty).map(x => ${containerTransformations(container.syntax)(q"x")})
                    }
                  }
                 """, tpe)
              case param"$_: Option[$tpe]" =>
                (
                  q"""object ${matcherName} extends OptionalQueryParamDecoderMatcher[$tpe](${argName.toLit})""",
                  tpe
                )
              case param"$_: Option[$tpe] = $_" =>
                (
                  q"""object ${matcherName} extends OptionalQueryParamDecoderMatcher[$tpe](${argName.toLit})""",
                  tpe
                )
              case param"$_: $container[$tpe]" if containerTransformations.contains(container.syntax) =>
                (q"""
                   object ${matcherName} {
                     val delegate = new QueryParamDecoderMatcher[$tpe](${argName.toLit}) {}
                     def unapply(params: Map[String, Seq[String]]): Option[$container[$tpe]] = delegate.unapplySeq(params).map(x => ${containerTransformations(
                  container.syntax
                )(q"x")})
                   }
                 """, tpe)
              case param"$_: $container[$tpe] = $_" if containerTransformations.contains(container.syntax) =>
                (q"""
                   object ${matcherName} {
                     val delegate = new QueryParamDecoderMatcher[$tpe](${argName.toLit}) {}
                     def unapply(params: Map[String, Seq[String]]): Option[$container[$tpe]] = delegate.unapplySeq(params).map(x => ${containerTransformations(
                  container.syntax
                )(q"x")})
                   }
                 """, tpe)
              case _ =>
                (
                  q"""object ${matcherName} extends QueryParamDecoderMatcher[$argType](${argName.toLit})""",
                  argType
                )
            }
            if (!List("Unit", "Boolean", "Double", "Float", "Short", "Int", "Long", "Char", "String").contains(elemType.toString())) {
              val queryParamDecoder = q"""
                implicit val ${Pat.Var(Term.Name(s"${elemType.toString()}QueryParamDecoder"))}: QueryParamDecoder[$elemType] = (value: QueryParameterValue) =>
                    Json.fromString(value.value).as[$elemType]
                      .leftMap(t => ParseFailure("Query decoding failed", t.getMessage))
                      .toValidatedNel
              """
              (List((elemType, queryParamDecoder)), queryParamMatcher)
            } else {
              (List.empty, queryParamMatcher)
            }
        })

      decoders.distinctBy(_._1.toString()).map(_._2) ++ matchers
    }

    /**
      * It's not possible to use backticks inside pattern matching as it has different semantics: backticks inside match
      * are just references to an already existing bindings.
      */
    def prepareParameters[F[_]: Traverse](parameters: F[LanguageParameter[ScalaLanguage]]): Target[F[LanguageParameter[ScalaLanguage]]] =
      if (parameters.exists(param => param.paramName.syntax != param.paramName.value)) {
        // let's try to prefix them all with underscore and see if it helps
        for {
          _ <- Target.log.debug("Found that not all parameters could be represented as unescaped terms")
          res <- parameters.traverse[Target, LanguageParameter[ScalaLanguage]] { param =>
            val newName = Term.Name(s"_${param.paramName.value}")
            Target.log.debug(s"Escaping param ${param.argName.value}").flatMap { _ =>
              if (newName.syntax == newName.value) Target.pure(param.withParamName(newName))
              else Target.raiseUserError(s"Can't escape parameter with name ${param.argName.value}.")
            }
          }
        } yield res
      } else {
        Target.pure(parameters)
      }

    def generateCodecs(
        methodName: String,
        bodyArgs: Option[LanguageParameter[ScalaLanguage]],
        responses: Responses[ScalaLanguage],
        consumes: Seq[ContentType],
        produces: Seq[ContentType]
    ): List[Defn.Val] =
      generateDecoders(methodName, bodyArgs, consumes) ++ generateEncoders(methodName, responses, produces) ++ generateResponseGenerators(
            methodName,
            responses
          )

    def generateDecoders(methodName: String, bodyArgs: Option[LanguageParameter[ScalaLanguage]], consumes: Seq[ContentType]): List[Defn.Val] =
      bodyArgs.toList.flatMap {
        case LanguageParameter(_, _, _, _, argType) =>
          List(
            q"private[this] val ${Pat.Typed(Pat.Var(Term.Name(s"${methodName.uncapitalized}Decoder")), t"EntityDecoder[F, $argType]")} = ${Http4sHelper
              .generateDecoder(argType, consumes)}"
          )
      }

    def generateEncoders(methodName: String, responses: Responses[ScalaLanguage], produces: Seq[ContentType]): List[Defn.Val] =
      for {
        response    <- responses.value
        (_, tpe, _) <- response.value
      } yield {
        val contentTypes = response.value.map(_._1).map(List(_)).getOrElse(produces) //for OpenAPI 3.x we should take ContentType from the response
        q"private[this] val ${Pat.Var(Term.Name(s"$methodName${response.statusCodeName}Encoder"))} = ${Http4sHelper.generateEncoder(tpe, contentTypes)}"
      }

    def generateResponseGenerators(methodName: String, responses: Responses[ScalaLanguage]): List[Defn.Val] =
      for {
        response <- responses.value
        if response.value.nonEmpty
      } yield {
        q"private[this] val ${Pat.Var(Term.Name(s"$methodName${response.statusCodeName}EntityResponseGenerator"))} = ${Http4sHelper
          .generateEntityResponseGenerator(q"org.http4s.Status.${response.statusCodeName}")}"
      }

    def generateTracingExtractor(methodName: String, tracingField: Term): Defn.Object =
      q"""
         object ${Term.Name(s"usingFor${methodName.capitalize}")} {
           def unapply(r: Request[F]): Option[(Request[F], TraceBuilder[F])] = Some(r -> $tracingField(r))
         }
       """

    def generateCustomExtractionFieldsExtractor(methodName: String, extractField: Term): Defn.Object =
      q"""
         object ${Term.Name(s"extractorFor${methodName.capitalize}")} {
           def unapply(r: Request[F]): Option[(Request[F], $customExtractionTypeName)] = Some(r -> $extractField(r))
         }
       """
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy