// All downloads are free. Search and download functionality uses the official Maven repository.
//
// sttp.tapir.server.interpreter.ServerInterpreter.scala (Maven / Gradle / Ivy)
//
// There is a newer version: 1.11.9
// Show newest version
package sttp.tapir.server.interpreter

import sttp.capabilities.StreamMaxLengthExceededException
import sttp.model.{Headers, StatusCode}
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody}
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interceptor._
import sttp.tapir.server.model.{MaxContentLength, ServerResponse, ValuedEndpointOutput}
import sttp.tapir.server.{model, _}
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile}
import sttp.tapir.EndpointInfo
import sttp.tapir.AttributeKey

class ServerInterpreter[R, F[_], B, S](
    serverEndpoints: ServerRequest => List[ServerEndpoint[R, F]],
    requestBody: RequestBody[F, S],
    toResponseBody: ToResponseBody[B, S],
    interceptors: List[Interceptor[F]],
    deleteFile: TapirFile => F[Unit]
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]) {
  /** Handles a single request.
    *
    * The request handler is assembled from the configured interceptors and then invoked with the request and the
    * server endpoints selected for it by `serverEndpoints`. All of this work is deferred through `monad.suspend`.
    *
    * @param request
    *   the incoming request
    * @return
    *   the result produced by the interceptors/endpoints, wrapped in `F`
    */
  def apply(request: ServerRequest): F[RequestResult[B]] = monad.suspend {
    val handler = callInterceptors(interceptors, Nil, responder(defaultSuccessStatusCode))
    handler.apply(request, serverEndpoints(request))
  }

  /** Walks the interceptor list, collecting endpoint interceptors and letting request interceptors wrap the rest.
    *
    * Endpoint interceptors are prepended to `eisAcc` as they are encountered; once the list is exhausted, the
    * accumulator is reversed (restoring the original order) and the endpoints are tried using [[firstNotNone]]. A
    * request interceptor receives a continuation that resumes the walk with the remaining interceptors, using the
    * request and endpoints that the interceptor passes on.
    *
    * @param is
    *   the interceptors which have not been processed yet
    * @param eisAcc
    *   endpoint interceptors collected so far, in reverse order
    * @param responder
    *   the responder handed to request interceptors
    */
  private def callInterceptors(
      is: List[Interceptor[F]],
      eisAcc: List[EndpointInterceptor[F]],
      responder: Responder[F, B]
  ): RequestHandler[F, R, B] =
    is match {
      case Nil =>
        RequestHandler.from { (request, ses, _) =>
          val endpointInterceptors = eisAcc.reverse
          firstNotNone(request, ses, endpointInterceptors, Nil)
        }
      case current :: remaining =>
        current match {
          case endpointInterceptor: EndpointInterceptor[F] =>
            callInterceptors(remaining, endpointInterceptor :: eisAcc, responder)
          case requestInterceptor: RequestInterceptor[F] =>
            // continuation invoked by the request interceptor with the endpoint interceptor it contributes
            val continueWith: EndpointInterceptor[F] => RequestHandler[F, R, B] = { contributed =>
              RequestHandler.from { (request, ses, _) =>
                callInterceptors(remaining, contributed :: eisAcc, responder).apply(request, ses)
              }
            }
            requestInterceptor(responder, continueWith)
        }
    }

  /** Tries the server endpoints one by one, stopping at the first one which yields something other than a failure.
    *
    * Whenever an endpoint reports a [[RequestResult.Failure]], its failure contexts are added to the accumulator and
    * the next endpoint is tried. If no endpoints are left, a failure carrying all the accumulated contexts (reversed)
    * is returned.
    *
    * @param request
    *   the request being handled
    * @param ses
    *   the server endpoints which still need to be tried
    * @param endpointInterceptors
    *   the interceptors applied around each endpoint
    * @param accumulatedFailureContexts
    *   the decode failure contexts gathered from the endpoints tried so far
    */
  private def firstNotNone(
      request: ServerRequest,
      ses: List[ServerEndpoint[R, F]],
      endpointInterceptors: List[EndpointInterceptor[F]],
      accumulatedFailureContexts: List[DecodeFailureContext]
  ): F[RequestResult[B]] =
    ses match {
      case se :: remaining =>
        val attempt = tryServerEndpoint[se.SECURITY_INPUT, se.PRINCIPAL, se.INPUT, se.ERROR_OUTPUT, se.OUTPUT](
          request,
          se,
          endpointInterceptors
        )
        attempt.flatMap {
          case RequestResult.Failure(failureContexts) =>
            firstNotNone(request, remaining, endpointInterceptors, failureContexts ++: accumulatedFailureContexts)
          case handled => handled.unit
        }
      case Nil =>
        val allFailures: RequestResult[B] = RequestResult.Failure(accumulatedFailureContexts.reverse)
        allFailures.unit
    }

  /** A body-less, header-less `InternalServerError` response. It is supplied as the security-failure response when
    * building endpoint handlers for the decode-failure and decode-success paths.
    */
  private val defaultSecurityFailureResponse: F[ServerResponse[B]] =
    monad.unit(ServerResponse[B](StatusCode.InternalServerError, Nil, None, None))

  /** Wraps the default endpoint handler with the given endpoint interceptors.
    *
    * The fold proceeds from the right, so the first interceptor in the list ends up as the outermost wrapper. Each
    * interceptor receives a responder which uses the default success status code.
    *
    * @param securityFailureResponse
    *   the response returned by the default handler's `onSecurityFailure`; passed by name
    * @param endpointInterceptors
    *   the interceptors to wrap around the default handler
    */
  private def endpointHandler(
      securityFailureResponse: => F[ServerResponse[B]],
      endpointInterceptors: List[EndpointInterceptor[F]]
  ): EndpointHandler[F, B] = {
    val innermost: EndpointHandler[F, B] = defaultEndpointHandler(securityFailureResponse)
    endpointInterceptors.foldRight(innermost) { (interceptor, wrapped) =>
      interceptor(responder(defaultSuccessStatusCode), wrapped)
    }
  }

  /** Attempts to handle the request using a single server endpoint.
    *
    * The steps are: decoding the basic (non-body) security and regular inputs, decoding the body for the security
    * inputs, running the security logic, and - only if that succeeds - decoding the body for the regular inputs and
    * calling `onDecodeSuccess` on the (intercepted) endpoint handler.
    *
    * Any decode failure is passed to the handler's `onDecodeFailure`. If that yields a response, it becomes the
    * result; otherwise a [[RequestResult.Failure]] with the failure's context is returned.
    *
    * @param request
    *   the request being handled
    * @param se
    *   the server endpoint to try
    * @param endpointInterceptors
    *   the interceptors wrapped around the default endpoint handler
    */
  private def tryServerEndpoint[A, U, I, E, O](
      request: ServerRequest,
      se: ServerEndpoint.Full[A, U, I, E, O, R, F],
      endpointInterceptors: List[EndpointInterceptor[F]]
  ): F[RequestResult[B]] = {

    // lifts decode results into ResultOrValue, short-circuiting with the outcome of `onDecodeFailure` on failures
    val resultOrValueFrom = new ResultOrValueFrom {
      def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]] = {
        val decodeFailureContext = interceptor.DecodeFailureContext(se.endpoint, input, failure, request)
        endpointHandler(defaultSecurityFailureResponse, endpointInterceptors)
          .onDecodeFailure(decodeFailureContext)
          .map {
            case Some(response) => RequestResult.Response(response)
            // no handler produced a response: report the failure context to the caller
            case None           => RequestResult.Failure(List(decodeFailureContext))
          }
      }
    }

    // 1. decoding both security & regular basic inputs - note that this does *not* include decoding the body
    val decodeBasicContext1 = DecodeInputsContext(request)
    // the security input doesn't have to match the whole path, a prefix is fine
    val (securityBasicInputs, decodeBasicContext2) =
      DecodeBasicInputs(se.endpoint.securityInput, decodeBasicContext1, matchWholePath = false)
    // the regular input is required to match the whole remaining path; otherwise a decode failure is reported
    // to keep the progress in path matching, we are using the context returned by decoding the security input
    val (regularBasicInputs, _) = DecodeBasicInputs(se.endpoint.input, decodeBasicContext2)
    (for {
      // 2. if the decoding failed, short-circuiting further processing with the decode failure that has a lower sort
      // index (so that the correct one is passed to the decode failure handler)
      _ <- resultOrValueFrom(DecodeBasicInputsResult.higherPriorityFailure(securityBasicInputs, regularBasicInputs))
      // 3. computing the security input value
      securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs, se.info))
      securityParams <- resultOrValueFrom(InputValue(se.endpoint.securityInput, securityValues))
      inputValues <- resultOrValueFrom(regularBasicInputs)
      a = securityParams.asAny.asInstanceOf[A]
      // 4. running the security logic
      securityLogicResult <- ResultOrValue(
        // an exception from the security logic is routed to `onSecurityFailure`; the default handler then responds
        // with the failed effect `monad.error(t)`
        se.securityLogic(monad)(a).map(Right(_): Either[RequestResult[B], Either[E, U]]).handleError { case t: Throwable =>
          endpointHandler(monad.error(t), endpointInterceptors)
            .onSecurityFailure(SecurityFailureContext(se, a, request))
            .map(r => Left(RequestResult.Response(r)): Either[RequestResult[B], Either[E, U]])
        }
      )
      response <- securityLogicResult match {
        // the security logic returned an error value: the default handler responds with the endpoint's error output,
        // using the default error status code
        case Left(e) =>
          resultOrValueFrom.value(
            endpointHandler(
              responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e)),
              endpointInterceptors
            )
              .onSecurityFailure(SecurityFailureContext(se, a, request))
              .map(r => RequestResult.Response(r): RequestResult[B])
          )

        case Right(u) =>
          for {
            // 5. decoding the body of regular inputs, computing the input value, and running the main logic
            values <- resultOrValueFrom(decodeBody(request, inputValues, se.endpoint.info))
            params <- resultOrValueFrom(InputValue(se.endpoint.input, values))
            response <- resultOrValueFrom.value(
              endpointHandler(defaultSecurityFailureResponse, endpointInterceptors)
                .onDecodeSuccess(interceptor.DecodeSuccessContext(se, a, u, params.asAny.asInstanceOf[I], request))
                .map(r => RequestResult.Response(r): RequestResult[B])
            )
          } yield response
      }
      // collapsing the short-circuited and the computed result into a single RequestResult
    } yield response).fold
  }

  /** Decodes the request body, if the already-decoded basic inputs contain a body input.
    *
    * A failure is passed through unchanged, as are values without any body input. Otherwise, depending on the kind of
    * body input, the body is decoded as a raw or as a streaming body. For a one-of body, the variant is chosen using
    * the request's content type, and a decode failure is reported if none matches. The maximum body length is read
    * from the endpoint's [[MaxContentLength]] attribute, if one is set.
    *
    * @param request
    *   the request whose body should be decoded
    * @param result
    *   the result of decoding the basic inputs
    * @param endpointInfo
    *   the endpoint's info, used to look up the maximum content length
    */
  private def decodeBody(
      request: ServerRequest,
      result: DecodeBasicInputsResult,
      endpointInfo: EndpointInfo
  ): F[DecodeBasicInputsResult] =
    result match {
      case failure: DecodeBasicInputsResult.Failure => (failure: DecodeBasicInputsResult).unit
      case values: DecodeBasicInputsResult.Values =>
        val maxBodyLength = endpointInfo.attribute(AttributeKey[MaxContentLength]).map(_.value)
        values.bodyInputWithIndex match {
          case None =>
            (values: DecodeBasicInputsResult).unit
          case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) =>
            decodeStreamingBody(request, values, bodyInput, maxBodyLength)
          case Some((Left(oneOfBodyInput), _)) =>
            oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match {
              case None =>
                unsupportedInputMediaTypeResponse(request, oneOfBodyInput)
              case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) =>
                decodeStreamingBody(request, values, body, maxBodyLength)
              case Some(Left(body)) =>
                decodeBody(request, values, body, maxBodyLength)
            }
        }
    }

  /** Decodes a streaming body: the request's body stream (limited by `maxBodyLength`, if given) is passed through the
    * body input's codec. On success, the decoded value is set as the body input's value; on failure, a decode failure
    * for the body input is returned.
    */
  private def decodeStreamingBody(
      request: ServerRequest,
      values: DecodeBasicInputsResult.Values,
      bodyInput: EndpointIO.StreamBodyWrapper[Any, Any],
      maxBodyLength: Option[Long]
  ): F[DecodeBasicInputsResult] = {
    val stream = requestBody.toStream(request, maxBodyLength)
    val decoded: DecodeBasicInputsResult = bodyInput.codec.decode(stream) match {
      case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure)
      case DecodeResult.Value(bodyV)     => values.setBodyInputValue(bodyV)
    }
    decoded.unit
  }

  /** Reads the raw request body and decodes it using the body input's codec.
    *
    * On success, the decoded value is set as the body input's value. If the codec reports a failure, the files
    * created while reading the body are deleted, one after another, before the failure is returned. A
    * [[StreamMaxLengthExceededException]] raised in this process is converted into a decode error for the body input.
    */
  private def decodeBody[RAW, T](
      request: ServerRequest,
      values: DecodeBasicInputsResult.Values,
      bodyInput: EndpointIO.Body[RAW, T],
      maxBodyLength: Option[Long]
  ): F[DecodeBasicInputsResult] = {
    def failed(reason: DecodeResult.Failure): DecodeBasicInputsResult =
      DecodeBasicInputsResult.Failure(bodyInput, reason)

    requestBody
      .toRaw(request, bodyInput.bodyType, maxBodyLength)
      .flatMap { raw =>
        bodyInput.codec.decode(raw.value) match {
          case failure: DecodeResult.Failure =>
            // sequentially removing the files created when reading the body
            val cleanup = raw.createdFiles.foldLeft(monad.unit(())) { (done, created) =>
              done.flatMap(_ => deleteFile(created.file))
            }
            cleanup.map(_ => failed(failure))
          case DecodeResult.Value(bodyV) =>
            (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit
        }
      }
      .handleError { case e: StreamMaxLengthExceededException =>
        failed(DecodeResult.Error("", e)).unit
      }
  }

  /** Creates a decode failure for a one-of body for which no variant matches the request's content type. The mismatch
    * lists the ranges of all the variants as the expected value, and the request's content type (or an empty string,
    * if there is none) as the actual one.
    */
  private def unsupportedInputMediaTypeResponse(
      request: ServerRequest,
      oneOfBodyInput: EndpointIO.OneOfBody[_, _]
  ): F[DecodeBasicInputsResult] = {
    val supportedRanges = oneOfBodyInput.variants.map(_.range.toString()).mkString(", or: ")
    val actualContentType = request.contentType.getOrElse("")
    val mismatch: DecodeBasicInputsResult =
      DecodeBasicInputsResult.Failure(oneOfBodyInput, DecodeResult.Mismatch(supportedRanges, actualContentType))
    mismatch.unit
  }

  /** The innermost endpoint handler, around which the endpoint interceptors are wrapped.
    *
    *   - on decode success, the endpoint's main logic is run, and its result is encoded using the endpoint's output
    *     (with the default success status code) or its error output (with the default error status code)
    *   - on security failure, the given `securityFailureResponse` is returned
    *   - on decode failure, no response is produced
    *
    * @param securityFailureResponse
    *   the response for security failures; passed by name, so it is evaluated only when it is needed
    */
  private def defaultEndpointHandler(securityFailureResponse: => F[ServerResponse[B]]): EndpointHandler[F, B] =
    new EndpointHandler[F, B] {
      override def onDecodeSuccess[A, U, I](
          ctx: DecodeSuccessContext[F, A, U, I]
      )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] =
        ctx.serverEndpoint
          .logic(monad)(ctx.principal)(ctx.input)
          .flatMap {
            case Left(err) =>
              responder(defaultErrorStatusCode)(ctx.request, ValuedEndpointOutput(ctx.serverEndpoint.errorOutput, err))
            case Right(result) =>
              responder(defaultSuccessStatusCode)(ctx.request, ValuedEndpointOutput(ctx.serverEndpoint.output, result))
          }

      override def onSecurityFailure[A](
          ctx: SecurityFailureContext[F, A]
      )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] =
        securityFailureResponse

      override def onDecodeFailure(
          ctx: DecodeFailureContext
      )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] =
        monad.unit(Option.empty[ServerResponse[B]])
    }

  /** Creates a responder which encodes a valued output into a server response.
    *
    * The output is encoded taking into account the content types accepted by the request (none, if they cannot be
    * obtained). The response uses the status code from the encoded output or, if there is none, `defaultStatusCode`.
    * The body, if there is one, is created from the encoded headers.
    *
    * @param defaultStatusCode
    *   the status code to use when the output doesn't specify one
    */
  private def responder(defaultStatusCode: StatusCode): Responder[F, B] = new Responder[F, B] {
    override def apply[O](request: ServerRequest, output: ValuedEndpointOutput[O]): F[ServerResponse[B]] = {
      val encoded =
        new EncodeOutputs(toResponseBody, request.acceptsContentTypes.getOrElse(Nil))
          .apply(output.output, ParamsAsAny(output.value), OutputValues.empty)
      val responseCode = encoded.statusCode.getOrElse(defaultStatusCode)
      val responseHeaders = encoded.headers
      val responseBody: Option[B] = encoded.body.map(bodyFromHeaders => bodyFromHeaders(Headers(responseHeaders)))

      ServerResponse(responseCode, responseHeaders, responseBody, Some(output)).unit
    }
  }

  // Status codes passed to `responder`, which uses them only when the encoded output doesn't specify a status code:
  // `Ok` is used for success outputs (and for the responder given to interceptors) ...
  private val defaultSuccessStatusCode: StatusCode = StatusCode.Ok
  // ... and `BadRequest` for error outputs.
  private val defaultErrorStatusCode: StatusCode = StatusCode.BadRequest

  /** An effect which yields either a final request result (`Left`), which short-circuits any further steps, or an
    * intermediate value (`Right`), with which the computation continues. Provides `map` and `flatMap`, so that it can
    * be used in for-comprehensions.
    */
  private case class ResultOrValue[T](v: F[Either[RequestResult[B], T]]) {

    /** Continues with `f` if there's a value; an already available result is passed through unchanged. */
    def flatMap[U](f: T => ResultOrValue[U]): ResultOrValue[U] =
      ResultOrValue(v.flatMap {
        case Right(value)  => f(value).v
        case Left(result)  => (Left(result): Either[RequestResult[B], U]).unit
      })

    /** Transforms the value, if there is one; an already available result is passed through unchanged. */
    def map[U](f: T => U): ResultOrValue[U] =
      ResultOrValue(v.map(resultOrValue => resultOrValue.map(f)))

    /** Collapses both sides into a request result; available only when the value itself is a request result. */
    def fold(implicit ev: T =:= RequestResult[B]): F[RequestResult[B]] =
      v.map(resultOrValue => resultOrValue.fold(shortCircuited => shortCircuited, computed => ev(computed)))
  }

  /** Lifts decoding outcomes into [[ResultOrValue]]: successes become values, with which the computation continues,
    * while failures are short-circuited with the request result returned by `onDecodeFailure`.
    */
  private abstract class ResultOrValueFrom {

    /** Effectful variant: the decode result is lifted once it is available. */
    def apply(v: F[DecodeBasicInputsResult]): ResultOrValue[DecodeBasicInputsResult.Values] =
      ResultOrValue(v.flatMap(decoded => apply(decoded).v))

    def apply(v: InputValueResult): ResultOrValue[Params] = v match {
      case InputValueResult.Failure(input, failure) => failed(input, failure)
      case InputValueResult.Value(params, _)        => succeeded(params)
    }

    def apply(v: DecodeBasicInputsResult): ResultOrValue[DecodeBasicInputsResult.Values] = v match {
      case DecodeBasicInputsResult.Failure(input, failure) => failed(input, failure)
      case values: DecodeBasicInputsResult.Values          => succeeded(values)
    }

    /** Short-circuits if a failure is present; otherwise continues with a unit value. */
    def apply(f: Option[DecodeBasicInputsResult.Failure]): ResultOrValue[Unit] = f match {
      case Some(DecodeBasicInputsResult.Failure(input, failure)) => failed(input, failure)
      case None                                                  => succeeded(())
    }

    /** Lifts an effect which always yields a value. */
    def value[T](v: F[T]): ResultOrValue[T] = ResultOrValue(v.map(Right(_)))

    /** Computes the request result with which processing is short-circuited, when decoding `input` fails. */
    def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]]

    // a value with which the computation continues
    private def succeeded[T](value: T): ResultOrValue[T] =
      ResultOrValue((Right(value): Either[RequestResult[B], T]).unit)

    // the result of handling the decode failure, short-circuiting the computation
    private def failed[T](input: EndpointInput[_], failure: DecodeResult.Failure): ResultOrValue[T] =
      ResultOrValue(onDecodeFailure(input, failure).map(result => Left(result): Either[RequestResult[B], T]))
  }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy