All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.scalatra.ScalatraKernel.scala Maven / Gradle / Ivy

package org.scalatra

import javax.servlet._
import javax.servlet.http._
import scala.util.DynamicVariable
import scala.util.matching.Regex
import scala.collection.JavaConversions._
import scala.collection.mutable.{ConcurrentMap, HashMap, ListBuffer}
import scala.xml.NodeSeq
import util.{MapWithIndifferentAccess, MultiMapHeadView, using}
import util.io.copy
import java.io.{File, FileInputStream}
import java.util.concurrent.ConcurrentHashMap
import scala.annotation.tailrec

/**
 * Shared type aliases and constants for [[org.scalatra.ScalatraKernel]].
 */
object ScalatraKernel {
  /** Request parameters: every parameter name maps to all of its values. */
  type MultiParams = Map[String, Seq[String]]

  /** The body of a route, filter or handler: a thunk producing any value. */
  type Action = () => Any

  /** HTTP methods that receive an (initially empty) entry in the route registry. */
  val httpMethods = "GET" :: "POST" :: "PUT" :: "DELETE" :: "OPTIONS" :: Nil

  /** HTTP methods regarded as writes; not referenced in the visible part of this file. */
  val writeMethods = List("POST", "PUT", "DELETE")

  /** Key name for a CSRF token; not referenced in the visible part of this file. */
  val csrfKey = "csrfToken"

  /** Name of the system property / init parameter that selects the runtime environment. */
  val EnvironmentKey = "org.scalatra.environment"
}
import ScalatraKernel._

/**
 * ScalatraKernel provides the DSL for building Scalatra applications.
 *
 * At its core, a type mixing in ScalatraKernel is a registry of possible actions;
 * every request is dispatched to the first route matching.
 *
 * The [[org.scalatra.ScalatraKernel#get]], [[org.scalatra.ScalatraKernel#post]],
 * [[org.scalatra.ScalatraKernel#put]] and [[org.scalatra.ScalatraKernel#delete]]
 * methods register a new action to a route for a given HTTP method, possibly
 * overwriting a previous one. This trait is thread safe.
 */
trait ScalatraKernel extends Handler with Initializable
{
  /**
   * Route registry keyed by HTTP method name. Every method listed in
   * `httpMethods` starts out with an empty route list, so lookups for those
   * methods never miss.
   */
  protected val Routes: ConcurrentMap[String, List[Route]] = {
    val registry = new ConcurrentHashMap[String, List[Route]]
    for (method <- httpMethods) registry.put(method, Nil)
    // Converted to the Scala ConcurrentMap view by the JavaConversions implicits.
    registry
  }

  /** Content type currently set on the response being built; `null` when none has been set. */
  def contentType: String = response.getContentType

  /** Sets the content type on the response being built. */
  def contentType_=(value: String) {
    response.setContentType(value)
  }

  /** Encoding applied to the response, and to the request when it declares none. */
  protected val defaultCharacterEncoding: String = "UTF-8"

  /** Holds the response of the request currently being handled; `null` outside of `handle`. */
  protected val _response: DynamicVariable[HttpServletResponse] =
    new DynamicVariable[HttpServletResponse](null)

  /** Holds the request currently being handled; `null` outside of `handle`. */
  protected val _request: DynamicVariable[HttpServletRequest] =
    new DynamicVariable[HttpServletRequest](null)

  // Implicit views that wrap the raw servlet API objects in the project's
  // Rich* helper types, so their extra methods can be called directly on a
  // request, session or servlet context inside this trait.
  protected implicit def requestWrapper(r: HttpServletRequest) = RichRequest(r)
  protected implicit def sessionWrapper(s: HttpSession) = new RichSession(s)
  protected implicit def servletContextWrapper(sc: ServletContext) = new RichServletContext(sc)

  /**
   * A registered route: a set of matchers plus the action to run when they match.
   *
   * @param routeMatchers matchers handed to `RouteMatcher.matchRoute`
   * @param action        the body to execute when the route matches
   */
  protected[scalatra] class Route(val routeMatchers: Iterable[RouteMatcher], val action: Action) {
    /**
     * Tries this route against the current request.
     *
     * Note that `realPath` is not consulted here; matching is delegated
     * entirely to `RouteMatcher.matchRoute`.
     *
     * @return `Some(result)` when the route matched and the action completed,
     *         `None` when it did not match or the action called `pass()`
     */
    def apply(realPath: String): Option[Any] =
      RouteMatcher.matchRoute(routeMatchers).flatMap(routeParams => invokeAction(routeParams))

    /** Runs the action with the route's captured parameters layered over the request parameters. */
    private def invokeAction(routeParams: MultiParams) = {
      val mergedParams = multiParams ++ routeParams
      _multiParams.withValue(mergedParams) {
        try Some(action())
        catch {
          // pass() means "behave as if this route had not matched".
          case _: PassException => None
        }
      }
    }

    override def toString() = routeMatchers.toString
  }

  /**
   * Pluggable way to convert Strings into RouteMatchers. By default, the
   * string is handed to SinatraPathPatternParser, i.e. it is interpreted the
   * same way Sinatra does. Being a protected implicit, subclasses can
   * override it to plug in a different path syntax.
   *
   * @param path the route path pattern as written in the DSL
   */
  protected implicit def string2RouteMatcher(path: String): RouteMatcher =
    SinatraPathPatternParser(path)

  /**
   * Path pattern is decoupled from requests. This adapts the PathPattern to
   * a RouteMatcher by supplying the request path at match time.
   *
   * @param pattern the parsed path pattern to evaluate against `requestPath`
   */
  protected implicit def pathPatternParser2RouteMatcher(pattern: PathPattern): RouteMatcher =
    new RouteMatcher {
      // requestPath is read on every invocation, so the matcher always sees
      // the path of the request currently being handled.
      def apply() = pattern(requestPath)

      // By overriding toString, we can list the available routes in the
      // default notFound handler.
      override def toString = pattern.regex.toString
    }

  /**
   * Adapts a regular expression to a RouteMatcher. The route matches when the
   * regex is found anywhere in `requestPath` (first match wins).
   *
   * @param regex the expression to search for in the request path
   */
  protected implicit def regex2RouteMatcher(regex: Regex): RouteMatcher = new RouteMatcher {
    // No capture groups: match with no route parameters. Otherwise all
    // groups are exposed together under the single "captures" key.
    def apply() = regex.findFirstMatchIn(requestPath) map { _.subgroups match {
      case Nil => Map.empty
      case xs => Map("captures" -> xs)
    }}

    // Shows the pattern itself wherever the matcher is printed.
    override def toString = regex.toString
  }

  /**
   * Adapts an arbitrary boolean condition to a RouteMatcher. The condition is
   * passed by name, so it is re-evaluated each time the matcher is invoked;
   * when it holds the route matches with no extra parameters.
   */
  protected implicit def booleanBlock2RouteMatcher(matcher: => Boolean): RouteMatcher =
    () => if (matcher) Some(Map.empty[String, Seq[String]]) else None

  /**
   * Handles a single HTTP request.
   *
   * Steps, in order:
   *  - applies `defaultCharacterEncoding` to the request (only when it
   *    declares none) and to the response;
   *  - binds the request, the response and the request parameters for the
   *    duration of the call, so that `request`, `response` and `params`
   *    resolve to this request;
   *  - runs the before filters, then the first route registered for
   *    `effectiveMethod` that produces a result, falling back to the
   *    `notFound` handler;
   *  - always runs the after filters, then renders the result.
   *
   * A `halt` raised along the way is turned into `sendError` (or nothing when
   * no status code was given); any other throwable goes to `handleError`.
   *
   * @param request  the incoming servlet request
   * @param response the servlet response to populate
   */
  def handle(request: HttpServletRequest, response: HttpServletResponse) {
    // As default, the servlet tries to decode params with ISO_8859-1.
    // It causes an EOFException if params are actually encoded with the other code (such as UTF-8)
    if (request.getCharacterEncoding == null)
      request.setCharacterEncoding(defaultCharacterEncoding)

    val realMultiParams = request.getParameterMap.asInstanceOf[java.util.Map[String,Array[String]]].toMap
      .transform { (k, v) => v: Seq[String] }

    response.setCharacterEncoding(defaultCharacterEncoding)

    _request.withValue(request) {
      _response.withValue(response) {
        _multiParams.withValue(Map() ++ realMultiParams) {
          val result = try {
            beforeFilters foreach { _() }
            // Only the methods in httpMethods have a registry entry. Any other
            // method (e.g. TRACE) is treated as having no routes, so it reaches
            // the notFound handler instead of failing the lookup with a
            // NoSuchElementException.
            val candidates = Routes.getOrElse(effectiveMethod, List[Route]())
            // The stream is lazy: routes are tried in order and evaluation
            // stops at the first one that yields a result.
            candidates.toStream.flatMap { _(requestPath) }.headOption.getOrElse(doNotFound())
          }
          catch {
            case HaltException(Some(code), Some(msg)) => response.sendError(code, msg)
            case HaltException(Some(code), None) => response.sendError(code)
            case HaltException(None, _) =>
            case e => handleError(e)
          }
          finally {
            afterFilters foreach { _() }
          }
          renderResponse(result)
        }
      }
    }
  }

  /**
   * The HTTP method used for route lookup: the request method in upper case,
   * with HEAD mapped onto GET so that HEAD requests are served by GET routes.
   *
   * Upper-casing uses a fixed locale. With the JVM default locale a Turkish
   * environment would turn "options" into "OPTİONS" (dotted capital I), which
   * would never match a registered method name.
   */
  protected def effectiveMethod = request.getMethod.toUpperCase(java.util.Locale.ENGLISH) match {
    case "HEAD" => "GET"
    case x => x
  }

  /**
   * The path of the current request that routes are matched against.
   * Abstract here: the concrete class mixing in this trait supplies it.
   */
  def requestPath: String

  /** Filters executed, in registration order, before route dispatch. */
  protected val beforeFilters = new ListBuffer[() => Any]

  /** Registers a block to run before every request is dispatched. */
  def before(fun: => Any) = beforeFilters += (() => fun)

  /** Filters executed, in registration order, once dispatch has finished (even when it failed). */
  protected val afterFilters = new ListBuffer[() => Any]

  /** Registers a block to run after every request has been dispatched. */
  def after(fun: => Any) = afterFilters += (() => fun)

  /** Handler invoked when no route produced a result; abstract, supplied by the concrete class. */
  protected var doNotFound: Action

  /** Replaces the not-found handler with the given block. */
  def notFound(fun: => Any) {
    doNotFound = () => fun
  }

  /**
   * Reacts to a throwable raised while handling a request: marks the response
   * as 500 and runs the error handler with the throwable exposed through
   * `caughtThrowable`.
   *
   * @param e the throwable that was caught
   * @return whatever the error handler returns
   */
  protected def handleError(e: Throwable): Any = {
    status(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
    _caughtThrowable.withValue(e)(errorHandler())
  }

  /** Carries the throwable being handled while the error handler runs; `null` otherwise. */
  private val _caughtThrowable = new DynamicVariable[Throwable](null)

  /** The throwable currently being handled by the error handler. */
  protected def caughtThrowable = _caughtThrowable.value

  /** Current error handler; by default it simply rethrows the caught throwable. */
  protected var errorHandler: Action = () => throw caughtThrowable

  /** Replaces the error handler with the given block. */
  def error(fun: => Any) {
    errorHandler = () => fun
  }

  /**
   * Renders the result of an action: picks a content type when the action did
   * not set one, then writes the body.
   *
   * @param actionResult the value produced by the action (or error handling)
   */
  protected def renderResponse(actionResult: Any) {
    val explicitType = contentType
    if (explicitType == null) {
      contentType = inferContentType(actionResult)
    }
    renderResponseBody(actionResult)
  }

  /**
   * Chooses a content type from the runtime type of the action result:
   * byte arrays are binary, XML node sequences are HTML, and everything else
   * is plain text.
   *
   * @param actionResult the value produced by the action
   * @return the content type to use for the response
   */
  protected def inferContentType(actionResult: Any): String =
    actionResult match {
      case _: Array[Byte] => "application/octet-stream"
      case _: NodeSeq     => "text/html"
      case _              => "text/plain"
    }

  /**
   * Writes the action result to the response, depending on its runtime type:
   *  - a byte array is written to the output stream as-is;
   *  - a `File` has its contents copied to the output stream;
   *  - `Unit` writes nothing;
   *  - anything else is printed through the response writer via `toString`.
   *
   * @param actionResult the value produced by the action
   */
  protected def renderResponseBody(actionResult: Any) {
    actionResult match {
      case payload: Array[Byte] =>
        val out = response.getOutputStream
        out.write(payload)
      case source: File =>
        // `using` is a project helper; presumably it closes the stream afterwards.
        using(new FileInputStream(source)) { stream =>
          copy(stream, response.getOutputStream)
        }
      case _: Unit =>
        // An action returning Unit has taken care of the response itself.
      case other: Any =>
        val writer = response.getWriter
        writer.print(other.toString)
    }
  }

  /** Parameters bound to the request (plus route captures) currently being handled. */
  protected[scalatra] val _multiParams = new DynamicVariable[MultiParams](Map())

  /** The current parameters; a missing name yields an empty sequence instead of failing. */
  protected def multiParams: MultiParams = _multiParams.value withDefaultValue Seq.empty

  /*
   * View over multiParams, built from the project's MultiMapHeadView and
   * MapWithIndifferentAccess helpers.
   *
   * Relies on multiParams never holding a null or empty value: the servlet
   * container won't put them in request.getParameters, and neither should we.
   */
  protected val _params = new MultiMapHeadView[String, String] with MapWithIndifferentAccess[String] {
    protected def multiMap = multiParams
  }

  /** The parameters of the current request, as exposed by `_params`. */
  def params = _params

  /** Sends a redirect to `uri` on the current response. */
  def redirect(uri: String) = _response.value.sendRedirect(uri)

  /** The request currently being handled. */
  implicit def request = _request.value

  /** The response currently being built. */
  implicit def response = _response.value

  /** The session of the current request, as returned by `getSession` with no arguments. */
  def session = request.getSession

  /** The session of the current request if one already exists; never creates one. */
  def sessionOption = Option(request.getSession(false))

  /** Sets the status code on the current response. */
  def status(code: Int) = _response.value.setStatus(code)

  /** Aborts the current request; `handle` answers with `sendError(code, msg)`. */
  def halt(code: Int, msg: String) = throw HaltException(Some(code), Some(msg))

  /** Aborts the current request; `handle` answers with `sendError(code)`. */
  def halt(code: Int) = throw HaltException(Some(code), None)

  /** Aborts the current request without sending an error status. */
  def halt() = throw HaltException(None, None)

  /** Control-flow exception raised by `halt` and interpreted in `handle`. */
  private case class HaltException(code: Option[Int], msg: Option[String]) extends RuntimeException

  /** Abandons the current route so that dispatch continues with the next matching one. */
  def pass() = throw new PassException

  /** Control-flow exception raised by `pass` and caught when a route's action is invoked. */
  protected[scalatra] class PassException extends RuntimeException

  /**
   * The Scalatra DSL core methods take a list of [[org.scalatra.RouteMatcher]] and a block as
   * the action body.
   * The return value of the block is converted to a string and sent to the client as the response body.
   *
   * See [[org.scalatra.ScalatraKernel.renderResponseBody]] for the detailed behaviour and how to handle your
   * response body more explicitly, and see how different return types are handled.
   *
   * The block is executed in the context of the ScalatraKernel instance, so all the methods defined in
   * this trait are also available inside the block.
   *
   * {{{
   *   get("/") {
   *     <form action="/echo">
   *       <label>Enter your name</label>
   *       <input type="text" name="name"/>
   *     </form>
   *   }
   *
   *   post("/echo") {
   *     "hello {params('name)}!"
   *   }
   * }}}
   *
   * ScalatraKernel provides implicit transformation from boolean blocks, strings and regular expressions
   * to [[org.scalatra.RouteMatcher]], so you can write code naturally
   * {{{
   *   get("/", request.getRemoteHost == "127.0.0.1") { "Hello localhost!" }
   * }}}
   *
   */
  def get(routeMatchers: RouteMatcher*)(action: => Any) = addRoute("GET", routeMatchers, action)

  /**
   * @see [[org.scalatra.ScalatraKernel.get]]
   */
  def post(routeMatchers: RouteMatcher*)(action: => Any) = addRoute("POST", routeMatchers, action)

  /**
   * @see [[org.scalatra.ScalatraKernel.get]]
   */
  def put(routeMatchers: RouteMatcher*)(action: => Any) = addRoute("PUT", routeMatchers, action)

  /**
   * @see [[org.scalatra.ScalatraKernel.get]]
   */
  def delete(routeMatchers: RouteMatcher*)(action: => Any) = addRoute("DELETE", routeMatchers, action)

  /**
   * @see [[org.scalatra.ScalatraKernel.get]]
   */
  def options(routeMatchers: RouteMatcher*)(action: => Any) = addRoute("OPTIONS", routeMatchers, action)

  /**
   * Registers a new route for the given HTTP method. Can be overridden so that subtraits can use
   * their own logic, for example restricting protocol usage, namespacing routes based on class name,
   * or raising errors on overlapping entries.
   *
   * This is the method invoked by get(), post() etc. The new route is prepended to the list
   * registered for `verb`.
   *
   * @param verb          the HTTP method name; it must already have an entry in the route registry
   * @param routeMatchers the matchers that decide whether the route applies
   * @param action        the body to run when the route matches
   * @return the newly registered route
   * @see removeRoute
   */
  protected[scalatra] def addRoute(verb: String, routeMatchers: Iterable[RouteMatcher], action: => Any): Route = {
    val route = new Route(routeMatchers, () => action)
    modifyRoutes(verb, route :: _)
    route
  }

  /**
   * Removes every occurrence of the given route from the list registered for an HTTP method.
   * If [[addRoute]] is overridden this should probably be overridden too.
   *
   * @param verb  the HTTP method name; it must already have an entry in the route registry
   * @param route the route to remove
   * @see addRoute
   */
  protected def removeRoute(verb: String, route: Route): Unit = {
    modifyRoutes(verb, _ filterNot (_ == route))
  }

  /**
   * Since Routes is a ConcurrentMap and we avoid locking, we need to retry if there are
   * concurrent modifications; this is abstracted here for removeRoute and addRoute.
   *
   * @param protocol the HTTP method name whose route list is rewritten
   * @param f        computes the new route list from the current one; it may run more than once
   */
  @tailrec private def modifyRoutes(protocol: String, f: (List[Route] => List[Route])): Unit = {
    val oldRoutes = Routes(protocol)
    // replace only succeeds when the entry still holds oldRoutes; otherwise
    // another thread got there first and we start over from the fresh value.
    if (!Routes.replace(protocol, oldRoutes, f(oldRoutes))) {
      modifyRoutes(protocol, f)
    }
  }

  /** Configuration handed over through `initialize`; stays `null` until then. */
  private var config: Config = _

  /** Stores the configuration for later `initParameter` lookups. */
  def initialize(config: Config) = this.config = config

  /**
   * Looks up an init parameter in the stored configuration.
   *
   * @param name the init parameter name
   * @return the value, or `None` when it is not set or the configuration is
   *         neither a ServletConfig nor a FilterConfig
   */
  def initParameter(name: String): Option[String] = config match {
    case config: ServletConfig => Option(config.getInitParameter(name))
    case config: FilterConfig => Option(config.getInitParameter(name))
    case _ => None
  }

  /**
   * Name of the runtime environment: the `EnvironmentKey` system property when set, otherwise
   * the init parameter of the same name, otherwise "development".
   */
  def environment: String = System.getProperty(EnvironmentKey, initParameter(EnvironmentKey).getOrElse("development"))

  /** True when the environment name starts with "dev", ignoring case. */
  def isDevelopmentMode = environment.toLowerCase.startsWith("dev")
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy