// ws.osiris.core.Api.kt — Maven / Gradle / Ivy
// The newest version!
package ws.osiris.core
import org.slf4j.LoggerFactory
import java.util.regex.Pattern
import kotlin.reflect.KClass
// File-level logger for the non-inline code in this file. The public inline `api` function
// cannot reference this private property, so it creates its own logger with the same name.
private val log = LoggerFactory.getLogger("ws.osiris.core")
/**
 * The MIME types that are treated as binary by default.
 *
 * The user can specify additional types that should be treated as binary using `binaryMimeTypes` in
 * the API definition; those are added to this set rather than replacing it.
 */
val STANDARD_BINARY_MIME_TYPES = setOf(
    // Generic binary payloads. This entry was previously misspelled as "application/octet-steam",
    // which meant real octet-stream bodies were never matched.
    "application/octet-stream",
    "image/png",
    "image/apng",
    "image/webp",
    "image/jpeg",
    "image/gif",
    "audio/mpeg",
    "video/mpeg",
    "application/pdf",
    "multipart/form-data"
)
/**
* A model describing an API; it contains the routes to the API endpoints and the code executed
* when the API receives requests.
*
* Instances are normally created by the `api` DSL function (via `buildApi`) or by combining
* existing instances with [merge].
*/
data class Api(
/**
* The routes defined by the API.
*
* Each route consists of:
* * An HTTP method
* * A path
* * The code that is executed when a request is received for the endpoint
* * The authorisation required to invoke the endpoint.
*/
val routes: List>,
/** Filters applied to requests before they are passed to a handler. */
val filters: List>,
/**
* The type of the object available to the code in the API definition that handles the HTTP requests.
*
* The code in the API definition runs with the components provider class as the implicit receiver
* and can directly access its properties and methods.
*
* For example, if a data store is needed to handle a request, it would be provided by
* the `ComponentsProvider` implementation:
*
* class Components(val dataStore: DataStore) : ComponentsProvider
*
* ...
*
* get("/orders/{orderId}") { req ->
* val orderId = req.pathParams["orderId"]
* dataStore.loadOrderDetails(orderId)
* }
*/
val componentsClass: KClass,
/** True if this API serves static files. */
val staticFiles: Boolean,
/** The MIME types that are treated by API Gateway as binary; these are encoded in the JSON using Base64. */
val binaryMimeTypes: Set
) {
companion object {
/**
* Merges multiple APIs into a single API.
*
* The APIs must not define any endpoints with the same path and method. This function does not
* check for duplicates itself; it simply concatenates the routes and filters of the arguments in
* the order they are passed.
*
* The merged API serves static files if any of the inputs does, and its binary MIME types are the
* union of the binary MIME types of all the inputs.
*
* This is intended to allow large APIs to be defined across multiple files:
*
* val api1 = api {
* get("/bar") {
* ...
* }
* }
*
* val api2 = api {
* get("/baz") {
* ...
* }
* }
*
* // An API containing all the endpoints and filters from `api1` and `api2`
* val api = Api.merge(api1, api2)
*
* @param api1 the first API to merge
* @param api2 the second API to merge
* @param rest any further APIs to merge
* @return a new API containing the routes and filters of every argument
*/
inline fun merge(api1: Api, api2: Api, vararg rest: Api): Api {
// two mandatory parameters guarantee the list is never empty, so the reduce calls below cannot fail
val apis = listOf(api1, api2) + rest
val routes = apis.map { it.routes }.reduce { allRoutes, apiRoutes -> allRoutes + apiRoutes }
val filters = apis.map { it.filters }.reduce { allFilters, apiFilters -> allFilters + apiFilters }
// static files are enabled in the merged API if any input API enables them
val staticFiles = apis.map { it.staticFiles }.reduce { sf1, sf2 -> sf1 || sf2}
val binaryMimeTypes = apis.flatMap { it.binaryMimeTypes }.toSet()
return Api(routes, filters, T::class, staticFiles, binaryMimeTypes)
}
}
}
/**
* This function is the entry point to the DSL used for defining an API.
*
* It is used to populate a top-level property named `api`. For example
*
* val api = api {
* get("/foo") { req ->
* ...
* }
* }
*
* The type parameter is the type of the implicit receiver of the handler code. This means the properties and
* functions of that type are available to be used by the handler code. See [ComponentsProvider] for details.
*
* @param cors the CORS flag passed to the root builder; it is used for endpoints that don't specify
* their own `cors` argument. Defaults to `false`.
* @param body the DSL block that defines the endpoints, filters and other settings of the API
* @return the API built from the definitions in [body]
*/
inline fun api(cors: Boolean = false, body: RootApiBuilder.() -> Unit): Api {
// This needs to be local because this function is inline and can only access public members of this package
val log = LoggerFactory.getLogger("ws.osiris.core")
log.debug("Creating the Api")
val builder = RootApiBuilder(T::class, cors)
log.debug("Running the RootApiBuilder")
// runs the user's DSL block, which populates the builder
builder.body()
log.debug("Building the Api from the builder")
val api = buildApi(builder)
log.debug("Created the Api")
return api
}
/**
* The type of the lambdas in the DSL containing the code that runs when a request is received.
*
* The lambda receives the request and may return a value of any type. When a handler is registered as a
* route, a returned value that is not a `Response` is wrapped in one using the request's response builder.
*/
internal typealias Handler = T.(Request) -> Any
// The type of the lambdas in the DSL that define filters. A filter lambda receives the request and a
// handler function that returns a `Response`, and may itself return a value of any type.
//
// This causes the compiler to crash
//typealias FilterHandler = T.(Request, Handler) -> Any
// This is equivalent to the line above but doesn't make the compiler crash
internal typealias FilterHandler = T.(Request, T.(Request) -> Response) -> Any
/**
* The type of lambda in the DSL passed to the `cors` function.
*
* This lambda receives a request (for any endpoint where `cors = true`) and populates the fields
* used to build the CORS headers. It runs with a [CorsHeadersBuilder] as its receiver, so the
* builder's properties can be assigned directly inside the lambda.
*/
internal typealias CorsHandler = CorsHeadersBuilder.(Request) -> Unit
/**
* The type of the handler passed to a [Filter].
*
* Handlers and filters can return objects of any type (see [Handler]). If the returned value is
* not a [Response] the library wraps it in a `Response` before returning it to the caller. This
* ensures that the object returned to a `Filter` implementation is guaranteed to be a `Response`.
*/
typealias RequestHandler = T.(Request) -> Response
/**
* Pattern matching resource paths; matches regular segments like `/foo` and variable segments like `/{foo}` and
* any combination of the two.
*
* The root path `/` on its own also matches. Segment names (with or without the surrounding braces) may
* contain ASCII letters, digits and the characters `_`, `-`, `~`, `.`, `(` and `)`, and must not be empty.
*/
internal val pathPattern = Pattern.compile("/|(?:(?:/[a-zA-Z0-9_\\-~.()]+)|(?:/\\{[a-zA-Z0-9_\\-~.()]+}))+")
/**
* Marks the DSL implicit receivers to avoid scoping problems.
*
* In this file it is applied to [ApiBuilder], [RootApiBuilder] and [ComponentsProvider]. It can only
* be applied to classes.
*/
@DslMarker
@Target(AnnotationTarget.CLASS)
internal annotation class OsirisDsl
/**
* This is an internal class that is part of the DSL implementation and should not be used by user code.
*
* It is the receiver of the DSL blocks and collects the routes, filters, CORS handler and static files
* configuration declared inside them. Nested `path` and `auth` blocks each get their own child builder.
*/
@OsirisDsl
open class ApiBuilder internal constructor(
// the components class that is also passed on to every child builder
internal val componentsClass: KClass,
// prepended to the path of every route and static files location defined in this builder
private val prefix: String,
// the auth applied to routes defined in this builder; null means no enclosing auth block
private val auth: Auth?,
// the CORS flag used for routes that don't specify their own
private val cors: Boolean
) {
// set by the `staticFiles` function; null if this builder has no static files block
internal var staticFilesBuilder: StaticFilesBuilder? = null
// set by the `cors` function; once a value has been stored any further assignment throws
internal var corsHandler: CorsHandler? = null
set(handler) {
if (field == null) {
field = handler
} else {
throw IllegalStateException("There must be only one cors block")
}
}
// the routes defined directly in this builder (not including routes defined in child builders)
internal val routes: MutableList> = arrayListOf()
// the filters defined directly in this builder (not including filters defined in child builders)
internal val filters: MutableList> = arrayListOf()
// the builders created for nested `path` and `auth` blocks
private val children: MutableList> = arrayListOf()
/** Defines an endpoint that handles GET requests to the path. */
fun get(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.GET, path, handler, cors)
/** Defines an endpoint that handles POST requests to the path. */
fun post(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.POST, path, handler, cors)
/** Defines an endpoint that handles PUT requests to the path. */
fun put(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.PUT, path, handler, cors)
/**
* Defines an endpoint that handles UPDATE requests to the path.
*
* NOTE(review): UPDATE is not one of the methods defined by the HTTP specifications; confirm that
* `HttpMethod.UPDATE` is intentional.
*/
fun update(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.UPDATE, path, handler, cors)
/**
* Defines an endpoint that handles OPTIONS requests to the path.
*
* There is no `cors` argument; the builder's own CORS flag is always used for the route.
*/
fun options(path: String, handler: Handler): Unit =
addRoute(HttpMethod.OPTIONS, path, handler, null)
/** Defines an endpoint that handles PATCH requests to the path. */
fun patch(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.PATCH, path, handler, cors)
/** Defines an endpoint that handles DELETE requests to the path. */
fun delete(path: String, cors: Boolean? = null, handler: Handler): Unit =
addRoute(HttpMethod.DELETE, path, handler, cors)
/** Defines a filter for the path; the filter is created from this builder's prefix, the path and the handler. */
fun filter(path: String, handler: FilterHandler): Unit {
filters.add(Filter(prefix, path, handler))
}
/** Defines a filter using the wildcard path, instead of a path chosen by the caller. */
fun filter(handler: FilterHandler): Unit = filter("/*", handler)
/**
* Defines a block of endpoints whose paths all start with [path].
*
* The block runs against a child builder whose prefix is this builder's prefix followed by [path].
* The child uses this builder's auth, and uses [cors] as its CORS flag if it is specified, otherwise
* this builder's CORS flag.
*/
fun path(path: String, cors: Boolean? = null, body: ApiBuilder.() -> Unit) {
val child = ApiBuilder(componentsClass, prefix + path, auth, cors ?: this.cors)
children.add(child)
child.body()
}
/**
* Defines a block of endpoints that all require the authorisation strategy [auth].
*
* The block runs against a child builder with the same prefix and CORS flag as this builder.
*
* @throws IllegalStateException if this builder is already inside an `auth` block
*/
fun auth(auth: Auth, body: ApiBuilder.() -> Unit) {
// not sure about this. the alternative is to allow nesting and the inner block applies.
// this seems misleading to me. common sense says that nesting an endpoint inside multiple
// auth blocks means it would be subject to multiple auth strategies. which isn't true
// and wouldn't make sense.
// if I did that then the auth fields could be non-nullable and default to None
if (this.auth != null) throw IllegalStateException("auth blocks cannot be nested")
val child = ApiBuilder(componentsClass, prefix, auth, cors)
children.add(child)
child.body()
}
/**
* Defines the static files configuration using this builder's prefix and auth.
*
* NOTE(review): the result is stored unconditionally, so a second `staticFiles` block in the same
* builder replaces the first without an error; confirm whether that should be rejected.
*/
fun staticFiles(body: StaticFilesBuilder.() -> Unit) {
val staticFilesBuilder = StaticFilesBuilder(prefix, auth)
staticFilesBuilder.body()
this.staticFilesBuilder = staticFilesBuilder
}
/**
* Sets the handler that populates the CORS headers.
*
* NOTE(review): `buildApi` reads `corsHandler` from the root builder only, so a `cors` block declared
* inside a nested `path` or `auth` block appears to be ignored; confirm.
*
* @throws IllegalStateException if a CORS handler has already been set on this builder
*/
fun cors(corsHandler: CorsHandler) {
this.corsHandler = corsHandler
}
//--------------------------------------------------------------------------------------------------
// Adds a route for the method and prefixed path. The handler is wrapped so it always returns a Response,
// a null auth becomes NoAuth and a null routeCors falls back to the builder's CORS flag.
private fun addRoute(method: HttpMethod, path: String, handler: Handler, routeCors: Boolean?) {
val cors = routeCors ?: cors
routes.add(LambdaRoute(method, prefix + path, requestHandler(handler), auth ?: NoAuth, cors))
}
// Returns all builders nested below this one: the direct children followed by their descendants, recursively.
internal fun descendants(): List> = children + children.flatMap { it.descendants() }
companion object {
// Wraps a handler so its return value is always a Response; any other value is passed to the
// request's response builder to create one.
private fun requestHandler(handler: Handler): RequestHandler = { req ->
val returnVal = handler(this, req)
returnVal as? Response ?: req.responseBuilder().build(returnVal)
}
}
}
/**
* The builder that is the receiver of the top-level `api` block.
*
* In addition to everything in [ApiBuilder] it holds the settings that apply to the whole API: the
* global filters and the binary MIME types.
*/
@OsirisDsl
class RootApiBuilder internal constructor(
componentsClass: KClass,
prefix: String,
auth: Auth?,
cors: Boolean
) : ApiBuilder(componentsClass, prefix, auth, cors) {
/** Creates a root builder with an empty path prefix and no auth. */
constructor(componentsType: KClass, cors: Boolean) : this(componentsType, "", null, cors)
/** The filters that `buildApi` places before all filters defined in the API; initialised from `StandardFilters.create()`. */
var globalFilters: List> = StandardFilters.create()
/**
* The MIME types that are treated as binary.
*
* Reading this property never returns null: it returns [STANDARD_BINARY_MIME_TYPES] plus any types
* that have been assigned. It can be assigned at most once; assigning when a value is already
* stored throws [IllegalStateException].
*/
var binaryMimeTypes: Set? = null
set(value) {
if (field != null) {
// the message reads the property through its getter, so the values shown include the standard types
throw IllegalStateException("Binary MIME types must only be set once. Current values: $binaryMimeTypes")
}
field = value
}
get() = field?.let { it + STANDARD_BINARY_MIME_TYPES } ?: STANDARD_BINARY_MIME_TYPES
/**
* Returns the static files configuration.
*
* This can be specified in any `ApiBuilder` in the API definition, but it must only be specified once.
*
* @return the static files configuration, or null if none of the builders defines one
* @throws IllegalArgumentException if more than one builder defines static files
*/
internal fun effectiveStaticFiles(): StaticFiles? {
// each builder holds at most one StaticFilesBuilder, so this counts the builders that define static files
val allStaticFiles = descendants().map { it.staticFilesBuilder } + staticFilesBuilder
val nonNullStaticFiles = allStaticFiles.filterNotNull()
if (nonNullStaticFiles.size > 1) {
throw IllegalArgumentException("staticFiles must only be specified once")
}
return nonNullStaticFiles.firstOrNull()?.build()
}
}
/**
* Builds the API defined by the builder.
*
* This function is an implementation detail and not intended to be called by users.
*
* It gathers the routes and filters from the root builder and all nested builders, adds OPTIONS routes
* for CORS paths, wraps every route in the filters that apply to it and adds a route for static files
* if they are configured.
*
* @param builder the root builder populated by the API definition
* @return the API containing all routes and filters
* @throws IllegalArgumentException if the static files path is illegal or if the routes use more than
* one auth type other than [NoAuth]
*/
fun buildApi(builder: RootApiBuilder): Api {
// global filters come first, then the root builder's filters, then the filters of nested builders
val filters = builder.globalFilters + builder.filters + builder.descendants().flatMap { it.filters }
// TODO validate that there's a CORS handler if anything has cors = true?
val corsHandler = builder.corsHandler
// the CORS filter goes in front of the other filters; without a CORS handler CORS routes get the plain filters
val corsFilters = if (corsHandler != null) listOf(corsFilter(corsHandler)) + filters else filters
val lambdaRoutes = builder.routes + builder.descendants().flatMap { it.routes }
val allLambdaRoutes = addOptionsMethods(lambdaRoutes)
// TODO the explicit type is needed to make type inference work. remove in future
val wrappedRoutes: List> = allLambdaRoutes.map { if (it.cors) it.wrap(corsFilters) else it.wrap(filters) }
val effectiveStaticFiles = builder.effectiveStaticFiles()
val allRoutes = when (effectiveStaticFiles) {
null -> wrappedRoutes
else -> wrappedRoutes + StaticRoute(
effectiveStaticFiles.path,
effectiveStaticFiles.indexFile,
effectiveStaticFiles.auth
)
}
// the static files path may only contain fixed segments; see staticFilesPattern
if (effectiveStaticFiles != null && !staticFilesPattern.matcher(effectiveStaticFiles.path).matches()) {
throw IllegalArgumentException("Static files path is illegal: $effectiveStaticFiles")
}
// NoAuth can be mixed with one other auth type, but two different auth types are rejected
val authTypes = allRoutes.map { it.auth }.filter { it != NoAuth }.toSet()
if (authTypes.size > 1) throw IllegalArgumentException("Only one auth type is supported but found $authTypes")
// the property's getter never returns null so the fallback is only there to satisfy the nullable type
val binaryMimeTypes = builder.binaryMimeTypes ?: setOf()
// the Api receives the filters without the CORS filter; that has already been applied to the CORS routes
return Api(allRoutes, filters, builder.componentsClass, effectiveStaticFiles != null, binaryMimeTypes)
}
/**
* Returns the routes plus an OPTIONS route for every path that has at least one route with `cors = true`
* and doesn't already have an OPTIONS route.
*
* The added routes have `cors = true` and a handler that builds a response without a body argument.
* The original routes are returned unchanged, followed by the added routes.
*/
private fun addOptionsMethods(routes: List>): List> {
// group the routes by path, ignoring any paths with no CORS routes and any that already have an OPTIONS method
val routesByPath = routes.groupBy { it.path }
.filterValues { pathRoutes -> pathRoutes.any { it.cors } }
.filterValues { pathRoutes -> pathRoutes.none { it.method == HttpMethod.OPTIONS } }
// the options method provides a CORS response for all other methods for the same path.
// if they all have the same auth then the OPTIONS method should probably have it too.
// if they don't all have the same auth then it's impossible to say what the OPTIONS auth
// should be. NoAuth seems reasonable. an OPTIONS request should be harmless.
// routes with NoAuth are ignored when deciding whether the routes for a path share the same auth
val authByPath = routesByPath
.mapValues { (_, pathRoutes) -> pathRoutes.map { it.auth }.filter { it != NoAuth }.toSet() }
.mapValues { (_, authSet) -> if (authSet.size == 1) authSet.first() else NoAuth }
// the default handler added for OPTIONS methods doesn't do anything except build the response.
// the response builder will have been initialised with the CORS headers so this will build a
// CORS-compliant response
val optionsHandler: RequestHandler = { req -> req.responseBuilder().build() }
val optionsRoutes = authByPath.map { (path, auth) -> LambdaRoute(HttpMethod.OPTIONS, path, optionsHandler, auth, true) }
log.debug("Adding routes for OPTIONS methods: {}", optionsRoutes)
return routes + optionsRoutes
}
/**
* Returns a filter that passes the request to the [corsHandler] and adds the returned headers to the default
* response headers.
*
* This filter is used as the first filter for any endpoint where `cors=true`.
*
* The filter doesn't modify the incoming request; it passes a copy with the extended default response
* headers to the next handler and returns whatever that handler returns.
*/
private fun corsFilter(corsHandler: CorsHandler): Filter =
defineFilter { req, handler ->
// a new builder is created for every request so the handler can choose headers per request
val corsHeadersBuilder = CorsHeadersBuilder()
corsHandler(corsHeadersBuilder, req)
val corHeaders = corsHeadersBuilder.build()
val defaultResponseHeaders = req.defaultResponseHeaders + corHeaders.headerMap
val newReq = req.copy(defaultResponseHeaders = defaultResponseHeaders)
handler(newReq)
}
// Pattern matching legal static files paths: either the root path or one or more fixed segments made of
// ASCII letters, digits and the characters _ - ~ . ( ). Unlike pathPattern, variable segments in braces
// are not allowed.
private val staticFilesPattern = Pattern.compile("/|(?:/[a-zA-Z0-9_\\-~.()]+)+")
/**
* Receiver of the `cors` block; the block assigns the properties and the library builds the CORS
* headers from them.
*
* A property that is left as null produces no header. Otherwise the values are joined with commas
* to form the header value.
*/
class CorsHeadersBuilder {
/** The values of the `Access-Control-Allow-Methods` header. */
var allowMethods: Set? = null
/** The values of the `Access-Control-Allow-Headers` header. */
var allowHeaders: Set? = null
/**
* The values of the `Access-Control-Allow-Origin` header.
*
* NOTE(review): this header normally carries a single origin or a wildcard; confirm that a
* comma-separated list of several origins is handled as intended by clients.
*/
var allowOrigin: Set? = null
/** Builds the headers from the properties that have been set. */
internal fun build(): Headers {
val headerMap = mapOf(
"Access-Control-Allow-Methods" to allowMethods?.joinToString(","),
"Access-Control-Allow-Headers" to allowHeaders?.joinToString(","),
"Access-Control-Allow-Origin" to allowOrigin?.joinToString(",")
)
// the null values have just been filtered out so the cast to a map with non-null values is safe
@Suppress("UNCHECKED_CAST")
return Headers(headerMap.filterValues { it != null } as Map)
}
}
/**
 * Receiver of the `staticFiles` block; collects the settings used to create a [StaticFiles].
 *
 * @param prefix the path prefix of the enclosing builder; it is prepended to [path]
 * @param auth the auth of the enclosing builder, or null to use [NoAuth]
 */
class StaticFilesBuilder(
    private val prefix: String,
    private val auth: Auth?
) {
    /** The path from which the static files are served, relative to the prefix; this must be set. */
    var path: String? = null

    /** The index file; this is optional and is passed to [StaticFiles] unchanged. */
    var indexFile: String? = null

    /**
     * Creates the static files configuration from the current property values.
     *
     * @throws IllegalArgumentException if [path] has not been set
     */
    internal fun build(): StaticFiles {
        val relativePath = requireNotNull(path) { "Must specify the static files path" }
        val effectiveAuth = auth ?: NoAuth
        return StaticFiles(prefix + relativePath, indexFile, effectiveAuth)
    }
}
/**
* The static files configuration created by [StaticFilesBuilder].
*
* The constructor is internal, so instances can only be created inside this module.
*
* @property path the full path of the static files, including the prefix of the builder that defined them
* @property indexFile the index file, or null if none was specified
* @property auth the auth for the static files; [NoAuth] if none was specified
*/
data class StaticFiles internal constructor(val path: String, val indexFile: String?, val auth: Auth)
/**
* Provides all the components used by the implementation of the API.
*
* The code in the API runs with the components provider class as the implicit receiver
* and can directly access its properties and methods.
*
* The interface itself declares no members; implementations define whatever properties and
* functions the handler code needs.
*
* For example, if a data store is needed to handle a request, it would be provided by
* the `ComponentsProvider` implementation:
*
* class Components(val dataStore: DataStore) : ComponentsProvider
*
* ...
*
* get("/orders/{orderId}") { req ->
* val orderId = req.pathParams["orderId"]
* dataStore.loadOrderDetails(orderId)
* }
*/
@OsirisDsl
interface ComponentsProvider
/**
* The authorisation strategy that should be applied to an endpoint in the API.
*
* An API can use [NoAuth] and at most one other strategy; `buildApi` rejects an API whose routes use
* more than one strategy other than [NoAuth].
*/
interface Auth {
/** The name of the authorisation strategy. */
val name: String
}
/**
* Authorisation strategy that allows anyone to call an endpoint in the API without authenticating.
*
* This is the strategy used for routes and static files that are not defined inside an `auth` block.
*/
object NoAuth : Auth {
/** The name of this strategy, `NONE`. */
override val name: String = "NONE"
}
/**
* Base class for exceptions that should be automatically mapped to HTTP status codes.
*
* @property httpStatus the HTTP status code that the exception is mapped to
* @param message the exception message
*/
abstract class HttpException(val httpStatus: Int, message: String) : RuntimeException(message)
/**
* Exception indicating that data could not be found at the specified location.
*
* This is thrown when a path doesn't match a resource and is translated to a
* status of 404 (not found) by default.
*
* @param message the exception message; defaults to `Not Found`
*/
class DataNotFoundException(message: String = "Not Found") : HttpException(404, message)
/**
* Exception indicating the caller is not authorised to access the resource.
*
* This is translated to a status of 403 (forbidden) by default.
*
* @param message the exception message; defaults to `Forbidden`
*/
class ForbiddenException(message: String = "Forbidden") : HttpException(403, message)