// com.firefly.net.tcp.ssl.SSLSession (Maven / Gradle / Ivy)
package com.firefly.net.tcp.ssl;
import com.firefly.net.BufferPool;
import com.firefly.net.SSLContextFactory;
import com.firefly.net.SSLEventHandler;
import com.firefly.net.Session;
import com.firefly.net.buffer.FileRegion;
import com.firefly.net.buffer.ThreadSafeIOBufferPool;
import com.firefly.utils.concurrent.Callback;
import com.firefly.utils.concurrent.CountingCallback;
import com.firefly.utils.io.BufferReaderHandler;
import com.firefly.utils.io.BufferUtils;
import io.netty.handler.ssl.SslHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
/**
 * TLS/SSL layer on top of a plain {@link Session}.
 * <p>
 * Drives the initial handshake with an {@link SSLEngine}, decrypts inbound network bytes in
 * {@link #read(ByteBuffer)} and encrypts outbound plaintext in {@link #write(ByteBuffer, Callback)}
 * before handing the ciphertext to the underlying session.
 * <p>
 * NOTE(review): no field is volatile and no method synchronizes, so this class presumably relies
 * on each session's I/O being driven by a single thread — confirm with the callers.
 */
public class SSLSession implements Closeable {

    protected static final Logger log = LoggerFactory.getLogger("firefly-system");

    /** Pool for the temporary direct buffers handed to the SSLEngine (see the FIXME notes). */
    private static final BufferPool bufferPool = new ThreadSafeIOBufferPool();

    private final Session session;
    private final SSLEngine sslEngine;

    /**
     * Encrypted bytes received from the network that have not been unwrapped yet; null until the
     * first non-empty receive. May alias the caller's receive buffer (see {@link #merge}).
     */
    private ByteBuffer inNetBuffer;

    /** Decrypted application data in write mode, drained by {@link #getOutAppBuffer()}. */
    private ByteBuffer outAppBuffer;

    private static final int requestBufferSize = 1024 * 8;
    private static final int writeBufferSize = 1024 * 8;

    /*
     * An empty ByteBuffer for use when one isn't available, say as a source
     * buffer during initial handshake wraps or for close operations.
     */
    private static final ByteBuffer hsBuffer = ByteBuffer.allocateDirect(0);

    /*
     * We have received the shutdown request by our caller, and have closed our
     * outbound side.
     */
    private boolean closed = false;

    /*
     * During our initial handshake, keep track of the next SSLEngine operation
     * that needs to occur:
     *
     * NEED_WRAP/NEED_UNWRAP
     *
     * Once the initial handshake has completed, we can short circuit handshake
     * checks with initialHSComplete.
     */
    private HandshakeStatus initialHSStatus;
    private boolean initialHSComplete;

    private final SSLEventHandler sslEventHandler;
    private final SslHandler sslHandler;

    /**
     * Creates an SSL session whose engine is built by {@code factory}.
     *
     * @param factory         creates the SSLEngine
     * @param clientMode      true to act as the TLS client
     * @param session         the underlying transport session
     * @param sslEventHandler notified when the initial handshake finishes
     * @throws Throwable if the engine cannot be created or the handshake cannot be started
     */
    public SSLSession(SSLContextFactory factory, boolean clientMode, Session session, SSLEventHandler sslEventHandler) throws Throwable {
        this(factory.createSSLEngine(clientMode), session, sslEventHandler);
    }

    /**
     * Starts the handshake on {@code sslEngine}; in client mode the first handshake flight is
     * wrapped and written immediately.
     */
    private SSLSession(SSLEngine sslEngine, Session session, SSLEventHandler sslEventHandler) throws Throwable {
        this.session = session;
        this.sslEventHandler = sslEventHandler;
        this.sslEngine = sslEngine;
        outAppBuffer = ByteBuffer.allocate(requestBufferSize);
        initialHSComplete = false;
        sslHandler = new SslHandler(sslEngine);

        // start tls
        this.sslEngine.beginHandshake();
        initialHSStatus = sslEngine.getHandshakeStatus();
        if (sslEngine.getUseClientMode()) {
            doHandshakeResponse();
        }
    }

    /**
     * The initial handshake is a procedure by which the two peers exchange
     * communication parameters until an SSLSession is established. Application
     * data can not be sent during this phase.
     * <p>
     * The received bytes are buffered, then the engine is driven (unwrap, wrap, delegated tasks)
     * until it either finishes the handshake or needs more network data.
     *
     * @param receiveBuffer Encrypted message
     * @return True means handshake success
     * @throws IOException if the engine fails during the handshake
     */
    private boolean doHandshake(ByteBuffer receiveBuffer) throws IOException {
        if (!session.isOpen()) {
            sslEngine.closeInbound();
            return (initialHSComplete = false);
        }
        if (initialHSComplete) {
            return true;
        }

        merge(receiveBuffer);
        // Keep stepping the state machine: a wrap may finish the handshake, and records that
        // were already buffered may be unwrappable right after a wrap.
        while (!initialHSComplete) {
            switch (initialHSStatus) {
                case NOT_HANDSHAKING:
                case FINISHED:
                    handshakeFinish();
                    break;
                case NEED_TASK:
                    initialHSStatus = doTasks();
                    break;
                case NEED_UNWRAP:
                    if (!doHandshakeReceive()) {
                        return false; // wait for more network data
                    }
                    break;
                case NEED_WRAP:
                    doHandshakeResponse();
                    break;
                default:
                    throw new RuntimeException("Invalid Handshaking State: " + initialHSStatus);
            }
        }
        return true;
    }

    /**
     * Unwraps buffered handshake records while the engine reports NEED_UNWRAP.
     *
     * @return false when more network data is required before the handshake can progress; true
     * once the engine asks for something other than NEED_UNWRAP
     * @throws IOException if the engine reports CLOSED or fails while unwrapping
     */
    private boolean doHandshakeReceive() throws IOException {
        while (initialHSStatus == HandshakeStatus.NEED_UNWRAP) {
            if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
                return false;
            }

            SSLEngineResult result = unwrapNetData(sslEngine.getSession().getPacketBufferSize());
            initialHSStatus = result.getHandshakeStatus();
            if (log.isDebugEnabled()) {
                log.debug("session {} handshake receives data, init: {} | ret: {} | complete: {} ",
                        session.getSessionId(), initialHSStatus, result.getStatus(), initialHSComplete);
            }

            switch (result.getStatus()) {
                case OK:
                    if (initialHSStatus == HandshakeStatus.NEED_TASK) {
                        initialHSStatus = doTasks();
                    }
                    break;
                case BUFFER_UNDERFLOW:
                    // Only a partial record is buffered; wait for the rest unless the engine
                    // has already moved on (e.g. FINISHED), which the caller handles.
                    if (initialHSStatus == HandshakeStatus.NEED_UNWRAP) {
                        return false;
                    }
                    break;
                case BUFFER_OVERFLOW:
                    // Grow the application buffer and retry the same record.
                    enlargeOutAppBuffer();
                    break;
                default: // CLOSED:
                    throw new IOException("Received " + result.getStatus() + " during initial handshaking");
            }
        }
        return true;
    }

    /**
     * Marks the initial handshake as complete and notifies the event handler.
     */
    private void handshakeFinish() {
        log.info("session {} handshake success!", session.getSessionId());
        initialHSComplete = true;
        sslEventHandler.handshakeFinished(this);
    }

    /**
     * Wraps and writes outgoing handshake records while the engine reports NEED_WRAP.
     *
     * @throws IOException if the engine reports BUFFER_UNDERFLOW or CLOSED while wrapping
     */
    private void doHandshakeResponse() throws IOException {
        while (initialHSStatus == HandshakeStatus.NEED_WRAP) {
            SSLEngineResult result;
            ByteBuffer writeBuf = ByteBuffer.allocateDirect(sslEngine.getSession().getPacketBufferSize());
            wrap:
            while (true) {
                result = sslEngine.wrap(hsBuffer, writeBuf);
                initialHSStatus = result.getHandshakeStatus();
                if (log.isDebugEnabled()) {
                    log.debug("session {} handshake response, init: {} | ret: {} | complete: {} ",
                            session.getSessionId(), initialHSStatus, result.getStatus(), initialHSComplete);
                }
                switch (result.getStatus()) {
                    case OK:
                        if (initialHSStatus == HandshakeStatus.NEED_TASK) {
                            initialHSStatus = doTasks();
                        }
                        writeBuf.flip();
                        session.write(writeBuf, Callback.NOOP);
                        break wrap;
                    case BUFFER_OVERFLOW:
                        // Keep what was produced so far and retry with more room; stay direct,
                        // like the buffers used by write().
                        int netSize = sslEngine.getSession().getPacketBufferSize();
                        ByteBuffer b = ByteBuffer.allocateDirect(writeBuf.position() + netSize);
                        writeBuf.flip();
                        b.put(writeBuf);
                        writeBuf = b;
                        break;
                    default: // BUFFER_UNDERFLOW, CLOSED:
                        throw new IOException("Received " + result.getStatus() + " during initial handshaking");
                }
            }
        }
    }

    /**
     * Appends newly received encrypted bytes to {@link #inNetBuffer}.
     * <p>
     * If no bytes are pending, {@code now} itself becomes the pending buffer (no copy), so later
     * unwraps advance the caller's buffer position. Otherwise the pending bytes and {@code now} are
     * copied into a fresh buffer and {@code now} is fully consumed.
     * NOTE(review): the aliasing assumes callers do not refill {@code now} while bytes from it are
     * still pending — confirm with the callers.
     *
     * @param now freshly received encrypted bytes
     */
    private void merge(ByteBuffer now) {
        // "now == inNetBuffer" happens when the same buffer is merged twice (handshake, then
        // read); merging it again would duplicate the ciphertext.
        if (!now.hasRemaining() || now == inNetBuffer)
            return;

        if (inNetBuffer != null && inNetBuffer.hasRemaining()) {
            ByteBuffer ret = ByteBuffer.allocate(inNetBuffer.remaining() + now.remaining());
            ret.put(inNetBuffer).put(now).flip();
            inNetBuffer = ret;
        } else {
            inNetBuffer = now;
        }
    }

    /**
     * Unwraps at most one packet-sized window starting at the current position of
     * {@link #inNetBuffer} into {@link #outAppBuffer}, then advances {@link #inNetBuffer} past the
     * bytes the engine consumed.
     * <p>
     * Because every window starts at the first unconsumed byte and spans a full packet when enough
     * data is buffered, BUFFER_UNDERFLOW only occurs when a record is genuinely incomplete.
     *
     * @param netSize maximum window size, normally the session's packet buffer size
     * @return the engine result
     * @throws IOException if the engine fails to unwrap
     */
    private SSLEngineResult unwrapNetData(int netSize) throws IOException {
        ByteBuffer window = inNetBuffer.duplicate();
        window.limit(window.position() + Math.min(window.remaining(), netSize));

        SSLEngineResult result;
        //FIXME using direct buffer avoid netty ByteBufAllocator bug
        ByteBuffer directTmpBuffer = bufferPool.acquire(window.remaining());
        try {
            directTmpBuffer.put(window).flip();
            result = sslEngine.unwrap(directTmpBuffer, outAppBuffer);
        } finally {
            bufferPool.release(directTmpBuffer);
        }
        inNetBuffer.position(inNetBuffer.position() + result.bytesConsumed());
        return result;
    }

    /**
     * Replaces {@link #outAppBuffer} with a buffer that has room for another full application
     * record, keeping the data already decrypted.
     */
    private void enlargeOutAppBuffer() {
        int appSize = sslEngine.getSession().getApplicationBufferSize();
        ByteBuffer b = ByteBuffer.allocate(appSize + outAppBuffer.position());
        outAppBuffer.flip();
        b.put(outAppBuffer);
        outAppBuffer = b;
    }

    /**
     * Drains the decrypted application data.
     *
     * @return a new read-mode buffer holding the decrypted bytes, or null if there are none
     */
    private ByteBuffer getOutAppBuffer() {
        outAppBuffer.flip();
        if (!outAppBuffer.hasRemaining()) {
            // flip() left limit == 0; restore write mode so the next unwrap has room.
            outAppBuffer.clear();
            return null;
        }

        ByteBuffer buf = ByteBuffer.allocate(outAppBuffer.remaining());
        buf.put(outAppBuffer).flip();
        outAppBuffer = ByteBuffer.allocate(requestBufferSize);
        if (log.isDebugEnabled()) {
            log.debug("SSL session {} unwrap, app buffer -> {}", session.getSessionId(), buf.remaining());
        }
        return buf;
    }

    /**
     * Do all the outstanding handshake tasks in the current Thread.
     *
     * @return The result of handshake
     */
    private SSLEngineResult.HandshakeStatus doTasks() {
        Runnable runnable;

        // We could run this in a separate thread, but do in the current for
        // now.
        while ((runnable = sslEngine.getDelegatedTask()) != null) {
            runnable.run();
        }
        return sslEngine.getHandshakeStatus();
    }

    /**
     * Closes the outbound side of the engine once; later calls do nothing.
     */
    @Override
    public void close() throws IOException {
        if (!closed) {
            sslEngine.closeOutbound();
            closed = true;
        }
    }

    /**
     * @return the application protocol reported by the netty {@link SslHandler} wrapping the engine
     */
    public String applicationProtocol() {
        return sslHandler.applicationProtocol();
    }

    /**
     * @return false once {@link #close()} has been called
     */
    public boolean isOpen() {
        return !closed;
    }

    /**
     * This method is used to decrypt data, it implied do handshake.
     * <p>
     * Every complete record buffered so far is decrypted; a trailing partial record is kept until
     * more data arrives.
     *
     * @param receiveBuffer Encrypted message
     * @return plaintext, or null if the handshake is not finished or nothing was decrypted
     * @throws IOException sslEngine error during data read
     */
    public ByteBuffer read(ByteBuffer receiveBuffer) throws IOException {
        if (!doHandshake(receiveBuffer))
            return null;

        if (!initialHSComplete)
            throw new IllegalStateException("The initial handshake is not complete.");

        if (log.isDebugEnabled()) {
            log.debug("SSL read current session {} status -> {}", session.getSessionId(), session.isOpen());
        }

        merge(receiveBuffer);
        while (inNetBuffer != null && inNetBuffer.hasRemaining()) {
            int netSize = sslEngine.getSession().getPacketBufferSize();
            int windowSize = Math.min(inNetBuffer.remaining(), netSize);
            if (log.isDebugEnabled()) {
                log.debug("SSL session {} unwrap, pocket -> {}, in -> {}, out -> {}, temp -> {}",
                        session.getSessionId(), netSize, inNetBuffer.remaining(), outAppBuffer.remaining(),
                        windowSize);
            }

            SSLEngineResult result = unwrapNetData(netSize);
            int consumed = result.bytesConsumed();
            if (log.isDebugEnabled()) {
                log.debug("SSL session {} unwrap, status -> {}, in -> {}, out -> {}, temp -> {}, consumed -> {}",
                        session.getSessionId(), result.getStatus(), inNetBuffer.remaining(), outAppBuffer.remaining(),
                        windowSize - consumed, consumed);
            }

            switch (result.getStatus()) {
                case BUFFER_OVERFLOW:
                    // Grow the application buffer and retry the same record.
                    enlargeOutAppBuffer();
                    break;
                case BUFFER_UNDERFLOW:
                    // The rest is an incomplete record; it stays in inNetBuffer for the next read.
                    return getOutAppBuffer();
                case OK:
                    if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                        doTasks();
                    }
                    break;
                default:
                    throw new IOException("sslEngine error during data read: " + result.getStatus());
            }
        }
        return getOutAppBuffer();
    }

    /**
     * Encrypts and writes several plaintext buffers; {@code callback} completes once all of them
     * have been written.
     *
     * @param outputBuffers Plaintext messages
     * @param callback      completed after every buffer has been written
     * @return total number of plaintext bytes consumed
     * @throws Throwable sslEngine error during data write
     */
    public int write(ByteBuffer[] outputBuffers, Callback callback) throws Throwable {
        if (outputBuffers.length == 0) {
            // Nothing to write; a CountingCallback needs at least one completion.
            callback.succeeded();
            return 0;
        }

        int ret = 0;
        CountingCallback countingCallback = new CountingCallback(callback, outputBuffers.length);
        for (ByteBuffer outputBuffer : outputBuffers) {
            ret += write(outputBuffer, countingCallback);
        }
        return ret;
    }

    /**
     * This method is used to encrypt and flush to socket channel.
     * <p>
     * The position of {@code outputBuffer} is left unchanged. {@code callback} is completed exactly
     * once: immediately for an empty buffer, otherwise after all encrypted records are written.
     *
     * @param outputBuffer Plaintext message
     * @param callback     completed once the whole message has been written
     * @return written length (number of plaintext bytes consumed)
     * @throws IOException sslEngine error during data write
     */
    public int write(ByteBuffer outputBuffer, Callback callback) throws IOException {
        if (!initialHSComplete)
            throw new IllegalStateException("The initial handshake is not complete.");

        if (!outputBuffer.hasRemaining()) {
            callback.succeeded();
            return 0;
        }

        // Work on a view: slice() never moves a position, so progress is tracked on this view
        // instead of re-encrypting the same leading bytes on every record.
        ByteBuffer plain = outputBuffer.duplicate();
        List<ByteBuffer> records = new ArrayList<>();
        int ret = 0;
        while (plain.hasRemaining()) {
            ByteBuffer writeBuf = ByteBuffer.allocateDirect(writeBufferSize);
            wrap:
            while (true) {
                SSLEngineResult result;
                //FIXME using direct buffer avoid netty ByteBufAllocator bug
                ByteBuffer directTmpBuffer = bufferPool.acquire(plain.remaining());
                try {
                    directTmpBuffer.put(plain.slice()).flip();
                    result = sslEngine.wrap(directTmpBuffer, writeBuf);
                } finally {
                    bufferPool.release(directTmpBuffer);
                }
                int consumed = result.bytesConsumed();
                plain.position(plain.position() + consumed);
                ret += consumed;

                switch (result.getStatus()) {
                    case OK:
                        if (consumed == 0 && result.bytesProduced() == 0) {
                            // The engine made no progress (e.g. it needs an unwrap first);
                            // looping again would never terminate.
                            throw new IOException("sslEngine error during data write: no progress, handshake status "
                                    + result.getHandshakeStatus());
                        }
                        if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                            doTasks();
                        }
                        writeBuf.flip();
                        records.add(writeBuf);
                        break wrap;
                    case BUFFER_OVERFLOW:
                        int netSize = sslEngine.getSession().getPacketBufferSize();
                        ByteBuffer b = ByteBuffer.allocateDirect(writeBuf.position() + netSize);
                        writeBuf.flip();
                        b.put(writeBuf);
                        writeBuf = b;
                        // retry the operation.
                        break;
                    default:
                        throw new IOException("sslEngine error during data write: " + result.getStatus());
                }
            }
        }

        // Complete the caller's callback once, after every record has been written.
        if (records.size() == 1) {
            session.write(records.get(0), callback);
        } else {
            CountingCallback countingCallback = new CountingCallback(callback, records.size());
            for (ByteBuffer record : records) {
                session.write(record, countingCallback);
            }
        }
        return ret;
    }

    /**
     * Encrypts each chunk read from a file region and writes it through this session.
     */
    private class FileBufferReaderHandler implements BufferReaderHandler {

        private final long len;

        private FileBufferReaderHandler(long len) {
            this.len = len;
        }

        @Override
        public void readBuffer(ByteBuffer buf, CountingCallback countingCallback, long count) {
            log.debug("write file, count: {} , lenth: {}", count, len);
            try {
                write(buf, countingCallback);
            } catch (Throwable e) {
                log.error("ssl session writing error", e);
                // Propagate the failure instead of leaving the transfer's callback pending.
                countingCallback.failed(e);
            }
        }
    }

    /**
     * Encrypts and writes the content of {@code file}, closing the region afterwards.
     *
     * @param file     the file region to send
     * @param callback passed to {@link FileRegion#transferTo}
     * @return always 0 — NOTE(review): the transferred byte count is not tracked here
     * @throws Throwable if the transfer fails
     */
    public long transferFileRegion(FileRegion file, Callback callback) throws Throwable {
        try (FileRegion fileRegion = file) {
            fileRegion.transferTo(callback, new FileBufferReaderHandler(file.getLength()));
        }
        return 0;
    }

    /**
     * @return true once the initial handshake has completed
     */
    public boolean isHandshakeFinished() {
        return initialHSComplete;
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy