All downloads are free. Search and download functionality uses the official Maven repository.

org.usergrid.websocket.WebSocketChannelHandler Maven / Gradle / Ivy

There is a newer version: 0.0.27.1
Show newest version
/*******************************************************************************
 * Copyright 2012 Apigee Corporation
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.usergrid.websocket;

import static org.apache.commons.lang.StringUtils.isEmpty;
import static org.apache.commons.lang.StringUtils.removeEnd;
import static org.apache.commons.lang.StringUtils.split;
import static org.jboss.netty.handler.codec.http.HttpHeaders.isKeepAlive;
import static org.jboss.netty.handler.codec.http.HttpHeaders.setContentLength;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONNECTION;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ORIGIN;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_KEY1;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_KEY2;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_LOCATION;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_ORIGIN;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_LOCATION;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_ORIGIN;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.WEBSOCKET_PROTOCOL;
import static org.jboss.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;

import java.security.MessageDigest;
import java.util.List;
import java.util.concurrent.ConcurrentMap;

import org.apache.shiro.mgt.SessionsSecurityManager;
import org.apache.shiro.subject.Subject;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpHeaders.Names;
import org.jboss.netty.handler.codec.http.HttpHeaders.Values;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.QueryStringDecoder;
import org.jboss.netty.handler.codec.http.websocket.DefaultWebSocketFrame;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrame;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameDecoder;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder;
import org.jboss.netty.util.CharsetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.usergrid.management.ManagementService;
import org.usergrid.persistence.EntityManagerFactory;
import org.usergrid.services.ServiceManagerFactory;

import com.google.common.base.Function;
import com.google.common.collect.MapMaker;

/**
 * Netty upstream handler that serves a demo index page over plain HTTP and
 * upgrades WebSocket requests using the pre-RFC handshakes (with or without
 * the {@code Sec-WebSocket-Key1}/{@code Sec-WebSocket-Key2} challenge).
 * Received text frames are echoed back upper-cased.
 *
 * <p>
 * The class also keeps a static registry of {@link ChannelGroup}s keyed by
 * subscription path.
 *
 * <p>
 * NOTE(review): the handler keeps per-connection state ({@link #websocket}),
 * so presumably one instance is created per channel pipeline - confirm against
 * the pipeline factory.
 */
public class WebSocketChannelHandler extends SimpleChannelUpstreamHandler {

	private static final Logger logger = LoggerFactory
			.getLogger(WebSocketChannelHandler.class);

	/** Number of path segments a challenge-handshake URI must contain. */
	private static final int EXPECTED_PATH_SEGMENTS = 3;

	/** Size in bytes of the third challenge key carried in the request body. */
	private static final int CHALLENGE_BODY_BYTES = 8;

	private final EntityManagerFactory emf;
	private final ServiceManagerFactory smf;
	private final ManagementService management;
	private final SessionsSecurityManager securityManager;
	private final boolean ssl;

	/** Set once this connection has been upgraded to a WebSocket. */
	boolean websocket = false;

	/** Shiro subject built in the constructor when a security manager is given. */
	Subject subject = null;

	/**
	 * Subscription path to channel group registry. This is a computing map:
	 * looking up a path that has no entry creates and stores a new, empty
	 * {@link DefaultChannelGroup} for it.
	 */
	static ConcurrentMap<String, ChannelGroup> subscribers = new MapMaker()
			.makeComputingMap(new Function<String, ChannelGroup>() {
				@Override
				public ChannelGroup apply(String key) {
					return new DefaultChannelGroup();
				}
			});

	// NOTE(review): never read or written in this class and the element type
	// is not visible here; left as a raw type so existing users keep compiling.
	List subscriptions;

	/**
	 * Creates a handler.
	 *
	 * @param emf
	 *            entity manager factory exposed through {@link #getEmf()}
	 * @param smf
	 *            service manager factory exposed through {@link #getSmf()}
	 * @param management
	 *            management service exposed through
	 *            {@link #getOrganizations()}
	 * @param securityManager
	 *            security manager used to build {@link #subject}; may be
	 *            {@code null}, in which case no subject is built
	 * @param ssl
	 *            whether advertised WebSocket locations use {@code wss://}
	 *            instead of {@code ws://}
	 */
	public WebSocketChannelHandler(EntityManagerFactory emf,
			ServiceManagerFactory smf, ManagementService management,
			SessionsSecurityManager securityManager, boolean ssl) {
		super();

		this.emf = emf;
		this.smf = smf;
		this.management = management;
		this.securityManager = securityManager;
		this.ssl = ssl;

		if (securityManager != null) {
			subject = new Subject.Builder(securityManager).buildSubject();
		}
	}

	/** @return the entity manager factory passed to the constructor */
	public EntityManagerFactory getEmf() {
		return emf;
	}

	/** @return the service manager factory passed to the constructor */
	public ServiceManagerFactory getSmf() {
		return smf;
	}

	/** @return the management service passed to the constructor */
	public ManagementService getOrganizations() {
		return management;
	}

	/** @return the security manager passed to the constructor, possibly null */
	public SessionsSecurityManager getSecurityManager() {
		return securityManager;
	}

	/**
	 * Builds the WebSocket URL advertised to the client from the scheme
	 * selected by {@link #ssl}, the request's {@code Host} header and the
	 * request URI with one trailing slash removed (a bare {@code "/"}
	 * contributes nothing).
	 *
	 * @param req
	 *            the incoming HTTP request
	 * @return the location string, which is also logged at info level
	 */
	private String getWebSocketLocation(HttpRequest req) {
		String uri = req.getUri();
		String suffix = uri.equals("/") ? "" : removeEnd(uri, "/");
		String scheme = ssl ? "wss://" : "ws://";
		String location = scheme + req.getHeader(HttpHeaders.Names.HOST)
				+ suffix;
		logger.info(location);
		return location;
	}

	/**
	 * Writes {@code res} to the channel. For any status other than 200 the
	 * status line text is used as the response body. The connection is closed
	 * after the write when the status is not 200 or the request is not
	 * keep-alive.
	 *
	 * @param ctx
	 *            handler context whose channel receives the response
	 * @param req
	 *            the request being answered
	 * @param res
	 *            the response to send
	 */
	private void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req,
			HttpResponse res) {
		boolean ok = res.getStatus().getCode() == 200;
		if (!ok) {
			// Generate a minimal error page from the status line.
			res.setContent(ChannelBuffers.copiedBuffer(res.getStatus()
					.toString(), CharsetUtil.UTF_8));
			setContentLength(res, res.getContent().readableBytes());
		}

		ChannelFuture written = ctx.getChannel().write(res);
		if (!ok || !isKeepAlive(req)) {
			written.addListener(ChannelFutureListener.CLOSE);
		}
	}

	/**
	 * Sends an empty HTTP/1.1 response with the given status; see
	 * {@link #sendHttpResponse(ChannelHandlerContext, HttpRequest, HttpResponse)}.
	 */
	private void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req,
			HttpResponseStatus status) {
		sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, status));
	}

	/** Logs the failure and closes the channel it occurred on. */
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) {
		logger.warn("Unexpected exception from downstream.", e.getCause());
		e.getChannel().close();
	}

	/** Logs the disconnect when the connection had been upgraded. */
	@Override
	public void channelDisconnected(ChannelHandlerContext ctx,
			ChannelStateEvent e) throws Exception {
		super.channelDisconnected(ctx, e);
		if (websocket) {
			logger.info("Websocket disconnected");
		}
	}

	/**
	 * Dispatches HTTP requests and WebSocket frames to their handlers; any
	 * other message type is ignored.
	 */
	@Override
	public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
			throws Exception {
		Object msg = e.getMessage();
		if (msg instanceof HttpRequest) {
			handleHttpRequest(ctx, (HttpRequest) msg);
		} else if (msg instanceof WebSocketFrame) {
			handleWebSocketFrame(ctx, (WebSocketFrame) msg);
		}
	}

	/**
	 * Handles one HTTP request: rejects non-GET methods, performs the
	 * WebSocket handshake for upgrade requests, serves the demo page for
	 * {@code "/"} and answers 403 for everything else.
	 *
	 * @param ctx
	 *            handler context of the connection
	 * @param req
	 *            the aggregated HTTP request
	 * @throws Exception
	 *             if the handshake digest cannot be computed
	 */
	private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req)
			throws Exception {
		// Allow only GET methods.
		if (!GET.equals(req.getMethod())) {
			sendHttpResponse(ctx, req, FORBIDDEN);
			return;
		}

		boolean upgradeRequested = Values.UPGRADE.equalsIgnoreCase(req
				.getHeader(CONNECTION))
				&& WEBSOCKET.equalsIgnoreCase(req.getHeader(Names.UPGRADE));

		if (upgradeRequested) {
			handleHandshake(ctx, req);
		} else if ("/".equals(req.getUri())) {
			sendIndexPage(ctx, req);
		} else {
			// Send an error page otherwise.
			sendHttpResponse(ctx, req, FORBIDDEN);
		}
	}

	/** Sends the HTML demo page produced by {@code WebSocketServerIndexPage}. */
	private void sendIndexPage(ChannelHandlerContext ctx, HttpRequest req) {
		HttpResponse res = new DefaultHttpResponse(HTTP_1_1, OK);

		ChannelBuffer content = WebSocketServerIndexPage
				.getContent(getWebSocketLocation(req));

		res.setHeader(CONTENT_TYPE, "text/html; charset=UTF-8");
		setContentLength(res, content.readableBytes());

		res.setContent(content);
		sendHttpResponse(ctx, req, res);
	}

	/**
	 * Builds the 101 handshake response and, if the request is acceptable,
	 * swaps the HTTP codec handlers in the pipeline for WebSocket frame
	 * handlers.
	 */
	private void handleHandshake(ChannelHandlerContext ctx, HttpRequest req)
			throws Exception {
		logger.info("Starting new websocket connection...");

		// Create the WebSocket handshake response.
		HttpResponse res = new DefaultHttpResponse(HTTP_1_1,
				new HttpResponseStatus(101, "Web Socket Protocol Handshake"));
		res.addHeader(Names.UPGRADE, WEBSOCKET);
		res.addHeader(CONNECTION, Values.UPGRADE);

		String path = new QueryStringDecoder(req.getUri()).getPath();
		logger.info(path);

		// Fill in the headers and contents depending on handshake method.
		if (req.containsHeader(SEC_WEBSOCKET_KEY1)
				&& req.containsHeader(SEC_WEBSOCKET_KEY2)) {
			if (!prepareChallengeHandshake(ctx, req, res, path)) {
				// An error response has already been sent.
				return;
			}
		} else {
			prepareLegacyHandshake(req, res);
		}

		// Only mark the connection as a websocket once it is really upgraded,
		// so rejected handshakes are not logged as websocket disconnects.
		websocket = true;

		// Upgrade the connection and send the handshake response. The
		// response is written while the HTTP encoder is still installed;
		// the encoder is swapped only afterwards.
		ChannelPipeline pipeline = ctx.getChannel().getPipeline();
		pipeline.remove("aggregator");
		pipeline.replace("decoder", "wsdecoder", new WebSocketFrameDecoder());

		ctx.getChannel().write(res);

		pipeline.replace("encoder", "wsencoder", new WebSocketFrameEncoder());
	}

	/**
	 * Fills {@code res} for the handshake variant that carries a challenge:
	 * validates the request path, copies origin/location/protocol headers and
	 * sets the MD5 challenge answer as the response body.
	 *
	 * @return {@code true} if {@code res} is ready to be sent; {@code false}
	 *         if the request was rejected and an error response was already
	 *         written
	 * @throws Exception
	 *             if the MD5 digest is unavailable
	 */
	private boolean prepareChallengeHandshake(ChannelHandlerContext ctx,
			HttpRequest req, HttpResponse res, String path) throws Exception {
		String[] segments = split(path, '/');

		if (segments.length != EXPECTED_PATH_SEGMENTS) {
			logger.info(
					"Wrong number of path segments, expected 3, found {}",
					segments.length);
			sendHttpResponse(ctx, req, FORBIDDEN);
			return false;
		}

		String nsStr = segments[0];
		String collStr = segments[1];
		String idStr = segments[2];

		logger.info(nsStr + "/" + collStr + "/" + idStr);

		// Defensive: kept from the original even though split() is not
		// expected to produce empty tokens.
		if (isEmpty(nsStr) || isEmpty(collStr) || isEmpty(idStr)) {
			sendHttpResponse(ctx, req, FORBIDDEN);
			return false;
		}

		// Calculate the answer of the challenge before touching the response,
		// rejecting malformed client input instead of letting an arithmetic
		// or parsing exception tear the connection down without a reply.
		int a;
		int b;
		try {
			a = challengeNumber(req.getHeader(SEC_WEBSOCKET_KEY1));
			b = challengeNumber(req.getHeader(SEC_WEBSOCKET_KEY2));
		} catch (IllegalArgumentException malformed) {
			logger.info("Rejecting websocket handshake: malformed challenge key");
			sendHttpResponse(ctx, req, HttpResponseStatus.BAD_REQUEST);
			return false;
		}

		ChannelBuffer body = req.getContent();
		if (body.readableBytes() < CHALLENGE_BODY_BYTES) {
			logger.info("Rejecting websocket handshake: challenge body too short");
			sendHttpResponse(ctx, req, HttpResponseStatus.BAD_REQUEST);
			return false;
		}
		long c = body.readLong();

		// New handshake method with a challenge:
		String origin = req.getHeader(ORIGIN);
		if (origin != null) {
			// Skip the header rather than hand addHeader a null value when
			// the client sent no Origin.
			res.addHeader(SEC_WEBSOCKET_ORIGIN, origin);
		}
		res.addHeader(SEC_WEBSOCKET_LOCATION, getWebSocketLocation(req));
		String protocol = req.getHeader(SEC_WEBSOCKET_PROTOCOL);
		if (protocol != null) {
			res.addHeader(SEC_WEBSOCKET_PROTOCOL, protocol);
		}

		// The answer is the MD5 of the two 32-bit key numbers followed by the
		// 8 body bytes; MD5 is dictated by the handshake, not a free choice.
		ChannelBuffer input = ChannelBuffers.buffer(16);
		input.writeInt(a);
		input.writeInt(b);
		input.writeLong(c);
		ChannelBuffer output = ChannelBuffers.wrappedBuffer(MessageDigest
				.getInstance("MD5").digest(input.array()));
		res.setContent(output);
		return true;
	}

	/**
	 * Fills {@code res} for the handshake variant without a challenge by
	 * copying the origin, location and (if present) protocol headers.
	 */
	private void prepareLegacyHandshake(HttpRequest req, HttpResponse res) {
		// Old handshake method with no challenge:
		String origin = req.getHeader(ORIGIN);
		if (origin != null) {
			res.addHeader(WEBSOCKET_ORIGIN, origin);
		}
		res.addHeader(WEBSOCKET_LOCATION, getWebSocketLocation(req));
		String protocol = req.getHeader(WEBSOCKET_PROTOCOL);
		if (protocol != null) {
			res.addHeader(WEBSOCKET_PROTOCOL, protocol);
		}
	}

	/**
	 * Decodes one challenge key: the digits of the key, read as a decimal
	 * number, divided by the number of space characters in the key.
	 *
	 * @param key
	 *            raw header value
	 * @return the quotient narrowed to {@code int}
	 * @throws IllegalArgumentException
	 *             if the key is {@code null}, has no digits, has no spaces, or
	 *             its digits do not fit in a {@code long}
	 */
	private static int challengeNumber(String key) {
		if (key == null) {
			throw new IllegalArgumentException("Missing challenge key");
		}
		String digits = key.replaceAll("[^0-9]", "");
		int spaces = key.replaceAll("[^ ]", "").length();
		if (isEmpty(digits) || spaces == 0) {
			throw new IllegalArgumentException("Malformed challenge key");
		}
		// Long.parseLong throws NumberFormatException (an
		// IllegalArgumentException) when the digit run overflows.
		return (int) (Long.parseLong(digits) / spaces);
	}

	/** Sends the received text back to the client, upper-cased. */
	private void handleWebSocketFrame(ChannelHandlerContext ctx,
			WebSocketFrame frame) {
		ctx.getChannel().write(
				new DefaultWebSocketFrame(frame.getTextData().toUpperCase()));
	}

	// Note: subscriptions are added and removed relatively infrequently
	// during the lifecycle of a connection i.e. typical minimum lifespan
	// would be 10 seconds when someone opens an app, it connects, then
	// they close the app.

	/**
	 * Adds {@code channel} to the group registered for {@code path}, creating
	 * the group if necessary.
	 *
	 * @param path
	 *            subscription key
	 * @param channel
	 *            channel to subscribe
	 */
	public void addSubscription(String path, Channel channel) {
		while (true) {
			ChannelGroup group = subscribers.get(path);
			synchronized (group) {
				// removeSubscription unmaps an emptied group while holding
				// this same lock. If the group is no longer the one mapped
				// for the path it was dropped between our lookup and taking
				// the lock, so retry instead of adding to an orphaned group.
				if (subscribers.get(path) == group) {
					group.add(channel);
					return;
				}
			}
		}
	}

	/**
	 * Removes {@code channel} from the group registered for {@code path} and
	 * drops the group from the registry once it is empty.
	 *
	 * @param path
	 *            subscription key
	 * @param channel
	 *            channel to unsubscribe
	 */
	public void removeSubscription(String path, Channel channel) {
		ChannelGroup group = subscribers.get(path);
		synchronized (group) {
			group.remove(channel);
			if (group.isEmpty()) {
				// Two-argument remove: only unmaps if the path still maps to
				// this exact group.
				subscribers.remove(path, group);
			}
		}
	}

	/**
	 * Returns the channel group for {@code path}. Because the registry is a
	 * computing map, asking for a path with no subscribers creates and
	 * registers an empty group.
	 *
	 * @param path
	 *            subscription key
	 * @return the group mapped to {@code path}
	 */
	public ChannelGroup getSubscriptionGroup(String path) {
		return subscribers.get(path);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy