All Downloads are FREE. Search and download functionalities are using the official Maven repository.

sttp.tapir.internal.EndpointAnnotationsMacro.scala Maven / Gradle / Ivy

package sttp.tapir.internal

import sttp.model.Header
import sttp.model.headers.Cookie
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.EndpointIO.annotations._
import sttp.tapir.{Codec, EndpointTransput, MultipartCodec, RawPart, Schema}

import scala.collection.mutable
import scala.reflect.macros.blackbox

/** Shared tree-building helpers for macros which derive endpoint inputs/outputs from a case class whose fields carry
  * `EndpointIO.annotations` / `Schema.annotations` annotations.
  *
  * All members operate on trees and symbols of the wrapped macro context `c`; nothing here is usable at runtime. Every
  * generated reference to tapir code is spelled with a leading `_root_.` so that names in scope at the expansion site
  * cannot shadow it.
  */
private[tapir] abstract class EndpointAnnotationsMacro(val c: blackbox.Context) {
  import c.universe._

  // Annotation types which subclasses dispatch on.
  protected val headerType = c.weakTypeOf[header]
  protected val headersType = c.weakTypeOf[headers]
  protected val cookiesType = c.weakTypeOf[cookies]

  // Schema-level annotations consulted by `addMetadataFromAnnotations`.
  private val schemaDescriptionType = c.weakTypeOf[Schema.annotations.description]
  private val schemaEncodedExampleType = c.weakTypeOf[Schema.annotations.encodedExample]
  private val schemaDefaultType = c.weakTypeOf[Schema.annotations.default[_]]
  private val schemaFormatType = c.weakTypeOf[Schema.annotations.format]
  private val schemaDeprecatedType = c.weakTypeOf[Schema.annotations.deprecated]
  private val schemaHiddenType = c.weakTypeOf[Schema.annotations.hidden]
  private val schemaValidateType = c.weakTypeOf[Schema.annotations.validate[_]]

  // Endpoint-level annotations consulted by `addMetadataFromAnnotations`.
  private val descriptionType = c.weakTypeOf[description]
  private val exampleType = c.weakTypeOf[example]
  private val customiseType = c.weakTypeOf[customise]

  private val apikeyType = c.weakTypeOf[apikey]
  protected val securitySchemeNameType = c.weakTypeOf[securitySchemeName]

  /** Aborts macro expansion when more than one field of the case class carries a `@body`-derived annotation. */
  protected def validateCaseClass[A](util: CaseClassUtil[c.type, A]): Unit = {
    val bodyAnnotatedFields = util.fields.count(field => bodyAnnotation(field).isDefined)
    if (bodyAnnotatedFields > 1) {
      c.abort(c.enclosingPosition, "No more than one body annotation is allowed")
    }
  }

  type StringListCodec[T] = Codec[List[String], T, TextPlain]
  val stringListConstructor: c.universe.Type = typeOf[StringListCodec[_]].typeConstructor

  /** Resolves an implicit value of type `tpe` at the expansion site.
    *
    * Aborts macro expansion with an "Unable to resolve implicit value" error, mentioning the dealiased type, when the
    * implicit search yields `EmptyTree`.
    */
  private def inferImplicitOrAbort(tpe: Type): Tree = {
    val resolved = c.inferImplicitValue(tpe, silent = true)
    if (resolved == EmptyTree) {
      c.abort(c.enclosingPosition, s"Unable to resolve implicit value of type ${tpe.dealias}")
    } else {
      resolved
    }
  }

  /** Builds a `sttp.tapir.header[T](name)` call for the field.
    *
    * @param altName
    *   the header name to use; when empty, the decoded name of the field is used instead.
    */
  protected def makeHeaderIO(field: c.Symbol)(altName: Option[String]): Tree = {
    val name = altName.getOrElse(field.name.toTermName.decodedName.toString)
    q"_root_.sttp.tapir.header[${field.asTerm.info}]($name)"
  }

  /** Builds a reference to `sttp.tapir.headers`; aborts unless the field's result type is exactly `List[Header]`. */
  protected def makeHeadersIO(field: c.Symbol): Tree =
    if (field.info.resultType =:= typeOf[List[Header]]) {
      q"_root_.sttp.tapir.headers"
    } else {
      c.abort(c.enclosingPosition, s"Annotation @headers can be applied only for field with type ${typeOf[List[Header]]}")
    }

  /** Builds a reference to `sttp.tapir.cookies`; aborts unless the field's result type is exactly `List[Cookie]`. */
  protected def makeCookiesIO(field: c.Symbol): Tree =
    if (field.info.resultType =:= typeOf[List[Cookie]]) {
      q"_root_.sttp.tapir.cookies"
    } else {
      c.abort(c.enclosingPosition, s"Annotation @cookies can be applied only for field with type ${typeOf[List[Cookie]]}")
    }

  /** Summons an implicit of type `tpe[F]`, where `F` is the type of `field`; aborts macro expansion if none is found.
    *
    * @param tpe
    *   a type constructor which is applied to the field's type (e.g. `stringListConstructor`).
    */
  protected def summonCodec(field: c.Symbol, tpe: Type): Tree =
    inferImplicitOrAbort(appliedType(tpe, field.asTerm.info))

  /** Returns the first annotation on the field whose type conforms to `body[_, _]`, if any. */
  protected def bodyAnnotation(field: c.Symbol): Option[c.universe.Annotation] =
    field.annotations.find(_.tree.tpe <:< typeOf[body[_, _]])

  /** Builds an `EndpointIO.Body` tree for a field annotated with the body annotation `ann`.
    *
    * The raw body type is read from the first type argument of the annotation's `bodyType` member. When it is
    * `Seq[RawPart]` an implicit `MultipartCodec` for the field's type is summoned and both the raw body type and the
    * codec are taken from it; otherwise an implicit `Codec[raw, field type, cf]` is summoned and the body type is read
    * from the (untypechecked) annotation tree itself. Expansion is aborted if the required implicit cannot be found.
    */
  protected def makeBodyIO(field: c.Symbol)(ann: c.universe.Annotation): Tree = {
    val annTpe = ann.tree.tpe
    val bodyTypeType = annTpe.member(TermName("bodyType")).infoIn(annTpe).finalResultType
    val rawType = bodyTypeType.typeArgs.head
    val resultType = field.asTerm.info
    if (rawType =:= typeOf[Seq[RawPart]]) { // multipart body
      val codec = inferImplicitOrAbort(appliedType(typeOf[MultipartCodec[_]], resultType))
      q"_root_.sttp.tapir.EndpointIO.Body(${codec}.rawBodyType, ${codec}.codec, _root_.sttp.tapir.EndpointIO.Info.empty)"
    } else {
      // the type of the annotation's `cf` member selects the codec format of the summoned codec
      val codecFormatType = annTpe.member(TermName("cf")).infoIn(annTpe).finalResultType
      val codec = inferImplicitOrAbort(appliedType(typeOf[Codec[_, _, _]], rawType, resultType, codecFormatType))
      q"_root_.sttp.tapir.EndpointIO.Body(${c.untypecheck(ann.tree)}.bodyType, $codec, _root_.sttp.tapir.EndpointIO.Info.empty)"
    }
  }

  /** Builds the function tree which converts the combined input value into an instance of the case class.
    *
    *   - no inputs: `(t: Unit) => ClassName()`
    *   - one input: `(t: F) => ClassName(t)`, where `F` is the type of the first field
    *   - many inputs: `(t: (I0, I1, ...)) => ClassName(t._a, t._b, ...)`, where the tuple is ordered by input index and
    *     the constructor arguments by field index
    *
    * @param inputIdxToFieldIdx
    *   position of an input in the tuple -> position of the corresponding field in the case class. The lookups below
    *   assume that both the keys and the values cover `0 until size`.
    */
  protected def mapToTargetFunc[A](inputIdxToFieldIdx: mutable.Map[Int, Int], util: CaseClassUtil[c.type, A]): Tree = {
    val className = util.className
    inputIdxToFieldIdx.size match {
      case 0 =>
        // spelled out explicitly instead of splicing a unit literal into the type position
        q"(t: _root_.scala.Unit) => $className()"
      case 1 =>
        q"(t: ${util.fields.head.info}) => $className(t)"
      case _ =>
        val tupleTypeComponents = (0 until inputIdxToFieldIdx.size) map { idx =>
          val field = util.fields(inputIdxToFieldIdx(idx))
          q"${field.asTerm.info}"
        }
        val tupleType = tq"(..$tupleTypeComponents)"

        // constructor arguments have to be passed in field order, hence the inverted lookup
        val fieldIdxToInputIdx = inputIdxToFieldIdx.map(_.swap)
        val ctorArgs = (0 until fieldIdxToInputIdx.size) map { idx =>
          val fieldName = TermName(s"_${fieldIdxToInputIdx(idx) + 1}") // tuple accessors are 1-based
          q"t.$fieldName"
        }

        q"(t: $tupleType) => $className(..$ctorArgs)"
    }
  }

  /** Builds the function tree which converts an instance of the case class into the combined input value:
    * `(t: ClassName) => (t.fieldOfInput0, t.fieldOfInput1, ...)`, with the components ordered by input index.
    *
    * @param inputIdxToFieldIdx
    *   position of an input in the tuple -> position of the corresponding field in the case class.
    */
  protected def mapFromTargetFunc[A](inputIdxToFieldIdx: mutable.Map[Int, Int], util: CaseClassUtil[c.type, A]): Tree = {
    val tupleArgs = (0 until inputIdxToFieldIdx.size) map { idx =>
      val field = util.fields(inputIdxToFieldIdx(idx))
      val fieldName = TermName(s"${field.name}")
      q"t.$fieldName"
    }
    val classType = util.classSymbol.asType
    q"(t: $classType) => (..$tupleArgs)"
  }

  /** Wraps `input` in the modifier calls implied by the metadata annotations present on `field`.
    *
    * The steps are applied in the order in which they are listed below, each one receiving the tree produced by the
    * previous one; a step whose annotation is absent returns its argument unchanged. The exact extraction semantics
    * are those of the corresponding `CaseClassUtil` helpers, which are defined elsewhere.
    */
  protected def addMetadataFromAnnotations[A](input: Tree, field: Symbol, util: CaseClassUtil[c.type, A]): Tree = {
    val transformations: List[Tree => Tree] = List(
      i => util.extractStringArgFromAnnotation(field, descriptionType).fold(i)(desc => q"$i.description($desc)"),
      i => util.extractStringArgFromAnnotation(field, schemaDescriptionType).fold(i)(desc => q"$i.schema(_.description($desc))"),
      i =>
        util
          .extractTreeFromAnnotation(field, exampleType)
          .fold(i)(example => q"$i.example($example)"),
      i =>
        util
          .extractTreeFromAnnotation(field, schemaEncodedExampleType)
          .fold(i)(encodedExample => q"$i.schema(_.encodedExample($encodedExample))"),
      i =>
        util
          .extractFirstTreeArgFromAnnotation(field, schemaDefaultType)
          .fold(i)(default => q"$i.default($default)"),
      i =>
        util
          .extractStringArgFromAnnotation(field, schemaFormatType)
          .fold(i)(format => q"$i.schema(_.format($format))"),
      i => if (util.annotated(field, schemaDeprecatedType)) q"$i.deprecated()" else i,
      i => if (util.annotated(field, schemaHiddenType)) q"$i.hidden()" else i,
      i =>
        util
          .extractTreeFromAnnotation(field, schemaValidateType)
          .fold(i)(validator => q"$i.validate($validator)"),
      // @apikey: wrap the input in an ApiKey auth input, optionally naming the security scheme
      i =>
        util.findAnnotation(field, apikeyType).fold(i) { a =>
          val challenge = authChallenge(a)
          setSecuritySchemeName(
            q"_root_.sttp.tapir.EndpointInput.Auth($i, $challenge, _root_.sttp.tapir.EndpointInput.AuthType.ApiKey(), _root_.sttp.tapir.EndpointInput.AuthInfo.Empty)",
            util.findAnnotation(field, securitySchemeNameType)
          )
        },
      // @customise: delegate to the companion's `customise`, referenced from `_root_` like all other generated code
      i =>
        util
          .extractFirstTreeArgFromAnnotation(field, customiseType)
          .fold(i) { f =>
            q"_root_.sttp.tapir.internal.EndpointAnnotationsMacro.customise($i, ${c.untypecheck(f)})"
          }
    )

    transformations.foldLeft(input)((i, f) => f(i))
  }

  /** Emits `any.toString` as a forced compiler info message at the macro's enclosing position. */
  protected def info(any: Any): Unit =
    c.info(c.enclosingPosition, any.toString, true)

  /** Builds a tree selecting `.challenge` on the (untypechecked) annotation tree. */
  protected def authChallenge(annotation: Annotation): Tree = {
    q"${c.untypecheck(annotation.tree)}.challenge"
  }

  /** Appends `.securitySchemeName(<annotation>.name)` to `auth` when a scheme-name annotation is given; otherwise
    * returns `auth` unchanged.
    */
  protected def setSecuritySchemeName[A](auth: Tree, schemeName: Option[Annotation]): Tree = {
    schemeName.fold(auth) { name =>
      q"${c.untypecheck(auth)}.securitySchemeName(${c.untypecheck(name.tree)}.name)"
    }
  }
}

private[tapir] object EndpointAnnotationsMacro {

  /** Applies the customisation function `f` to `i` and casts the result back to the static type of `i`.
    *
    * The cast is unchecked: we assume that the customisation function doesn't return a value of a different type.
    */
  def customise[X <: EndpointTransput[_]](i: X, f: EndpointTransput[_] => EndpointTransput[_]): X = {
    val customised = f(i)
    customised.asInstanceOf[X]
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy