All Downloads are FREE. Search and download functionalities are using the official Maven repository.

zio.http.Middleware.scala Maven / Gradle / Ivy

/*
 * Copyright 2023 the ZIO HTTP contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package zio.http

import java.io.File
import java.nio.charset.Charset

import scala.annotation.nowarn

import zio._
import zio.metrics._

import zio.http.codec.{PathCodec, SegmentCodec}

@nowarn("msg=shadows type")
trait Middleware[-UpperEnv] { self =>

  def apply[Env1 <: UpperEnv, Err](routes: Routes[Env1, Err]): Routes[Env1, Err]

  def @@[UpperEnv1 <: UpperEnv](
    that: Middleware[UpperEnv1],
  ): Middleware[UpperEnv1] =
    self ++ that

  def ++[UpperEnv1 <: UpperEnv](
    that: Middleware[UpperEnv1],
  ): Middleware[UpperEnv1] =
    new Middleware[UpperEnv1] {
      def apply[Env1 <: UpperEnv1, Err](
        routes: Routes[Env1, Err],
      ): Routes[Env1, Err] =
        self(that(routes))
    }
}

@nowarn("msg=shadows type")
object Middleware extends HandlerAspects {

  /**
   * Configuration for the CORS aspect.
   */
  final case class CorsConfig(
    allowedOrigin: Header.Origin => Option[Header.AccessControlAllowOrigin] = origin =>
      Some(Header.AccessControlAllowOrigin.Specific(origin)),
    allowedMethods: Header.AccessControlAllowMethods = Header.AccessControlAllowMethods.All,
    allowedHeaders: Header.AccessControlAllowHeaders = Header.AccessControlAllowHeaders.All,
    allowCredentials: Header.AccessControlAllowCredentials = Header.AccessControlAllowCredentials.Allow,
    exposedHeaders: Header.AccessControlExposeHeaders = Header.AccessControlExposeHeaders.All,
    maxAge: Option[Header.AccessControlMaxAge] = None,
  )

  /**
   * Creates a aspect for Cross-Origin Resource Sharing (CORS) using default
   * options.
   * @see
   *   https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
   */
  def cors: Middleware[Any] = cors(CorsConfig())

  /**
   * Creates a aspect for Cross-Origin Resource Sharing (CORS).
   * @see
   *   https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
   */
  def cors(config: CorsConfig): Middleware[Any] = {
    def allowedHeaders(
      requestedHeaders: Option[Header.AccessControlRequestHeaders],
      allowedHeaders: Header.AccessControlAllowHeaders,
    ): Header.AccessControlAllowHeaders =
      // Returning an intersection of requested headers and allowed headers
      // if there are no requested headers, we return the configured allowed headers without modification
      allowedHeaders match {
        case Header.AccessControlAllowHeaders.Some(values) =>
          requestedHeaders match {
            case Some(Header.AccessControlRequestHeaders(headers)) =>
              val intersection = headers.toSet.intersect(values.toSet)
              NonEmptyChunk.fromIterableOption(intersection) match {
                case Some(values) => Header.AccessControlAllowHeaders.Some(values)
                case None         => Header.AccessControlAllowHeaders.None
              }
            case None                                              => allowedHeaders
          }
        case Header.AccessControlAllowHeaders.All          =>
          requestedHeaders match {
            case Some(Header.AccessControlRequestHeaders(headers)) => Header.AccessControlAllowHeaders.Some(headers)
            case _                                                 => Header.AccessControlAllowHeaders.All
          }
        case Header.AccessControlAllowHeaders.None         => Header.AccessControlAllowHeaders.None
      }

    def corsHeaders(
      allowOrigin: Header.AccessControlAllowOrigin,
      requestedHeaders: Option[Header.AccessControlRequestHeaders],
      isPreflight: Boolean,
    ): Headers =
      Headers(
        allowOrigin,
        config.allowedMethods,
        config.allowCredentials,
      ) ++
        Headers.ifThenElse(isPreflight)(
          onTrue = Headers(allowedHeaders(requestedHeaders, config.allowedHeaders)),
          onFalse = Headers(config.exposedHeaders),
        ) ++ config.maxAge.fold(Headers.empty)(Headers(_))

    // HandlerAspect:
    val aspect =
      HandlerAspect.interceptHandlerStateful[Any, Headers, Unit](
        Handler.fromFunction[Request] { request =>
          val originHeader = request.header(Header.Origin)
          val acrhHeader   = request.header(Header.AccessControlRequestHeaders)

          originHeader match {
            case Some(origin) =>
              config.allowedOrigin(origin) match {
                case Some(allowOrigin) if config.allowedMethods.contains(request.method) =>
                  (corsHeaders(allowOrigin, acrhHeader, isPreflight = false), (request, ()))
                case _                                                                   =>
                  (Headers.empty, (request, ()))
              }

            case None => (Headers.empty, (request, ()))
          }
        },
      )(Handler.fromFunction[(Headers, Response)] { case (headers, response) =>
        response.addHeaders(headers)
      })

    val optionsRoute = {
      implicit val trace: Trace = Trace.empty

      Method.OPTIONS / trailing -> handler { (_: Path, request: Request) =>
        val originHeader = request.header(Header.Origin)
        val acrmHeader   = request.header(Header.AccessControlRequestMethod)
        val acrhHeader   = request.header(Header.AccessControlRequestHeaders)

        (
          originHeader,
          acrmHeader,
        ) match {
          case (Some(origin), Some(acrm)) =>
            config.allowedOrigin(origin) match {
              case Some(allowOrigin) if config.allowedMethods.contains(acrm.method) =>
                Response(
                  Status.NoContent,
                  headers = corsHeaders(allowOrigin, acrhHeader, isPreflight = true),
                )

              case _ =>
                Response.notFound
            }
          case _                          =>
            Response.notFound
        }
      }
    }

    new Middleware[Any] {
      def apply[Env1, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        optionsRoute +: (routes @@ aspect)
    }
  }

  def ensureHeader(header: Header.HeaderType)(make: => header.HeaderValue): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler { (req: Request) =>
            if (req.headers.contains(header.name)) h(req)
            else h(req.addHeader(make))
          }
        }
    }

  def ensureHeader(headerName: String)(make: => String): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler { (req: Request) =>
            if (req.headers.contains(headerName)) h(req)
            else h(req.addHeader(headerName, make))
          }
        }
    }

  private[http] case class ForwardedHeaders(headers: Headers)

  def forwardHeaders(header: Header.HeaderType, headers: Header.HeaderType*)(implicit
    trace: Trace,
  ): Middleware[Any] = {
    val allHeaders = header +: headers
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler { (req: Request) =>
            val headerValues = ChunkBuilder.make[Header]()
            headerValues.sizeHint(allHeaders.length)
            var i            = 0
            while (i < allHeaders.length) {
              val name = allHeaders(i)
              req.headers.get(name).foreach { value =>
                headerValues += value
              }
              i += 1
            }
            RequestStore.update[ForwardedHeaders] { old =>
              ForwardedHeaders {
                old.map(_.headers).getOrElse(Headers.empty) ++
                  Headers.fromIterable(headerValues.result())
              }
            } *> h(req)
          }
        }
    }
  }

  def forwardHeaders(headerName: String, headerNames: String*)(implicit trace: Trace): Middleware[Any] = {
    val allHeaders = headerName +: headerNames
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler { (req: Request) =>
            val headerValues = ChunkBuilder.make[Header]()
            headerValues.sizeHint(allHeaders.length)
            var i            = 0
            while (i < allHeaders.length) {
              val name = allHeaders(i)
              req.headers.get(name).foreach { value =>
                headerValues += Header.Custom(name, value)
              }
              i += 1
            }
            RequestStore.update[ForwardedHeaders] { old =>
              ForwardedHeaders {
                old.map(_.headers).getOrElse(Headers.empty) ++
                  Headers.fromIterable(headerValues.result())
              }
            } *> h(req)
          }
        }
    }
  }

  def logAnnotate(key: => String, value: => String)(implicit trace: Trace): Middleware[Any] =
    logAnnotate(LogAnnotation(key, value))

  def logAnnotate(logAnnotation: => LogAnnotation, logAnnotations: LogAnnotation*)(implicit
    trace: Trace,
  ): Middleware[Any] =
    logAnnotate((logAnnotation +: logAnnotations).toSet)

  def logAnnotate(logAnnotations: => Set[LogAnnotation])(implicit trace: Trace): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler((req: Request) => ZIO.logAnnotate(logAnnotations)(h(req)))
        }
    }

  /**
   * Creates a middleware that will annotate log messages that are logged while
   * a request is handled with log annotations derived from the request.
   */
  def logAnnotate(fromRequest: Request => Set[LogAnnotation])(implicit trace: Trace): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { h =>
          handler((req: Request) => ZIO.logAnnotate(fromRequest(req))(h(req)))
        }
    }

  /**
   * Creates a middleware that will annotate log messages that are logged while
   * a request is handled with the names and the values of the specified
   * headers.
   */
  def logAnnotateHeaders(headerName: String, headerNames: String*)(implicit trace: Trace): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = {
        val headers = headerName +: headerNames
        routes.transform[Env1] { h =>
          handler((req: Request) => {
            val annotations = Set.newBuilder[LogAnnotation]
            annotations.sizeHint(headers.length)
            var i           = 0
            while (i < headers.length) {
              val name = headers(i)
              annotations += LogAnnotation(name, req.headers.get(name).mkString)
              i += 1
            }
            ZIO.logAnnotate(annotations.result())(h(req))
          })
        }
      }
    }

  /**
   * Creates middleware that will annotate log messages that are logged while a
   * request is handled with the names and the values of the specified headers.
   */
  def logAnnotateHeaders(header: Header.HeaderType, headers: Header.HeaderType*)(implicit
    trace: Trace,
  ): Middleware[Any] =
    logAnnotateHeaders(header.name, headers.map(_.name): _*)

  def timeout(duration: Duration)(implicit trace: Trace): Middleware[Any] =
    new Middleware[Any] {
      def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        routes.transform[Env1] { handler =>
          handler.timeoutFail(Response(status = Status.RequestTimeout))(duration)
        }
    }

  /**
   * Creates middleware that will track metrics.
   *
   * @param totalRequestsName
   *   Total HTTP requests metric name.
   * @param requestDurationName
   *   HTTP request duration metric name.
   * @param requestDurationBoundaries
   *   Boundaries for the HTTP request duration metric.
   * @param extraLabels
   *   A set of extra labels all metrics will be tagged with.
   */
  def metrics(
    concurrentRequestsName: String = "http_concurrent_requests_total",
    totalRequestsName: String = "http_requests_total",
    requestDurationName: String = "http_request_duration_seconds",
    requestDurationBoundaries: MetricKeyType.Histogram.Boundaries = defaultBoundaries,
    extraLabels: Set[MetricLabel] = Set.empty,
  )(implicit trace: Trace): Middleware[Any] = {
    val requestsTotal: Metric.Counter[RuntimeFlags] = Metric.counterInt(totalRequestsName)
    val concurrentRequests: Metric.Gauge[Double]    = Metric.gauge(concurrentRequestsName)
    val requestDuration: Metric.Histogram[Double]   = Metric.histogram(requestDurationName, requestDurationBoundaries)
    val nanosToSeconds: Double                      = 1e9d

    def labelsForRequest(routePattern: RoutePattern[_]): Set[MetricLabel] =
      Set(
        MetricLabel("method", routePattern.method.render),
        MetricLabel("path", routePattern.pathCodec.render),
      ) ++ extraLabels

    def labelsForResponse(res: Response): Set[MetricLabel] =
      Set(
        MetricLabel("status", res.status.code.toString),
      )

    def report(
      start: Long,
      requestLabels: Set[MetricLabel],
      labels: Set[MetricLabel],
    )(implicit trace: Trace): ZIO[Any, Nothing, Unit] =
      for {
        _   <- requestsTotal.tagged(labels).increment
        _   <- concurrentRequests.tagged(requestLabels).decrement
        end <- Clock.nanoTime
        took = end - start
        _ <- requestDuration.tagged(labels).update(took / nanosToSeconds)
      } yield ()

    def aspect(routePattern: RoutePattern[_])(implicit trace: Trace): HandlerAspect[Any, Unit] =
      HandlerAspect.interceptHandlerStateful(Handler.fromFunctionZIO[Request] { req =>
        val requestLabels = labelsForRequest(routePattern)

        for {
          start <- Clock.nanoTime
          _     <- concurrentRequests.tagged(requestLabels).increment
        } yield ((start, requestLabels), (req, ()))
      })(Handler.fromFunctionZIO[((Long, Set[MetricLabel]), Response)] { case ((start, requestLabels), response) =>
        val allLabels = requestLabels ++ labelsForResponse(response)

        report(start, requestLabels, allLabels).as(response)
      })

    new Middleware[Any] {
      def apply[Env1, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
        Routes.fromIterable(
          routes.routes.map(route => route.transform[Env1](_ @@ aspect(route.routePattern))),
        )
    }
  }

  private sealed trait StaticServe[-R, +E] { self =>
    def run(path: Path, req: Request): Handler[R, E, Request, Response]

  }

  private object StaticServe {
    def make[R, E](f: (Path, Request) => Handler[R, E, Request, Response]): StaticServe[R, E] =
      new StaticServe[R, E] {
        override def run(path: Path, request: Request) = f(path, request)
      }

    def fromDirectory(docRoot: File)(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) =>
      val target = new File(docRoot.getAbsolutePath + path.encode)
      if (target.getCanonicalPath.startsWith(docRoot.getCanonicalPath))
        Handler.fromFile(target, Charset.defaultCharset())
      else {
        Handler.fromZIO(
          ZIO.logWarning(s"attempt to access file outside of docRoot: ${target.getAbsolutePath}"),
        ) *> Handler.badRequest
      }
    }

    def fromResource(resourcePrefix: String)(implicit trace: Trace): StaticServe[Any, Throwable] = make { (path, _) =>
      // validate that resourcePrefix starts with an optional slash, followed by at least 1 java identifier character
      val rp = if (resourcePrefix.startsWith("/")) resourcePrefix else "/" + resourcePrefix
      if (rp.length < 2 || !Character.isJavaIdentifierStart(rp.charAt(1))) {
        Handler.die(new IllegalArgumentException("resourcePrefix must have at least 1 valid character"))
      } else {
        Handler.fromResource(s"${resourcePrefix}/${path.dropLeadingSlash.encode}")
      }
    }

  }

  private def toMiddleware[E](path: Path, staticServe: StaticServe[Any, E])(implicit trace: Trace): Middleware[Any] =
    new Middleware[Any] {

      private def checkFishy(acc: Boolean, segment: String): Boolean = {
        val stop = segment.indexOf('/') >= 0 || segment.indexOf('\\') >= 0 || segment == ".."
        acc || stop
      }

      override def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = {
        val mountpoint =
          Method.GET / path.segments.map(PathCodec.literal).reduceLeftOption(_ / _).getOrElse(PathCodec.empty)
        val pattern    = mountpoint / trailing
        val other      = Routes(
          pattern -> Handler
            .identity[Request]
            .flatMap { request =>
              val isFishy = request.path.segments.foldLeft(false)(checkFishy)
              if (isFishy) {
                Handler.fromZIO(ZIO.logWarning(s"fishy request detected: ${request.path.encode}")) *> Handler.badRequest
              } else {
                val segs   = pattern.pathCodec.segments.collect { case SegmentCodec.Literal(v) =>
                  v
                }
                val unnest = segs.foldLeft(Path.empty)(_ / _).addLeadingSlash
                val path   = request.path.unnest(unnest).addLeadingSlash
                staticServe.run(path, request).sandbox
              }
            },
        )
        routes ++ other
      }

    }

  /**
   * Creates a middleware for serving static files from the directory `docRoot`
   * at the url path `path`.
   *
   * Example: `val serveDirectory = Middleware.serveDirectory(Path.empty /
   * "assets", new File("/some/local/path"))`
   *
   * With this middleware in place, a request to
   * `https://www.domain.com/assets/folder/file1.jpg` would serve the local file
   * `/some/local/path/folder/file1.jpg`.
   */
  def serveDirectory(path: Path, docRoot: File)(implicit trace: Trace): Middleware[Any] =
    toMiddleware(path, StaticServe.fromDirectory(docRoot))

  /**
   * Creates a middleware for serving static files at URL path `path` from
   * resources with the given `resourcePrefix`.
   *
   * Example: `Middleware.serveResources(Path.empty / "assets", "webapp")`
   *
   * With this middleware in place, a request to
   * `https://www.domain.com/assets/folder/file1.jpg` would serve the file
   * `src/main/resources/webapp/folder/file1.jpg`. Note how the URL path is
   * removed and the resourcePrefix prepended.
   *
   * Most build systems support resources in the `src/main/resources` directory.
   * In the above example, the file `src/main/resources/webapp/folder/file1.jpg`
   * would be served.
   *
   * * The `resourcePrefix` defaults to `"public"`. To prevent insecure sharing
   * of * resource files, `resourcePrefix` must start with a `/` followed by at
   * least 1 *
   * [[java.lang.Character.isJavaIdentifierStart(x\$1:Char)* valid java identifier character]].
   * The `/` * will be prepended if it is not present.
   */
  def serveResources(path: Path, resourcePrefix: String = "public")(implicit trace: Trace): Middleware[Any] =
    toMiddleware(path, StaticServe.fromResource(resourcePrefix))

  /**
   * Creates a middleware for managing the flash scope.
   */
  def flashScopeHandling: HandlerAspect[Any, Unit] = Middleware.intercept { (req, resp) =>
    req.cookie(Flash.COOKIE_NAME).fold(resp)(flash => resp.addCookie(Cookie.clear(flash.name)))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy