All downloads are free. Search and download functionality uses the official Maven repository.

ratpack.websocket.internal.WebSocketConnector Maven / Gradle / Ivy

There is a newer version: 2.0.0-rc-1
Show newest version
/*
 * Copyright 2013 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ratpack.websocket.internal;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.*;
import ratpack.handling.Context;
import ratpack.handling.direct.DirectChannelAccess;
import ratpack.http.Request;
import ratpack.server.PublicAddress;
import ratpack.func.Action;
import ratpack.websocket.WebSocket;
import ratpack.websocket.WebSocketHandler;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_KEY;
import static io.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_VERSION;
import static io.netty.handler.codec.http.HttpMethod.valueOf;
import static ratpack.util.ExceptionUtils.toException;
import static ratpack.util.ExceptionUtils.uncheck;

public class WebSocketConnector {

  public static  void connect(final Context context, String path, int maxLength, final WebSocketHandler handler) {
    PublicAddress publicAddress = context.get(PublicAddress.class);
    URI address = publicAddress.getAddress(context);
    URI httpPath = address.resolve(path);

    URI wsPath;
    try {
      wsPath = new URI("ws", httpPath.getUserInfo(), httpPath.getHost(), httpPath.getPort(), httpPath.getPath(), httpPath.getQuery(), httpPath.getFragment());
    } catch (URISyntaxException e) {
      throw uncheck(e);
    }

    WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(wsPath.toString(), null, false, maxLength);

    Request request = context.getRequest();
    HttpMethod method = valueOf(request.getMethod().getName());
    FullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, method, request.getUri());
    nettyRequest.headers().add(SEC_WEBSOCKET_VERSION, request.getHeaders().get(SEC_WEBSOCKET_VERSION));
    nettyRequest.headers().add(SEC_WEBSOCKET_KEY, request.getHeaders().get(SEC_WEBSOCKET_KEY));

    final WebSocketServerHandshaker handshaker = factory.newHandshaker(nettyRequest);

    final DirectChannelAccess directChannelAccess = context.getDirectChannelAccess();
    final Channel channel = directChannelAccess.getChannel();

    handshaker.handshake(channel, nettyRequest).addListener(new HandshakeFutureListener<>(context, handshaker, handler));
  }

  private static class HandshakeFutureListener implements ChannelFutureListener {

    private final Context context;
    private final WebSocketServerHandshaker handshaker;
    private final WebSocketHandler handler;

    private volatile T openResult;
    private final CountDownLatch openLatch = new CountDownLatch(1);

    public HandshakeFutureListener(Context context, WebSocketServerHandshaker handshaker, WebSocketHandler handler) {
      this.context = context;
      this.handshaker = handshaker;
      this.handler = handler;
    }

    public void operationComplete(ChannelFuture future) throws Exception {
      if (future.isSuccess()) {
        final AtomicBoolean open = new AtomicBoolean(true);
        final WebSocket webSocket = new DefaultWebSocket(context.getDirectChannelAccess().getChannel(), open, new Runnable() {
          @Override
          public void run() {
            try {
              handler.onClose(new DefaultWebSocketClose<>(false, openResult));
            } catch (Exception e) {
              throw uncheck(e);
            }
          }
        });

        context.getDirectChannelAccess().takeOwnership(new Action() {
          @Override
          public void execute(Object msg) throws Exception {
            openLatch.await();
            Channel channel = context.getDirectChannelAccess().getChannel();
            if (channel.isOpen()) {
              if (msg instanceof WebSocketFrame) {
                WebSocketFrame frame = (WebSocketFrame) msg;
                if (frame instanceof CloseWebSocketFrame) {
                  open.set(false);
                  handshaker.close(channel, (CloseWebSocketFrame) frame.retain()).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                      handler.onClose(new DefaultWebSocketClose<>(true, openResult));
                    }
                  });
                  return;
                }
                if (frame instanceof PingWebSocketFrame) {
                  channel.write(new PongWebSocketFrame(frame.content().retain()));
                  return;
                }
                if (frame instanceof TextWebSocketFrame) {
                  TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) frame;
                  handler.onMessage(new DefaultWebSocketMessage<>(webSocket, textWebSocketFrame.text(), openResult));
                }
              }
            }
          }
        });

        try {
          openResult = handler.onOpen(webSocket);
        } catch (Exception e) {
          handshaker.close(context.getDirectChannelAccess().getChannel(), new CloseWebSocketFrame(1011, e.getMessage()));
        }
        openLatch.countDown();
      } else {
        context.error(toException(future.cause()));
      }
    }
  }
}