All downloads are free. Search and download functionality uses the official Maven repository.

zio.http.Response.scala Maven / Gradle / Ivy

/*
 * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zio.http

import java.nio.file.{AccessDeniedException, NotDirectoryException}

import scala.annotation.tailrec

import zio._

import zio.stream.ZStream

import zio.schema.Schema

import zio.http.internal.{HeaderOps, OutputEncoder}
import zio.http.template._

final case class Response(
  status: Status = Status.Ok,
  headers: Headers = Headers.empty,
  body: Body = Body.empty,
) extends HeaderOps[Response] { self =>

  // Slot reserved for encoders so that a static response can be encoded once
  // and the cached result reused instead of re-encoding it (optimization).
  private[http] var encoded: AnyRef = null

  /**
   * Returns a copy of this response with an additional `Set-Cookie` header
   * built from the given cookie.
   */
  def addCookie(cookie: Cookie.Response): Response = {
    val setCookieHeaders = Headers(Header.SetCookie(cookie))
    self.copy(headers = self.headers ++ setCookieHeaders)
  }

  /**
   * Adds flash values to the cookie-based flash-scope. The flash cookie is
   * scoped to the root path.
   */
  def addFlash[A](setter: Flash.Setter[A]): Response = {
    val flashCookie = Flash.Setter.run(setter)
    self.addCookie(flashCookie.copy(path = Some(Path.root)))
  }

  /**
   * Collects the potentially streaming body of the response into a single
   * chunk.
   *
   * Any errors that occur from the collection of the body will be caught and
   * propagated to the Body. When materialization yields the very same body
   * instance, this response itself is returned rather than a copy.
   */
  def collect(implicit trace: Trace): ZIO[Any, Nothing, Response] =
    self.body.materialize.map { materialized =>
      if (materialized eq self.body) self else self.copy(body = materialized)
    }

  /**
   * Adds a `content-type` header whose value is the full type of `mediaType`.
   */
  def contentType(mediaType: MediaType): Response =
    self.addHeader("content-type", mediaType.fullType)

  /**
   * Consumes the streaming body fully and then discards it while also ignoring
   * any failures. The returned response has an empty body.
   */
  def ignoreBody(implicit trace: Trace): ZIO[Any, Nothing, Response] = {
    val withoutBody = self.copy(body = Body.empty)
    val original    = self.body
    // An already-complete body needs no draining.
    if (!original.isComplete) original.asStream.runDrain.ignore.as(withoutBody)
    else Exit.succeed(withoutBody)
  }

  /**
   * Applies the given patch to this response.
   */
  def patch(p: Response.Patch)(implicit trace: Trace): Response =
    p.apply(self)

  /**
   * Sets the status of the response
   */
  def status(status: Status): Response =
    self.copy(status = status)

  /**
   * Creates a Handler that always produces this response.
   */
  def toHandler(implicit trace: Trace): Handler[Any, Nothing, Any, Response] =
    Handler.fromResponse(self)

  /**
   * Replaces the current headers with the result of applying `update` to them.
   */
  override def updateHeaders(update: Headers => Headers)(implicit trace: Trace): Response = {
    val updated = update(self.headers)
    self.copy(headers = updated)
  }

}

object Response {

  /**
   * Models the set of operations that one would want to apply on a Response.
   *
   * Patches are composed with `++`; applying a combined patch runs its left
   * side first and then its right side.
   */
  sealed trait Patch { self =>
    def ++(that: Patch): Patch = Patch.Combine(self, that)

    /**
     * Applies this patch to `res` and returns the resulting response.
     */
    def apply(res: Response)(implicit trace: Trace): Response = {

      @tailrec
      def loop(current: Response, patch: Patch): Response =
        patch match {
          case Patch.Empty                  => current
          case Patch.AddHeaders(headers)    => current.addHeaders(headers)
          case Patch.RemoveHeaders(headers) => current.removeHeaders(headers)
          case Patch.SetStatus(status)      => current.status(status)
          // The left operand is applied through a nested (non-tail) call; only
          // the right operand continues the tail-recursive loop.
          case Patch.Combine(left, right)   => loop(left(current), right)
          case Patch.UpdateHeaders(f)       => current.updateHeaders(f)
        }

      loop(res, self)
    }
  }

  object Patch {
    import Header.HeaderType

    case object Empty                                     extends Patch
    final case class AddHeaders(headers: Headers)         extends Patch
    final case class RemoveHeaders(headers: Set[String])  extends Patch
    final case class SetStatus(status: Status)            extends Patch
    final case class Combine(left: Patch, right: Patch)   extends Patch
    final case class UpdateHeaders(f: Headers => Headers) extends Patch

    /** A patch that leaves the response unchanged. */
    def empty: Patch = Empty

    /** Adds a typed header, rendered with the header type's own renderer. */
    def addHeader(headerType: HeaderType)(value: headerType.HeaderValue): Patch =
      addHeader(headerType.name, headerType.render(value))

    def addHeader(header: Header): Patch                          = addHeaders(Headers(header))
    def addHeaders(headers: Headers): Patch                       = AddHeaders(headers)
    def addHeader(name: CharSequence, value: CharSequence): Patch = addHeaders(Headers(name, value))

    /** Removes every header whose name matches one of the given header types. */
    def removeHeaders(headerTypes: Set[HeaderType]): Patch = RemoveHeaders(headerTypes.map(_.name))
    def status(status: Status): Patch                      = SetStatus(status)
    def updateHeaders(f: Headers => Headers): Patch        = UpdateHeaders(f)
  }

  /** Creates an empty 400 Bad Request response. */
  def badRequest: Response =
    error(Status.BadRequest)

  /** Creates a 400 Bad Request response carrying an HTML-encoded message. */
  def badRequest(message: String): Response =
    error(Status.BadRequest, message)

  /**
   * Creates an error response whose body is `message` HTML-encoded with
   * `OutputEncoder.encodeHtml`. No content-type header is set.
   */
  def error(status: Status.Error, message: String): Response =
    Response(status = status, body = Body.fromString(OutputEncoder.encodeHtml(message)))

  /**
   * Creates an error response with the given body. When the body carries a
   * media type, a matching `Content-Type` header is added; otherwise the
   * response has no headers.
   */
  def error(status: Status.Error, body: Body): Response =
    Response(
      status = status,
      body = body,
      headers = body.mediaType.map(mediaType => Headers(Header.ContentType(mediaType))).getOrElse(Headers.empty),
    )

  /** Creates an error response with the given status and an empty body. */
  def error(status: Status.Error): Response =
    Response(status = status)

  /** Creates an empty 403 Forbidden response. */
  def forbidden: Response =
    error(Status.Forbidden)

  /** Creates a 403 Forbidden response carrying an HTML-encoded message. */
  def forbidden(message: String): Response =
    error(Status.Forbidden, message)

  /** Creates a response from `cause` using `ErrorResponseConfig.default`. */
  def fromCause(cause: Cause[Any]): Response =
    fromCause(cause, ErrorResponseConfig.default)

  /**
   * Creates a new response from the specified cause. Note that this method is
   * not polymorphic, but will attempt to inspect the runtime class of the
   * failure inside the cause, if any.
   *
   * A failure that is itself a `Response` is returned as-is; a `Throwable` is
   * translated by `fromThrowable`; a nested `Cause` is unwrapped. Otherwise a
   * cause that consists only of interruption yields 408 Request Timeout and
   * anything else 500 Internal Server Error, with the pretty-printed cause as
   * a `text/plain` body when `config.withErrorBody` is set.
   */
  @tailrec
  def fromCause(cause: Cause[Any], config: ErrorResponseConfig): Response =
    cause.failureOrCause match {
      case Left(failure: Response)  => failure
      case Left(failure: Throwable) => fromThrowable(failure, config)
      case Left(failure: Cause[_])  => fromCause(failure, config)
      case _                        =>
        val body =
          if (config.withErrorBody) Body.fromString(cause.prettyPrint).contentType(MediaType.text.`plain`)
          else Body.empty
        if (cause.isInterruptedOnly) error(Status.RequestTimeout, body)
        else error(Status.InternalServerError, body)
    }

  /**
   * Creates a new response from the specified cause, translating any typed
   * error to a response using the provided function. Causes without a typed
   * failure are handled by `fromCause`.
   */
  def fromCauseWith[E](cause: Cause[E], config: ErrorResponseConfig)(f: E => Response): Response =
    cause.failureOrCause match {
      case Left(failure) => f(failure)
      case Right(other)  => fromCause(other, config)
    }

  /**
   * Creates a response with content-type set to text/event-stream.
   *
   * @param data
   *   stream of data to be sent as Server Sent Events; each event is encoded
   *   with `ServerSentEvent.defaultBinaryCodec` and sent as a chunk
   */
  def fromServerSentEvents[T: Schema](data: ZStream[Any, Nothing, ServerSentEvent[T]])(implicit
    trace: Trace,
  ): Response = {
    val codec = ServerSentEvent.defaultBinaryCodec[T]
    Response(
      Status.Ok,
      contentTypeEventStream,
      Body.fromCharSequenceStreamChunked(data.map(codec.encode).map(_.asString)),
    )
  }

  /**
   * Creates a new 101 Switching Protocols response for the provided socket
   * app, with the app's environment supplied from the calling effect.
   */
  def fromSocketApp[R](app: WebSocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] =
    ZIO.environment[R].map { env =>
      Response(
        Status.SwitchingProtocols,
        Headers.empty,
        Body.fromSocketApp(app.provideEnvironment(env)),
      )
    }

  /** Creates a response from `throwable` using `ErrorResponseConfig.default`. */
  def fromThrowable(throwable: Throwable): Response =
    fromThrowable(throwable, ErrorResponseConfig.default)

  /**
   * Creates a new response for the specified throwable. Note that this method
   * relies on the runtime class of the throwable.
   */
  def fromThrowable(throwable: Throwable, config: ErrorResponseConfig): Response = {
    val status = statusForThrowable(throwable)
    error(status, throwableToMessage(throwable, status, config))
  }

  // Maps a throwable's runtime class to the HTTP error status reported for it.
  // Cases are checked in order, so the first matching class wins.
  private def statusForThrowable(throwable: Throwable): Status.Error =
    throwable match { // TODO: Enhance
      case _: AccessDeniedException | _: IllegalAccessException | _: IllegalAccessError => Status.Forbidden
      case _: NotDirectoryException | _: IllegalArgumentException                       => Status.BadRequest
      case _: java.io.FileNotFoundException                                             => Status.NotFound
      case _: java.net.ConnectException                                                 => Status.ServiceUnavailable
      case _: java.net.SocketTimeoutException                                           => Status.GatewayTimeout
      case _                                                                            => Status.InternalServerError
    }

  // Builds the error body for `throwable`, honouring the body / stack-trace
  // settings of `config`. A `maxStackTraceDepth` of 0 keeps every frame.
  private def throwableToMessage(throwable: Throwable, status: Status, config: ErrorResponseConfig): Body =
    if (!config.withErrorBody) Body.empty
    else {
      val stackTrace =
        if (!config.withStackTrace) ""
        else {
          val frames = throwable.getStackTrace
          val kept   = if (config.maxStackTraceDepth == 0) frames else frames.take(config.maxStackTraceDepth)
          if (frames.isEmpty) "" else kept.mkString("\n", "\n", "")
        }
      val message    = Option(throwable.getMessage).getOrElse("")
      bodyFromThrowable(message, stackTrace, status, config)
    }

  // Renders the error body in the format selected by `config.errorFormat` and
  // tags it with that format's media type.
  private def bodyFromThrowable(
    message: String,
    stackTrace: String,
    status: Status,
    config: ErrorResponseConfig,
  ): Body = {
    def htmlResponse: Body = {
      val data = Template.container(s"$status") {
        div(
          div(
            styles := "text-align: center",
            div(s"${status.code}", styles := "font-size: 20em"),
            div(message),
            div(stackTrace),
          ),
        )
      }
      Body.fromString(data.encode.toString)
    }

    def textResponse: Body =
      Body.fromString {
        val statusCode = status.code
        s"${scala.Console.BOLD}${scala.Console.RED}${status}${scala.Console.RESET} - " +
          s"${scala.Console.BOLD}${scala.Console.CYAN}$statusCode${scala.Console.RESET} - " +
          s"$message" +
          s"${scala.Console.BOLD}${scala.Console.RED} $stackTrace ${scala.Console.RESET}"
      }

    // Message and stack trace are escaped: the stack trace always contains
    // newlines and messages may contain quotes, both of which would otherwise
    // produce invalid JSON.
    def jsonResponse: Body =
      Body.fromString(
        s"""{"status": "${status.code}", "message": "${jsonEscape(message)}", "stackTrace": "${jsonEscape(
            stackTrace,
          )}"}""",
      )

    val body = config.errorFormat match {
      case ErrorResponseConfig.ErrorFormat.Html => htmlResponse
      case ErrorResponseConfig.ErrorFormat.Text => textResponse
      case ErrorResponseConfig.ErrorFormat.Json => jsonResponse
    }
    body.contentType(config.errorFormat.mediaType)
  }

  // Escapes `value` for embedding between double quotes in a JSON document
  // (RFC 8259, section 7): quote, backslash and every control character below
  // U+0020 are escaped.
  private def jsonEscape(value: String): String = {
    val out = new java.lang.StringBuilder(value.length + 16)
    value.foreach {
      case '"'          => out.append("\\\"")
      case '\\'         => out.append("\\\\")
      case '\n'         => out.append("\\n")
      case '\r'         => out.append("\\r")
      case '\t'         => out.append("\\t")
      case '\b'         => out.append("\\b")
      case '\f'         => out.append("\\f")
      case c if c < ' ' => out.append('\\').append('u').append("%04x".format(c.toInt))
      case c            => out.append(c)
    }
    out.toString
  }

  /** Creates an empty 504 Gateway Timeout response. */
  def gatewayTimeout: Response = error(Status.GatewayTimeout)

  /** Creates a 504 Gateway Timeout response carrying an HTML-encoded message. */
  def gatewayTimeout(message: String): Response = error(Status.GatewayTimeout, message)

  /**
   * Creates a response with content-type set to text/html
   */
  def html(data: Html, status: Status = Status.Ok): Response =
    Response(
      status,
      contentTypeHtml,
      Body.fromString(data.encode.toString),
    )

  /** Creates an empty 505 HTTP Version Not Supported response. */
  def httpVersionNotSupported: Response = error(Status.HttpVersionNotSupported)

  /** Creates a 505 HTTP Version Not Supported response carrying an HTML-encoded message. */
  def httpVersionNotSupported(message: String): Response = error(Status.HttpVersionNotSupported, message)

  /** Creates an empty 500 Internal Server Error response. */
  def internalServerError: Response = error(Status.InternalServerError)

  /** Creates a 500 Internal Server Error response carrying an HTML-encoded message. */
  def internalServerError(message: String): Response = error(Status.InternalServerError, message)

  /**
   * Creates a response with content-type set to application/json
   */
  def json(data: CharSequence): Response =
    Response(
      Status.Ok,
      contentTypeJson,
      Body.fromCharSequence(data),
    )

  /** Creates an empty 511 Network Authentication Required response. */
  def networkAuthenticationRequired: Response = error(Status.NetworkAuthenticationRequired)

  /** Creates a 511 Network Authentication Required response carrying an HTML-encoded message. */
  def networkAuthenticationRequired(message: String): Response = error(Status.NetworkAuthenticationRequired, message)

  /** Creates an empty 510 Not Extended response. */
  def notExtended: Response = error(Status.NotExtended)

  /** Creates a 510 Not Extended response carrying an HTML-encoded message. */
  def notExtended(message: String): Response = error(Status.NotExtended, message)

  /** Creates an empty 404 Not Found response. */
  def notFound: Response = error(Status.NotFound)

  /** Creates a 404 Not Found response carrying an HTML-encoded message. */
  def notFound(message: String): Response = error(Status.NotFound, message)

  /** Creates an empty 501 Not Implemented response. */
  def notImplemented: Response = error(Status.NotImplemented)

  /** Creates a 501 Not Implemented response carrying an HTML-encoded message. */
  def notImplemented(message: String): Response = error(Status.NotImplemented, message)

  /**
   * Creates an empty response with status 200
   */
  def ok: Response = status(Status.Ok)

  /**
   * Creates an empty response with status 307 or 308 depending on if it's
   * permanent or not.
   *
   * Note: if you intend to always redirect a browser with a HTTP GET to the
   * given location you very likely should use `Response.seeOther` instead.
   */
  def redirect(location: URL, isPermanent: Boolean = false): Response = {
    val status = if (isPermanent) Status.PermanentRedirect else Status.TemporaryRedirect
    Response(status = status, headers = Headers(Header.Location(location)))
  }

  /**
   * Creates an empty response with status 303.
   */
  def seeOther(location: URL): Response =
    Response(status = Status.SeeOther, headers = Headers(Header.Location(location)))

  /** Creates an empty 503 Service Unavailable response. */
  def serviceUnavailable: Response = error(Status.ServiceUnavailable)

  /** Creates a 503 Service Unavailable response carrying an HTML-encoded message. */
  def serviceUnavailable(message: String): Response = error(Status.ServiceUnavailable, message)

  /**
   * Creates an empty response with the provided Status
   */
  def status(status: Status): Response =
    Response(status = status)

  /**
   * Creates a response with content-type set to text/plain
   */
  def text(text: CharSequence): Response =
    Response(
      Status.Ok,
      contentTypeText,
      Body.fromCharSequence(text),
    )

  /** Creates an empty 401 Unauthorized response. */
  def unauthorized: Response = error(Status.Unauthorized)

  /** Creates a 401 Unauthorized response carrying an HTML-encoded message. */
  def unauthorized(message: String): Response = error(Status.Unauthorized, message)

  // Pre-built content-type headers shared by the constructors above.
  private val contentTypeJson: Headers        = Headers(Header.ContentType(MediaType.application.json).untyped)
  private val contentTypeHtml: Headers        = Headers(Header.ContentType(MediaType.text.html).untyped)
  private val contentTypeText: Headers        = Headers(Header.ContentType(MediaType.text.plain).untyped)
  private val contentTypeEventStream: Headers = Headers(Header.ContentType(MediaType.text.`event-stream`).untyped)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy