// org.apache.shenyu.plugin.websocket.WebSocketPlugin Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shenyu.plugin.websocket;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.shenyu.common.constant.Constants;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.impl.WebSocketRuleHandle;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.enums.RpcTypeEnum;
import org.apache.shenyu.loadbalancer.cache.UpstreamCacheManager;
import org.apache.shenyu.loadbalancer.entity.Upstream;
import org.apache.shenyu.loadbalancer.factory.LoadBalancerFactory;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.context.ShenyuContext;
import org.apache.shenyu.plugin.api.result.ShenyuResultEnum;
import org.apache.shenyu.plugin.api.result.ShenyuResultWrap;
import org.apache.shenyu.plugin.api.utils.RequestUrlUtils;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.apache.shenyu.plugin.websocket.handler.WebSocketPluginDataHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.NonNull;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * The type Web socket plugin.
 *
 * <p>Picks an upstream for the matched selector, builds the upstream WebSocket URI and
 * hands the exchange to the {@link WebSocketService} together with a handler that relays
 * messages and close statuses between the inbound session and the upstream session.
 */
public class WebSocketPlugin extends AbstractShenyuPlugin {

    private static final Logger LOG = LoggerFactory.getLogger(WebSocketPlugin.class);

    private static final String SEC_WEB_SOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    /**
     * Instantiates a new Web socket plugin.
     *
     * @param webSocketClient the web socket client used to connect to the upstream
     * @param webSocketService the web socket service that handles the inbound handshake
     */
    public WebSocketPlugin(final WebSocketClient webSocketClient, final WebSocketService webSocketService) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
    }

    /**
     * Proxies the inbound WebSocket request to an upstream chosen by the load balancer.
     *
     * <p>If no upstream list is cached for the selector, or the exchange carries no
     * {@link ShenyuContext}, the error is logged and the rest of the chain is executed.
     * If the load balancer yields no upstream, an error result is written to the exchange.
     *
     * @param exchange the current server exchange
     * @param chain the plugin chain
     * @param selector the matched selector
     * @param rule the matched rule
     * @return a {@code Mono<Void>} signalling completion of the request handling
     * @throws NullPointerException if the request has no remote address
     */
    @Override
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain, final SelectorData selector, final RuleData rule) {
        final List<Upstream> upstreamList = UpstreamCacheManager.getInstance().findUpstreamListBySelectorId(selector.getId());
        final ShenyuContext shenyuContext = exchange.getAttribute(Constants.CONTEXT);
        if (CollectionUtils.isEmpty(upstreamList) || Objects.isNull(shenyuContext)) {
            LOG.error("websocket upstream configuration error:{}", rule);
            return chain.execute(exchange);
        }
        final WebSocketRuleHandle ruleHandle = buildRuleHandle(rule);
        // The caller's address is passed to the load balancer together with the configured algorithm.
        final String ip = Objects.requireNonNull(exchange.getRequest().getRemoteAddress()).getAddress().getHostAddress();
        final Upstream upstream = LoadBalancerFactory.selector(upstreamList, ruleHandle.getLoadBalance(), ip);
        if (Objects.isNull(upstream)) {
            LOG.error("websocket has no upstream, error:{}", rule);
            Object error = ShenyuResultWrap.error(exchange, ShenyuResultEnum.CANNOT_FIND_HEALTHY_UPSTREAM_URL);
            return WebFluxResultUtils.result(exchange, error);
        }
        final URI wsRequestUrl = buildWsRealPath(exchange, upstream);
        LOG.info("you websocket urlPath is :{}", wsRequestUrl.toASCIIString());
        final HttpHeaders headers = exchange.getRequest().getHeaders();
        return this.webSocketService.handleRequest(exchange, new ShenyuWebSocketHandler(
                wsRequestUrl, this.webSocketClient, filterHeaders(headers), buildWsProtocols(headers)));
    }

    /**
     * Looks up the cached rule handle for the given rule.
     *
     * @param rule the matched rule
     * @return the handle obtained from {@link WebSocketPluginDataHandler#CACHED_HANDLE}
     */
    private WebSocketRuleHandle buildRuleHandle(final RuleData rule) {
        return WebSocketPluginDataHandler.CACHED_HANDLE.get().obtainHandle(CacheKeyUtils.INST.getKey(rule));
    }

    /**
     * Builds the URI of the upstream WebSocket endpoint.
     *
     * @param exchange the current server exchange
     * @param upstream the selected upstream
     * @return the upstream request URI
     */
    private URI buildWsRealPath(final ServerWebExchange exchange, final Upstream upstream) {
        String protocol = upstream.getProtocol();
        // Fall back to plain WebSocket when the upstream declares no protocol.
        if (!StringUtils.hasLength(protocol)) {
            protocol = "ws://";
        }
        return RequestUrlUtils.buildRequestUri(exchange, upstream.buildDomain(protocol));
    }

    /**
     * Extracts the requested sub-protocols from the {@code Sec-WebSocket-Protocol} header.
     *
     * <p>Each header value may itself be a comma-delimited list; every token is trimmed.
     *
     * @param headers the inbound request headers
     * @return the requested sub-protocols, or an empty list when the header is absent
     */
    private List<String> buildWsProtocols(final HttpHeaders headers) {
        final List<String> protocols = headers.get(SEC_WEB_SOCKET_PROTOCOL);
        if (CollectionUtils.isEmpty(protocols)) {
            return Collections.emptyList();
        }
        return protocols.stream()
                .flatMap(header -> Arrays.stream(StringUtils.commaDelimitedListToStringArray(header)))
                .map(String::trim)
                .collect(Collectors.toList());
    }

    /**
     * Copies the inbound headers, dropping every header whose name starts with
     * {@code sec-websocket} (case-insensitively) and the {@code Host} header.
     *
     * @param headers the inbound request headers
     * @return a new {@link HttpHeaders} instance with the remaining headers
     */
    private HttpHeaders filterHeaders(final HttpHeaders headers) {
        final HttpHeaders filtered = new HttpHeaders();
        headers.forEach((name, values) -> {
            if (!name.toLowerCase().startsWith("sec-websocket")) {
                filtered.addAll(name, values);
            }
        });
        filtered.remove(HttpHeaders.HOST);
        return filtered;
    }

    /**
     * Maps a received close status to one that may be forwarded in a close frame.
     *
     * <p>See https://github.com/spring-cloud/spring-cloud-gateway/pull/2254
     *
     * @param closeStatus the close status reported by one side of the proxy
     * @return {@code closeStatus} itself when its code is forwardable,
     *         otherwise {@link CloseStatus#PROTOCOL_ERROR}
     */
    private static CloseStatus adaptCloseStatus(final CloseStatus closeStatus) {
        final int code = closeStatus.getCode();
        // Codes 3000-4999 are passed through untouched.
        if (code > 2999 && code < 5000) {
            return closeStatus;
        }
        switch (code) {
            case 1000:
            case 1001:
            case 1002:
            case 1003:
            case 1007:
            case 1008:
            case 1009:
            case 1010:
            case 1011:
                return closeStatus;
            default:
                // 1004 is reserved; 1005 (no status code), 1006 (no close frame) and
                // 1015 (TLS handshake failure) should not be used in a close frame;
                // 1012 (service restarted) and 1013 (service overload) are not in RFC6455.
                // Those and every other code are reported as a protocol error.
                return CloseStatus.PROTOCOL_ERROR;
        }
    }

    /**
     * Returns the plugin name.
     *
     * @return the name of {@link PluginEnum#WEB_SOCKET}
     */
    @Override
    public String named() {
        return PluginEnum.WEB_SOCKET.getName();
    }

    /**
     * Decides whether this plugin is skipped for the exchange.
     *
     * <p>Delegates to {@code skipExcept} with {@link RpcTypeEnum#WEB_SOCKET}; presumably the
     * plugin runs only for WebSocket requests - confirm against {@code AbstractShenyuPlugin}.
     *
     * @param exchange the current server exchange
     * @return the result of {@code skipExcept(exchange, RpcTypeEnum.WEB_SOCKET)}
     */
    @Override
    public boolean skip(final ServerWebExchange exchange) {
        return skipExcept(exchange, RpcTypeEnum.WEB_SOCKET);
    }

    /**
     * Writes the "no selector" result to the exchange.
     *
     * @param pluginName the plugin name
     * @param exchange the current server exchange
     * @param chain the plugin chain (not used)
     * @return the result of {@link WebFluxResultUtils#noSelectorResult}
     */
    @Override
    protected Mono<Void> handleSelectorIfNull(final String pluginName, final ServerWebExchange exchange, final ShenyuPluginChain chain) {
        return WebFluxResultUtils.noSelectorResult(pluginName, exchange);
    }

    /**
     * Writes the "no rule" result to the exchange.
     *
     * @param pluginName the plugin name
     * @param exchange the current server exchange
     * @param chain the plugin chain (not used)
     * @return the result of {@link WebFluxResultUtils#noRuleResult}
     */
    @Override
    protected Mono<Void> handleRuleIfNull(final String pluginName, final ServerWebExchange exchange, final ShenyuPluginChain chain) {
        return WebFluxResultUtils.noRuleResult(pluginName, exchange);
    }

    /**
     * Returns the plugin order.
     *
     * @return the code of {@link PluginEnum#WEB_SOCKET}
     */
    @Override
    public int getOrder() {
        return PluginEnum.WEB_SOCKET.getCode();
    }

    /**
     * Handler for the inbound session that opens a client session to the upstream URL
     * and relays messages and close statuses in both directions.
     */
    private static class ShenyuWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        /**
         * Instantiates a new shenyu web socket handler.
         *
         * @param url the upstream url
         * @param client the client used to connect to the upstream
         * @param headers the headers sent with the upstream handshake
         * @param protocols the requested sub-protocols; {@code null} is treated as empty
         */
        ShenyuWebSocketHandler(final URI url, final WebSocketClient client,
                               final HttpHeaders headers,
                               final List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            this.subProtocols = ObjectUtils.defaultIfNull(protocols, Collections.emptyList());
        }

        /**
         * Returns the sub-protocols requested by the inbound client.
         *
         * @return the sub-protocol list, never {@code null}
         */
        @NonNull
        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        /**
         * Connects to the upstream and bridges it with the given inbound session.
         *
         * @param session the inbound (server side) session
         * @return the {@code Mono<Void>} returned by the client's {@code execute} call
         */
        @NonNull
        @Override
        public Mono<Void> handle(@NonNull final WebSocketSession session) {
            // pass headers along so custom headers can be sent through
            return client.execute(url, this.headers, new WebSocketHandler() {

                @NonNull
                @Override
                public Mono<Void> handle(@NonNull final WebSocketSession proxySocketSession) {
                    // When the upstream closes, close the inbound session (if still open) with an adapted status.
                    Mono<Void> serverClose = proxySocketSession.closeStatus().filter(it -> session.isOpen())
                            .map(WebSocketPlugin::adaptCloseStatus).flatMap(session::close);
                    // When the inbound session closes, close the upstream session (if still open) likewise.
                    Mono<Void> proxyClose = session.closeStatus().filter(it -> proxySocketSession.isOpen())
                            .map(WebSocketPlugin::adaptCloseStatus).flatMap(proxySocketSession::close);
                    // Use retain() for Reactor Netty
                    Mono<Void> proxySessionSend = proxySocketSession
                            .send(session.receive().doOnNext(WebSocketMessage::retain));
                    Mono<Void> serverSessionSend = session.send(
                            proxySocketSession.receive().doOnNext(WebSocketMessage::retain));
                    // Ensure closeStatus from one propagates to the other.
                    // Subscribed on its own here; it is not part of the returned Mono.
                    Mono.when(serverClose, proxyClose).subscribe();
                    // Combine both relay directions and discard the zipped value.
                    return Mono.zip(proxySessionSend, serverSessionSend).then();
                }

                @NonNull
                @Override
                public List<String> getSubProtocols() {
                    return ShenyuWebSocketHandler.this.subProtocols;
                }
            });
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy