// All downloads are FREE. Search and download functionality uses the official Maven repository.
//
// xin.alum.aim.websocks.WebSocketServerHandler Maven / Gradle / Ivy
//
// There is a newer version: 1.9.6
// Show newest version
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package xin.alum.aim.websocks;

import xin.alum.aim.AIM;
import xin.alum.aim.config.DataAgreement;
import xin.alum.aim.constant.AIMConstant;
import xin.alum.aim.constant.ChannelAttr;
import xin.alum.aim.constant.ChannelClose;
import xin.alum.aim.handler.BaseServerHandler;
import xin.alum.aim.model.Sent;
import xin.alum.aim.model.proto.ReplyProto;
import xin.alum.aim.model.proto.SentProto;
import com.google.protobuf.MessageLite;
import io.netty.buffer.*;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;
import org.springframework.scheduling.annotation.Async;
import xin.alum.aim.server.ServerInitializer;

import java.util.List;
import java.util.Map;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.*;

/**
 * Handles the HTTP upgrade handshake and the WebSocket frames that follow it.
 *
 * <p>Inbound {@link FullHttpRequest}s are either answered directly (bad request, non-GET,
 * demo page at {@code "/"}) or upgraded to a WebSocket connection; inbound
 * {@link WebSocketFrame}s are dispatched to the {@code AIM.request} callbacks.
 *
 * <p>The class is annotated {@code @Sharable}, so it must not keep per-connection state in
 * instance fields. The negotiated {@link WebSocketServerHandshaker} is therefore stored on
 * the channel itself (see {@link #HANDSHAKER}).
 */
@ChannelHandler.Sharable
public class WebSocketServerHandler extends BaseServerHandler {

    /** Maximum frame payload length passed to the handshaker factory, in bytes (5 MiB). */
    private static final int MAX_FRAME_PAYLOAD_LENGTH = 5 * 1024 * 1024;

    /**
     * Channel attribute holding the handshaker that upgraded that particular channel.
     * Keeping it per channel (instead of in an instance field) prevents concurrent
     * connections of this shared handler from overwriting each other's handshaker.
     */
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER =
            AttributeKey.valueOf(WebSocketServerHandler.class.getName() + "#HANDSHAKER");

    /**
     * Converts an outgoing message into the matching WebSocket frame and delegates the
     * actual write to {@code super.send}.
     *
     * <ul>
     *   <li>{@code AIMConstant.KEY_PONG} / {@code AIMConstant.KEY_PING} (identity match)
     *       become empty pong / ping frames;</li>
     *   <li>any other {@link String} becomes a text frame with
     *       {@code AIMConstant.TEXT_FRAME_Delimiters} appended;</li>
     *   <li>a protobuf {@link MessageLite} or {@link MessageLite.Builder} becomes a binary
     *       frame wrapping its serialized bytes;</li>
     *   <li>anything else is passed through unchanged.</li>
     * </ul>
     *
     * @param ch  channel to write to
     * @param msg message to send
     * @return whatever {@code super.send} returns for the converted message
     */
    @Async
    @Override
    public ChannelFuture send(Channel ch, Object msg) {
        // NOTE(review): the heartbeat markers are compared by reference, presumably on
        // purpose so that ordinary text with the same content is still sent as a text
        // frame - confirm against AIMConstant before switching to equals().
        if (msg == AIMConstant.KEY_PONG) {
            msg = new PongWebSocketFrame();
        } else if (msg == AIMConstant.KEY_PING) {
            msg = new PingWebSocketFrame();
        } else if (msg instanceof String) {
            // JSON / plain text payload
            msg = new TextWebSocketFrame(((String) msg).concat(AIMConstant.TEXT_FRAME_Delimiters));
        } else if (msg instanceof MessageLite) {
            // Protobuf payload
            msg = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(((MessageLite) msg).toByteArray()));
        } else if (msg instanceof MessageLite.Builder) {
            // Protobuf builder: build first, then serialize
            msg = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(((MessageLite.Builder) msg).build().toByteArray()));
        }
        return super.send(ch, msg);
    }

    /**
     * Routes an inbound message to the HTTP or WebSocket handling path; any other message
     * type is only logged.
     *
     * @param ctx handler context of the receiving channel
     * @param msg decoded inbound message
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof FullHttpRequest) {
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        } else {
            logger.error("未知消息类型");
        }
    }

    /**
     * Answers plain HTTP requests and performs the WebSocket upgrade handshake.
     *
     * <p>Undecodable requests get 400, non-GET requests get 403, {@code "/"} gets the
     * benchmark demo page. Every other URI is treated as an upgrade request: URL query
     * parameters are copied onto the channel as attributes, {@code AIM.request.onHandShake}
     * decides whether the connection is accepted, and the handshake is then either
     * completed or rejected with the response object the callback was given.
     *
     * @param ctx handler context of the requesting channel
     * @param req the aggregated HTTP request
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
        // Handle a bad request.
        if (!req.decoderResult().isSuccess()) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), BAD_REQUEST, ctx.alloc().buffer(0)));
            return;
        }

        // Allow only GET methods.
        if (!GET.equals(req.method())) {
            sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0)));
            return;
        }

        // Send the demo page.
        if ("/".equals(req.uri())) {
            ByteBuf content = WebSocketServerBenchmarkPage.getContent(getWebSocketLocation(req));
            FullHttpResponse page = new DefaultFullHttpResponse(req.protocolVersion(), OK, content);

            page.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8");
            HttpUtil.setContentLength(page, content.readableBytes());

            sendHttpResponse(ctx, req, page);
            return;
        }

        // Handshake
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(req), null, true, MAX_FRAME_PAYLOAD_LENGTH);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }

        Channel ch = ctx.channel();

        // Initialise the channel from the URL query parameters.
        Map<String, List<String>> params = new QueryStringDecoder(req.uri()).parameters();
        if (!params.isEmpty()) {
            // The wire format may be selected through the "agreement" parameter; otherwise
            // the configured default is used.
            List<String> agreements = params.get("agreement");
            DataAgreement agreement = agreements == null
                    ? AIM.properties.getAgreement()
                    : DataAgreement.valueOf(agreements.get(0));
            ServerInitializer.InitAgreement(ctx.pipeline(), agreement);

            // NOTE(review): every client-supplied key becomes a channel attribute (and a
            // permanently pooled AttributeKey). A client can therefore set any attribute
            // whose name it knows - consider restricting this to an allow-list.
            params.forEach((key, val) ->
                    ch.attr(AttributeKey.valueOf(key)).set(val.size() == 1 ? val.get(0) : val));
        }

        // Stored after the parameter loop so a query key can never overwrite it.
        ch.attr(HANDSHAKER).set(handshaker);

        FullHttpResponse res = new DefaultFullHttpResponse(req.protocolVersion(), OK, ctx.alloc().buffer(0));
        if (AIM.request.onHandShake(ch, req, res)) {
            // The prepared response is only written on rejection; release its buffer here
            // so the accepted path does not leak it.
            res.release();
            ChannelFuture f = handshaker.handshake(ch, req);
            f.addListener(t -> {
                if (t.isSuccess()) {
                    AIM.request.onHandShaked(ctx.channel());
                } else {
                    logger.error("{}握手失败", f.channel());
                }
            });
        } else {
            logger.error("{}握手被拒绝", ch);
            // Refuse the connection with the (possibly customised) response.
            reject(ch, res);
        }
    }

    /**
     * Rejects an upgrade request by writing the given response, tagged with the supported
     * WebSocket version header, and flushing it.
     *
     * <p>NOTE(review): the channel is not closed here and no Content-Length is set;
     * presumably the {@code onHandShake} callback or the client handles that - confirm.
     *
     * @param ch  channel to answer on
     * @param res response to send; its status and headers are whatever the caller set
     * @return future of the write
     */
    private ChannelFuture reject(Channel ch, HttpResponse res) {
        res.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue());
        return ch.writeAndFlush(res, ch.newPromise());
    }

    /**
     * Dispatches a WebSocket frame: close frames are echoed and the connection is shut
     * down, pings are answered with a pong, text and binary frames record the data
     * agreement on the channel and are forwarded to {@code AIM.request}. Any exception is
     * logged and swallowed.
     *
     * @param ctx   handler context of the receiving channel
     * @param frame the inbound frame
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        try {
            Channel ch = ctx.channel();
            if (frame instanceof CloseWebSocketFrame) {
                // Answer the close. Each write is given its own reference via retain().
                // NOTE(review): the close frame is written twice (here and by the
                // handshaker); kept as in the original - confirm whether super.send is
                // needed in addition to handshaker.close().
                super.send(ch, frame.retain());
                WebSocketServerHandshaker handshaker = ch.attr(HANDSHAKER).get();
                if (handshaker != null) {
                    handshaker.close(ch, (CloseWebSocketFrame) frame.retain());
                } else {
                    // No handshaker recorded for this channel: just close it.
                    ch.close();
                }
                super.onClose(ch, ChannelClose.CLOSE);
            } else if (frame instanceof PingWebSocketFrame) {
                super.send(ch, new PongWebSocketFrame());
                AIM.request.onPing(ch);
            } else if (frame instanceof TextWebSocketFrame) {
                ch.attr(ChannelAttr.AGREEMENT).set(DataAgreement.Json);
                AIM.request.onText(ch, ((TextWebSocketFrame) frame).text());
            } else if (frame instanceof BinaryWebSocketFrame) {
                ch.attr(ChannelAttr.AGREEMENT).set(DataAgreement.ProtoBuf);
                AIM.request.onByte(ch, frame.content());
            }
        } catch (Exception e) {
            logger.error("WebSocket消息处理异常:{},{}", e.getMessage(), e);
        }
    }

    /**
     * Writes an HTTP response. For any status other than 200 the status line text is
     * appended to the body and Content-Length is set accordingly; the connection is closed
     * after the write unless the request asked for keep-alive and the status is 200.
     *
     * <p>NOTE(review): the response is written but not flushed here. The original comment
     * says it is flushed in {@code channelReadComplete()}, which is not defined in this
     * class - presumably the superclass does it; confirm.
     *
     * @param ctx handler context to write to
     * @param req request being answered (used for the keep-alive decision)
     * @param res response to send
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
        // Generate an error page if the response status code is not OK (200).
        HttpResponseStatus responseStatus = res.status();
        if (responseStatus.code() != 200) {
            ByteBufUtil.writeUtf8(res.content(), responseStatus.toString());
            HttpUtil.setContentLength(res, res.content().readableBytes());
        }
        // Send the response and close the connection if necessary.
        boolean keepAlive = HttpUtil.isKeepAlive(req) && responseStatus.code() == 200;
        HttpUtil.setKeepAlive(res, keepAlive);
        ChannelFuture future = ctx.write(res);
        if (!keepAlive) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket URL advertised to the client from the request's Host header and
     * the configured path, using {@code wss://} when SSL is enabled and {@code ws://}
     * otherwise.
     *
     * @param req request whose Host header is used
     * @return the WebSocket location URL
     */
    private static String getWebSocketLocation(FullHttpRequest req) {
        String location = req.headers().get(HttpHeaderNames.HOST) + AIM.properties.getPath();
        String scheme = AIM.properties.isSsl() ? "wss://" : "ws://";
        return scheme + location;
    }
}