// Listing of com.firefly.kotlin.ext.http.HttpServerExtension.kt (Maven / Gradle / Ivy artifact).
// This is the newest published version.
package com.firefly.kotlin.ext.http
import com.firefly.`$`
import com.firefly.codec.http2.model.*
import com.firefly.codec.websocket.frame.Frame
import com.firefly.codec.websocket.stream.AbstractWebSocketBuilder
import com.firefly.codec.websocket.stream.WebSocketConnection
import com.firefly.kotlin.ext.annotation.NoArg
import com.firefly.kotlin.ext.common.CoroutineLocalContext
import com.firefly.kotlin.ext.common.Json
import com.firefly.kotlin.ext.common.launchTraceable
import com.firefly.server.http2.SimpleHTTPServer
import com.firefly.server.http2.SimpleHTTPServerConfiguration
import com.firefly.server.http2.SimpleRequest
import com.firefly.server.http2.WebSocketHandler
import com.firefly.server.http2.router.*
import com.firefly.server.http2.router.handler.body.HTTPBodyConfiguration
import com.firefly.server.http2.router.handler.error.DefaultErrorResponseHandlerLoader
import com.firefly.server.http2.router.impl.RoutingContextImpl
import kotlinx.coroutines.*
import kotlinx.coroutines.future.await
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.net.InetAddress
import java.util.*
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.concurrent.ConcurrentLinkedDeque
import java.util.concurrent.TimeUnit
import java.util.function.Supplier
import kotlin.coroutines.CoroutineContext
/**
* Firefly HTTP server extensions.
*
* @author Pengtao Qiu
*/
/**
 * Logger named "firefly-system", used by these HTTP server extensions for framework-level messages
 * (router registration, session lookup failures).
 */
val sysLogger: org.slf4j.Logger = LoggerFactory.getLogger("firefly-system")
// HTTP server API extensions
/**
 * Deserialize the request body into [T] with `Json.parse`, decoding the body text with [charset].
 *
 * @param charset Name of the charset used to decode the request body.
 * @return The deserialized request body.
 */
inline fun <reified T> RoutingContext.getJsonBody(charset: String): T = Json.parse(getStringBody(charset))

/**
 * Deserialize the request body (as returned by `stringBody`) into [T] with `Json.parse`.
 *
 * @return The deserialized request body.
 */
inline fun <reified T> RoutingContext.getJsonBody(): T = Json.parse(stringBody)
/**
 * Read the routing-context attribute [name] as a [T].
 *
 * The type parameter must be reified because the value is checked with `is T` and the error
 * message reports `T::class.java`.
 *
 * @param name The attribute name.
 * @return The attribute value, or `null` when no attribute named [name] is set.
 * @throws ClassCastException If the attribute is set but is not an instance of [T].
 */
inline fun <reified T> RoutingContext.getAttr(name: String): T? {
    val data = getAttribute(name) ?: return null
    return if (data is T) {
        data
    } else {
        throw ClassCastException("The attribute $name type is ${data::class.java}. It can't cast to ${T::class.java}")
    }
}
/**
 * Deserialize the request body into [T] with `Json.parse`, decoding the body text with [charset].
 *
 * @param charset Name of the charset used to decode the request body.
 * @return The deserialized request body.
 */
inline fun <reified T> SimpleRequest.getJsonBody(charset: String): T = Json.parse(getStringBody(charset))

/**
 * Deserialize the request body (as returned by `stringBody`) into [T] with `Json.parse`.
 *
 * @return The deserialized request body.
 */
inline fun <reified T> SimpleRequest.getJsonBody(): T = Json.parse(stringBody)
/**
 * Callbacks registered by a handler that waits for a downstream asynchronous handler to finish.
 *
 * @param C Type of the result passed to [succeeded].
 * @property succeeded Invoked with the result when the asynchronous handler succeeds.
 * @property failed Invoked with the failure cause (which may be `null`) when it fails.
 */
data class AsyncPromise<C>(val succeeded: suspend (C) -> Unit, val failed: suspend (Throwable?) -> Unit)

/** Routing-context attribute key under which the stack of pending [AsyncPromise]s is stored. */
const val promiseQueueKey = "_promiseQueue"

/**
 * Get the pending-promise stack of this context.
 *
 * @return The stack, or `null` if no promise has been registered on this context yet.
 */
fun <C> RoutingContext.getPromiseQueue(): Deque<AsyncPromise<C>>? = getAttr(promiseQueueKey)

/**
 * Get the pending-promise stack of this context, storing a new [ConcurrentLinkedDeque] first if absent.
 */
@Suppress("UNCHECKED_CAST")
fun <C> RoutingContext.createPromiseQueueIfAbsent(): Deque<AsyncPromise<C>> =
    attributes.computeIfAbsent(promiseQueueKey) { ConcurrentLinkedDeque<AsyncPromise<C>>() } as Deque<AsyncPromise<C>>

/**
 * Set the callbacks that are invoked when the asynchronous handler finishes.
 *
 * Promises are pushed onto the head of the deque, so the most recently registered one is completed first.
 *
 * @return This routing context, for chaining.
 */
fun <C> RoutingContext.asyncComplete(
    succeeded: suspend (C) -> Unit,
    failed: suspend (Throwable?) -> Unit
): RoutingContext {
    val queue = createPromiseQueueIfAbsent<C>()
    queue.push(AsyncPromise(succeeded, failed))
    return this
}

/**
 * Set the success callback; a failure is forwarded to the next pending promise through [asyncFail].
 *
 * @return This routing context, for chaining.
 */
fun <C> RoutingContext.asyncComplete(succeeded: suspend (C) -> Unit): RoutingContext {
    asyncComplete(succeeded, { this.asyncFail(it) })
    return this
}

/**
 * Register the callbacks, then execute the next handler.
 *
 * @return The result of [RoutingContext.next].
 */
fun <C> RoutingContext.asyncNext(succeeded: suspend (C) -> Unit, failed: suspend (Throwable?) -> Unit): Boolean {
    asyncComplete(succeeded, failed)
    return next()
}

/**
 * Register the success callback (failures are forwarded via [asyncFail]), then execute the next handler.
 *
 * @return The result of [RoutingContext.next].
 */
fun <C> RoutingContext.asyncNext(succeeded: suspend (C) -> Unit): Boolean {
    asyncComplete(succeeded, { this.asyncFail(it) })
    return next()
}

/**
 * Cause reported to a waiting [asyncNext] caller when [asyncFail] is invoked without a cause.
 * `CompletableFuture.completeExceptionally` rejects `null` with an NPE, which would otherwise leave the
 * waiting coroutine suspended forever.
 */
private fun missingFailureCause(): Throwable =
    IllegalStateException("The asynchronous handler failed without a cause")

/**
 * Execute the next handler and suspend until it completes the registered promise.
 *
 * @return Whether a next handler existed, paired with the result it delivered.
 * @throws Throwable The failure cause delivered through [asyncFail] (or an [IllegalStateException] if none was given).
 */
suspend fun <C> RoutingContext.asyncNext(): Pair<Boolean, C> {
    val future = CompletableFuture<C>()
    val hasNext = asyncNext<C>({ future.complete(it) }, { future.completeExceptionally(it ?: missingFailureCause()) })
    return hasNext to future.await()
}

/**
 * Execute the next handler and suspend until it completes the registered promise, waiting at most [time] [unit]s.
 *
 * @return Whether a next handler existed, paired with the result it delivered.
 * @throws TimeoutCancellationException If no result arrives within the timeout.
 */
suspend fun <C> RoutingContext.asyncNext(time: Long, unit: TimeUnit): Pair<Boolean, C> {
    val future = CompletableFuture<C>()
    val hasNext = asyncNext<C>({ future.complete(it) }, { future.completeExceptionally(it ?: missingFailureCause()) })
    return withTimeout(unit.toMillis(time)) { hasNext to future.await() }
}

/**
 * Complete the most recently registered pending promise successfully with [result].
 * Does nothing if no promise is pending.
 */
suspend fun <C> RoutingContext.asyncSucceed(result: C) {
    // poll() (not pop()) so an empty deque yields null instead of throwing NoSuchElementException.
    getPromiseQueue<C>()?.poll()?.succeeded?.invoke(result)
}

/**
 * Complete the most recently registered pending promise with a failure.
 * Does nothing if no promise is pending.
 *
 * @param x The failure cause; may be `null`.
 */
suspend fun RoutingContext.asyncFail(x: Throwable? = null) {
    // The result type is irrelevant to the failure callback, so any type argument works here.
    getPromiseQueue<Any?>()?.poll()?.failed?.invoke(x)
}
/**
 * Resolve the client IP address of this request.
 *
 * If an `X-Forwarded-For` header is present and splits on "," into at least one entry, the first entry
 * (trimmed) is returned; otherwise the remote address of the underlying connection is used.
 */
fun RoutingContext.getRealClientIp(): String {
    val forwardedFor: String? = fields["X-Forwarded-For"]
    val hops = if (forwardedFor == null) null else `$`.string.split(forwardedFor, ",")
    if (hops == null || hops.isEmpty()) {
        return request.connection.remoteAddress.toString()
    }
    return hops[0].trim()
}
/**
 * Get the current [HTTPSession], caching it in this context's attributes under [sessionKey].
 *
 * Lookup failures are reported as `null` instead of being thrown: a missing session silently, a
 * [CompletionException] at INFO level, and any other exception at ERROR level. The one exception that
 * is propagated is [CancellationException], so a cancelled coroutine unwinds instead of continuing with
 * a `null` session.
 *
 * @param sessionKey The attribute key used to cache the session on this context.
 * @return The current session, or `null` if there is none or it could not be obtained.
 * @throws CancellationException If the wait for the session is cancelled.
 */
suspend fun RoutingContext.getCurrentSessionQuietly(sessionKey: String = "_sessionKey"): HTTPSession? = try {
    val cached = attributes[sessionKey]
    if (cached != null) {
        cached as HTTPSession
    } else {
        val session = getSession(false).await()
        attributes[sessionKey] = session
        session
    }
} catch (e: SessionNotFound) {
    null
} catch (e: CancellationException) {
    // Never swallow cancellation; the generic Exception handler below would otherwise turn it into null.
    throw e
} catch (e: CompletionException) {
    sysLogger.info("get session failure. ${e.cause}")
    null
} catch (e: Exception) {
    sysLogger.error("get session exception", e)
    null
}
// HTTP server DSL
/**
 * Configure the response status line through a [StatusLineBlock].
 *
 * @param block Statements that assign `status`, `reason` and/or `httpVersion`.
 */
inline fun RoutingContext.statusLine(block: StatusLineBlock.() -> Unit) = StatusLineBlock(this).block()
/**
 * Receiver of the [statusLine] DSL. Every setter first forwards the new value to the wrapped
 * routing context, then records it locally.
 *
 * Note: getters return the value last assigned through this block (or the initial default); they do
 * not read the value currently stored on the context.
 */
class StatusLineBlock(private val ctx: RoutingContext) {
    var status: Int = HttpStatus.OK_200
        set(newStatus) {
            ctx.setStatus(newStatus)
            field = newStatus
        }

    var reason: String = HttpStatus.Code.OK.message
        set(newReason) {
            ctx.setReason(newReason)
            field = newReason
        }

    var httpVersion: HttpVersion = HttpVersion.HTTP_1_1
        set(newVersion) {
            ctx.httpVersion = newVersion
            field = newVersion
        }

    override fun toString(): String = buildString {
        append("StatusLineBlock(status=").append(status)
        append(", reason='").append(reason)
        append("', httpVersion=").append(httpVersion).append(')')
    }
}
/**
 * Field-building operators shared by the header and trailer DSL blocks.
 */
interface HttpFieldOperator {
    /** Add a complete [HttpField], written as `+field`. */
    operator fun HttpField.unaryPlus()

    /** Put a field, named by this [HttpHeader] constant, with [value]. */
    infix fun HttpHeader.to(value: String)

    /** Put a field, named by this string, with [value]. */
    infix fun String.to(value: String)
}
/**
 * Edit the response headers through a [HeaderBlock].
 *
 * @param block Header statements such as `"X-Name" to "value"` or `+httpField`.
 */
inline fun RoutingContext.header(block: HeaderBlock.() -> Unit) = HeaderBlock(this).block()
/**
 * Receiver of the [header] DSL; every operator writes straight into the response's [HttpFields].
 *
 * @param ctx The routing context whose response headers are edited.
 */
class HeaderBlock(ctx: RoutingContext) : HttpFieldOperator {
    /** The response header fields of the wrapped context. */
    val httpFields: HttpFields = ctx.response.fields

    override operator fun HttpField.unaryPlus() {
        httpFields.add(this)
    }

    override infix fun HttpHeader.to(value: String) {
        httpFields.put(this, value)
    }

    override infix fun String.to(value: String) {
        httpFields.put(this, value)
    }

    override fun toString(): String = "HeaderBlock(httpFields=" + httpFields + ")"
}
/**
 * Define the response trailer fields through a [TrailerBlock].
 *
 * @param block Trailer statements such as `"X-Name" to "value"` or `+httpField`.
 */
inline fun RoutingContext.trailer(block: TrailerBlock.() -> Unit) = TrailerBlock(this).block()
/**
 * Receiver of the [trailer] DSL. On construction it installs itself as the response's trailer supplier;
 * [get] returns the fields collected by this block.
 *
 * @param ctx The routing context whose response receives the trailer fields.
 */
class TrailerBlock(ctx: RoutingContext) : Supplier<HttpFields>, HttpFieldOperator {
    /** Trailer fields collected by this block; starts empty. */
    val httpFields: HttpFields = HttpFields()

    init {
        // httpFields is declared above, so it is already initialized when `this` is published here.
        ctx.response.trailerSupplier = this
    }

    override fun get(): HttpFields = httpFields

    override infix fun String.to(value: String) {
        httpFields.put(this, value)
    }

    override infix fun HttpHeader.to(value: String) {
        httpFields.put(this, value)
    }

    override operator fun HttpField.unaryPlus() {
        httpFields.add(this)
    }

    override fun toString(): String = "TrailerBlock(httpFields=$httpFields)"
}
/**
 * A handler whose logic runs in a coroutine; register it with `RouterBlock.asyncHandler`.
 */
interface AsyncHandler {
    /**
     * Process the request.
     *
     * @param ctx The routing context of the current request.
     */
    suspend fun handle(ctx: RoutingContext)
}

/** Coroutine-local attribute key under which `RouterBlock.asyncHandler` passes the current routing context. */
const val httpCtxKey = "_httpCtxKey"

/**
 * Look up the routing context bound to the current coroutine scope.
 *
 * @return The routing context stored under [httpCtxKey], or `null` if there is none.
 */
fun getRequestContext(): RoutingContext? {
    val ctx: RoutingContext? = CoroutineLocalContext.getAttr(httpCtxKey)
    return ctx
}
/**
 * DSL receiver that configures one [Router]: its matching rules and its handlers.
 *
 * Each matching property forwards the assigned value to the underlying router in its setter. The
 * property initializers are NOT forwarded, so reading e.g. [method] before assigning it returns the
 * default without the router having been restricted to that method.
 */
@HttpServerMarker
class RouterBlock(
    private val router: Router,
    private val coroutineDispatcher: CoroutineDispatcher
) {

    /** HTTP method name to match; each assignment is forwarded to `router.method`. */
    var method: String = HttpMethod.GET.asString()
        set(value) {
            router.method(value)
            field = value
        }

    /** HTTP method names to match; each element of an assigned list is forwarded to `router.method`. */
    var methods: List<String> = listOf(HttpMethod.GET.asString(), HttpMethod.POST.asString())
        set(value) {
            value.forEach { router.method(it) }
            field = value
        }

    /** HTTP method to match; each assignment is forwarded to `router.method`. */
    var httpMethod: HttpMethod = HttpMethod.GET
        set(value) {
            router.method(value)
            field = value
        }

    /** HTTP methods to match; each element of an assigned list is forwarded to `router.method`. */
    var httpMethods: List<HttpMethod> = listOf(HttpMethod.GET, HttpMethod.POST)
        set(value) {
            value.forEach { router.method(it) }
            field = value
        }

    /** Path to match; forwarded to `router.path`. */
    var path: String = ""
        set(value) {
            router.path(value)
            field = value
        }

    /** Paths to match; forwarded to `router.paths`. */
    var paths: List<String> = listOf()
        set(value) {
            router.paths(value)
            field = value
        }

    /** Path regular expression to match; forwarded to `router.pathRegex`. */
    var regexPath: String = ""
        set(value) {
            router.pathRegex(value)
            field = value
        }

    /** Request content type to match; forwarded to `router.consumes`. */
    var consumes: String = ""
        set(value) {
            router.consumes(value)
            field = value
        }

    /** Produced content type to match; forwarded to `router.produces`. */
    var produces: String = ""
        set(value) {
            router.produces(value)
            field = value
        }

    fun getId() = router.id

    /**
     * Register a handler that is executed in the coroutine asynchronously.
     *
     * The response is marked asynchronous, and the routing context is passed to the coroutine under
     * [httpCtxKey] so that [getRequestContext] can find it.
     *
     * @param handler The handler that processes the business logic.
     */
    fun asyncHandler(handler: suspend RoutingContext.(context: CoroutineContext) -> Unit) {
        router.handler {
            it.response.isAsynchronous = true
            launchTraceable(coroutineDispatcher, mutableMapOf(httpCtxKey to it)) {
                handler.invoke(it, coroutineContext)
            }
        }
    }

    /**
     * Register a handler that is executed in the coroutine asynchronously.
     *
     * @param handler The handler that processes the business logic.
     */
    fun asyncHandler(handler: AsyncHandler) = asyncHandler {
        handler.handle(this)
    }

    /**
     * Register an asynchronous handler that completes the pending promise automatically:
     * with `Unit` if [handler] returns normally, or with the thrown exception otherwise.
     */
    fun asyncCompleteHandler(handler: suspend RoutingContext.(context: CoroutineContext) -> Unit) = asyncHandler {
        try {
            handler.invoke(this, it)
            asyncSucceed(Unit)
        } catch (x: Throwable) {
            asyncFail(x)
        }
    }

    /**
     * Register a handler that is executed in the network thread synchronously.
     *
     * @param handler The handler that processes the business logic.
     */
    fun handler(handler: RoutingContext.() -> Unit) {
        router.handler(handler)
    }

    /**
     * Register a handler that is executed in the network thread synchronously.
     *
     * @param handler The handler that processes the business logic.
     */
    fun handler(handler: Handler) {
        router.handler(handler)
    }

    /**
     * Run [block] with this (possibly null) resource and close it afterwards, even if [block] fails or the
     * coroutine is cancelled (closing runs in [NonCancellable]).
     *
     * If both [block] and `close()` fail, the exception from [block] is rethrown and the close failure is
     * attached to it as a suppressed exception, mirroring Kotlin's `use`.
     */
    suspend fun <T : Closeable?, R> T.safeUse(block: suspend (T) -> R): R {
        var closed = false
        try {
            return block(this)
        } catch (e: Exception) {
            closed = true
            try {
                withContext(NonCancellable) {
                    this@safeUse?.close()
                }
            } catch (closeException: Exception) {
                if (closeException !== e) {
                    e.addSuppressed(closeException)
                }
            }
            throw e
        } finally {
            if (!closed) {
                withContext(NonCancellable) {
                    this@safeUse?.close()
                }
            }
        }
    }

    override fun toString(): String = router.toString()
}
/**
 * DSL receiver that registers a WebSocket endpoint at [path].
 *
 * The connect callback comes from [onConnect]; frame and error events are delegated to this builder's
 * own `onFrame` / `onError`.
 */
@HttpServerMarker
class WebSocketBlock(
    server: SimpleHTTPServer,
    router: Router,
    private val path: String
) : AbstractWebSocketBuilder() {

    /** Invoked when a WebSocket connection to [path] is established; `null` means no callback. */
    var onConnect: ((WebSocketConnection) -> Unit)? = null

    init {
        server.registerWebSocket(path, object : WebSocketHandler {
            override fun onConnect(connection: WebSocketConnection) {
                this@WebSocketBlock.onConnect?.invoke(connection)
            }

            // NOTE(review): these delegation targets were reconstructed from an e-mail-obfuscated listing;
            // confirm AbstractWebSocketBuilder exposes onFrame(Frame, WebSocketConnection) and
            // onError(Throwable, WebSocketConnection).
            override fun onFrame(frame: Frame, connection: WebSocketConnection) {
                this@WebSocketBlock.onFrame(frame, connection)
            }

            override fun onError(t: Throwable, connection: WebSocketConnection) {
                this@WebSocketBlock.onError(t, connection)
            }
        })
        // A no-op route for the path; presumably so the router manager matches the upgrade request — confirm.
        router.path(path).handler { }
    }

    /** Set the callback invoked when a connection is established. */
    fun onConnect(onConnect: (WebSocketConnection) -> Unit) {
        this.onConnect = onConnect
    }

    override fun toString(): String = "WebSocket(path='$path')"
}
/**
 * Start/stop operations of an HTTP server.
 */
interface HttpServerLifecycle {
    /**
     * Start the HTTP server. The host and port must be set in the SimpleHTTPServerConfiguration.
     */
    fun listen()

    /**
     * Start the HTTP server, binding to the address of the local host.
     *
     * @param port The server port.
     */
    fun listen(port: Int)

    /**
     * Start the HTTP server.
     *
     * @param host The server hostname.
     * @param port The server port.
     */
    fun listen(host: String, port: Int)

    /** Stop the HTTP server. */
    fun stop()
}
/**
 * DSL marker for the HTTP server builders (HttpServer, RouterBlock, WebSocketBlock). Inside a nested
 * builder block it prevents implicit calls to members of an outer builder receiver.
 */
@DslMarker
annotation class HttpServerMarker
/**
 * HTTP server DSL. It helps you write an HTTP server elegantly.
 *
 * [block] is invoked once at the end of construction, after the bad-message handler has been installed.
 * The registered routers are wired into the server only when one of the `listen` functions is called.
 *
 * @param serverConfiguration HTTP server configuration.
 * @param httpBodyConfiguration HTTP body configuration.
 * @param block The HTTP server DSL block. You can register routers in this block.
 * @throws IllegalStateException If the router manager or the default error handler cannot be obtained.
 */
@HttpServerMarker
class HttpServer(
    serverConfiguration: SimpleHTTPServerConfiguration = SimpleHTTPServerConfiguration(),
    httpBodyConfiguration: HTTPBodyConfiguration = HTTPBodyConfiguration(),
    block: HttpServer.() -> Unit
) : HttpServerLifecycle {

    val server = SimpleHTTPServer(serverConfiguration)
    val routerManager = checkNotNull(RouterManager.create(httpBodyConfiguration)) {
        "RouterManager.create returned null"
    }
    val coroutineDispatcher = server.handlerExecutorService.asCoroutineDispatcher()
    val defaultErrorHandler = checkNotNull(DefaultErrorResponseHandlerLoader.getInstance().handler) {
        "No default error response handler is available"
    }

    init {
        server.badMessage { status, reason, request ->
            // Render malformed requests through the default error handler, using a context with no router matches.
            val ctx = RoutingContextImpl(request, Collections.emptyNavigableSet())
            defaultErrorHandler.render(ctx, status, BadMessageException(reason))
        }
        block.invoke(this)
    }

    constructor() : this(SimpleHTTPServerConfiguration(), HTTPBodyConfiguration(), {})

    constructor(block: HttpServer.() -> Unit)
            : this(SimpleHTTPServerConfiguration(), HTTPBodyConfiguration(), block)

    constructor(
        serverConfiguration: SimpleHTTPServerConfiguration,
        httpBodyConfiguration: HTTPBodyConfiguration
    )
            : this(serverConfiguration, httpBodyConfiguration, {})

    override fun stop() = server.stop()

    override fun listen(host: String, port: Int) = routedServer().listen(host, port)

    override fun listen(port: Int) = listen(InetAddress.getLocalHost().hostAddress, port)

    override fun listen() = routedServer().listen()

    /** Install the router manager as the header-complete handler and return the server for chaining. */
    private fun routedServer() = server.headerComplete(routerManager::accept)

    /** Enable secure connections in the server configuration. */
    fun enableSecureConnection(): HttpServer {
        this.server.configuration.isSecureConnectionEnabled = true
        return this
    }

    /**
     * Register a router using the DSL with an auto-incremented ID.
     *
     * @param block The router builder.
     */
    inline fun router(block: RouterBlock.() -> Unit) {
        val r = RouterBlock(routerManager.register(), coroutineDispatcher)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register a router using the DSL with a specified ID.
     *
     * @param id The router ID. Routers are sorted by ID in ascending order.
     * @param block The router builder.
     */
    inline fun router(id: Int, block: RouterBlock.() -> Unit) {
        val r = RouterBlock(routerManager.register(id), coroutineDispatcher)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register a WebSocket endpoint at [path] using the DSL.
     *
     * @param path The WebSocket path.
     * @param block The WebSocket builder.
     */
    inline fun webSocket(path: String, block: WebSocketBlock.() -> Unit) {
        val r = WebSocketBlock(server, routerManager.register(), path)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register routers using the DSL.
     *
     * @param block Register routers in this block.
     */
    inline fun addRouters(block: HttpServer.() -> Unit) = block.invoke(this)
}
/**
 * Writes one JSON access-log line per request to the logger named [logName].
 *
 * @param logName Name of the SLF4J logger that receives the access log.
 * @property userTracingId Routing-context attribute whose `String` value is recorded as the log's `tid`.
 */
class AccessLogService(
    logName: String = "firefly-access",
    val userTracingId: String = "_const_firefly_user_id_"
) {
    private val log = LoggerFactory.getLogger(logName)

    /** Content-type prefixes whose request bodies are text and therefore included in the log. */
    private val loggableContentTypes = listOf("application/json", "application/x-www-form-urlencoded")

    /**
     * Log the access record of [ctx] as JSON at INFO level.
     *
     * @param startTime Request start timestamp.
     * @param endTime Request end timestamp; the recorded time is `endTime - startTime`.
     */
    fun recordAccessLog(ctx: RoutingContext, startTime: Long, endTime: Long) {
        val accessLog = toAccessLog(ctx, endTime - startTime)
        log.info(`$`.json.toJson(accessLog))
    }

    /**
     * Build the [AccessLog] record of [ctx].
     *
     * The request body is included only for POST/PUT requests whose `Content-Type` starts with
     * `application/json` or `application/x-www-form-urlencoded`. The comparison is case-insensitive,
     * because media types are case-insensitive. Otherwise the body is recorded as an empty string.
     *
     * @param time Elapsed request time to record.
     * @throws ClassCastException If the [userTracingId] attribute is set but is not a `String`.
     */
    fun toAccessLog(ctx: RoutingContext, time: Long): AccessLog {
        val requestBody = loggableRequestBody(ctx)
        val tid = ctx.getAttr<String>(userTracingId)
        return AccessLog(
            ctx.getRealClientIp(),
            ctx.method,
            ctx.uri.toString(),
            requestBody,
            time,
            ctx.response.status,
            tid
        )
    }

    /** Return the request body when it is a loggable text payload, or `""` otherwise. */
    private fun loggableRequestBody(ctx: RoutingContext): String? {
        val method = ctx.method
        if (method != HttpMethod.POST.asString() && method != HttpMethod.PUT.asString()) {
            return ""
        }
        val contentType: String? = ctx.fields[HttpHeader.CONTENT_TYPE.asString()]
        if (contentType == null || !`$`.string.hasText(contentType)) {
            return ""
        }
        val isText = loggableContentTypes.any { contentType.startsWith(it, ignoreCase = true) }
        return if (isText) ctx.stringBody else ""
    }
}
/**
 * One access-log record, built by AccessLogService.toAccessLog and serialized to JSON.
 *
 * NOTE(review): @NoArg presumably generates a no-argument constructor (e.g. for deserialization) — confirm.
 *
 * @property ip Client IP address.
 * @property method HTTP method of the request.
 * @property uri Request URI.
 * @property requestBody Logged request body; an empty string when the body is not logged.
 * @property time Elapsed time as computed by the caller (AccessLogService passes endTime - startTime).
 * @property responseStatus Response status code.
 * @property tid Tracing id read from the routing-context attribute configured in AccessLogService, if set.
 */
@NoArg
data class AccessLog(
var ip: String,
var method: String,
var uri: String,
var requestBody: String?,
var time: Long,
var responseStatus: Int,
var tid: String?
)