All downloads are free. Search and download functionality uses the official Maven repository.

next.plugins.websocket.scala Maven / Gradle / Ivy

The newest version!

import akka.Done
import akka.util.ByteString
import com.arakelian.jq.{ImmutableJqLibrary, ImmutableJqRequest}
import com.networknt.schema.SpecVersion.VersionFlag
import com.networknt.schema.{InputFormat, JsonSchemaFactory, PathType, SchemaValidatorsConfig}
import io.otoroshi.wasm4s.scaladsl.WasmFunctionParameters
import otoroshi.env.Env
import otoroshi.gateway.Errors
import otoroshi.utils.JsonPathValidator
import otoroshi.utils.syntax.implicits._
import otoroshi.wasm.WasmConfig
import play.api.Logger
import play.api.http.websocket.CloseCodes
import play.api.libs.json._
import play.api.mvc.Results

import java.nio.charset.StandardCharsets
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters.asScalaBufferConverter
import scala.util._

/**
 * Strategy applied when a websocket validator rejects a frame.
 *
 * Each strategy exposes its wire form through `json` ("drop" or "close").
 * NOTE(review): presumably `Drop` discards the offending frame and `Close` closes the
 * connection — confirm against the code that consumes `rejectStrategy(ctx)`.
 */
sealed trait RejectStrategy {
  def json: JsValue
// NOTE(review): this copy of the file is missing the closing brace of the trait above
// (and other closing braces/lines below); restore them from the upstream source.
/** Known reject strategies and their parser. */
object RejectStrategy       {
  case object Drop  extends RejectStrategy { def json: JsValue = JsString("drop")  }
  case object Close extends RejectStrategy { def json: JsValue = JsString("close") }
  /**
   * Parses a strategy name case-insensitively.
   * Unknown values fall back to `Drop` instead of failing.
   */
  def parse(value: String): RejectStrategy = value.toLowerCase() match {
    case "drop"  => Drop
    case "close" => Close
    case _       => Drop

  /**
   * Decodes a `RejectStrategy` from JSON.
   * NOTE(review): the body is truncated in this copy — as shown it returns the raw
   * `JsValue`, which is not a `RejectStrategy`; the lookup/parse lines are missing.
   */
  def read(json: JsValue): RejectStrategy = json

/**
 * Kind of websocket frame accepted by a type validator; `json` gives the lowercase wire name.
 */
sealed trait FrameFormat {
  def json: JsValue
// NOTE(review): the closing brace of the trait above is missing in this copy of the file.
/** Known frame formats and their parser. */
object FrameFormat       {
  case object All    extends FrameFormat { def json: JsValue = JsString("all")    }
  case object Binary extends FrameFormat { def json: JsValue = JsString("binary") }
  case object Text   extends FrameFormat { def json: JsValue = JsString("text")   }
  // Inside this object `Json` names this case object and shadows play-json's `Json`.
  case object Json   extends FrameFormat { def json: JsValue = JsString("json")   }
  /** Parses a format name case-insensitively; unknown values fall back to `All`. */
  def parse(value: String): FrameFormat = value.toLowerCase() match {
    case "all"    => All
    case "binary" => Binary
    case "text"   => Text
    case "json"   => Json
    case _        => All

/**
 * Configuration of the `WebsocketTypeValidator` plugin.
 *
 * @param allowedFormat  frame format accepted by the validator (defaults to `All`)
 * @param rejectStrategy strategy returned for rejected frames (defaults to `Drop`)
 */
case class WebsocketTypeValidatorConfig(
    allowedFormat: FrameFormat = FrameFormat.All,
    rejectStrategy: RejectStrategy = RejectStrategy.Drop
) extends NgPluginConfig {
  /** Serializes this config with the companion's `format`. */
  override def json: JsValue = WebsocketTypeValidatorConfig.format.writes(this)

/** Default value and JSON (de)serialization for [[WebsocketTypeValidatorConfig]]. */
object WebsocketTypeValidatorConfig {
  val default = WebsocketTypeValidatorConfig()
  // JSON keys: "allowed_format" and "reject_strategy".
  val format  = new Format[WebsocketTypeValidatorConfig] {
    override def writes(o: WebsocketTypeValidatorConfig): JsValue = Json.obj(
      "allowed_format"  -> o.allowedFormat.json,
      "reject_strategy" -> o.rejectStrategy.json
    // Any exception thrown while decoding becomes a JsError carrying the exception message.
    override def reads(json: JsValue): JsResult[WebsocketTypeValidatorConfig] = {
      // NOTE(review): the constructor call and the right-hand sides of the named
      // arguments below are missing from this copy; restore them from upstream.
      Try {
          allowedFormat = json
          rejectStrategy =
      } match {
        case Failure(e) => JsError(e.getMessage)
        case Success(s) => JsSuccess(s)

/**
 * Configuration of the `WebsocketContentValidatorIn` plugin.
 *
 * @param validator      optional JSON-path check applied to each frame; `None` accepts every frame
 * @param rejectStrategy strategy returned for rejected frames (defaults to `Drop`)
 */
case class FrameFormatValidatorConfig(
    validator: Option[JsonPathValidator] = None,
    rejectStrategy: RejectStrategy = RejectStrategy.Drop
) extends NgPluginConfig {
  /** Serializes this config with the companion's `format`. */
  override def json: JsValue = FrameFormatValidatorConfig.format.writes(this)

/** Default value and JSON (de)serialization for [[FrameFormatValidatorConfig]]. */
object FrameFormatValidatorConfig {
  // Example validator on the JSON path "$.message" with the expected value "foo".
  // NOTE(review): JsonPathValidator's arguments are presumably (path, expected value, error) — confirm.
  val default = FrameFormatValidatorConfig(validator = Some(JsonPathValidator("$.message", JsString("foo"), None)))
  // JSON keys: "validator" and "reject_strategy".
  val format  = new Format[FrameFormatValidatorConfig] {
    // NOTE(review): the value written for "validator" is missing from this copy.
    override def writes(o: FrameFormatValidatorConfig): JsValue = Json.obj(
      "validator"       ->,
      "reject_strategy" -> o.rejectStrategy.json
    // Any exception thrown while decoding becomes a JsError carrying the exception message.
    override def reads(json: JsValue): JsResult[FrameFormatValidatorConfig] = {
      Try {
          // A validator that fails to decode is silently dropped (`asOpt`), which leaves
          // `validator = None`, i.e. every frame is accepted — NOTE(review): confirm intended.
          validator = (json \ "validator")
            .flatMap(v => JsonPathValidator.format.reads(v).asOpt),
          // NOTE(review): the constructor call and this right-hand side are truncated in this copy.
          rejectStrategy =
      } match {
        case Failure(e) => JsError(e.getMessage)
        case Success(s) => JsSuccess(s)

/**
 * Websocket plugin that checks the content of each request-flow frame against the
 * optional `JsonPathValidator` of its [[FrameFormatValidatorConfig]].
 */
class WebsocketContentValidatorIn extends NgWebsocketValidatorPlugin {

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(FrameFormatValidatorConfig.default)
  override def core: Boolean                               = false
  override def name: String                                = "Websocket content validator in"
  override def description: Option[String]                 = "Validate the content of each frame".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket)
  override def steps: Seq[NgStep]                          = Seq(NgStep.ValidateAccess)

  override def onResponseFlow: Boolean = false
  override def onRequestFlow: Boolean  = true

  /**
   * Evaluates the configured validator against a JSON document made of the plugin
   * context (`ctx.json`) plus the route under `route` and the frame under `message`.
   * Resolves to `true` when no validator is configured (`forall` on `None`).
   *
   * NOTE(review): the lines that read the config and obtain the frame content are missing
   * in this copy (`val config =` has no right-hand side and `.map` has no receiver).
   * NOTE(review): the `message` lambda parameter shadows the method's `message` argument.
   * NOTE(review): `Materializer` is not among the visible imports — confirm it is imported.
   */
  private def validate(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Boolean] = {
    implicit val m: Materializer = env.otoroshiMaterializer
    val config                   =

      .map(message => {
        val json = ctx.json.asObject ++ Json.obj(
          "route"   -> ctx.route.json,
          "message" -> message
        config.validator.forall(validator => validator.validate(json))

  /** Passes the frame through when `validate` succeeds, otherwise fails with `PolicyViolated`. */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    validate(ctx, message)
      .flatMap {
        case true  => Right(message).vfuture
        case false => Left(NgWebsocketError(CloseCodes.PolicyViolated, "failed to validate message")).vfuture

  // NOTE(review): body truncated in this copy; presumably returns the configured `rejectStrategy`.
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
    val config =

/**
 * Websocket plugin that checks each request-flow frame against the `FrameFormat`
 * configured in its [[WebsocketTypeValidatorConfig]].
 */
class WebsocketTypeValidator extends NgWebsocketValidatorPlugin {

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(WebsocketTypeValidatorConfig.default)
  override def core: Boolean                               = false
  override def name: String                                = "Websocket type validator"
  override def description: Option[String]                 = "Validate the type of each frame".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket)
  override def steps: Seq[NgStep]                          = Seq(NgStep.ValidateAccess)

  override def onResponseFlow: Boolean = false
  override def onRequestFlow: Boolean  = true

  /**
   * Accepts or rejects a frame according to `allowedFormat`:
   *  - `Binary`: non-binary frames are rejected with `Unacceptable`;
   *  - `Text`: non-text frames are rejected with `Unacceptable`; text frames are then
   *    checked with a UTF-8 encoder and rejected with `InconsistentData` if that fails;
   *  - `Json`: text frames must parse as JSON (`Unacceptable` otherwise) and pass the
   *    same UTF-8 check;
   *  - every other case (including `All`) accepts the frame.
   *
   * NOTE(review): with `Json`, binary frames reach the final `case _` and are accepted —
   * confirm whether they should be rejected as in the `Text` case.
   * NOTE(review): `canEncode` on an already-decoded `String` generally only fails for unpaired
   * surrogates, so it may not detect invalid UTF-8 in the raw frame bytes.
   * NOTE(review): the config lookup and the receivers of `.flatMap`/`.map` in the Text/Json
   * branches (plus the Text branch's `else`) are missing from this copy.
   */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    implicit val m: Materializer = env.otoroshiMaterializer

    val config =

    (config.allowedFormat match {
      case FrameFormat.Binary if !message.isBinary =>
        Left(NgWebsocketError(CloseCodes.Unacceptable, "expected binary content")).vfuture
      case FrameFormat.Text if !message.isText     =>
        Left(NgWebsocketError(CloseCodes.Unacceptable, "expected text content")).vfuture
      case FrameFormat.Text if message.isText      =>
          .flatMap(str => {
            if (!StandardCharsets.UTF_8.newEncoder().canEncode(str)) {
              Left(NgWebsocketError(CloseCodes.InconsistentData, "non-UTF-8 data within content")).vfuture
            } else {
      case FrameFormat.Json if message.isText      =>
          // Keep both the parse result and the raw text so the UTF-8 check can run on the text.
          .map(bs => (Try(Json.parse(bs)), bs))
          .flatMap(res => {
            res._1 match {
              case Success(_) if !StandardCharsets.UTF_8.newEncoder().canEncode(res._2) =>
                Left(NgWebsocketError(CloseCodes.InconsistentData, "non-UTF-8 data within content")).vfuture
              case Failure(_)                                                           => Left(NgWebsocketError(CloseCodes.Unacceptable, "expected json content")).vfuture
              case _                                                                    => Right(message).vfuture
      case _                                       => Right(message).vfuture

  // NOTE(review): body truncated in this copy; presumably returns the configured `rejectStrategy`.
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
    val config =

/**
 * Configuration of the `WebsocketJsonFormatValidator` plugin.
 *
 * @param schema         JSON Schema document, as a string, that frames must satisfy
 * @param specification  JSON Schema dialect id passed to `VersionFlag.fromId` (defaults to draft 2020-12)
 * @param rejectStrategy strategy returned for rejected frames (defaults to `Drop`)
 */
case class WebsocketJsonFormatValidatorConfig(
    schema: Option[String] = None,
    specification: String = VersionFlag.V202012.getId,
    rejectStrategy: RejectStrategy = RejectStrategy.Drop
) extends NgPluginConfig {
  /** Serializes this config with the companion's `format`. */
  override def json: JsValue = WebsocketJsonFormatValidatorConfig.format.writes(this)

/** Default value and JSON (de)serialization for [[WebsocketJsonFormatValidatorConfig]]. */
object WebsocketJsonFormatValidatorConfig {
  // Example schema: frames must be JSON objects with a "name" property.
  val default = WebsocketJsonFormatValidatorConfig(schema = "{ \"type\": \"object\", \"required\": [\"name\"] }".some)
  // JSON keys: "schema", "specification" and "reject_strategy".
  val format  = new Format[WebsocketJsonFormatValidatorConfig] {
    override def writes(o: WebsocketJsonFormatValidatorConfig): JsValue = Json.obj(
      "schema"          -> o.schema,
      "specification"   -> o.specification,
      "reject_strategy" -> o.rejectStrategy.json
    // A missing specification falls back to draft 2020-12; any exception becomes a JsError.
    override def reads(json: JsValue): JsResult[WebsocketJsonFormatValidatorConfig] = {
      // NOTE(review): the constructor call, the `json` lookups and the reject strategy
      // right-hand side are truncated in this copy.
      Try {
          schema ="schema").asOpt[String],
          specification ="specification").asOpt[String].getOrElse(VersionFlag.V202012.getId),
          rejectStrategy =
      } match {
        case Failure(e) => JsError(e.getMessage)
        case Success(s) => JsSuccess(s)

/**
 * Websocket plugin that validates request-flow frames against the user-supplied JSON
 * Schema of its [[WebsocketJsonFormatValidatorConfig]] (networknt json-schema-validator).
 */
class WebsocketJsonFormatValidator extends NgWebsocketValidatorPlugin {

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(WebsocketJsonFormatValidatorConfig.default)
  override def core: Boolean                               = false
  override def name: String                                = "Websocket json format validator"
  override def description: Option[String]                 = "Validate the json".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket)
  override def steps: Seq[NgStep]                          = Seq(NgStep.ValidateAccess)

  override def onResponseFlow: Boolean = false
  override def onRequestFlow: Boolean  = true

  /**
   * Validates the frame content against `config.schema`, using the JSON Schema dialect
   * identified by `config.specification`. The frame is accepted when validation returns
   * no messages; otherwise the call fails with `PolicyViolated`.
   *
   * NOTE(review): the schema factory and schema are rebuilt for every frame; caching them
   * per config would avoid re-parsing the schema each time.
   * NOTE(review): `VersionFlag.fromId(...).get()` (it returns a Java `Optional`, so `.get()`
   * throws for an unknown id) and `getSchema` (e.g. for the `""` fallback when no schema is
   * set) run outside the `Try`. Their exceptions therefore propagate out of `.map` instead
   * of producing `false`.
   * NOTE(review): `Try` only captures non-fatal exceptions, so the `Throwable` case in
   * `recover` effectively only ever sees those.
   * NOTE(review): the config lookup, the receiver of `.map` and the `recover` body are missing
   * from this copy.
   */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    implicit val m: Materializer = env.otoroshiMaterializer

    val config = ctx

      .map(data => {
        val userSchema = config.schema.getOrElse("")

        val jsonSchemaFactory = JsonSchemaFactory.getInstance(VersionFlag.fromId(config.specification).get())

        val schemaConfig = new SchemaValidatorsConfig()

        val schema = jsonSchemaFactory.getSchema(userSchema, schemaConfig)

        // An empty set of validation messages means the frame satisfies the schema.
        Try {
          schema.validate(data, InputFormat.JSON).isEmpty
        } recover { case _: Throwable =>
        } get
      .flatMap {
        case true  => Right(message).vfuture
        case false => Left(NgWebsocketError(CloseCodes.PolicyViolated, "failed to validate message")).vfuture

  // NOTE(review): body truncated in this copy; presumably returns the configured `rejectStrategy`.
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
    val config = ctx

/**
 * Configuration of the `WebsocketSizeValidator` plugin.
 *
 * @param clientMaxPayload   maximum frame size accepted on the request flow (default 4096; presumably bytes)
 * @param upstreamMaxPayload maximum frame size accepted on the response flow (default 4096; presumably bytes)
 * @param rejectStrategy     strategy returned for rejected frames (defaults to `Drop`)
 */
case class WebsocketSizeValidatorConfig(
    clientMaxPayload: Int = 4096,
    upstreamMaxPayload: Int = 4096,
    rejectStrategy: RejectStrategy = RejectStrategy.Drop
) extends NgPluginConfig {
  /** Serializes this config with the companion's `format`. */
  override def json: JsValue = WebsocketSizeValidatorConfig.format.writes(this)

/** Default value and JSON (de)serialization for [[WebsocketSizeValidatorConfig]]. */
object WebsocketSizeValidatorConfig {
  val default = WebsocketSizeValidatorConfig()
  // JSON keys: "client_max_payload", "upstream_max_payload" and "reject_strategy".
  val format  = new Format[WebsocketSizeValidatorConfig] {
    override def writes(o: WebsocketSizeValidatorConfig): JsValue = Json.obj(
      "client_max_payload"   -> o.clientMaxPayload,
      "upstream_max_payload" -> o.upstreamMaxPayload,
      "reject_strategy"      -> o.rejectStrategy.json
    // Missing limits fall back to 4096, matching the case-class defaults; any exception becomes a JsError.
    override def reads(json: JsValue): JsResult[WebsocketSizeValidatorConfig] = {
      // NOTE(review): the constructor call, the `json` lookups and the reject strategy
      // right-hand side are truncated in this copy.
      Try {
          clientMaxPayload ="client_max_payload").asOpt[Int].getOrElse(4096),
          upstreamMaxPayload ="upstream_max_payload").asOpt[Int].getOrElse(4096),
          rejectStrategy =
      } match {
        case Failure(e) => JsError(e.getMessage)
        case Success(s) => JsSuccess(s)

/**
 * Websocket plugin that rejects oversized frames. Request-flow frames are checked
 * against `clientMaxPayload` and response-flow frames against `upstreamMaxPayload`
 * (see [[WebsocketSizeValidatorConfig]]).
 */
class WebsocketSizeValidator extends NgWebsocketValidatorPlugin {

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(WebsocketSizeValidatorConfig.default)
  override def core: Boolean                               = false
  override def name: String                                = "Websocket size validator"
  override def description: Option[String]                 = "Make sure the frame does not exceed the maximum size set.".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket)
  override def steps: Seq[NgStep]                          = Seq(NgStep.ValidateAccess, NgStep.TransformResponse)

  override def onResponseFlow: Boolean = true
  override def onRequestFlow: Boolean  = true

  /**
   * Accepts `message` when its measured size is at most `maxSize` (inclusive); otherwise
   * fails with `CloseCodes.TooBig` and the given `reason`.
   *
   * `env`/`ec` are implicit. Both callers below omit this parameter list, and the
   * `Future` combinators need an implicit `ExecutionContext` in scope.
   * NOTE(review): the expression computing the frame size (the receiver of `.map`) is
   * missing from this copy of the file.
   *
   * @param maxSize maximum accepted size
   * @param reason  message attached to the `TooBig` error
   */
  private def internalCanAccess(ctx: NgWebsocketPluginContext, message: WebsocketMessage, maxSize: Int, reason: String)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    implicit val m: Materializer = env.otoroshiMaterializer

      .map(_ <= maxSize)
      .flatMap {
        case true  => Right(message).vfuture
        case false => Left(NgWebsocketError(CloseCodes.TooBig, reason)).vfuture

  /**
   * Checks a request-flow frame against `clientMaxPayload`.
   * NOTE(review): the config lookup (`val config =`) is truncated in this copy.
   */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    val config =

    internalCanAccess(ctx, message, config.clientMaxPayload, "limit exceeded")

  // NOTE(review): body truncated in this copy; presumably returns the configured `rejectStrategy`.
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
    val config =

  /**
   * Checks a response-flow frame against `upstreamMaxPayload`.
   * NOTE(review): the config lookup (`val config =`) is truncated in this copy.
   */
  override def onResponseMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    val config =

    internalCanAccess(ctx, message, config.upstreamMaxPayload, reason = "upstream payload limit exceeded")

/**
 * Configuration of the `JqWebsocketMessageTransformer` plugin.
 *
 * @param requestFilter  jq filter applied to request-flow messages (default ".", the identity filter)
 * @param responseFilter jq filter applied to response-flow messages (default ".", the identity filter)
 */
case class JqWebsocketMessageTransformerConfig(requestFilter: String = ".", responseFilter: String = ".")
    extends NgPluginConfig                 {
  /** Serializes this config with the companion's `format`. */
  override def json: JsValue = JqWebsocketMessageTransformerConfig.format.writes(this)
/** JSON (de)serialization for [[JqWebsocketMessageTransformerConfig]] (keys "request_filter", "response_filter"). */
object JqWebsocketMessageTransformerConfig {
  val format = new Format[JqWebsocketMessageTransformerConfig] {
    // Missing filters fall back to "."; any exception becomes a JsError with the exception message.
    override def reads(json: JsValue): JsResult[JqWebsocketMessageTransformerConfig] = Try {
        // NOTE(review): the constructor call and the `json` lookups are truncated in this copy.
        requestFilter ="request_filter").asOpt[String].getOrElse("."),
        responseFilter ="response_filter").asOpt[String].getOrElse(".")
    } match {
      case Failure(e) => JsError(e.getMessage)
      case Success(s) => JsSuccess(s)
    override def writes(o: JqWebsocketMessageTransformerConfig): JsValue             = Json.obj(
      "request_filter"  -> o.requestFilter,
      "response_filter" -> o.responseFilter

/**
 * Websocket plugin that rewrites text frames carrying JSON by running them through jq:
 * `requestFilter` on the request flow, `responseFilter` on the response flow.
 * Non-text frames and non-JSON payloads are rejected with `PolicyViolated`.
 * The reject strategy is always `Drop`.
 */
class JqWebsocketMessageTransformer extends NgWebsocketPlugin {

  // NOTE(review): `library` is not referenced in the visible lines; it is presumably passed
  // to the jq request builder, whose calls are missing from this copy.
  private val library = ImmutableJqLibrary.of()
  private val logger  = Logger("otoroshi-plugins-jq-websocket")

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(JqWebsocketMessageTransformerConfig())
  override def core: Boolean                               = false
  override def name: String                                = "Websocket JQ transformer"
  override def description: Option[String]                 = "Transform messages JSON content using JQ filters".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket)
  override def steps: Seq[NgStep]                          = Seq(NgStep.TransformResponse)

  override def onRequestFlow: Boolean                                        = true
  override def onResponseFlow: Boolean                                       = true
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = RejectStrategy.Drop

  /** Applies the configured `requestFilter` to a request-flow frame (config lookup truncated in this copy). */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    val config = ctx
    onMessage(ctx, message, config.requestFilter)

  /** Applies the configured `responseFilter` to a response-flow frame (config lookup truncated in this copy). */
  override def onResponseMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    val config = ctx
    onMessage(ctx, message, config.responseFilter)

  /**
   * Applies `filter` to a text frame whose payload parses as JSON.
   * The plugin context JSON is passed to the filter as the jq argument `context`.
   * Non-text frames and non-JSON payloads fail with `CloseCodes.PolicyViolated`.
   *
   * NOTE(review): the parse result is discarded (`Success(_)`); parsing only checks that the
   * payload is valid JSON before jq runs.
   * NOTE(review): the error log text says "response body" even when `filter` is the request
   * filter, because both flows share this method.
   * NOTE(review): the jq request builder calls after `ImmutableJqRequest`, the logging call
   * around the error message, the receiver inside `JsArray(` and the result handling are
   * missing from this copy.
   */
  def onMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage, filter: String)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    implicit val mat = env.otoroshiMaterializer
    if (message.isText) {
      message.str().flatMap { bodyStr =>
        Try(Json.parse(bodyStr)) match {
          case Failure(e) => Left(NgWebsocketError(CloseCodes.PolicyViolated, "message payload is not json")).vfuture
          case Success(_) => {
            val request  = ImmutableJqRequest
              .putArgJson("context", ctx.json.stringify)
            val response = request.execute()
            if (response.hasErrors) {
                s"error while transforming response body:\n${response.getErrors.asScala.mkString("\n")}"
              val errors = JsArray( => JsString(err)))
            } else {
              val rawBody = response.getOutput
    } else {
      Left(NgWebsocketError(CloseCodes.PolicyViolated, "message payload is not text")).vfuture

/**
 * Websocket plugin that hands each frame, on both flows, to a WASM plugin. The frame is
 * then forwarded, replaced or rejected according to the JSON the WASM code returns.
 * The reject strategy is always `Drop`.
 */
class WasmWebsocketTransformer extends NgWebsocketPlugin {

  override def multiInstance: Boolean                      = true
  override def defaultConfigObject: Option[NgPluginConfig] = Some(WasmConfig())
  override def core: Boolean                               = false
  override def name: String                                = "Wasm Websocket transformer"
  override def description: Option[String]                 = "Transform messages and filter websocket messages".some
  override def visibility: NgPluginVisibility              = NgPluginVisibility.NgUserLand
  override def categories: Seq[NgPluginCategory]           = Seq(NgPluginCategory.Websocket, NgPluginCategory.Wasm)
  override def steps: Seq[NgStep]                          = Seq(NgStep.TransformResponse)

  override def onRequestFlow: Boolean                                        = true
  override def onResponseFlow: Boolean                                       = true
  override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = RejectStrategy.Drop

  private val logger = Logger("otoroshi-plugins-wasm-websocket-transformer")

  /**
   * Runs one frame through the configured WASM VM and maps its answer back to a frame
   * or an error.
   *
   * Input sent to WASM: a JSON object (its base expression is missing from this copy;
   * presumably the context JSON) extended with `message: { kind, payload }`. For text
   * frames `kind` is "text" and the payload is the string; for binary frames `kind` is
   * "binary" and the payload is the frame bytes.
   *
   * Output read back (as far as visible here): `error` (default false). When it is true,
   * `reason` (default "error") and `statusCode` (default 500) build the returned
   * `NgWebsocketError`. Otherwise `kind` ("text" gives a string payload, anything else
   * bytes) and `payload` from the `message` object presumably describe the outgoing frame.
   * Failures are logged and turned into an `NgWebsocketError` with code 500.
   *
   * NOTE(review): 500 is an HTTP status, not a websocket close code (RFC 6455 close codes
   * lie in 1000–4999); confirm how `NgWebsocketError`'s code is used downstream.
   * NOTE(review): logging only `getMessage` drops the exception's stack trace.
   * NOTE(review): several lines (the base input JSON, the VM call, construction of the
   * outgoing frame and closing braces) are missing from this copy of the file.
   *
   * @param ctx          websocket plugin context
   * @param message      frame to transform
   * @param functionName presumably the WASM export to invoke; its use is in a missing line
   * @return the transformed frame, or an error
   */
  def onMessage(
      ctx: NgWebsocketPluginContext,
      message: WebsocketMessage,
      functionName: Option[String]
  )(implicit env: Env, ec: ExecutionContext): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    implicit val mat = env.otoroshiMaterializer
    val config       = ctx
    (if (message.isText) {
       message.str().map { str =>[JsObject] ++ Json.obj(
           "message" -> Json.obj(
             "kind"    -> "text",
             "payload" -> str
     } else {
       message.bytes().map { bytes =>[JsObject] ++ Json.obj(
           "message" -> Json.obj(
             // Binary frames are tagged "binary": the answer handling below treats any kind
             // other than "text" as bytes, so the guest must be able to tell the two apart.
             "kind"    -> "binary",
             "payload" -> bytes
     }).flatMap { input =>
      env.wasmIntegration.wasmVmFor(config).flatMap {
        case None                    => Left(NgWebsocketError(500, "plugin not found !")).vfuture
        case Some((vm, localConfig)) =>

          ).flatMap {
            case Left(err)     =>
              Left(NgWebsocketError(500, err.stringify)).vfuture
            case Right(resStr) => {
              Try(Json.parse(resStr._1)) match {
                case Failure(e)        =>
                  Left(NgWebsocketError(500, Json.obj("error" -> e.getMessage).stringify)).vfuture
                case Success(response) => {
                  AttrsHelper.updateAttrs(ctx.attrs, response)
                  val error ="error").asOpt[Boolean].getOrElse(false)
                  if (error) {
                    val reason: String  ="reason").asOpt[String].getOrElse("error")
                    val statusCode: Int ="statusCode").asOpt[Int].getOrElse(500)
                    Left(NgWebsocketError(statusCode, reason)).vfuture
                  } else {
                    val msg                       ="message").asOpt[JsObject].getOrElse(Json.obj())
                    val kind                      ="kind").asOpt[String].getOrElse("text")
                    // NOTE(review): this local `message` shadows the method's `message` argument.
                    val message: WebsocketMessage = if (kind == "text") {
                      val payload ="payload").asOpt[String].getOrElse("")
                    } else {
                      val payload = msg
                        .map(bytes => ByteString(bytes))
          }.andThen { case e =>

            e match {
              case Failure(exception) => logger.error(exception.getMessage)
              case Success(_)         =>
          }.recover { case e: Throwable =>
            Left(NgWebsocketError(500, Json.obj("error" -> e.getMessage).stringify))

  /** Sends request-flow frames to the WASM export "on_request_message". */
  override def onRequestMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    onMessage(ctx, message, "on_request_message".some)

  /** Sends response-flow frames to the WASM export "on_response_message". */
  override def onResponseMessage(ctx: NgWebsocketPluginContext, message: WebsocketMessage)(implicit
      env: Env,
      ec: ExecutionContext
  ): Future[Either[NgWebsocketError, WebsocketMessage]] = {
    onMessage(ctx, message, "on_response_message".some)

© 2015 - 2025 Weber Informatics LLC | Privacy Policy