com.neovisionaries.ws.client.HandshakeBuilder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of nv-websocket-client Show documentation
Show all versions of nv-websocket-client Show documentation
WebSocket client implementation in Java.
/*
* Copyright (C) 2015-2016 Neo Visionaries Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.neovisionaries.ws.client;
import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
/**
 * Builder for the opening handshake request (an HTTP/1.1 upgrade request)
 * that a WebSocket client sends to a server.
 *
 * <p>
 * An instance collects the target host and path, the value of the
 * {@code Sec-WebSocket-Key} header, the requested sub-protocols and
 * extensions, optional user information for Basic authentication and
 * arbitrary custom headers. {@link #buildRequestLine()} and
 * {@link #buildHeaders()} produce the parts of the request, and
 * {@link #build(String, List)} serializes them into the final text.
 * </p>
 *
 * <p>
 * The mutable state (protocols, extensions, custom headers, user
 * information and key) is guarded by the monitor of this instance.
 * </p>
 */
class HandshakeBuilder
{
    private static final String[] CONNECTION_HEADER = { "Connection", "Upgrade" };
    private static final String[] UPGRADE_HEADER    = { "Upgrade", "websocket" };
    private static final String[] VERSION_HEADER    = { "Sec-WebSocket-Version", "13" };
    private static final String RN = "\r\n";


    private final boolean mSecure;
    private String mUserInfo;
    private final String mHost;
    private final String mPath;
    private final URI mUri;
    private String mKey;
    private Set<String> mProtocols;
    private List<WebSocketExtension> mExtensions;
    private List<String[]> mHeaders;


    /**
     * Create a builder for the given endpoint.
     *
     * @param secure
     *         {@code true} to use the {@code wss} scheme in the URI
     *         returned by {@link #getURI()}, {@code false} to use {@code ws}.
     *
     * @param userInfo
     *         User information used for the {@code Authorization: Basic}
     *         header. May be {@code null}.
     *
     * @param host
     *         The host. It may contain {@code :{port}} at its end.
     *
     * @param path
     *         The path. It may contain {@code ?{query}} at its end.
     *
     * @throws IllegalArgumentException
     *         {@code "{scheme}://{host}{path}"} is rejected by
     *         {@link URI#create(String)}.
     */
    public HandshakeBuilder(boolean secure, String userInfo, String host, String path)
    {
        mSecure   = secure;
        mUserInfo = userInfo;
        mHost     = host;
        mPath     = path;

        // 'host' may contain ':{port}' at its end.
        // 'path' may contain '?{query}' at its end.
        mUri = URI.create(String.format("%s://%s%s",
            (secure ? "wss" : "ws"), host, path));
    }


    /**
     * Copy constructor. The protocol set, the extension list and the
     * custom header list are copied (extensions and header pairs are
     * copied element by element), so later changes to either builder
     * do not affect the other.
     *
     * @param source
     *         The builder to copy from.
     */
    public HandshakeBuilder(HandshakeBuilder source)
    {
        // Immutable parts can be read without holding the lock.
        mSecure = source.mSecure;
        mHost   = source.mHost;
        mPath   = source.mPath;
        mUri    = source.mUri;

        // The mutable parts are modified under the monitor of 'source',
        // so take the same lock to copy a consistent snapshot.
        synchronized (source)
        {
            mUserInfo   = source.mUserInfo;
            mKey        = source.mKey;
            mProtocols  = copyProtocols(source.mProtocols);
            mExtensions = copyExtensions(source.mExtensions);
            mHeaders    = copyHeaders(source.mHeaders);
        }
    }


    /**
     * Add a sub-protocol to be listed in {@code Sec-WebSocket-Protocol}.
     * Adding the same protocol twice has no additional effect.
     *
     * @param protocol
     *         The protocol name.
     *
     * @throws IllegalArgumentException
     *         The given name is {@code null}, empty, or contains a
     *         character outside U+0021 to U+007E or a separator character.
     */
    public void addProtocol(String protocol)
    {
        if (!isValidProtocol(protocol))
        {
            throw new IllegalArgumentException(
                "'protocol' must be a non-empty string with characters in the range " +
                "U+0021 to U+007E not including separator characters.");
        }

        synchronized (this)
        {
            if (mProtocols == null)
            {
                // 'LinkedHashSet' is used because the elements
                // "MUST all be unique strings" and must be
                // "ordered by preference". See RFC 6455, p18, 10.
                mProtocols = new LinkedHashSet<String>();
            }

            mProtocols.add(protocol);
        }
    }


    /**
     * Remove a sub-protocol. Does nothing if the argument is
     * {@code null} or no protocol has been added.
     *
     * @param protocol
     *         The protocol name to remove.
     */
    public void removeProtocol(String protocol)
    {
        if (protocol == null)
        {
            return;
        }

        synchronized (this)
        {
            if (mProtocols == null)
            {
                return;
            }

            mProtocols.remove(protocol);

            // Keep the invariant "no protocols" == null.
            if (mProtocols.isEmpty())
            {
                mProtocols = null;
            }
        }
    }


    /**
     * Remove all sub-protocols.
     */
    public void clearProtocols()
    {
        synchronized (this)
        {
            mProtocols = null;
        }
    }


    /**
     * Check whether the given string is acceptable as a protocol name:
     * non-empty, every character within U+0021 to U+007E, and none of
     * them reported as a separator by {@code Token.isSeparator(char)}.
     */
    private static boolean isValidProtocol(String protocol)
    {
        if (protocol == null || protocol.length() == 0)
        {
            return false;
        }

        int len = protocol.length();

        for (int i = 0; i < len; ++i)
        {
            char ch = protocol.charAt(i);

            if (ch < 0x21 || 0x7E < ch || Token.isSeparator(ch))
            {
                return false;
            }
        }

        return true;
    }


    /**
     * Check whether the given sub-protocol has been added.
     *
     * @param protocol
     *         The protocol name to look up.
     *
     * @return
     *         {@code true} if the protocol is contained.
     */
    public boolean containsProtocol(String protocol)
    {
        synchronized (this)
        {
            if (mProtocols == null)
            {
                return false;
            }

            return mProtocols.contains(protocol);
        }
    }


    /**
     * Add an extension to be listed in {@code Sec-WebSocket-Extensions}.
     * A {@code null} argument is ignored.
     *
     * @param extension
     *         The extension to add.
     */
    public void addExtension(WebSocketExtension extension)
    {
        if (extension == null)
        {
            return;
        }

        synchronized (this)
        {
            if (mExtensions == null)
            {
                mExtensions = new ArrayList<WebSocketExtension>();
            }

            mExtensions.add(extension);
        }
    }


    /**
     * Add an extension given in its string form. The string is handed
     * to {@code WebSocketExtension.parse(String)} and the result is
     * passed to {@link #addExtension(WebSocketExtension)}.
     *
     * @param extension
     *         The string representation of the extension.
     */
    public void addExtension(String extension)
    {
        addExtension(WebSocketExtension.parse(extension));
    }


    /**
     * Remove one occurrence of the given extension. Does nothing if
     * the argument is {@code null} or no extension has been added.
     *
     * @param extension
     *         The extension to remove.
     */
    public void removeExtension(WebSocketExtension extension)
    {
        if (extension == null)
        {
            return;
        }

        synchronized (this)
        {
            if (mExtensions == null)
            {
                return;
            }

            mExtensions.remove(extension);

            // Keep the invariant "no extensions" == null.
            if (mExtensions.isEmpty())
            {
                mExtensions = null;
            }
        }
    }


    /**
     * Remove every extension whose name equals the given name
     * (exact, case-sensitive comparison).
     *
     * @param name
     *         The extension name. {@code null} is ignored.
     */
    public void removeExtensions(String name)
    {
        if (name == null)
        {
            return;
        }

        synchronized (this)
        {
            if (mExtensions == null)
            {
                return;
            }

            // Collect first, remove afterwards, so that the list is
            // not modified while it is being iterated.
            List<WebSocketExtension> extensionsToRemove = new ArrayList<WebSocketExtension>();

            for (WebSocketExtension extension : mExtensions)
            {
                if (extension.getName().equals(name))
                {
                    extensionsToRemove.add(extension);
                }
            }

            for (WebSocketExtension extension : extensionsToRemove)
            {
                mExtensions.remove(extension);
            }

            if (mExtensions.isEmpty())
            {
                mExtensions = null;
            }
        }
    }


    /**
     * Remove all extensions.
     */
    public void clearExtensions()
    {
        synchronized (this)
        {
            mExtensions = null;
        }
    }


    /**
     * Check whether the given extension has been added.
     *
     * @param extension
     *         The extension to look up.
     *
     * @return
     *         {@code true} if the extension list contains the argument.
     *         {@code false} for {@code null}.
     */
    public boolean containsExtension(WebSocketExtension extension)
    {
        if (extension == null)
        {
            return false;
        }

        synchronized (this)
        {
            if (mExtensions == null)
            {
                return false;
            }

            return mExtensions.contains(extension);
        }
    }


    /**
     * Check whether an extension with the given name has been added
     * (exact, case-sensitive comparison).
     *
     * @param name
     *         The extension name.
     *
     * @return
     *         {@code true} if at least one extension has that name.
     *         {@code false} for {@code null}.
     */
    public boolean containsExtension(String name)
    {
        if (name == null)
        {
            return false;
        }

        synchronized (this)
        {
            if (mExtensions == null)
            {
                return false;
            }

            for (WebSocketExtension extension : mExtensions)
            {
                if (extension.getName().equals(name))
                {
                    return true;
                }
            }

            return false;
        }
    }


    /**
     * Add a custom header. Headers are kept in insertion order and the
     * same name may be added more than once.
     *
     * @param name
     *         The header name. The call is ignored if this is
     *         {@code null} or empty.
     *
     * @param value
     *         The header value. {@code null} is treated as an empty string.
     */
    public void addHeader(String name, String value)
    {
        if (name == null || name.length() == 0)
        {
            return;
        }

        if (value == null)
        {
            value = "";
        }

        synchronized (this)
        {
            if (mHeaders == null)
            {
                mHeaders = new ArrayList<String[]>();
            }

            mHeaders.add(new String[] { name, value });
        }
    }


    /**
     * Remove every custom header whose name equals the given name.
     * The comparison is exact (case-sensitive).
     *
     * @param name
     *         The header name. The call is ignored if this is
     *         {@code null} or empty.
     */
    public void removeHeaders(String name)
    {
        if (name == null || name.length() == 0)
        {
            return;
        }

        synchronized (this)
        {
            if (mHeaders == null)
            {
                return;
            }

            // Collect first, remove afterwards, so that the list is
            // not modified while it is being iterated.
            List<String[]> headersToRemove = new ArrayList<String[]>();

            for (String[] header : mHeaders)
            {
                if (header[0].equals(name))
                {
                    headersToRemove.add(header);
                }
            }

            for (String[] header : headersToRemove)
            {
                mHeaders.remove(header);
            }

            // Keep the invariant "no custom headers" == null.
            if (mHeaders.isEmpty())
            {
                mHeaders = null;
            }
        }
    }


    /**
     * Remove all custom headers.
     */
    public void clearHeaders()
    {
        synchronized (this)
        {
            mHeaders = null;
        }
    }


    /**
     * Set the user information used for the {@code Authorization: Basic}
     * header.
     *
     * @param userInfo
     *         The user information. {@code null} or an empty string
     *         suppresses the header.
     */
    public void setUserInfo(String userInfo)
    {
        synchronized (this)
        {
            mUserInfo = userInfo;
        }
    }


    /**
     * Set the user information as {@code "{id}:{password}"}.
     *
     * @param id
     *         The ID. {@code null} is treated as an empty string.
     *
     * @param password
     *         The password. {@code null} is treated as an empty string.
     */
    public void setUserInfo(String id, String password)
    {
        if (id == null)
        {
            id = "";
        }

        if (password == null)
        {
            password = "";
        }

        String userInfo = String.format("%s:%s", id, password);

        setUserInfo(userInfo);
    }


    /**
     * Clear the user information.
     */
    public void clearUserInfo()
    {
        synchronized (this)
        {
            mUserInfo = null;
        }
    }


    /**
     * Get the URI built from the scheme, host and path given to the
     * constructor.
     *
     * @return
     *         The URI of the endpoint.
     */
    public URI getURI()
    {
        return mUri;
    }


    /**
     * Set the value of the {@code Sec-WebSocket-Key} header.
     *
     * @param key
     *         The key.
     */
    public void setKey(String key)
    {
        // Same monitor as the other mutators so that buildHeaders()
        // and the copy constructor observe the latest value.
        synchronized (this)
        {
            mKey = key;
        }
    }


    /**
     * Build the request line.
     *
     * @return
     *         {@code "GET {path} HTTP/1.1"} without a line terminator.
     */
    public String buildRequestLine()
    {
        return String.format("GET %s HTTP/1.1", mPath);
    }


    /**
     * Build the list of request headers as {name, value} pairs, in
     * this order: {@code Host}, {@code Connection}, {@code Upgrade},
     * {@code Sec-WebSocket-Version}, {@code Sec-WebSocket-Key}, then
     * (only when set) {@code Sec-WebSocket-Protocol},
     * {@code Sec-WebSocket-Extensions}, {@code Authorization}, and
     * finally the custom headers.
     *
     * @return
     *         A newly created list. Note that the pairs for the fixed
     *         headers and the custom headers are shared with this
     *         class, not copied.
     */
    public List<String[]> buildHeaders()
    {
        List<String[]> headers = new ArrayList<String[]>();

        // Host
        headers.add(new String[] { "Host", mHost } );

        // Connection
        headers.add(CONNECTION_HEADER);

        // Upgrade
        headers.add(UPGRADE_HEADER);

        // Sec-WebSocket-Version
        headers.add(VERSION_HEADER);

        // The remaining headers are derived from mutable state which is
        // modified under this monitor. Hold it so that a concurrent
        // clear/add/remove cannot slip in between a null check and the use.
        synchronized (this)
        {
            // Sec-WebSocket-Key
            headers.add(new String[] { "Sec-WebSocket-Key", mKey } );

            // Sec-WebSocket-Protocol
            if (mProtocols != null && !mProtocols.isEmpty())
            {
                headers.add(new String[] { "Sec-WebSocket-Protocol", Misc.join(mProtocols, ", ") } );
            }

            // Sec-WebSocket-Extensions
            if (mExtensions != null && !mExtensions.isEmpty())
            {
                headers.add(new String[] { "Sec-WebSocket-Extensions", Misc.join(mExtensions, ", ") } );
            }

            // Authorization: Basic
            if (mUserInfo != null && mUserInfo.length() != 0)
            {
                headers.add(new String[] { "Authorization", "Basic " + Base64.encode(mUserInfo) } );
            }

            // Custom headers
            if (mHeaders != null && !mHeaders.isEmpty())
            {
                headers.addAll(mHeaders);
            }
        }

        return headers;
    }


    /**
     * Serialize a request line and headers into the text of an
     * opening handshake request.
     *
     * @param requestLine
     *         The request line, e.g. the value returned from
     *         {@link #buildRequestLine()}.
     *
     * @param headers
     *         {name, value} pairs, e.g. the value returned from
     *         {@link #buildHeaders()}.
     *
     * @return
     *         The request line and one {@code "{name}: {value}"} line per
     *         header, each terminated by CRLF, followed by an empty line.
     */
    public static String build(String requestLine, List<String[]> headers)
    {
        StringBuilder builder = new StringBuilder();

        // Append the request line, "GET {path} HTTP/1.1".
        builder.append(requestLine).append(RN);

        // For each header.
        for (String[] header : headers)
        {
            // Append the header, "{name}: {value}".
            builder.append(header[0]).append(": ").append(header[1]).append(RN);
        }

        // Append an empty line.
        builder.append(RN);

        return builder.toString();
    }


    /**
     * Create an order-preserving copy of the given protocol set,
     * or return {@code null} if the argument is {@code null}.
     */
    private static Set<String> copyProtocols(Set<String> protocols)
    {
        if (protocols == null)
        {
            return null;
        }

        Set<String> newProtocols = new LinkedHashSet<String>(protocols.size());

        newProtocols.addAll(protocols);

        return newProtocols;
    }


    /**
     * Create a copy of the given extension list in which every element
     * is duplicated through the copy constructor of
     * {@code WebSocketExtension}, or return {@code null} if the
     * argument is {@code null}.
     */
    private static List<WebSocketExtension> copyExtensions(List<WebSocketExtension> extensions)
    {
        if (extensions == null)
        {
            return null;
        }

        List<WebSocketExtension> newExtensions =
            new ArrayList<WebSocketExtension>(extensions.size());

        for (WebSocketExtension extension : extensions)
        {
            newExtensions.add(new WebSocketExtension(extension));
        }

        return newExtensions;
    }


    /**
     * Create a copy of the given header list in which every
     * {name, value} pair is duplicated, or return {@code null} if the
     * argument is {@code null}.
     */
    private static List<String[]> copyHeaders(List<String[]> headers)
    {
        if (headers == null)
        {
            return null;
        }

        List<String[]> newHeaders = new ArrayList<String[]>(headers.size());

        for (String[] header : headers)
        {
            newHeaders.add(copyHeader(header));
        }

        return newHeaders;
    }


    /**
     * Duplicate a {name, value} pair into a new two-element array.
     */
    private static String[] copyHeader(String[] header)
    {
        String[] newHeader = new String[2];

        newHeader[0] = header[0];
        newHeader[1] = header[1];

        return newHeader;
    }
}