// cask.main.Main.scala (Maven / Gradle / Ivy)
package cask.main
import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.{DispatchTrie, Util}
import cask.main
import cask.router.{Decorator, EndpointMetadata, EntryPoint, RawDecorator, Result}
import cask.util.Logger
import io.undertow.Undertow
import io.undertow.server.{HttpHandler, HttpServerExchange}
import io.undertow.server.handlers.BlockingHandler
import io.undertow.util.HttpString
import scala.concurrent.ExecutionContext
/**
 * A combination of [[cask.Main]] and [[cask.Routes]], ideal for small
 * one-file web applications.
 *
 * The instance is both the server entrypoint and its own single set of
 * routes: `allRoutes` contains only `this`.
 */
class MainRoutes extends Main with Routes{
  // The result type is spelled out to match the `Seq[Routes]` declared on
  // `Main`. Left to inference from `Seq(this)` it would be the narrower
  // `Seq[MainRoutes]`, which gets in the way of a subclass overriding this
  // with a list that also contains other `Routes` instances.
  def allRoutes: Seq[Routes] = Seq(this)
}
/**
 * Defines the main entrypoint and configuration of the Cask web application.
 *
 * You can pass in an arbitrary number of [[cask.Routes]] objects for it to
 * serve, and override various properties on [[Main]] in order to configure
 * application-wide properties.
 */
abstract class Main{
  /** Decorators placed ahead of the route-level and endpoint-level ones for every endpoint. */
  def mainDecorators: Seq[Decorator[_, _, _, _]] = Nil

  /** The [[Routes]] instances whose endpoints are served. */
  def allRoutes: Seq[Routes]

  /** Port handed to the Undertow HTTP listener in `main`. */
  def port: Int = 8080

  /** Host handed to the Undertow HTTP listener in `main`. */
  def host: String = "localhost"

  /** When false, `main` calls `Main.silenceJboss()` before building the server. */
  def verbose: Boolean = false

  /**
   * Forwarded to the default error handler: when true the error body is the
   * detailed message from `ErrorMsgs.formatInvokeError`, otherwise it is only
   * the status code and its reason phrase.
   */
  def debugMode: Boolean = true

  /** Supplies the value stored in `executionContext`. */
  def createExecutionContext = castor.Context.Simple.executionContext

  /** Supplies the value stored in `actorContext`; built from `executionContext` and `log.exception`. */
  def createActorContext = new castor.Context.Simple(executionContext, log.exception)

  // These two are initialised in declaration order: `createActorContext`
  // reads `executionContext`, so it has to come second.
  val executionContext = createExecutionContext
  implicit val actorContext: castor.Context = createActorContext

  /** Logger used by this application. Being a `def`, every access builds a new console logger. */
  implicit def log: cask.util.Logger = new cask.util.Logger.Console()

  /** Routing structure computed from `allRoutes`; recomputed on every call. */
  def dispatchTrie = Main.prepareDispatchTrie(allRoutes)

  /** The request handler installed by `main`: a `Main.DefaultHandler` wrapped in a `BlockingHandler`. */
  def defaultHandler = {
    val routingHandler = new Main.DefaultHandler(
      dispatchTrie,
      mainDecorators,
      debugMode,
      handleNotFound,
      handleMethodNotAllowed,
      handleEndpointError
    )
    new BlockingHandler(routingHandler)
  }

  /** Response used when no endpoint matches the request path. */
  def handleNotFound(req: Request): Response.Raw =
    Main.defaultHandleNotFound(req)

  /** Response used when the path matches an endpoint but the request method does not. */
  def handleMethodNotAllowed(req: Request): Response.Raw =
    Main.defaultHandleMethodNotAllowed(req)

  /** Response used when invoking an endpoint produces a `Result.Error`. */
  def handleEndpointError(routes: Routes,
                          metadata: EndpointMetadata[_],
                          e: cask.router.Result.Error,
                          req: Request): Response.Raw =
    Main.defaultHandleError(routes, metadata, e, debugMode, req)

  /** Builds an Undertow server listening on `host`:`port` with `defaultHandler` and starts it. */
  def main(args: Array[String]): Unit = {
    if (!verbose) Main.silenceJboss()
    val configured = Undertow.builder
      .addHttpListener(port, host)
      .setHandler(defaultHandler)
    configured.build.start()
  }
}
object Main{
/**
 * Undertow handler that routes an exchange to the matching endpoint and
 * writes the resulting response.
 *
 * @param dispatchTrie           maps decoded path segments to a per-method map of
 *                               (routes, endpoint metadata)
 * @param mainDecorators         decorators placed ahead of the route-level and
 *                               endpoint-level ones
 * @param debugMode              accepted for configuration; not read inside this class
 * @param handleNotFound         builds the response when no path matches
 * @param handleMethodNotAllowed builds the response when the path matches but the
 *                               method does not
 * @param handleError            builds the response when endpoint invocation yields
 *                               a `Result.Error`
 * @param log                    receives any non-fatal exception that escapes
 *                               request handling
 */
class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]],
                     mainDecorators: Seq[Decorator[_, _, _, _]],
                     debugMode: Boolean,
                     handleNotFound: Request => Response.Raw,
                     handleMethodNotAllowed: Request => Response.Raw,
                     handleError: (Routes, EndpointMetadata[_], Result.Error, Request) => Response.Raw)
                    (implicit log: Logger) extends HttpHandler() {
  def handleRequest(exchange: HttpServerExchange): Unit = try {
    // println("Handling Request: " + exchange.getRequestPath)
    // Requests carrying `Upgrade: websocket` are looked up under the
    // pseudo-method "websocket"; everything else uses the lower-cased HTTP
    // method. `runner` knows how to deliver a successful endpoint result.
    val (effectiveMethod, runner) = if ("websocket".equalsIgnoreCase(exchange.getRequestHeaders.getFirst("Upgrade"))) {
      Tuple2(
        "websocket",
        (r: Any) =>
          r.asInstanceOf[WebsocketResult] match{
            case l: WsHandler =>
              io.undertow.Handlers.websocket(l).handleRequest(exchange)
            case l: WebsocketResult.Listener =>
              io.undertow.Handlers.websocket(l.value).handleRequest(exchange)
            case r: WebsocketResult.Response[Response.Data] =>
              Main.writeResponse(exchange, r.value)
          }
      )
    } else Tuple2(
      exchange.getRequestMethod.toString.toLowerCase(),
      (r: Any) => Main.writeResponse(exchange, r.asInstanceOf[Response.Raw])
    )
    // NOTE(review): URLDecoder performs form decoding, so a literal '+' in a
    // path segment becomes a space - confirm this is intended for paths.
    val decodedSegments = Util
      .splitPath(exchange.getRequestURI)
      .iterator
      .map(java.net.URLDecoder.decode(_, "UTF-8"))
      .toList
    dispatchTrie.lookup(decodedSegments, Vector()) match {
      case None => Main.writeResponse(exchange, handleNotFound(Request(exchange, decodedSegments, Map())))
      case Some((methodMap, routeBindings, remaining)) =>
        methodMap.get(effectiveMethod) match {
          case None => Main.writeResponse(exchange, handleMethodNotAllowed(Request(exchange, remaining, routeBindings)))
          case Some((routes, metadata)) =>
            val req = Request(exchange, remaining, routeBindings)
            Decorator.invoke(
              req,
              metadata.endpoint,
              metadata.entryPoint.asInstanceOf[EntryPoint[Routes, _]],
              routes,
              (mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
              Nil,
              Nil
            ) match {
              case Result.Success(res) => runner(res)
              case e: Result.Error =>
                Main.writeResponse(
                  exchange,
                  handleError(routes, metadata, e, req)
                )
            }
        }
    }
  } catch { case scala.util.control.NonFatal(e) =>
    // Report through the configured logger rather than printing the stack
    // trace directly, so that a custom `Logger` sees these failures too.
    log.exception(e)
    // Signal the failure to the client when nothing has been sent yet; once
    // the response has started the status can no longer be changed.
    if (!exchange.isResponseStarted) {
      exchange.setStatusCode(500)
      ()
    }
  }
}
/**
 * Default "not found" response: status 404 with a body of the form
 * `Error 404: <reason>`, where the reason comes from `Status.codesToStatus`.
 *
 * @param req the request that matched no endpoint; not inspected here
 */
def defaultHandleNotFound(req: Request): Response.Raw = {
  val code = 404
  val reason = Status.codesToStatus(code).reason
  Response(s"Error $code: $reason", statusCode = code)
}
/**
 * Default "method not allowed" response: status 405 with a body of the form
 * `Error 405: <reason>`, where the reason comes from `Status.codesToStatus`.
 *
 * @param req the request whose method matched no endpoint on the path; not inspected here
 */
def defaultHandleMethodNotAllowed(req: Request): Response.Raw = {
  val code = 405
  val reason = Status.codesToStatus(code).reason
  Response(s"Error $code: $reason", statusCode = code)
}
/**
 * Builds the routing trie for a set of [[Routes]].
 *
 * Every endpoint contributes its split path, a map from each of its methods
 * to `(routes, metadata)`, and a flag saying whether it accepts a trailing
 * subpath. Endpoints sharing the same path segments are merged into a single
 * entry before the trie is constructed.
 *
 * @param allRoutes the routes whose endpoint metadata is indexed
 * @return a trie whose values map a method name to its `(routes, metadata)` pair
 */
def prepareDispatchTrie(allRoutes: Seq[Routes]): DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = {
  // One (segments, method -> target, accepts-subpath) triple per endpoint.
  val perEndpoint = allRoutes.flatMap { routes =>
    routes.caskMetadata.value.map { metadata =>
      val pathSegments = Util.splitPath(metadata.endpoint.path)
      val targetsByMethod = metadata.endpoint.methods
        .map(_ -> (routes, metadata: EndpointMetadata[_]))
        .toMap[String, (Routes, EndpointMetadata[_])]
      // An endpoint takes a subpath either by declaring it or by having an
      // argument reader that consumes the remaining path segments.
      val acceptsSubpath =
        metadata.endpoint.subpath ||
        metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments))
      (pathSegments, targetsByMethod, acceptsSubpath)
    }
  }

  // Merge endpoints that live on the same path: concatenate their method
  // entries and treat the path as a subpath route if any of them is one.
  val trieInputs = perEndpoint
    .groupBy(_._1)
    .map { case (pathSegments, samePath) =>
      (pathSegments, samePath.flatMap(_._2), samePath.exists(_._3))
    }
    .toSeq

  DispatchTrie.construct(0, trieInputs)(_.map(_._1)).map(_.toMap)
}
/**
 * Transfers a response onto an Undertow exchange: headers, cookies, status
 * code and finally the body.
 *
 * @param exchange the exchange to write to
 * @param response the response whose data, headers, cookies and status are sent
 */
def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = {
  val headerMap = exchange.getResponseHeaders
  // Headers carried by the data go in first, then the response's own headers.
  // Both use `put`, so (assuming Undertow's replace-on-put semantics) the
  // response's headers win when a name appears in both.
  response.data.headers.foreach { header =>
    headerMap.put(new HttpString(header._1), header._2)
  }
  response.headers.foreach { header =>
    headerMap.put(new HttpString(header._1), header._2)
  }

  for (cookie <- response.cookies) exchange.setResponseCookie(Cookie.toUndertow(cookie))
  exchange.setStatusCode(response.statusCode)

  val sink = exchange.getOutputStream
  // Forwards writes unchanged, but skips close/flush once the exchange has
  // already completed.
  val guarded = new java.io.OutputStream {
    override def write(b: Int): Unit = sink.write(b)
    override def write(b: Array[Byte]): Unit = sink.write(b)
    override def write(b: Array[Byte], off: Int, len: Int): Unit = sink.write(b, off, len)
    override def flush() = {
      if (!exchange.isComplete) sink.flush()
    }
    override def close() = {
      if (!exchange.isComplete) sink.close()
    }
  }
  response.data.write(guarded)
}
/**
 * Default response for a failed endpoint invocation.
 *
 * An exception thrown by the endpoint is logged and answered with 500;
 * invalid or mismatched arguments are answered with 400. In debug mode the
 * body is the detailed message from `ErrorMsgs.formatInvokeError`; otherwise
 * it is `Error <code>: <reason>`.
 *
 * @param routes    the routes instance owning the endpoint
 * @param metadata  metadata of the endpoint that failed
 * @param e         the error produced by the invocation
 * @param debugMode whether to include the detailed error message in the body
 * @param req       the request being handled; not inspected here
 * @param log       logger that receives the underlying exception, if any
 */
def defaultHandleError(routes: Routes,
                       metadata: EndpointMetadata[_],
                       e: Result.Error,
                       debugMode: Boolean,
                       req: Request)
                      (implicit log: Logger) = {
  val statusCode = e match {
    case thrown: Result.Error.Exception =>
      // Only genuine exceptions are logged; argument errors are not.
      log.exception(thrown.t)
      500
    case _: Result.Error.InvalidArguments => 400
    case _: Result.Error.MismatchedArguments => 400
  }

  val body =
    if (debugMode) {
      val entryPoint = metadata.entryPoint.asInstanceOf[EntryPoint[cask.main.Routes, _]]
      ErrorMsgs.formatInvokeError(routes, entryPoint, e)
    } else {
      s"Error $statusCode: ${Status.codesToStatus(statusCode).reason}"
    }

  Response(body, statusCode = statusCode)
}
/**
 * Quietens the logging of the libraries underneath the server.
 *
 * First triggers the static initializer of `org.jboss.threads.Version` while
 * `System.out` is temporarily unset, then raises the `java.util.logging`
 * level of the `org.jboss`, `org.xnio` and `io.undertow` loggers to WARNING.
 */
def silenceJboss(): Unit = {
  // Some jboss classes litter logs from their static initializers. This is a
  // workaround to stop this rather annoying behavior.
  val tmp = System.out
  System.setOut(null)
  try {
    org.jboss.threads.Version.getVersionString() // this causes the static initializer to be run
  } finally {
    // Restore the real stream even if the call above throws; otherwise
    // `System.out` would stay null for the rest of the process.
    System.setOut(tmp)
  }
  // Other loggers print way too much information. Set them to only print
  // interesting stuff.
  val level = java.util.logging.Level.WARNING
  java.util.logging.Logger.getLogger("org.jboss").setLevel(level)
  java.util.logging.Logger.getLogger("org.xnio").setLevel(level)
  java.util.logging.Logger.getLogger("io.undertow").setLevel(level)
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy