gateway.websockets.scala Maven / Gradle / Ivy
package otoroshi.gateway
import akka.actor.{Actor, ActorRef, PoisonPill, Props}
import akka.http.scaladsl.ClientTransport
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.model.ws.{InvalidUpgradeResponse, Message, ValidUpgrade, WebSocketRequest}
import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.http.scaladsl.util.FastFuture
import akka.stream.scaladsl.{Flow, Keep, Sink, Source, SourceQueueWithComplete, Tcp}
import akka.stream.{FlowShape, Materializer, OverflowStrategy}
import akka.util.ByteString
import akka.{Done, NotUsed}
import org.joda.time.DateTime
import otoroshi.el.TargetExpressionLanguage
import otoroshi.env.Env
import otoroshi.events._
import otoroshi.models._
import otoroshi.next.models.{NgContextualPlugins, NgPluginInstance, NgRoute}
import otoroshi.next.plugins.RejectStrategy
import otoroshi.next.plugins.api.{
NgAccess,
NgPluginWrapper,
NgWebsocketError,
NgWebsocketPlugin,
NgWebsocketPluginContext,
NgWebsocketResponse,
NgWebsocketValidatorPlugin,
WebsocketMessage
}
import otoroshi.next.proxy.NgProxyEngineError
import otoroshi.next.proxy.NgProxyEngineError.NgResultProxyEngineError
import otoroshi.next.utils.FEither
import otoroshi.script.Implicits._
import otoroshi.script.TransformerRequestContext
import otoroshi.security.{IdGenerator, OtoroshiClaim}
import otoroshi.utils.future.Implicits._
import otoroshi.utils.http.RequestImplicits._
import otoroshi.utils.http.{HeadersHelper, ManualResolveTransport, WSCookieWithSameSite, WSProxyServerUtils}
import otoroshi.utils.syntax.implicits.BetterSyntax
import otoroshi.utils.udp._
import otoroshi.utils.{TypedMap, UrlSanitizer}
import play.api.Logger
import play.api.http.websocket.{
CloseMessage,
PingMessage,
PongMessage,
BinaryMessage => PlayWSBinaryMessage,
Message => PlayWSMessage,
TextMessage => PlayWSTextMessage
}
import play.api.libs.json.{JsValue, Json}
import play.api.libs.streams.ActorFlow
import play.api.mvc.Results.NotFound
import play.api.mvc._
import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.util.{Failure, Success, Try}
class WebSocketHandler()(implicit env: Env) {
type WSFlow = Flow[PlayWSMessage, PlayWSMessage, _]
implicit lazy val currentEc = env.otoroshiExecutionContext
implicit lazy val currentScheduler = env.otoroshiScheduler
implicit lazy val currentSystem = env.otoroshiActorSystem
implicit lazy val currentMaterializer = env.otoroshiMaterializer
lazy val logger = Logger("otoroshi-websocket-handler")
def forwardCall(
reverseProxyAction: ReverseProxyAction,
snowMonkey: SnowMonkey,
headersInFiltered: Seq[String],
headersOutFiltered: Seq[String]
) = {
WebSocket.acceptOrResult[PlayWSMessage, PlayWSMessage] { req =>
reverseProxyAction.async[WSFlow](
ReverseProxyActionContext(req, Source.empty, snowMonkey, logger),
true,
c => actuallyCallDownstream(c, headersInFiltered, headersOutFiltered)
)
}
}
def forwardCallRaw(
req: RequestHeader,
reverseProxyAction: ReverseProxyAction,
snowMonkey: SnowMonkey,
headersInFiltered: Seq[String],
headersOutFiltered: Seq[String]
) = {
reverseProxyAction.async[WSFlow](
ReverseProxyActionContext(req, Source.empty, snowMonkey, logger),
true,
c => actuallyCallDownstream(c, headersInFiltered, headersOutFiltered)
)
}
def actuallyCallDownstream(
ctx: ActualCallContext,
headersInFiltered: Seq[String],
headersOutFiltered: Seq[String]
): Future[Either[Result, Flow[PlayWSMessage, PlayWSMessage, _]]] = {
val ActualCallContext(
req,
descriptor,
_target,
apiKey,
paUsr,
jwtInjection,
snowMonkeyContext,
snowflake,
attrs,
elCtx,
globalConfig,
withTrackingCookies,
bodyAlreadyConsumed,
requestBody,
secondStart,
firstOverhead,
cbDuration,
callAttempts,
attempts,
alreadyFailed
) = ctx
val counterIn = attrs.get(otoroshi.plugins.Keys.RequestCounterInKey).get
val counterOut = attrs.get(otoroshi.plugins.Keys.RequestCounterOutKey).get
val canaryId = attrs.get(otoroshi.plugins.Keys.RequestCanaryIdKey).get
val callDate = attrs.get(otoroshi.plugins.Keys.RequestTimestampKey).get
val start = attrs.get(otoroshi.plugins.Keys.RequestStartKey).get
val requestTimestamp = callDate.toString("yyyy-MM-dd'T'HH:mm:ss.SSSZZ")
if (logger.isTraceEnabled) logger.trace("[WEBSOCKET] Call backend !!!")
val stateValue = IdGenerator.extendedToken(128)
val stateToken: String = descriptor.secComVersion match {
case SecComVersion.V1 => stateValue
case SecComVersion.V2 =>
OtoroshiClaim(
iss = env.Headers.OtoroshiIssuer,
sub = env.Headers.OtoroshiIssuer,
aud = descriptor.name,
exp = DateTime
.now()
.plus(descriptor.secComTtl.toMillis)
.toDate
.getTime,
iat = DateTime.now().toDate.getTime,
jti = IdGenerator.uuid
).withClaim("state", stateValue)
.serialize(descriptor.algoChallengeFromOtoToBack)
}
val rawUri = req.relativeUri.substring(1)
val uriParts = rawUri.split("/").toSeq
val uri: String = descriptor.maybeStrippedUri(req, rawUri)
// val index = reqCounter.incrementAndGet() % (if (descriptor.targets.nonEmpty) descriptor.targets.size else 1)
// // Round robin loadbalancing is happening here !!!!!
// val target = descriptor.targets.apply(index.toInt)
val scheme =
if (descriptor.redirectToLocal) descriptor.localScheme else _target.scheme
val host = if (descriptor.redirectToLocal) descriptor.localHost else _target.host
val root = descriptor.root
val url = TargetExpressionLanguage(
s"${if (_target.scheme == "https") "wss" else "ws"}://$host$root$uri",
Some(req),
Some(descriptor),
None,
apiKey,
paUsr,
elCtx,
attrs,
env
)
// val queryString = req.queryString.toSeq.flatMap { case (key, values) => values.map(v => (key, v)) }
val fromOtoroshi = req.headers
.get(env.Headers.OtoroshiRequestId)
.orElse(req.headers.get(env.Headers.OtoroshiGatewayParentRequest))
val promise = Promise[ProxyDone]
val claim = descriptor.generateInfoToken(apiKey, paUsr, Some(req))
if (logger.isTraceEnabled) logger.trace(s"Claim is : $claim")
attrs.put(otoroshi.plugins.Keys.OtoTokenKey -> claim.payload)
//val stateRequestHeaderName =
// descriptor.secComHeaders.stateRequestName.getOrElse(env.Headers.OtoroshiState)
val stateResponseHeaderName =
descriptor.secComHeaders.stateResponseName
.getOrElse(env.Headers.OtoroshiStateResp)
val headersIn: Seq[(String, String)] = HeadersHelper.composeHeadersIn(
descriptor = descriptor,
req = req,
apiKey = apiKey,
paUsr = paUsr,
elCtx = elCtx,
currentReqHasBody = false,
headersInFiltered = headersInFiltered,
snowflake = snowflake,
requestTimestamp = requestTimestamp,
host = host,
claim = claim,
stateToken = stateToken,
fromOtoroshi = fromOtoroshi,
snowMonkeyContext = SnowMonkeyContext(
Source.empty[ByteString],
Source.empty[ByteString]
),
jwtInjection = jwtInjection,
attrs = attrs
)
if (logger.isTraceEnabled)
logger.trace(
s"[WEBSOCKET] calling '$url' with headers \n ${headersIn.map(_.toString()) mkString "\n"}"
)
val overhead = System.currentTimeMillis() - start
val quotas: Future[RemainingQuotas] =
apiKey.map(_.updateQuotas()).getOrElse(FastFuture.successful(RemainingQuotas()))
promise.future.andThen {
case Success(resp) => {
val duration = System.currentTimeMillis() - start
// logger.trace(s"[$snowflake] Call forwarded in $duration ms. with $overhead ms overhead for (${req.version}, ${req.theProtocol}://${req.host}${req.relativeUri} => $url, $from)")
descriptor
.updateMetrics(duration, overhead, counterIn.get(), counterOut.get(), 0, globalConfig)
.andThen { case Failure(e) =>
logger.error("Error while updating call metrics reporting", e)
}
env.datastores.globalConfigDataStore.updateQuotas(globalConfig)
quotas.andThen {
case Success(q) => {
val fromLbl =
req.headers.get(env.Headers.OtoroshiVizFromLabel).getOrElse("internet")
val viz: OtoroshiViz = OtoroshiViz(
to = descriptor.id,
toLbl = descriptor.name,
from = req.headers.get(env.Headers.OtoroshiVizFrom).getOrElse("internet"),
fromLbl = fromLbl,
fromTo = s"$fromLbl###${descriptor.name}"
)
val evt = GatewayEvent(
`@id` = env.snowflakeGenerator.nextIdStr(),
reqId = snowflake,
parentReqId = fromOtoroshi,
`@timestamp` = DateTime.now(),
`@calledAt` = callDate,
protocol = req.version,
to = Location(
scheme = req.theWsProtocol,
host = req.theHost,
uri = req.relativeUri
),
target = Location(
scheme = scheme,
host = host,
uri = req.relativeUri
),
backendDuration = attrs.get(otoroshi.plugins.Keys.BackendDurationKey).getOrElse(-1L),
duration = duration,
overhead = overhead,
cbDuration = cbDuration,
overheadWoCb = Math.abs(overhead - cbDuration),
callAttempts = callAttempts,
url = url,
method = req.method,
from = req.theIpAddress,
env = descriptor.env,
data = DataInOut(
dataIn = counterIn.get(),
dataOut = counterOut.get()
),
status = resp.status,
headers = req.headers.toSimpleMap.toSeq.map(Header.apply),
headersOut = resp.headersOut,
otoroshiHeadersIn = resp.otoroshiHeadersIn,
otoroshiHeadersOut = resp.otoroshiHeadersOut,
extraInfos = attrs.get(otoroshi.plugins.Keys.GatewayEventExtraInfosKey),
identity = apiKey
.map(k =>
Identity(
identityType = "APIKEY",
identity = k.clientId,
label = k.clientName,
tags = k.tags,
metadata = k.metadata
)
)
.orElse(
paUsr.map(k =>
Identity(
identityType = "PRIVATEAPP",
identity = k.email,
label = k.name,
tags = k.tags,
metadata = k.metadata
)
)
),
responseChunked = false,
`@serviceId` = descriptor.id,
`@service` = descriptor.name,
descriptor = Some(descriptor),
`@product` = descriptor.metadata.getOrElse("product", "--"),
remainingQuotas = q,
viz = Some(viz),
clientCertChain = req.clientCertChainPem,
err = attrs.get(otoroshi.plugins.Keys.GwErrorKey).isDefined,
gwError = attrs.get(otoroshi.plugins.Keys.GwErrorKey).map(_.message),
userAgentInfo = attrs.get[JsValue](otoroshi.plugins.Keys.UserAgentInfoKey),
geolocationInfo = attrs.get[JsValue](otoroshi.plugins.Keys.GeolocationInfoKey),
extraAnalyticsData = attrs.get[JsValue](otoroshi.plugins.Keys.ExtraAnalyticsDataKey)
)
evt.toAnalytics()
if (descriptor.logAnalyticsOnServer) {
evt.log()(env, env.analyticsExecutionContext) // pressure EC
}
}
}(env.analyticsExecutionContext) // pressure EC
}
}(env.analyticsExecutionContext) // pressure EC
val wsCookiesIn = req.cookies.toSeq.map(c =>
WSCookieWithSameSite(
name = c.name,
value = c.value,
domain = c.domain,
path = Option(c.path),
maxAge = c.maxAge.map(_.toLong),
secure = c.secure,
httpOnly = c.httpOnly,
sameSite = c.sameSite
)
)
val rawRequest = otoroshi.script.HttpRequest(
url = s"${req.theProtocol}://${req.theHost}${req.relativeUri}",
method = req.method,
headers = req.headers.toSimpleMap,
cookies = wsCookiesIn,
version = req.version,
clientCertificateChain = req.clientCertificateChain,
target = None,
claims = claim,
body = () => requestBody
)
val otoroshiRequest = otoroshi.script.HttpRequest(
url = url,
method = req.method,
headers = headersIn.toMap,
cookies = wsCookiesIn,
version = req.version,
clientCertificateChain = req.clientCertificateChain,
target = Some(_target),
claims = claim,
body = () => requestBody
)
val upstreamStart = System.currentTimeMillis()
descriptor
.transformRequest(
TransformerRequestContext(
index = -1,
snowflake = snowflake,
rawRequest = rawRequest,
otoroshiRequest = otoroshiRequest,
descriptor = descriptor,
apikey = apiKey,
user = paUsr,
request = req,
config = descriptor.transformerConfig,
attrs = attrs
)
)
.flatMap {
case Left(badResult) => {
quotas
.map { remainingQuotas =>
val _headersOut: Seq[(String, String)] =
HeadersHelper.composeHeadersOutBadResult(
descriptor = descriptor,
req = req,
badResult = badResult,
apiKey = apiKey,
paUsr = paUsr,
elCtx = elCtx,
snowflake = snowflake,
requestTimestamp = requestTimestamp,
headersOutFiltered = headersOutFiltered,
overhead = overhead,
upstreamLatency = 0L,
canaryId = canaryId,
remainingQuotas = remainingQuotas,
attrs = attrs
)
promise.trySuccess(
ProxyDone(
badResult.header.status,
false,
0,
headersOut = badResult.header.headers.toSeq.map(Header.apply),
otoroshiHeadersOut = _headersOut.map(Header.apply),
otoroshiHeadersIn = headersIn.map(Header.apply)
)
)
badResult.withHeaders(_headersOut: _*)
}
.asLeft[WSFlow]
}
case Right(_)
if descriptor.tcpUdpTunneling && !req.relativeUri
.startsWith("/.well-known/otoroshi/tunnel") => {
Errors
.craftResponseResult(
s"Resource not found",
NotFound,
req,
None,
Some("errors.resource.not.found"),
attrs = attrs
)
.asLeft[WSFlow]
}
case Right(_httpReq)
if descriptor.tcpUdpTunneling && req.relativeUri
.startsWith("/.well-known/otoroshi/tunnel") => {
val target = _httpReq.target.getOrElse(_target)
val (theHost: String, thePort: Int) =
(
target.scheme,
TargetExpressionLanguage(target.host, Some(req), Some(descriptor), None, apiKey, paUsr, elCtx, attrs, env)
) match {
case (_, host) if host.contains(":") =>
(host.split(":").apply(0), host.split(":").apply(1).toInt)
case (scheme, host) if scheme.contains("https") => (host, 443)
case (_, host) => (host, 80)
}
val remoteAddress = target.ipAddress match {
case Some(ip) =>
new InetSocketAddress(
InetAddress.getByAddress(theHost, InetAddress.getByName(ip).getAddress),
thePort
)
case None => new InetSocketAddress(theHost, thePort)
}
req
.getQueryString("transport")
.map(_.toLowerCase())
.getOrElse("tcp") match {
case "tcp" => {
val flow: Flow[PlayWSMessage, PlayWSMessage, _] =
Flow[PlayWSMessage]
.collect {
case PlayWSBinaryMessage(data) =>
data
case _ =>
ByteString.empty
}
.via(
Tcp()
.outgoingConnection(
remoteAddress = remoteAddress,
connectTimeout = descriptor.clientConfig.connectionTimeout.millis,
idleTimeout = descriptor.clientConfig.idleTimeout.millis
)
.map(bs => PlayWSBinaryMessage(bs))
)
.alsoTo(Sink.onComplete { case _ =>
promise.trySuccess(
ProxyDone(
200,
false,
0,
Seq.empty[Header],
Seq.empty[Header],
Seq.empty[Header]
)
)
})
FastFuture.successful(Right(flow))
}
case "udp-old" => {
val flow: Flow[PlayWSMessage, PlayWSMessage, _] =
Flow[PlayWSMessage]
.collect {
case PlayWSBinaryMessage(data) =>
Datagram(data, remoteAddress)
case _ =>
Datagram(ByteString.empty, remoteAddress)
}
.via(
UdpClient
.flow(new InetSocketAddress("0.0.0.0", 0))
.map(dg => PlayWSBinaryMessage(dg.data))
)
.alsoTo(Sink.onComplete { case _ =>
promise.trySuccess(
ProxyDone(
200,
false,
0,
Seq.empty[Header],
Seq.empty[Header],
Seq.empty[Header]
)
)
})
FastFuture.successful(Right(flow))
}
case "udp" => {
import akka.stream.scaladsl.{Flow, GraphDSL, UnzipWith, ZipWith}
import GraphDSL.Implicits._
val base64decoder = java.util.Base64.getDecoder
val base64encoder = java.util.Base64.getEncoder
val fromJson: Flow[PlayWSMessage, (Int, String, Datagram), NotUsed] =
Flow[PlayWSMessage].collect {
case PlayWSBinaryMessage(data) =>
val json = Json.parse(data.utf8String)
val port: Int = (json \ "port").as[Int]
val address: String = (json \ "address").as[String]
val _data: ByteString = (json \ "data")
.asOpt[String]
.map(str => ByteString(base64decoder.decode(str)))
.getOrElse(ByteString.empty)
(port, address, Datagram(_data, remoteAddress))
case _ =>
(0, "localhost", Datagram(ByteString.empty, remoteAddress))
}
val updFlow: Flow[Datagram, Datagram, Future[InetSocketAddress]] =
UdpClient
.flow(new InetSocketAddress("0.0.0.0", 0))
def nothing[T]: Flow[T, T, NotUsed] = Flow[T].map(identity)
val flow: Flow[PlayWSMessage, PlayWSBinaryMessage, NotUsed] = fromJson via Flow
.fromGraph(GraphDSL.create() { implicit builder =>
val dispatch = builder.add(
UnzipWith[(Int, String, Datagram), Int, String, Datagram](a => a)
)
val merge = builder.add(
ZipWith[Int, String, Datagram, (Int, String, Datagram)]((a, b, c) => (a, b, c))
)
dispatch.out2 ~> updFlow.async ~> merge.in2
dispatch.out1 ~> nothing[String].async ~> merge.in1
dispatch.out0 ~> nothing[Int].async ~> merge.in0
FlowShape(dispatch.in, merge.out)
})
.map { case (port, address, dg) =>
PlayWSBinaryMessage(
ByteString(
Json.stringify(
Json.obj(
"port" -> port,
"address" -> address,
"data" -> base64encoder.encodeToString(dg.data.toArray)
)
)
)
)
}
.alsoTo(Sink.onComplete { case _ =>
promise.trySuccess(
ProxyDone(
200,
false,
0,
Seq.empty[Header],
Seq.empty[Header],
Seq.empty[Header]
)
)
})
FastFuture.successful(Right(flow))
}
}
}
case Right(httpRequest) if !descriptor.tcpUdpTunneling => {
if (descriptor.useNewWSClient) {
FastFuture.successful(
Right(
WebSocketProxyActor.wsCall(
UrlSanitizer.sanitize(httpRequest.url),
// httpRequest.headers.toSeq, //.filterNot(_._1 == "Cookie"),
HeadersHelper
.addClaims(httpRequest.headers, httpRequest.claims, descriptor),
descriptor,
target = httpRequest.target.getOrElse(_target),
rawRequest = req
)
)
)
} else {
attrs.put(
otoroshi.plugins.Keys.RequestTargetKey -> httpRequest.target
.getOrElse(_target)
)
FastFuture.successful(
Right(
ActorFlow
.actorRef(out =>
WebSocketProxyActor.props(
UrlSanitizer.sanitize(httpRequest.url),
out,
httpRequest.headers.toSeq, //.filterNot(_._1 == "Cookie"),
req,
descriptor,
None, // TODO - check if we can pass the current route
None,
httpRequest.target.getOrElse(_target),
attrs,
env
)
)
.alsoTo(Sink.onComplete { case _ =>
promise.trySuccess(
ProxyDone(
200,
false,
0,
headersOut = Seq.empty[Header],
otoroshiHeadersOut = Seq.empty[Header],
otoroshiHeadersIn = req.headers.toSimpleMap.map(Header.apply).toSeq
)
)
})
)
)
}
}
}
}
}
object WebSocketProxyActor {
lazy val logger = Logger("otoroshi-websocket")
def props(
url: String,
out: ActorRef,
headers: Seq[(String, String)],
rawRequest: RequestHeader,
descriptor: ServiceDescriptor,
route: Option[NgRoute],
ctxPlugins: Option[NgContextualPlugins],
target: Target,
attrs: TypedMap,
env: Env
) =
Props(new WebSocketProxyActor(url, out, headers, rawRequest, descriptor, route, ctxPlugins, target, attrs, env))
def wsCall(
url: String,
headers: Seq[(String, String)],
descriptor: ServiceDescriptor,
target: Target,
rawRequest: RequestHeader,
route: Option[NgRoute] = None
)(implicit
env: Env,
ec: ExecutionContext,
mat: Materializer
): Flow[PlayWSMessage, PlayWSMessage, _] = {
val avoid = Seq("Upgrade", "Connection", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key")
val _headers = headers.toList.filterNot(t => avoid.contains(t._1)).flatMap {
// case (key, value) if key.toLowerCase == "cookie" =>
// Try(value.split(";").toSeq.map(_.trim).filterNot(_.isEmpty).map { cookie =>
// val parts = cookie.split("=")
// val name = parts(0)
// val cookieValue = parts.tail.mkString("=")
// akka.http.scaladsl.model.headers.Cookie(name, cookieValue)
// }) match {
// case Success(seq) => seq
// case Failure(e) => List.empty
// }
case (key, value) if key.toLowerCase == "host" =>
val part = value.split(":")
Seq(akka.http.scaladsl.model.headers.Host(part.head))
case (key, value) if key.toLowerCase == "user-agent" =>
Seq(akka.http.scaladsl.model.headers.`User-Agent`(value))
case (key, value) =>
Seq(RawHeader(key, value))
}
val request = _headers.foldLeft[WebSocketRequest](WebSocketRequest(url))((r, header) =>
r.copy(extraHeaders = r.extraHeaders :+ header)
)
// WARN: DOES NOT MAKE USE OF WS PLUGINS BECAUSE OF THE LIMITS OF THE AKKA STREAM SINK API
val flow = Flow.fromSinkAndSourceMat(
Sink.asPublisher[akka.http.scaladsl.model.ws.Message](fanout = false),
Source.asSubscriber[akka.http.scaladsl.model.ws.Message]
)(Keep.both)
val (connected, (publisher, subscriber)) = env.gatewayClient.ws(
request = request,
targetOpt = Some(target),
mtlsConfigOpt = Some(target.mtlsConfig).filter(_.mtls),
clientFlow = flow,
customizer = {
descriptor.clientConfig.proxy
.orElse(env.datastores.globalConfigDataStore.latestSafe.flatMap(_.proxies.services))
.filter(p =>
WSProxyServerUtils
.isIgnoredForHost(Uri(url).authority.host.toString(), p.nonProxyHosts.getOrElse(Seq.empty))
)
.map { proxySettings =>
val proxyAddress = InetSocketAddress.createUnresolved(proxySettings.host, proxySettings.port)
val httpsProxyTransport = (proxySettings.principal, proxySettings.password) match {
case (Some(principal), Some(password)) => {
val auth = akka.http.scaladsl.model.headers.BasicHttpCredentials(principal, password)
ClientTransport.httpsProxy(proxyAddress, auth)
}
case _ => ClientTransport.httpsProxy(proxyAddress)
}
a: ClientConnectionSettings =>
a.withTransport(httpsProxyTransport)
.withIdleTimeout(descriptor.clientConfig.idleTimeout.millis)
.withConnectingTimeout(descriptor.clientConfig.connectionTimeout.millis)
} getOrElse { a: ClientConnectionSettings =>
val maybeIpAddress = target.ipAddress.map(addr => InetSocketAddress.createUnresolved(addr, target.thePort))
if (env.manualDnsResolve && maybeIpAddress.isDefined) {
a.withTransport(ManualResolveTransport.resolveTo(maybeIpAddress.get))
.withIdleTimeout(descriptor.clientConfig.idleTimeout.millis)
.withConnectingTimeout(descriptor.clientConfig.connectionTimeout.millis)
} else {
a.withIdleTimeout(descriptor.clientConfig.idleTimeout.millis)
.withConnectingTimeout(descriptor.clientConfig.connectionTimeout.millis)
}
}
}
)
Flow.lazyFutureFlow[PlayWSMessage, PlayWSMessage, Any] { () =>
connected.flatMap { r =>
if (logger.isTraceEnabled)
logger.trace(
s"[WEBSOCKET] connected to target ${r.response.status} :: ${r.response.headers.map(h => h.toString()).mkString(", ")}"
)
r match {
case ValidUpgrade(response, chosenSubprotocol) =>
val f: Flow[PlayWSMessage, PlayWSMessage, NotUsed] = Flow.fromSinkAndSource(
Sink.fromSubscriber(subscriber).contramap {
case PlayWSTextMessage(text) => akka.http.scaladsl.model.ws.TextMessage(text)
case PlayWSBinaryMessage(data) => akka.http.scaladsl.model.ws.BinaryMessage(data)
case PingMessage(data) => akka.http.scaladsl.model.ws.BinaryMessage(data)
case PongMessage(data) => akka.http.scaladsl.model.ws.BinaryMessage(data)
case CloseMessage(status, reason) =>
logger.error(s"close message $status: $reason")
akka.http.scaladsl.model.ws.BinaryMessage(ByteString.empty)
// throw new RuntimeException(reason)
case m =>
logger.error(s"Unknown message $m")
throw new RuntimeException(s"Unknown message $m")
},
Source.fromPublisher(publisher).mapAsync(1) {
case akka.http.scaladsl.model.ws.TextMessage.Strict(text) =>
FastFuture.successful(PlayWSTextMessage(text))
case akka.http.scaladsl.model.ws.TextMessage.Streamed(source) =>
source.runFold("")((concat, str) => concat + str).map(str => PlayWSTextMessage(str))
case akka.http.scaladsl.model.ws.BinaryMessage.Strict(data) =>
FastFuture.successful(PlayWSBinaryMessage(data))
case akka.http.scaladsl.model.ws.BinaryMessage.Streamed(source) =>
source
.runFold(ByteString.empty)((concat, str) => concat ++ str)
.map(data => PlayWSBinaryMessage(data))
case other => FastFuture.failed(new RuntimeException(s"Unkown message type ${other}"))
}
)
FastFuture.successful(f)
case InvalidUpgradeResponse(response, cause) =>
FastFuture.failed(new RuntimeException(cause))
}
}
}
}
}
class WebSocketProxyActor(
url: String,
out: ActorRef,
headers: Seq[(String, String)],
rawRequest: RequestHeader,
descriptor: ServiceDescriptor,
route: Option[NgRoute],
ctxPlugins: Option[NgContextualPlugins],
target: Target,
attrs: TypedMap,
env: Env
) extends Actor {
import scala.concurrent.duration._
implicit val ec = env.otoroshiExecutionContext
implicit val mat = env.otoroshiMaterializer
implicit val e = env
lazy val source = Source.queue[akka.http.scaladsl.model.ws.Message](50000, OverflowStrategy.dropTail)
lazy val logger = Logger("otoroshi-websocket-handler-actor")
val queueRef = new AtomicReference[SourceQueueWithComplete[akka.http.scaladsl.model.ws.Message]]
val avoid = Seq("Upgrade", "Connection", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key")
// Seq("Upgrade", "Connection", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions", "Host")
val wsEngine = if (route.isDefined && ctxPlugins.isDefined && ctxPlugins.get.hasWebsocketPlugins) {
new WebsocketEngine(route.get, ctxPlugins.get, rawRequest, target, attrs)
} else {
new WebsocketEngine(NgRoute.empty, NgContextualPlugins.empty(rawRequest), rawRequest, target, attrs)
}
override def preStart() =
try {
if (logger.isTraceEnabled) logger.trace("[WEBSOCKET] initializing client call ...")
val _headers = headers.toList.filterNot(t => avoid.contains(t._1)).flatMap {
//case (key, value) if key.toLowerCase == "cookie" =>
// Try(value.split(";").toSeq.map(_.trim).filterNot(_.isEmpty).map { cookie =>
// val parts = cookie.split("=")
// val name = parts(0)
// val cookieValue = parts.tail.mkString("=")
// akka.http.scaladsl.model.headers.Cookie(name, cookieValue)
// }) match {
// case Success(seq) => seq
// case Failure(e) => List.empty
// }
case (key, value) if key.toLowerCase == "host" =>
Seq(akka.http.scaladsl.model.headers.Host(value.split(":").head))
case (key, value) if key.toLowerCase == "user-agent" =>
Seq(akka.http.scaladsl.model.headers.`User-Agent`(value))
case (key, value) =>
Seq(RawHeader(key, value))
}
val request = _headers.foldLeft[WebSocketRequest](WebSocketRequest(url))((r, header) =>
r.copy(extraHeaders = r.extraHeaders :+ header)
)
val (connected, materialized) = env.gatewayClient.ws(
request = request,
targetOpt = Some(target),
mtlsConfigOpt = Some(target.mtlsConfig).filter(_.mtls),
clientFlow = Flow
.fromSinkAndSourceMat(
Sink.foreach[akka.http.scaladsl.model.ws.Message] { data =>
{
wsEngine
.handleResponse(data)(closeFunction)
.map {
case Left(error) =>
Option(queueRef.get()).foreach(_.complete())
case Right(msg) => {
msg.asPlay.map { msg =>
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] message from target: ${msg}")
out ! msg
}
}
}
}
},
source
)(Keep.both)
.alsoTo(Sink.onComplete { _ =>
if (logger.isTraceEnabled) logger.trace(s"[WEBSOCKET] target stopped")
Option(queueRef.get()).foreach(_.complete())
// out ! PoisonPill
self ! PoisonPill
}),
customizer = descriptor.clientConfig.proxy
.orElse(env.datastores.globalConfigDataStore.latestSafe.flatMap(_.proxies.services))
.filter(p =>
WSProxyServerUtils
.isIgnoredForHost(Uri(url).authority.host.toString(), p.nonProxyHosts.getOrElse(Seq.empty))
)
.map { proxySettings =>
val proxyAddress = InetSocketAddress.createUnresolved(proxySettings.host, proxySettings.port)
val httpsProxyTransport = (proxySettings.principal, proxySettings.password) match {
case (Some(principal), Some(password)) => {
val auth = akka.http.scaladsl.model.headers.BasicHttpCredentials(principal, password)
ClientTransport.httpsProxy(proxyAddress, auth)
}
case _ => ClientTransport.httpsProxy(proxyAddress)
}
// TODO: use proxy transport when akka http will be updated
a: ClientConnectionSettings =>
//a //.withTransport(httpsProxyTransport)
a.withIdleTimeout(descriptor.clientConfig.idleTimeout.millis)
.withConnectingTimeout(descriptor.clientConfig.connectionTimeout.millis)
} getOrElse { a: ClientConnectionSettings =>
a.withIdleTimeout(descriptor.clientConfig.idleTimeout.millis)
.withConnectingTimeout(descriptor.clientConfig.connectionTimeout.millis)
}
)
queueRef.set(materialized._2)
connected.andThen {
case Success(r) => {
implicit val ec = env.otoroshiExecutionContext
implicit val mat = env.otoroshiMaterializer
if (logger.isTraceEnabled)
logger.trace(
s"[WEBSOCKET] connected to target ${r.response.status} :: ${r.response.headers.map(h => h.toString()).mkString(", ")}"
)
r.response.entity.dataBytes.runFold(ByteString.empty)(_ ++ _).map { bs =>
if (logger.isTraceEnabled) logger.trace(s"[WEBSOCKET] connected to target with response '${bs.utf8String}'")
}
}
case Failure(e) => logger.error(s"[WEBSOCKET] error", e)
}(context.dispatcher)
} catch {
case e: Exception => logger.error("[WEBSOCKET] error during call", e)
}
override def postStop() = {
if (logger.isTraceEnabled) logger.trace(s"[WEBSOCKET] client stopped")
Option(queueRef.get()).foreach(_.complete())
// out ! PoisonPill
}
def closeFunction(message: NgWebsocketResponse): Unit = {
Option(queueRef.get()).foreach(_.complete())
message match {
case NgWebsocketResponse(_, Some(status), Some(reason)) => out ! CloseMessage(status, reason)
case _ => // do nothing
}
}
def receive: Receive = {
case data: play.api.http.websocket.Message => {
wsEngine
.handleRequest(data)(closeFunction)
.map {
case Left(error) => {
error.rejectStrategy.foreach {
case RejectStrategy.Close =>
if (logger.isDebugEnabled)
logger.debug(s"[WEBSOCKET] close message from client: ${error.statusCode} : ${error.reason}")
Option(queueRef.get()).foreach(_.complete())
case _ => // TODO - logging ??
}
}
case Right(msg) => {
msg.asAkka.map { msg =>
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] message from client: ${msg}")
Option(queueRef.get()).foreach { q =>
q.offer(msg)
}
}
}
}
}
}
}
class WebsocketEngine(
route: NgRoute,
ctxPlugins: NgContextualPlugins,
rawRequest: RequestHeader,
target: Target,
attrs: TypedMap
) {
private def getPlugins()(
f: NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketPlugin] => Boolean
): Seq[NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketPlugin]] = {
ctxPlugins.websocketPlugins
.filter(f)
}
private def handle[A](
validators: Seq[NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketPlugin]],
data: WebsocketMessage,
applyResponseFilter: Boolean = false
)(
closeConnection: NgWebsocketResponse => Unit
)(implicit env: Env, ec: ExecutionContext): Future[Either[NgWebsocketError, WebsocketMessage]] = {
val promise = Promise[Either[NgWebsocketError, WebsocketMessage]]()
def next(current: WebsocketMessage, plugins: Seq[NgPluginWrapper[NgWebsocketPlugin]]): Unit = {
plugins.headOption match {
case None => promise.trySuccess(Right(current))
case Some(wrapper) =>
val ctx = NgWebsocketPluginContext(
snowflake = attrs.get(otoroshi.plugins.Keys.SnowFlakeKey).get,
route = route,
request = rawRequest,
attrs = attrs,
config = wrapper.plugin.defaultConfig
.map(dc => dc ++ wrapper.instance.config.raw)
.getOrElse(wrapper.instance.config.raw),
target = target
)
(if (applyResponseFilter) {
wrapper.plugin.onResponseMessage(ctx, current)
} else {
wrapper.plugin.onRequestMessage(ctx, current)
}).andThen {
case Failure(_) =>
promise.trySuccess(
Left(NgWebsocketError(500.some, "internal_server_error".some, wrapper.plugin.rejectStrategy(ctx).some))
)
case Success(Left(error)) => {
//println("DENIED", wrapper.plugin.rejectStrategy(ctx), wrapper.plugin.name, error.statusCode, error.reason)
wrapper.plugin.rejectStrategy(ctx) match {
case RejectStrategy.Close =>
closeConnection(NgWebsocketResponse(NgAccess.NgAllowed, error.statusCode, error.reason))
case _ => // TODO - logging ??
}
promise.trySuccess(Left(error.copy(rejectStrategy = wrapper.plugin.rejectStrategy(ctx).some)))
}
case Success(Right(nextMessage)) if plugins.size == 1 => {
promise.trySuccess(Right(nextMessage))
}
case Success(Right(nextMessage)) => {
next(nextMessage, plugins.tail)
}
}
}
}
next(data, validators)
promise.future
}
def handleRequest(data: play.api.http.websocket.Message)(
closeConnection: NgWebsocketResponse => Unit
)(implicit env: Env, ec: ExecutionContext): Future[Either[NgWebsocketError, WebsocketMessage]] = {
if (ctxPlugins.hasNoWebsocketPlugins) {
val r: Either[NgWebsocketError, WebsocketMessage] =
Right[NgWebsocketError, WebsocketMessage](WebsocketMessage.PlayMessage(data))
r.vfuture
} else {
val requestValidators: Seq[NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketPlugin]] =
getPlugins()(_.plugin.onRequestFlow)
handle(requestValidators, WebsocketMessage.PlayMessage(data))(closeConnection)
}
}
def handleResponse(data: akka.http.scaladsl.model.ws.Message)(
closeConnection: NgWebsocketResponse => Unit
)(implicit env: Env, ec: ExecutionContext): Future[Either[NgWebsocketError, WebsocketMessage]] = {
if (ctxPlugins.hasNoWebsocketPlugins) {
WebsocketMessage.AkkaMessage(data).rightf[NgWebsocketError]
} else {
val responseValidators = getPlugins()(_.plugin.onResponseFlow)
handle(responseValidators, WebsocketMessage.AkkaMessage(data), applyResponseFilter = true)(closeConnection)
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy