All Downloads are FREE. Search and download functionalities are using the official Maven repository.

pl.iterators.stir.server.RejectionHandler.scala Maven / Gradle / Ivy

The newest version!
package pl.iterators.stir.server

import cats.effect.IO
import org.http4s.{ Headers, Response }
import org.http4s.Status._
import org.http4s.headers.{ `WWW-Authenticate`, Allow }
import pl.iterators.stir.impl.util._
import pl.iterators.stir.marshalling.ToResponseMarshallable
import pl.iterators.stir.server.directives.BasicDirectives

import scala.annotation.tailrec
import scala.collection.immutable.VectorBuilder
import scala.reflect.ClassTag

/**
 * A function from a list of rejections to an optional [[Route]] that answers them.
 *
 * Returning `None` means this handler has nothing to say about the given rejections.
 */
trait RejectionHandler extends (Seq[Rejection] => Option[Route]) { self =>
  import RejectionHandler._

  /**
   * Returns a handler whose routes are wrapped with `BasicDirectives.mapResponse(map)`.
   *
   * Only handlers produced by the builder (`BuiltRejectionHandler`) are supported. The resulting handler is
   * never flagged as the default one, even if this one was.
   *
   * @throws IllegalArgumentException if this handler is not a `BuiltRejectionHandler`
   */
  def mapRejectionResponse(map: Response[IO] => Response[IO]): RejectionHandler =
    self match {
      case built: BuiltRejectionHandler =>
        val mappedCases = built.cases.map(_.mapResponse(map))
        val mappedNotFound = built.notFound.map(route => BasicDirectives.mapResponse(map)(route))
        new BuiltRejectionHandler(mappedCases, mappedNotFound, isDefault = false)

      case unsupported =>
        throw new IllegalArgumentException("Can only mapRejectionResponse on BuiltRejectionHandler " +
          s"(e.g. obtained by calling `.result()` on the `RejectionHandler` builder). Type was: ${unsupported.getClass}")
    }

  /**
   * Combines this handler with `that`, consulting `that` only where this handler yields nothing.
   *
   *  - If this handler is the default one it is returned unchanged.
   *  - Two built handlers are merged into a single built handler (this handler's cases first).
   *  - Any other combination is chained lazily through `Option.orElse`.
   */
  def withFallback(that: RejectionHandler): RejectionHandler = {
    // Generic composition used whenever at least one side is not a built handler.
    def chained: RejectionHandler =
      new RejectionHandler {
        def apply(rejections: Seq[Rejection]): Option[Route] =
          self(rejections).orElse(that(rejections))
      }

    self match {
      case primary: BuiltRejectionHandler if primary.isDefault =>
        self // the default handler already handles everything
      case primary: BuiltRejectionHandler =>
        that match {
          case secondary: BuiltRejectionHandler =>
            new BuiltRejectionHandler(
              primary.cases ++ secondary.cases,
              primary.notFound.orElse(secondary.notFound),
              secondary.isDefault)
          case _ => chained
        }
      case _ => chained
    }
  }

  /**
   * Makes sure the default handler is the last resort: a handler already flagged as default is returned
   * as is, every other handler gets `default` attached via [[withFallback]].
   */
  def seal: RejectionHandler =
    self match {
      case built: BuiltRejectionHandler if built.isDefault => built
      case _                                               => withFallback(default)
    }
}

object RejectionHandler {

  /**
   * Creates a new, empty [[RejectionHandler]] builder.
   *
   * The builder is created with `isDefault = false`, so the handler it produces is not treated as the
   * default handler by `withFallback` / `seal`.
   */
  def newBuilder(): Builder = new Builder(isDefault = false)

  /**
   * Mutable builder collecting rejection handlers in registration order.
   *
   * Every registration method returns the builder itself so that calls can be chained; `result()` snapshots
   * what has been registered so far into a `BuiltRejectionHandler`.
   *
   * @param isDefault flag passed through unchanged to the handler created by `result()`
   */
  final class Builder private[RejectionHandler] (isDefault: Boolean) {
    private[this] val registered = new VectorBuilder[Handler]
    private[this] var notFoundRoute: Option[Route] = None

    // Appends one handler entry and hands the builder back for chaining.
    private[this] def register(handler: Handler): this.type = {
      registered += handler
      this
    }

    /**
     * Registers a partial function that turns a single [[Rejection]] into a [[Route]].
     */
    def handle(pf: PartialFunction[Rejection, Route]): this.type =
      register(CaseHandler(pf))

    /**
     * Registers a function that receives all rejections of runtime type `T` at once.
     * The function is only invoked with a non-empty sequence.
     */
    def handleAll[T <: Rejection: ClassTag](f: Seq[T] => Route): this.type =
      register(TypeHandler[T](scala.reflect.classTag[T].runtimeClass, f))

    /**
     * Sets the [[Route]] used for the special "not found" case (an empty rejection list);
     * a later call replaces an earlier one.
     */
    def handleNotFound(route: Route): this.type = {
      notFoundRoute = Some(route)
      this
    }

    /** Builds the [[RejectionHandler]] from everything registered so far. */
    def result(): RejectionHandler =
      new BuiltRejectionHandler(registered.result(), notFoundRoute, isDefault)
  }

  /**
   * Common base of the handler entries collected by the builder and evaluated by `BuiltRejectionHandler`.
   * Sealed: the implementations in this file are `CaseHandler` and `TypeHandler`.
   */
  private sealed abstract class Handler {
    /** Returns a copy of this handler whose resulting routes are wrapped using the given response mapping. */
    def mapResponse(map: Response[IO] => Response[IO]): Handler
  }
  /**
   * Handler entry backed by a partial function over single rejections.
   *
   * @param pf maps a rejection it is defined at to the [[Route]] answering it
   */
  private final case class CaseHandler(pf: PartialFunction[Rejection, Route]) extends Handler {

    /** Returns a new entry whose routes are wrapped with `BasicDirectives.mapResponse(map)`. */
    override def mapResponse(map: Response[IO] => Response[IO]): CaseHandler = {
      val wrap: Route => Route = route => BasicDirectives.mapResponse(map)(route)
      CaseHandler(pf.andThen(wrap))
    }
  }
  /**
   * Handler entry that answers all rejections of one runtime class at once.
   *
   * It doubles as a `PartialFunction[Rejection, T]` so it can be passed to `collect`: it is defined for
   * instances of `runtimeClass` and casts them to `T`.
   *
   * @param runtimeClass class used to select the rejections this entry is responsible for
   * @param f            builds the [[Route]] from the selected rejections
   */
  private final case class TypeHandler[T <: Rejection](
      runtimeClass: Class[_], f: Seq[T] => Route) extends Handler with PartialFunction[Rejection, T] {

    def isDefinedAt(rejection: Rejection): Boolean =
      runtimeClass.isInstance(rejection)

    // Unchecked cast; callers are expected to consult `isDefinedAt` first (as `collect` does).
    def apply(rejection: Rejection): T =
      rejection.asInstanceOf[T]

    /** Returns a new entry whose routes are wrapped with `BasicDirectives.mapResponse(map)`. */
    override def mapResponse(map: Response[IO] => Response[IO]): TypeHandler[T] = {
      val wrap: Route => Route = route => BasicDirectives.mapResponse(map)(route)
      // `compose` keeps `f` evaluated first, then wraps the route it produced.
      TypeHandler[T](runtimeClass, wrap.compose(f))
    }
  }

  /**
   * [[RejectionHandler]] produced by the builder.
   *
   * @param cases     handler entries, tried strictly in this order
   * @param notFound  route returned when the rejection list is empty
   * @param isDefault whether this instance is the default handler
   */
  private class BuiltRejectionHandler(
      val cases: Vector[Handler],
      val notFound: Option[Route],
      val isDefault: Boolean) extends RejectionHandler {

    /**
     * An empty rejection list yields `notFound`. Otherwise the entries are consulted one after another and
     * the first one that produces a route wins; `None` is returned when no entry matches.
     */
    def apply(rejections: Seq[Rejection]): Option[Route] =
      if (rejections.isEmpty) notFound
      else
        // The iterator keeps evaluation lazy: later entries are not touched once a route has been found.
        cases.iterator
          .map[Option[Route]] {
            case CaseHandler(pf) =>
              rejections.collectFirst(pf)
            case typed @ TypeHandler(_, f) =>
              val matching = rejections.collect(typed)
              if (matching.isEmpty) None else Some(f(matching))
          }
          .collectFirst { case Some(route) => route }
  }

  import Directives._

  // Completes the request with the given marshallable, which is taken by name and forwarded to `complete`.
  // NOTE(review): despite the name, the code visible here does nothing with the request entity — confirm
  // whether entity handling is intentionally absent.
  private def rejectRequestEntityAndComplete(m: => ToResponseMarshallable) = complete(m)

  /**
   * The built-in [[RejectionHandler]], created with `isDefault = true`.
   *
   * The entries are registered in priority order: the built handler tries each entry against all rejections
   * before moving on to the next entry, so an earlier entry wins regardless of the order of the rejections.
   * For that reason the single-case `handle` blocks must stay separate. The final catch-all entry fails with
   * an error for any rejection no earlier entry recognised, and an empty rejection list is answered with the
   * `handleNotFound` route. Commented-out entries are retained unchanged.
   */
  final val default =
    new Builder(isDefault = true)
      .handleAll[SchemeRejection] { schemeRejections =>
        val supportedSchemes = schemeRejections.map(_.supported).mkString(", ")
        rejectRequestEntityAndComplete((BadRequest, s"Uri scheme not allowed, supported schemes: $supportedSchemes"))
      }
      .handleAll[MethodRejection] { methodRejections =>
        val allowedMethods = methodRejections.map(_.supported)
        val allowedNames = allowedMethods.map(_.name).mkString(", ")
        rejectRequestEntityAndComplete((MethodNotAllowed, Headers(Allow(allowedMethods.toSet)),
          s"HTTP method not allowed, supported methods: $allowedNames"))
      }
      .handle {
        case AuthorizationFailedRejection =>
          rejectRequestEntityAndComplete((Forbidden,
            "The supplied authentication is not authorized to access this resource"))
      }
      .handle {
        case MalformedFormFieldRejection(fieldName, detail, _) =>
          rejectRequestEntityAndComplete((BadRequest, s"The form field '$fieldName' was malformed:\n$detail"))
      }
      .handle {
        case MalformedHeaderRejection(headerName, detail, _) =>
          rejectRequestEntityAndComplete((BadRequest, s"The value of HTTP header '$headerName' was malformed:\n$detail"))
      }
      .handle {
        case MalformedQueryParamRejection(paramName, detail, _) =>
          rejectRequestEntityAndComplete((BadRequest, s"The query parameter '$paramName' was malformed:\n$detail"))
      }
      .handle {
        // Every cause currently maps to BadRequest; the size-specific mapping below is disabled.
        case MalformedRequestContentRejection(detail, _) =>
//            case _: EntityStreamSizeException => rejectRequestEntityAndComplete((PayloadTooLarge, rejectionMessage))
          rejectRequestEntityAndComplete((BadRequest, s"The request content was malformed:\n$detail"))
      }
      .handle {
        case MissingCookieRejection(cookieName) =>
          rejectRequestEntityAndComplete((BadRequest, s"Request is missing required cookie '$cookieName'"))
      }
      .handle {
        case MissingFormFieldRejection(fieldName) =>
          rejectRequestEntityAndComplete((BadRequest, s"Request is missing required form field '$fieldName'"))
      }
      .handle {
        case MissingHeaderRejection(headerName) =>
          rejectRequestEntityAndComplete((BadRequest, s"Request is missing required HTTP header '$headerName'"))
      }
//      .handle {
//        case MissingAttributeRejection(_) =>
//          rejectRequestEntityAndComplete((InternalServerError, InternalServerError.defaultMessage))
//      }
//      .handle {
//        case InvalidOriginRejection(allowedOrigins) =>
//          rejectRequestEntityAndComplete((Forbidden, s"Allowed `Origin` header values: ${allowedOrigins.mkString(", ")}"))
//      }
      .handle {
        case MissingQueryParamRejection(paramName) =>
          rejectRequestEntityAndComplete((NotFound, s"Request is missing required query parameter '$paramName'"))
      }
      .handle {
        case InvalidRequiredValueForQueryParamRejection(paramName, requiredValue, _) =>
          rejectRequestEntityAndComplete((NotFound,
            s"Request is missing required value '$requiredValue' for query parameter '$paramName'"))
      }
      .handle {
        case EntityRejection(cause) =>
          rejectRequestEntityAndComplete((BadRequest, s"Request entity malformed: ${cause.getMessage.nullAsEmpty}"))
      }
//      .handle {
//        case TooManyRangesRejection(_) =>
//          rejectRequestEntityAndComplete((RangeNotSatisfiable, "Request contains too many ranges"))
//      }
//      .handle {
//        case CircuitBreakerOpenRejection(_) =>
//          rejectRequestEntityAndComplete(ServiceUnavailable)
//      }
//      .handle {
//        case UnsatisfiableRangeRejection(unsatisfiableRanges, actualEntityLength) =>
//          rejectRequestEntityAndComplete((RangeNotSatisfiable, List(`Content-Range`(ContentRange.Unsatisfiable(actualEntityLength))),
//            unsatisfiableRanges.mkString("None of the following requested Ranges were satisfiable:\n", "\n", "")))
//      }
      .handleAll[AuthenticationFailedRejection] { authRejections =>
        // The explanation is derived from the first rejection only; `handleAll` never passes an empty seq.
        val explanation = authRejections.head.cause match {
          case AuthenticationFailedRejection.CredentialsMissing =>
            "The resource requires authentication, which was not supplied with the request"
          case AuthenticationFailedRejection.CredentialsRejected =>
            "The supplied authentication is invalid"
        }
        // Multiple challenges per WWW-Authenticate header are allowed per spec,
        // however, it seems many browsers will ignore all challenges but the first.
        // Therefore, multiple WWW-Authenticate headers are rendered, instead.
        //
        // See https://code.google.com/p/chromium/issues/detail?id=103220
        // and https://bugzilla.mozilla.org/show_bug.cgi?id=669675
        val challengeHeaders = Headers(authRejections.map(r => `WWW-Authenticate`(r.challenge)))
        rejectRequestEntityAndComplete((Unauthorized, challengeHeaders, explanation))
      }
//      .handleAll[UnacceptedResponseContentTypeRejection] { rejections =>
//        val supported = rejections.flatMap(_.supported)
//        val msg = supported.map(_.format).mkString("Resource representation is only available with these types:\n", "\n", "")
//        rejectRequestEntityAndComplete((NotAcceptable, msg))
//      }
//      .handleAll[UnacceptedResponseEncodingRejection] { rejections =>
//        val supported = rejections.flatMap(_.supported)
//        rejectRequestEntityAndComplete((NotAcceptable, "Resource representation is only available with these Content-Encodings:\n" +
//          supported.map(_.value).mkString("\n")))
//      }
//      .handleAll[UnsupportedRequestContentTypeRejection] { rejections =>
//        val unsupported = rejections.find(_.contentType.isDefined).flatMap(_.contentType).fold("")(" [" + _ + "]")
//        val supported = rejections.flatMap(_.supported).mkString(" or ")
//        val expected =
//          if (supported.isEmpty) ""
//          else " Expected:\n" + supported
//
//        rejectRequestEntityAndComplete((UnsupportedMediaType, s"The request's Content-Type$unsupported is not supported.$expected"))
//      }
//      .handleAll[UnsupportedRequestEncodingRejection] { rejections =>
//        val supported = rejections.map(_.supported.value).mkString(" or ")
//        rejectRequestEntityAndComplete((BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported))
//      }
//      .handle { case ExpectedWebSocketRequestRejection => rejectRequestEntityAndComplete((BadRequest, "Expected WebSocket Upgrade request")) }
//      .handleAll[UnsupportedWebSocketSubprotocolRejection] { rejections =>
//        val supported = rejections.map(_.supportedProtocol)
//        rejectRequestEntityAndComplete(HttpResponse(
//          BadRequest,
//          entity = s"None of the websocket subprotocols offered in the request are supported. Supported are ${supported.map("'" + _ + "'").mkString(",")}.",
//          headers = `Sec-WebSocket-Protocol`(supported) :: Nil))
//      }
      .handle {
        case ValidationRejection(message, _) =>
          rejectRequestEntityAndComplete((BadRequest, message))
      }
      .handle {
        // Catch-all: anything not recognised above is treated as a programming error.
        case unhandled => sys.error(s"Unhandled rejection: $unhandled")
      }
      .handleNotFound {
        rejectRequestEntityAndComplete((NotFound, "The requested resource could not be found."))
      }
      .result()

  /**
   * Separates the [[TransformationRejection]]s from the ordinary rejections, removes duplicates from the
   * ordinary ones and then feeds them through every transformation, in the order in which the
   * transformations occur in the input.
   *
   * @param rejections the raw rejections, possibly containing transformation rejections
   * @return the ordinary rejections after all transformations have been applied
   */
  def applyTransformations(rejections: Seq[Rejection]): Seq[Rejection] = {
    val transformations = rejections.collect { case t: TransformationRejection => t }
    val ordinary = rejections.filterNot(_.isInstanceOf[TransformationRejection])
    transformations.foldLeft(ordinary.distinct) { (remaining, transformation) =>
      transformation.transform(remaining)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy