org.eclipse.jetty.websocket.WebSocketClientFactory Maven / Gradle / Ivy
//
// ========================================================================
// Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.websocket;
import java.io.EOFException;
import java.io.IOException;
import java.net.ProtocolException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.ssl.SSLEngine;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpParser;
import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.AsyncEndPoint;
import org.eclipse.jetty.io.Buffer;
import org.eclipse.jetty.io.Buffers;
import org.eclipse.jetty.io.ByteArrayBuffer;
import org.eclipse.jetty.io.ConnectedEndPoint;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.SimpleBuffers;
import org.eclipse.jetty.io.nio.AsyncConnection;
import org.eclipse.jetty.io.nio.SelectChannelEndPoint;
import org.eclipse.jetty.io.nio.SelectorManager;
import org.eclipse.jetty.io.nio.SslConnection;
import org.eclipse.jetty.util.B64Code;
import org.eclipse.jetty.util.QuotedStringTokenizer;
import org.eclipse.jetty.util.component.AggregateLifeCycle;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ThreadPool;
/* ------------------------------------------------------------ */
/**
* WebSocketClientFactory contains the common components needed by multiple {@link WebSocketClient} instances
* (for example, a {@link ThreadPool}, a {@link SelectorManager NIO selector}, etc).
* WebSocketClients with different configurations should share the same factory to avoid to waste resources.
* If a ThreadPool or MaskGen is passed in the constructor, then it is not added with {@link AggregateLifeCycle#addBean(Object)},
* so it's lifecycle must be controlled externally.
*
* @see WebSocketClient
*/
public class WebSocketClientFactory extends AggregateLifeCycle
{
private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName());
private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept");
private final Queue connections = new ConcurrentLinkedQueue();
private final SslContextFactory _sslContextFactory = new SslContextFactory();
private final ThreadPool _threadPool;
private final WebSocketClientSelector _selector;
private MaskGen _maskGen;
private WebSocketBuffers _buffers;
/* ------------------------------------------------------------ */
/**
* Creates a WebSocketClientFactory with the default configuration.
*/
public WebSocketClientFactory()
{
this(null);
}
/* ------------------------------------------------------------ */
/**
* Creates a WebSocketClientFactory with the given ThreadPool and the default configuration.
*
* @param threadPool the ThreadPool instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool)
{
this(threadPool, new RandomMaskGen());
}
/* ------------------------------------------------------------ */
/**
* Creates a WebSocketClientFactory with the given ThreadPool and the given MaskGen.
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the MaskGen instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen)
{
this(threadPool, maskGen, 8192);
}
/* ------------------------------------------------------------ */
/**
* Creates a WebSocketClientFactory with the specified configuration.
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the mask generator to use
* @param bufferSize the read buffer size
*/
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen, int bufferSize)
{
if (threadPool == null)
threadPool = new QueuedThreadPool();
_threadPool = threadPool;
addBean(_threadPool);
_buffers = new WebSocketBuffers(bufferSize);
addBean(_buffers);
_maskGen = maskGen;
addBean(_maskGen);
_selector = new WebSocketClientSelector();
addBean(_selector);
addBean(_sslContextFactory);
}
/* ------------------------------------------------------------ */
/**
* @return the SslContextFactory used to configure SSL parameters
*/
public SslContextFactory getSslContextFactory()
{
return _sslContextFactory;
}
/* ------------------------------------------------------------ */
/**
* Get the selectorManager. Used to configure the manager.
*
* @return The {@link SelectorManager} instance.
*/
public SelectorManager getSelectorManager()
{
return _selector;
}
/* ------------------------------------------------------------ */
/**
* Get the ThreadPool.
* Used to set/query the thread pool configuration.
*
* @return The {@link ThreadPool}
*/
public ThreadPool getThreadPool()
{
return _threadPool;
}
/* ------------------------------------------------------------ */
/**
* @return the shared mask generator, or null if no shared mask generator is used
* @see WebSocketClient#getMaskGen()
*/
public MaskGen getMaskGen()
{
return _maskGen;
}
/* ------------------------------------------------------------ */
/**
* @param maskGen the shared mask generator, or null if no shared mask generator is used
* @see WebSocketClient#setMaskGen(MaskGen)
*/
public void setMaskGen(MaskGen maskGen)
{
if (isRunning())
throw new IllegalStateException(getState());
removeBean(_maskGen);
_maskGen = maskGen;
addBean(maskGen);
}
/* ------------------------------------------------------------ */
/**
* @param bufferSize the read buffer size
* @see #getBufferSize()
*/
public void setBufferSize(int bufferSize)
{
if (isRunning())
throw new IllegalStateException(getState());
removeBean(_buffers);
_buffers = new WebSocketBuffers(bufferSize);
addBean(_buffers);
}
/* ------------------------------------------------------------ */
/**
* @return the read buffer size
*/
public int getBufferSize()
{
return _buffers.getBufferSize();
}
@Override
protected void doStop() throws Exception
{
closeConnections();
super.doStop();
}
/* ------------------------------------------------------------ */
/**
* Creates and returns a new instance of a {@link WebSocketClient}, configured with this
* WebSocketClientFactory instance.
*
* @return a new {@link WebSocketClient} instance
*/
public WebSocketClient newWebSocketClient()
{
return new WebSocketClient(this);
}
protected SSLEngine newSslEngine(SocketChannel channel) throws IOException
{
SSLEngine sslEngine;
if (channel != null)
{
String peerHost = channel.socket().getInetAddress().getHostAddress();
int peerPort = channel.socket().getPort();
sslEngine = _sslContextFactory.newSslEngine(peerHost, peerPort);
}
else
{
sslEngine = _sslContextFactory.newSslEngine();
}
sslEngine.setUseClientMode(true);
sslEngine.beginHandshake();
return sslEngine;
}
protected boolean addConnection(WebSocketConnection connection)
{
return isRunning() && connections.add(connection);
}
protected boolean removeConnection(WebSocketConnection connection)
{
return connections.remove(connection);
}
protected void closeConnections()
{
for (WebSocketConnection connection : connections)
connection.shutdown();
}
/* ------------------------------------------------------------ */
/**
* WebSocket Client Selector Manager
*/
class WebSocketClientSelector extends SelectorManager
{
@Override
public boolean dispatch(Runnable task)
{
return _threadPool.dispatch(task);
}
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, final SelectionKey key) throws IOException
{
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)key.attachment();
int maxIdleTime = holder.getMaxIdleTime();
if (maxIdleTime < 0)
maxIdleTime = (int)getMaxIdleTime();
SelectChannelEndPoint result = new SelectChannelEndPoint(channel, selectSet, key, maxIdleTime);
AsyncEndPoint endPoint = result;
// Detect if it is SSL, and wrap the connection if so
if ("wss".equals(holder.getURI().getScheme()))
{
SSLEngine sslEngine = newSslEngine(channel);
SslConnection sslConnection = new SslConnection(sslEngine, endPoint);
endPoint.setConnection(sslConnection);
endPoint = sslConnection.getSslEndPoint();
}
AsyncConnection connection = selectSet.getManager().newConnection(channel, endPoint, holder);
endPoint.setConnection(connection);
return result;
}
@Override
public AsyncConnection newConnection(SocketChannel channel, AsyncEndPoint endpoint, Object attachment)
{
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)attachment;
return new HandshakeConnection(endpoint, holder);
}
@Override
protected void endPointOpened(SelectChannelEndPoint endpoint)
{
// TODO expose on outer class ??
}
@Override
protected void endPointUpgraded(ConnectedEndPoint endpoint, Connection oldConnection)
{
LOG.debug("upgrade {} -> {}", oldConnection, endpoint.getConnection());
}
@Override
protected void endPointClosed(SelectChannelEndPoint endpoint)
{
endpoint.getConnection().onClose();
}
@Override
protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
{
if (!(attachment instanceof WebSocketClient.WebSocketFuture))
super.connectionFailed(channel, ex, attachment);
else
{
__log.debug(ex);
WebSocketClient.WebSocketFuture future = (WebSocketClient.WebSocketFuture)attachment;
future.handshakeFailed(ex);
}
}
}
/* ------------------------------------------------------------ */
/**
* Handshake Connection.
* Handles the connection until the handshake succeeds or fails.
*/
class HandshakeConnection extends AbstractConnection implements AsyncConnection
{
private final AsyncEndPoint _endp;
private final WebSocketClient.WebSocketFuture _future;
private final String _key;
private final HttpParser _parser;
private String _accept;
private String _error;
private ByteArrayBuffer _handshake;
public HandshakeConnection(AsyncEndPoint endpoint, WebSocketClient.WebSocketFuture future)
{
super(endpoint, System.currentTimeMillis());
_endp = endpoint;
_future = future;
byte[] bytes = new byte[16];
new Random().nextBytes(bytes);
_key = new String(B64Code.encode(bytes));
Buffers buffers = new SimpleBuffers(_buffers.getBuffer(), null);
_parser = new HttpParser(buffers, _endp, new HttpParser.EventHandler()
{
@Override
public void startResponse(Buffer version, int status, Buffer reason) throws IOException
{
if (status != 101)
{
_error = "Bad response status " + status + " " + reason;
_endp.close();
}
}
@Override
public void parsedHeader(Buffer name, Buffer value) throws IOException
{
if (__ACCEPT.equals(name))
_accept = value.toString();
}
@Override // TODO simone says shouldn't be needed
public void startRequest(Buffer method, Buffer url, Buffer version) throws IOException
{
if (_error == null)
_error = "Bad response: " + method + " " + url + " " + version;
_endp.close();
}
@Override // TODO simone says shouldn't be needed
public void content(Buffer ref) throws IOException
{
if (_error == null)
_error = "Bad response. " + ref.length() + "B of content?";
_endp.close();
}
});
}
private boolean handshake()
{
if (_handshake==null)
{
String path = _future.getURI().getPath();
if (path == null || path.length() == 0)
path = "/";
if (_future.getURI().getRawQuery() != null)
path += "?" + _future.getURI().getRawQuery();
String origin = _future.getOrigin();
StringBuilder request = new StringBuilder(512);
request.append("GET ").append(path).append(" HTTP/1.1\r\n")
.append("Host: ").append(_future.getURI().getHost()).append(":");
int port = _future.getURI().getPort();
if (port <= 0)
{
// fix it
String scheme = _future.getURI().getScheme();
if ("ws".equalsIgnoreCase(scheme))
{
port = 80;
}
else if ("wss".equalsIgnoreCase(scheme))
{
port = 443;
}
else
{
throw new RuntimeException("No valid port provided for scheme [" + scheme + "]");
}
}
request.append(port).append("\r\n");
request.append("Upgrade: websocket\r\n")
.append("Connection: Upgrade\r\n")
.append("Sec-WebSocket-Key: ")
.append(_key).append("\r\n");
if (origin != null)
request.append("Origin: ").append(origin).append("\r\n");
request.append("Sec-WebSocket-Version: ").append(WebSocketConnectionRFC6455.VERSION).append("\r\n");
if (_future.getProtocol() != null)
request.append("Sec-WebSocket-Protocol: ").append(_future.getProtocol()).append("\r\n");
Map cookies = _future.getCookies();
if (cookies != null && cookies.size() > 0)
{
for (String cookie : cookies.keySet())
request.append("Cookie: ")
.append(QuotedStringTokenizer.quoteIfNeeded(cookie, HttpFields.__COOKIE_DELIM))
.append("=")
.append(QuotedStringTokenizer.quoteIfNeeded(cookies.get(cookie), HttpFields.__COOKIE_DELIM))
.append("\r\n");
}
request.append("\r\n");
_handshake=new ByteArrayBuffer(request.toString(), false);
}
// TODO extensions
try
{
int len = _handshake.length();
int flushed = _endp.flush(_handshake);
if (flushed<0)
throw new IOException("incomplete handshake");
}
catch (IOException e)
{
_future.handshakeFailed(e);
}
return _handshake.length()==0;
}
public Connection handle() throws IOException
{
while (_endp.isOpen() && !_parser.isComplete())
{
if (_handshake==null || _handshake.length()>0)
if (!handshake())
return this;
if (!_parser.parseAvailable())
{
if (_endp.isInputShutdown())
_future.handshakeFailed(new IOException("Incomplete handshake response"));
return this;
}
}
if (_error == null)
{
if (_accept == null)
{
_error = "No Sec-WebSocket-Accept";
}
else if (!WebSocketConnectionRFC6455.hashKey(_key).equals(_accept))
{
_error = "Bad Sec-WebSocket-Accept";
}
else
{
WebSocketConnection connection = newWebSocketConnection();
Buffer header = _parser.getHeaderBuffer();
if (header.hasContent())
connection.fillBuffersFrom(header);
_buffers.returnBuffer(header);
_future.onConnection(connection);
return connection;
}
}
_endp.close();
return this;
}
private WebSocketConnection newWebSocketConnection() throws IOException
{
__log.debug("newWebSocketConnection()");
return new WebSocketClientConnection(
_future._client.getFactory(),
_future.getWebSocket(),
_endp,
_buffers,
System.currentTimeMillis(),
_future.getMaxIdleTime(),
_future.getProtocol(),
null,
WebSocketConnectionRFC6455.VERSION,
_future.getMaskGen());
}
public void onInputShutdown() throws IOException
{
_endp.close();
}
public boolean isIdle()
{
return false;
}
public boolean isSuspended()
{
return false;
}
public void onClose()
{
if (_error != null)
_future.handshakeFailed(new ProtocolException(_error));
else
_future.handshakeFailed(new EOFException());
}
}
private static class WebSocketClientConnection extends WebSocketConnectionRFC6455
{
private final WebSocketClientFactory factory;
public WebSocketClientConnection(WebSocketClientFactory factory, WebSocket webSocket, EndPoint endPoint, WebSocketBuffers buffers, long timeStamp, int maxIdleTime, String protocol, List extensions, int draftVersion, MaskGen maskGen) throws IOException
{
super(webSocket, endPoint, buffers, timeStamp, maxIdleTime, protocol, extensions, draftVersion, maskGen);
this.factory = factory;
}
@Override
public void onClose()
{
super.onClose();
factory.removeConnection(this);
}
}
}