All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.firefly.kotlin.ext.http.HttpServerExtension.kt Maven / Gradle / Ivy

The newest version!
package com.firefly.kotlin.ext.http

import com.firefly.`$`
import com.firefly.codec.http2.model.*
import com.firefly.codec.websocket.frame.Frame
import com.firefly.codec.websocket.stream.AbstractWebSocketBuilder
import com.firefly.codec.websocket.stream.WebSocketConnection
import com.firefly.kotlin.ext.annotation.NoArg
import com.firefly.kotlin.ext.common.CoroutineLocalContext
import com.firefly.kotlin.ext.common.Json
import com.firefly.kotlin.ext.common.launchTraceable
import com.firefly.server.http2.SimpleHTTPServer
import com.firefly.server.http2.SimpleHTTPServerConfiguration
import com.firefly.server.http2.SimpleRequest
import com.firefly.server.http2.WebSocketHandler
import com.firefly.server.http2.router.*
import com.firefly.server.http2.router.handler.body.HTTPBodyConfiguration
import com.firefly.server.http2.router.handler.error.DefaultErrorResponseHandlerLoader
import com.firefly.server.http2.router.impl.RoutingContextImpl
import kotlinx.coroutines.*
import kotlinx.coroutines.future.await
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.net.InetAddress
import java.util.*
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.concurrent.ConcurrentLinkedDeque
import java.util.concurrent.TimeUnit
import java.util.function.Supplier
import kotlin.coroutines.CoroutineContext

/**
 * Firefly HTTP server extensions.
 *
 * [sysLogger] is the `firefly-system` logger; this file uses it to report router/WebSocket
 * registrations and session lookup failures.
 *
 * @author Pengtao Qiu
 */
val sysLogger = run {
    val systemLoggerName = "firefly-system"
    LoggerFactory.getLogger(systemLoggerName)
}

// HTTP server API extensions
/**
 * Parses the request body, decoded with [charset], as JSON into a [T].
 *
 * @param charset The charset passed to `getStringBody` to decode the body.
 * @return The body deserialized by [Json.parse].
 */
inline fun <reified T : Any> RoutingContext.getJsonBody(charset: String): T = Json.parse(getStringBody(charset))

/**
 * Parses the request body (as returned by `stringBody`) as JSON into a [T].
 *
 * @return The body deserialized by [Json.parse].
 */
inline fun <reified T : Any> RoutingContext.getJsonBody(): T = Json.parse(stringBody)

/**
 * Reads a routing-context attribute with a type check.
 *
 * @param name The attribute name.
 * @return The attribute value, or `null` when the attribute is absent.
 * @throws ClassCastException if the attribute exists but is not a [T]. For generic [T] only the
 * erased class is checked.
 */
inline fun <reified T : Any> RoutingContext.getAttr(name: String): T? {
    val data = getAttribute(name) ?: return null
    return data as? T
        ?: throw ClassCastException("The attribute $name type is ${data::class.java}. It can't cast to ${T::class.java}")
}

/**
 * Parses the request body, decoded with [charset], as JSON into a [T].
 *
 * @param charset The charset passed to `getStringBody` to decode the body.
 * @return The body deserialized by [Json.parse].
 */
inline fun <reified T : Any> SimpleRequest.getJsonBody(charset: String): T = Json.parse(getStringBody(charset))

/**
 * Parses the request body (as returned by `stringBody`) as JSON into a [T].
 *
 * @return The body deserialized by [Json.parse].
 */
inline fun <reified T : Any> SimpleRequest.getJsonBody(): T = Json.parse(stringBody)

/**
 * A pair of callbacks waiting for an asynchronous handler to finish.
 *
 * @property succeeded Invoked with the handler's result.
 * @property failed Invoked with the failure cause. The cause may be `null` when the failure was
 * signalled without one.
 */
data class AsyncPromise<C>(val succeeded: suspend (C) -> Unit, val failed: suspend (Throwable?) -> Unit)

/** Routing-context attribute key under which the queue of pending [AsyncPromise] callbacks is stored. */
const val promiseQueueKey = "_promiseQueue"

/**
 * Returns the pending-callback queue stored under [promiseQueueKey].
 *
 * @return The queue, or `null` when the attribute is absent. The queue is created lazily by
 * [createPromiseQueueIfAbsent].
 * @throws ClassCastException if the attribute exists but is not a [Deque] (raised by [getAttr]).
 */
fun <C> RoutingContext.getPromiseQueue(): Deque<AsyncPromise<C>>? = getAttr(promiseQueueKey)

/**
 * Returns the pending-callback queue stored under [promiseQueueKey].
 *
 * If the attribute is absent, the queue is first created as a [ConcurrentLinkedDeque] through the
 * attribute map's `computeIfAbsent`. The cast is unchecked because the element type is erased at runtime.
 */
@Suppress("UNCHECKED_CAST")
fun <C> RoutingContext.createPromiseQueueIfAbsent(): Deque<AsyncPromise<C>> =
    attributes.computeIfAbsent(promiseQueueKey) { ConcurrentLinkedDeque<AsyncPromise<C>>() } as Deque<AsyncPromise<C>>

/**
 * Set the callback that is called when the asynchronous handler finishes.
 *
 * The callback pair is pushed onto the front of this request's promise queue, which is created on demand.
 * [asyncSucceed] and [asyncFail] take callbacks from the front, so the most recently registered pair
 * runs first.
 *
 * @param succeeded Invoked with the result passed to [asyncSucceed].
 * @param failed Invoked with the (possibly `null`) cause passed to [asyncFail].
 * @return This routing context, for chaining.
 */
fun <C> RoutingContext.asyncComplete(
    succeeded: suspend (C) -> Unit,
    failed: suspend (Throwable?) -> Unit
                                    ): RoutingContext {
    createPromiseQueueIfAbsent<C>().push(AsyncPromise(succeeded, failed))
    return this
}

/**
 * Registers only a success callback.
 *
 * The failure callback forwards the cause to [asyncFail], which hands the failure to the next pending
 * callback in the promise queue.
 *
 * @param succeeded Invoked with the result passed to [asyncSucceed].
 * @return This routing context, for chaining.
 */
fun <C> RoutingContext.asyncComplete(succeeded: suspend (C) -> Unit): RoutingContext =
    asyncComplete(succeeded) { asyncFail<C>(it) }

/**
 * Execute the next asynchronous handler and set the callback that is called when the asynchronous handler finishes.
 *
 * The callbacks are registered with [asyncComplete] before `next()` is called, so they are already
 * queued when the next handler runs.
 *
 * @param succeeded Invoked with the result passed to [asyncSucceed].
 * @param failed Invoked with the (possibly `null`) cause passed to [asyncFail].
 * @return The value returned by `next()`. Presumably this reports whether another handler matched;
 * confirm against `RoutingContext.next`.
 */
fun <C> RoutingContext.asyncNext(succeeded: suspend (C) -> Unit, failed: suspend (Throwable?) -> Unit): Boolean {
    asyncComplete(succeeded, failed)
    return next()
}

/**
 * Execute the next handler with only a success callback.
 *
 * A failure is forwarded to [asyncFail], so it propagates to the next pending callback in the promise queue.
 *
 * @param succeeded Invoked with the result passed to [asyncSucceed].
 * @return The value returned by `next()`.
 */
fun <C> RoutingContext.asyncNext(succeeded: suspend (C) -> Unit): Boolean {
    asyncComplete(succeeded) { asyncFail<C>(it) }
    return next()
}

/**
 * Execute the next handler and suspend until it reports completion through [asyncSucceed] or [asyncFail].
 *
 * There is no timeout. If the downstream handlers never call [asyncSucceed] or [asyncFail], this suspends
 * indefinitely; use the overload that takes a timeout when that can happen.
 *
 * @return The value returned by `next()`, paired with the result passed to [asyncSucceed].
 * @throws Throwable The cause passed to [asyncFail]. If that cause was `null`, an [IllegalStateException]
 * is thrown instead: `CompletableFuture.completeExceptionally(null)` throws NPE, which would leave the
 * future incomplete and this function suspended forever.
 */
suspend fun <C> RoutingContext.asyncNext(): Pair<Boolean, C> {
    val future = CompletableFuture<C>()
    val hasNext = asyncNext<C>(
        { future.complete(it) },
        { future.completeExceptionally(it ?: IllegalStateException("The asynchronous handler failed without a cause")) }
    )
    return hasNext to future.await()
}

/**
 * Execute the next handler and suspend until it reports completion, or until the timeout elapses.
 *
 * The timeout only bounds the wait for the callback. The `next()` call itself happens before the
 * timeout starts.
 *
 * @param time Maximum time to wait.
 * @param unit Unit of [time].
 * @return The value returned by `next()`, paired with the result passed to [asyncSucceed].
 * @throws TimeoutCancellationException if no callback fires within the timeout (raised by [withTimeout]).
 * @throws Throwable The cause passed to [asyncFail], or an [IllegalStateException] if that cause was `null`.
 */
suspend fun <C> RoutingContext.asyncNext(time: Long, unit: TimeUnit): Pair<Boolean, C> {
    val future = CompletableFuture<C>()
    val hasNext = asyncNext<C>(
        { future.complete(it) },
        // completeExceptionally(null) throws NPE and would leave the future pending forever.
        { future.completeExceptionally(it ?: IllegalStateException("The asynchronous handler failed without a cause")) }
    )
    return withTimeout(unit.toMillis(time)) { hasNext to future.await() }
}

/**
 * Execute asynchronous succeeded callback.
 *
 * Removes the most recently registered callback pair from the front of the promise queue and invokes
 * its `succeeded` callback with [result]. When there is no queue, or the queue is empty, nothing happens.
 * (`pollFirst` is used because `pop` would throw [NoSuchElementException] on an empty queue, defeating
 * the null-safe chain.)
 *
 * @param result The value handed to the callback.
 */
suspend fun <C> RoutingContext.asyncSucceed(result: C) {
    getPromiseQueue<C>()?.pollFirst()?.succeeded?.invoke(result)
}

/**
 * Execute asynchronous failed callback.
 *
 * Removes the most recently registered callback pair from the front of the promise queue and invokes
 * its `failed` callback with [x]. When there is no queue, or the queue is empty, nothing happens.
 * (`pollFirst` is used because `pop` would throw [NoSuchElementException] on an empty queue.)
 *
 * @param x The failure cause; `null` when none is available.
 */
suspend fun <C> RoutingContext.asyncFail(x: Throwable? = null) {
    getPromiseQueue<C>()?.pollFirst()?.failed?.invoke(x)
}

/**
 * Get the real client ip.
 *
 * If the `X-Forwarded-For` request header is present and splits on "," into at least one entry, the
 * first entry is returned, trimmed. Otherwise the result is `request.connection.remoteAddress.toString()`.
 *
 * NOTE(review): for an `InetSocketAddress` that fallback string usually looks like `/127.0.0.1:53422`
 * (leading slash and port) rather than a bare IP. Confirm this is what callers expect.
 */
fun RoutingContext.getRealClientIp(): String {
    val forwardedFor: String? = fields["X-Forwarded-For"]
    val hops = forwardedFor?.let { `$`.string.split(it, ",") }
    return if (hops != null && hops.isNotEmpty()) {
        hops[0].trim()
    } else {
        request.connection.remoteAddress.toString()
    }
}

/**
 * Get current HTTPSession. This function does not throw any exception.
 *
 * A session already cached in the routing-context attributes under [sessionKey] is returned directly.
 * Otherwise it is fetched with `getSession(false)` (presumably without creating a new session; confirm)
 * and cached under that key.
 *
 * @param sessionKey Attribute key used to cache the session.
 * @return The session. Returns `null` when it is not found (`SessionNotFound`), when the lookup fails
 * with a `CompletionException` (logged at INFO), or on any other exception (logged at ERROR).
 */
suspend fun RoutingContext.getCurrentSessionQuietly(sessionKey: String = "_sessionKey"): HTTPSession? {
    return try {
        when (val cached = attributes[sessionKey]) {
            null -> getSession(false).await().also { loaded -> attributes[sessionKey] = loaded }
            else -> cached as HTTPSession
        }
    } catch (e: SessionNotFound) {
        null
    } catch (e: CompletionException) {
        sysLogger.info("get session failure. ${e.cause}")
        null
    } catch (e: Exception) {
        sysLogger.error("get session exception", e)
        null
    }
}

// HTTP server DSL

/**
 * Response status line block.
 *
 * Creates a [StatusLineBlock] bound to this routing context and runs [block] with it as the receiver.
 *
 * @param block Response status line statement
 */
inline fun RoutingContext.statusLine(block: StatusLineBlock.() -> Unit) = StatusLineBlock(this).block()

/**
 * Receiver of the [statusLine] DSL.
 *
 * Each assignment is first forwarded to the routing context and then remembered locally. Reading a
 * property therefore returns the last value assigned through this block, or the initial default. The
 * default is never sent to the context.
 */
class StatusLineBlock(private val ctx: RoutingContext) {
    /** Response status code; assignment calls `ctx.setStatus`. */
    var status: Int = HttpStatus.OK_200
        set(code) {
            ctx.setStatus(code)
            field = code
        }

    /** Response reason phrase; assignment calls `ctx.setReason`. */
    var reason: String = HttpStatus.Code.OK.message
        set(text) {
            ctx.setReason(text)
            field = text
        }

    /** Response HTTP version; assignment sets `ctx.httpVersion`. */
    var httpVersion: HttpVersion = HttpVersion.HTTP_1_1
        set(version) {
            ctx.httpVersion = version
            field = version
        }

    override fun toString(): String = "StatusLineBlock(status=$status, reason='$reason', httpVersion=$httpVersion)"

}

/**
 * Operators shared by the header and trailer DSL blocks for writing HTTP fields.
 *
 * Inside an implementing receiver, `"Name" to "value"` resolves to the member [to] rather than to
 * Kotlin's `Pair` builder.
 */
interface HttpFieldOperator {
    /** Writes the field whose name is the receiver string. [HeaderBlock] and [TrailerBlock] use `HttpFields.put`. */
    infix fun String.to(value: String)

    /** Writes the field identified by the receiver header. [HeaderBlock] and [TrailerBlock] use `HttpFields.put`. */
    infix fun HttpHeader.to(value: String)

    /** `+field` adds the receiver field. [HeaderBlock] and [TrailerBlock] use `HttpFields.add`. */
    operator fun HttpField.unaryPlus()
}

/**
 * Response HTTP header block.
 *
 * Creates a [HeaderBlock] over this context's response fields and runs [block] with it as the receiver.
 *
 * @param block HTTP header statement
 */
inline fun RoutingContext.header(block: HeaderBlock.() -> Unit) = HeaderBlock(this).block()

/**
 * Receiver of the [header] DSL. It writes directly into the routing context's response header fields.
 */
class HeaderBlock(ctx: RoutingContext) : HttpFieldOperator {

    /** The response header fields (`ctx.response.fields`), captured when the block is created. */
    val httpFields: HttpFields = ctx.response.fields

    /** `"Name" to "value"`: writes the field with `HttpFields.put`. */
    override infix fun String.to(value: String) {
        val fieldName: String = this
        httpFields.put(fieldName, value)
    }

    /** `HttpHeader.X to "value"`: writes the field with `HttpFields.put`. */
    override infix fun HttpHeader.to(value: String) {
        val header: HttpHeader = this
        httpFields.put(header, value)
    }

    /** `+field`: adds the field with `HttpFields.add`. */
    override operator fun HttpField.unaryPlus() {
        val httpField: HttpField = this
        httpFields.add(httpField)
    }

    override fun toString(): String = "HeaderBlock(httpFields=$httpFields)"
}

/**
 * Response HTTP trailer block.
 *
 * Creates a [TrailerBlock] and runs [block] with it as the receiver. Creating the block installs it
 * as the response's trailer supplier, even if [block] adds nothing.
 *
 * @param block HTTP trailer statement
 */
inline fun RoutingContext.trailer(block: TrailerBlock.() -> Unit) = TrailerBlock(this).block()

/**
 * Receiver of the [trailer] DSL.
 *
 * Constructing it installs this block as the response's trailer supplier. [get] then returns the
 * fields collected through the operators below.
 */
class TrailerBlock(ctx: RoutingContext) : Supplier<HttpFields>, HttpFieldOperator {

    /** Trailer fields collected by this block; starts as a fresh [HttpFields]. */
    val httpFields: HttpFields = HttpFields()

    init {
        // httpFields is declared above this init block, so it is already initialized when `this` is
        // handed to the response.
        ctx.response.trailerSupplier = this
    }

    /** Returns the collected trailer fields. */
    override fun get(): HttpFields = httpFields

    /** `"Name" to "value"`: writes the trailer field with `HttpFields.put`. */
    override infix fun String.to(value: String) {
        httpFields.put(this, value)
    }

    /** `HttpHeader.X to "value"`: writes the trailer field with `HttpFields.put`. */
    override infix fun HttpHeader.to(value: String) {
        httpFields.put(this, value)
    }

    /** `+field`: adds the trailer field with `HttpFields.add`. */
    override operator fun HttpField.unaryPlus() {
        httpFields.add(this)
    }

    override fun toString(): String = "TrailerBlock(httpFields=$httpFields)"

}

/**
 * Object form of an asynchronous route handler.
 *
 * Pass it to [RouterBlock.asyncHandler], which calls [handle] inside a coroutine.
 */
interface AsyncHandler {
    /** Handles the request described by [ctx]; runs as a suspending call. */
    suspend fun handle(ctx: RoutingContext)
}

/** Coroutine-local key: [RouterBlock.asyncHandler] stores the current [RoutingContext] under it, and [getRequestContext] reads it. */
const val httpCtxKey = "_httpCtxKey"

/**
 * Get the routing context in the current coroutine scope.
 *
 * [RouterBlock.asyncHandler] stores the context under [httpCtxKey] when it starts a handler coroutine.
 *
 * @return The routing context stored under [httpCtxKey] in the coroutine-local context. May be `null`.
 */
fun getRequestContext(): RoutingContext? {
    val boundContext: RoutingContext? = CoroutineLocalContext.getAttr(httpCtxKey)
    return boundContext
}

/**
 * Receiver of the [HttpServer.router] DSL; configures a single [Router].
 *
 * Assigning a property forwards the value to the wrapped router and then records it, so the getter
 * returns the last assigned value. The initial property values are only what the getters report before
 * any assignment; they are never passed to the router.
 */
@HttpServerMarker
class RouterBlock(
    private val router: Router,
    private val coroutineDispatcher: CoroutineDispatcher
                 ) {

    /** HTTP method name; assignment calls `router.method(value)`. */
    var method: String = HttpMethod.GET.asString()
        set(value) {
            router.method(value)
            field = value
        }

    /** HTTP method names; assignment calls `router.method(...)` once per element. */
    var methods: List<String> = listOf(HttpMethod.GET.asString(), HttpMethod.POST.asString())
        set(value) {
            value.forEach { router.method(it) }
            field = value
        }

    /** HTTP method; assignment calls `router.method(value)`. */
    var httpMethod: HttpMethod = HttpMethod.GET
        set(value) {
            router.method(value)
            field = value
        }

    /** HTTP methods; assignment calls `router.method(...)` once per element. */
    var httpMethods: List<HttpMethod> = listOf(HttpMethod.GET, HttpMethod.POST)
        set(value) {
            value.forEach { router.method(it) }
            field = value
        }

    /** Path; assignment calls `router.path(value)`. */
    var path: String = ""
        set(value) {
            router.path(value)
            field = value
        }

    /** Paths; assignment calls `router.paths(value)`. */
    var paths: List<String> = listOf()
        set(value) {
            router.paths(value)
            field = value
        }

    /** Path regular expression; assignment calls `router.pathRegex(value)`. */
    var regexPath: String = ""
        set(value) {
            router.pathRegex(value)
            field = value
        }

    /** Consumed content type; assignment calls `router.consumes(value)`. */
    var consumes: String = ""
        set(value) {
            router.consumes(value)
            field = value
        }

    /** Produced content type; assignment calls `router.produces(value)`. */
    var produces: String = ""
        set(value) {
            router.produces(value)
            field = value
        }

    /** Returns the ID of the wrapped router. */
    fun getId() = router.id

    /**
     * Register a handler that is executed in the coroutine asynchronously.
     *
     * The response is marked asynchronous. The handler then runs in a coroutine started by
     * `launchTraceable` on [coroutineDispatcher], with the routing context bound under [httpCtxKey] so
     * that [getRequestContext] can find it.
     *
     * @param handler The handler that processes the business logic. It receives the coroutine's context.
     */
    fun asyncHandler(handler: suspend RoutingContext.(context: CoroutineContext) -> Unit) {
        router.handler {
            it.response.isAsynchronous = true
            launchTraceable(coroutineDispatcher, mutableMapOf(httpCtxKey to it)) {
                handler.invoke(it, coroutineContext)
            }
        }
    }

    /**
     * Register a handler that is executed in the coroutine asynchronously.
     *
     * @param handler The handler that processes the business logic.
     */
    fun asyncHandler(handler: AsyncHandler) = asyncHandler {
        handler.handle(this)
    }

    /**
     * Automatically call the succeeded callback when the asynchronous handler has completed.
     *
     * If [handler] returns normally, `asyncSucceed(Unit)` is called. If it throws, the throwable is
     * passed to `asyncFail`.
     */
    fun asyncCompleteHandler(handler: suspend RoutingContext.(context: CoroutineContext) -> Unit) = asyncHandler {
        try {
            handler.invoke(this, it)
            asyncSucceed(Unit)
        } catch (x: Throwable) {
            asyncFail<Unit>(x)
        }
    }

    /**
     * Register a handler that is executed in the network thread synchronously.
     * It is passed straight to the router; no coroutine is started.
     *
     * @param handler The handler that processes the business logic.
     */
    fun handler(handler: RoutingContext.() -> Unit) {
        router.handler(handler)
    }

    /**
     * Register a handler that is executed in the network thread synchronously.
     * It is passed straight to the router; no coroutine is started.
     *
     * @param handler The handler that processes the business logic.
     */
    fun handler(handler: Handler) {
        router.handler(handler)
    }

    /**
     * Automatically close the resource when the block has completed.
     *
     * The resource is closed inside [NonCancellable] whether [block] returns or throws; a `null` receiver
     * is simply not closed. This mirrors `kotlin.io.use`:
     * - If [block] throws, an exception from `close()` is attached to the original via `addSuppressed`,
     *   and the original is rethrown.
     * - If [block] succeeds, an exception from `close()` propagates.
     *
     * @return The value returned by [block].
     */
    suspend fun <T : Closeable?, R> T.safeUse(block: suspend (T) -> R): R {
        var failure: Throwable? = null
        try {
            return block(this)
        } catch (e: Throwable) {
            failure = e
            throw e
        } finally {
            // Copy to a val so it can be smart-cast inside the lambda.
            val primary = failure
            withContext(NonCancellable) {
                if (primary == null) {
                    this@safeUse?.close()
                } else {
                    try {
                        this@safeUse?.close()
                    } catch (closeException: Throwable) {
                        primary.addSuppressed(closeException)
                    }
                }
            }
        }
    }

    override fun toString(): String = router.toString()

}

/**
 * Receiver of the [HttpServer.webSocket] DSL.
 *
 * Construction registers a WebSocket handler for [path] on the server. That handler delegates to this
 * builder's callbacks. It also registers an empty handler for [path] on the given router.
 */
@HttpServerMarker
class WebSocketBlock(
    server: SimpleHTTPServer,
    router: Router,
    private val path: String
                    ) : AbstractWebSocketBuilder() {

    /** Called when a WebSocket connection is established on [path]; `null` means no callback. */
    var onConnect: ((WebSocketConnection) -> Unit)? = null

    init {
        server.registerWebSocket(path, object : WebSocketHandler {
            // Calls must be qualified with this@WebSocketBlock. Inside the anonymous object an unqualified
            // onFrame(...)/onError(...) would call these overrides themselves and recurse forever.
            override fun onConnect(connection: WebSocketConnection) {
                this@WebSocketBlock.onConnect?.invoke(connection)
            }

            override fun onFrame(frame: Frame, connection: WebSocketConnection) {
                this@WebSocketBlock.onFrame(frame, connection)
            }

            override fun onError(t: Throwable, connection: WebSocketConnection) {
                this@WebSocketBlock.onError(t, connection)
            }
        })
        // Empty handler for the path. Presumably this lets the router manager match the upgrade request; confirm.
        router.path(path).handler { }
    }

    /** Sets the [onConnect] callback. */
    fun onConnect(onConnect: (WebSocketConnection) -> Unit) {
        this.onConnect = onConnect
    }

    override fun toString(): String = "WebSocket(path='$path')"
}

/**
 * Lifecycle operations of an HTTP server: start listening and stop.
 */
interface HttpServerLifecycle {

    /**
     * Stop the HTTP server.
     */
    fun stop()

    /**
     * Start the HTTP server.
     *
     * @param host The server hostname.
     * @param port The server port.
     */
    fun listen(host: String, port: Int)

    /**
     * Start the HTTP server on the given port, using the address of the local host as the host.
     *
     * @param port The server port.
     */
    fun listen(port: Int)

    /**
     * Start the HTTP server. You must set host and port in the SimpleHTTPServerConfiguration.
     */
    fun listen()
}

/**
 * DSL marker for the HTTP server DSL ([HttpServer], [RouterBlock], [WebSocketBlock]).
 *
 * Because this annotation is a [DslMarker], only the innermost marked receiver is available implicitly
 * inside a nested block. An outer receiver must be referenced explicitly, e.g. `this@HttpServer`.
 */
@DslMarker
annotation class HttpServerMarker

/**
 * HTTP server DSL. It helps you write HTTP server elegantly.
 *
 * Construction does the following, in order:
 * 1. Creates the underlying [SimpleHTTPServer] and [RouterManager].
 * 2. Installs a bad-message handler that renders errors with the default error handler.
 * 3. Runs [block] so it can register routers.
 *
 * Nothing listens until one of the `listen` functions is called.
 *
 * @param serverConfiguration HTTP server configuration.
 * @param httpBodyConfiguration HTTP body configuration.
 * @param block The HTTP server DSL block. You can register routers in this block.
 * @throws IllegalStateException if the router manager or the default error handler cannot be obtained.
 */
@HttpServerMarker
class HttpServer(
    serverConfiguration: SimpleHTTPServerConfiguration = SimpleHTTPServerConfiguration(),
    httpBodyConfiguration: HTTPBodyConfiguration = HTTPBodyConfiguration(),
    block: HttpServer.() -> Unit
                ) : HttpServerLifecycle {

    val server = SimpleHTTPServer(serverConfiguration)
    val routerManager: RouterManager = checkNotNull(RouterManager.create(httpBodyConfiguration)) {
        "RouterManager.create returned null"
    }
    /** Dispatcher backed by the server's handler executor service; used by coroutine route handlers. */
    val coroutineDispatcher = server.handlerExecutorService.asCoroutineDispatcher()
    val defaultErrorHandler = checkNotNull(DefaultErrorResponseHandlerLoader.getInstance().handler) {
        "No default error response handler is available"
    }

    init {
        // Bad messages are rendered directly by the default error handler, using a routing context with
        // an empty router match set.
        server.badMessage { status, reason, request ->
            val ctx = RoutingContextImpl(request, Collections.emptyNavigableSet())
            defaultErrorHandler.render(ctx, status, BadMessageException(reason))
        }
        block.invoke(this)
    }

    constructor() : this(SimpleHTTPServerConfiguration(), HTTPBodyConfiguration(), {})

    constructor(block: HttpServer.() -> Unit)
            : this(SimpleHTTPServerConfiguration(), HTTPBodyConfiguration(), block)

    constructor(
        serverConfiguration: SimpleHTTPServerConfiguration,
        httpBodyConfiguration: HTTPBodyConfiguration
               )
            : this(serverConfiguration, httpBodyConfiguration, {})

    /** Wires the router manager in as the header-complete callback; the result is used to chain `listen`. */
    private fun routedServer() = server.headerComplete(routerManager::accept)

    override fun stop() = server.stop()

    override fun listen(host: String, port: Int) = routedServer().listen(host, port)

    /** Listens on [port] at `InetAddress.getLocalHost().hostAddress`. */
    override fun listen(port: Int) = listen(InetAddress.getLocalHost().hostAddress, port)

    override fun listen() = routedServer().listen()

    /** Enables secure connections in the server configuration. Call it before `listen`. */
    fun enableSecureConnection(): HttpServer {
        this.server.configuration.isSecureConnectionEnabled = true
        return this
    }

    /**
     * Register a router using the DSL with a autoincrement ID.
     *
     * @param block The router builder.
     */
    inline fun router(block: RouterBlock.() -> Unit) {
        val r = RouterBlock(routerManager.register(), coroutineDispatcher)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register a router using the DSL with a specified ID.
     *
     * @param id The router ID. The router is sorted by ID ascending order.
     * @param block The router builder.
     */
    inline fun router(id: Int, block: RouterBlock.() -> Unit) {
        val r = RouterBlock(routerManager.register(id), coroutineDispatcher)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register a WebSocket endpoint on [path]; the endpoint also gets a router with an auto-increment ID.
     *
     * @param path The WebSocket path.
     * @param block The WebSocket builder.
     */
    inline fun webSocket(path: String, block: WebSocketBlock.() -> Unit) {
        val r = WebSocketBlock(server, routerManager.register(), path)
        block.invoke(r)
        sysLogger.info("register $r")
    }

    /**
     * Register routers using the DSL.
     *
     * @param block Register routers in this block.
     */
    inline fun addRouters(block: HttpServer.() -> Unit) = block.invoke(this)
}

/**
 * Builds access-log records for routed requests and writes them as JSON to a dedicated logger.
 *
 * @param logName Name of the logger the records are written to.
 * @param userTracingId Routing-context attribute key whose [String] value is recorded as [AccessLog.tid].
 */
class AccessLogService(
    logName: String = "firefly-access",
    val userTracingId: String = "_const_firefly_user_id_"
                      ) {

    private val log = LoggerFactory.getLogger(logName)

    /**
     * Builds an [AccessLog] for [ctx] and logs it at INFO level as JSON.
     *
     * @param startTime Request start timestamp.
     * @param endTime Request end timestamp; `endTime - startTime` becomes [AccessLog.time].
     */
    fun recordAccessLog(ctx: RoutingContext, startTime: Long, endTime: Long) {
        val accessLog = toAccessLog(ctx, endTime - startTime)
        log.info(`$`.json.toJson(accessLog))
    }

    /**
     * Converts the request/response state of [ctx] into an [AccessLog].
     *
     * The request body is captured only for POST and PUT requests whose `Content-Type` starts with
     * `application/json` or `application/x-www-form-urlencoded`. The comparison ignores case, because
     * media types are case-insensitive. In every other case an empty string is recorded.
     *
     * @param time Elapsed request time, in whatever unit the caller measured it.
     * @throws ClassCastException if the [userTracingId] attribute is present but is not a [String].
     */
    fun toAccessLog(ctx: RoutingContext, time: Long): AccessLog {
        val requestBody = loggableRequestBody(ctx)
        val tid = ctx.getAttr<String>(userTracingId)
        return AccessLog(
            ctx.getRealClientIp(),
            ctx.method,
            ctx.uri.toString(),
            requestBody,
            time,
            ctx.response.status,
            tid
                        )
    }

    /** Returns the request body if [toAccessLog]'s rules allow logging it, otherwise an empty string. */
    private fun loggableRequestBody(ctx: RoutingContext): String? {
        val method = ctx.method
        if (method != HttpMethod.POST.asString() && method != HttpMethod.PUT.asString()) {
            return ""
        }
        val contentType: String? = ctx.fields[HttpHeader.CONTENT_TYPE.asString()]
        if (contentType == null || !`$`.string.hasText(contentType)) {
            return ""
        }
        val isTextual = contentType.startsWith("application/json", ignoreCase = true) ||
            contentType.startsWith("application/x-www-form-urlencoded", ignoreCase = true)
        return if (isTextual) ctx.stringBody else ""
    }

}

/**
 * One access-log record. [AccessLogService.toAccessLog] builds it, and
 * [AccessLogService.recordAccessLog] writes it as JSON.
 *
 * NOTE(review): [NoArg] presumably makes the compiler generate a no-argument constructor (e.g. for JSON
 * deserialization); confirm.
 */
@NoArg
data class AccessLog(
    // Client address as returned by getRealClientIp().
    var ip: String,
    // HTTP request method.
    var method: String,
    // Request URI (ctx.uri.toString()).
    var uri: String,
    // Request body for POST/PUT requests with a JSON or URL-encoded form content type; an empty string otherwise.
    var requestBody: String?,
    // Elapsed time (endTime - startTime in recordAccessLog); the unit depends on the caller's timestamps.
    var time: Long,
    // Response status code.
    var responseStatus: Int,
    // Value of the user-tracing attribute (AccessLogService.userTracingId), or null when absent.
    var tid: String?
                    )




© 2015 - 2024 Weber Informatics LLC | Privacy Policy