org.xnio.ssl.JsseStreamConduit Maven / Gradle / Ivy
Go to download
This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including
all dependencies. It is intended for use by those not using Maven; Maven users should just import the EJB and
JMS BOMs instead (shaded JARs cause lots of problems with Maven, as it is very easy to inadvertently end up
with different versions of classes on the class path).
/*
* JBoss, Home of Professional Open Source
*
* Copyright 2013 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.xnio.ssl;
import static java.lang.Math.max;
import static java.lang.Thread.currentThread;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreClear;
import static org.xnio.Bits.anyAreSet;
import static org.xnio._private.Messages.msg;
import java.io.EOFException;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.xnio.Buffers;
import org.xnio.Pool;
import org.xnio.Pooled;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitReadableByteChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.ReadReadyHandler;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.StreamSourceConduit;
import org.xnio.conduits.WriteReadyHandler;
final class JsseStreamConduit implements StreamSourceConduit, StreamSinkConduit, Runnable {
/** When the {@code org.xnio.ssl.TRACE_SSL} system property is {@code true}, copies of unwrapped data are traced. */
private static final boolean TRACE_SSL = Boolean.getBoolean("org.xnio.ssl.TRACE_SSL");
//================================================================
//
// Immutable state
//
//================================================================
/** The owning connection; its I/O thread serves as both the read and the write thread. */
private final JsseSslConnection connection;
/** The engine used for the TLS handshake and record processing. */
private final SSLEngine engine;
/** The upstream source conduit wrapped by this conduit. */
private final StreamSourceConduit sourceConduit;
/** The upstream sink conduit wrapped by this conduit. */
private final StreamSinkConduit sinkConduit;
/** The buffer into which incoming SSL data is written. */
private final Pooled<ByteBuffer> receiveBuffer;
/** The buffer from which outbound SSL data is sent. */
private final Pooled<ByteBuffer> sendBuffer;
/** The buffer into which inbound clear data is written. */
private final Pooled<ByteBuffer> readBuffer;
//================================================================
//
// Mutable state
//
//================================================================
// Bit set of the FLAG_*, READ_FLAG_* and WRITE_FLAG_* constants.
// always inline tasks, for now
private int state = FLAG_INLINE_TASKS;
// tasks counter - protected by {@code this}
private int tasks;
/** The user's read handler; {@code null} when none is registered. */
private ReadReadyHandler readReadyHandler;
/** The user's write handler; {@code null} when none is registered. */
private WriteReadyHandler writeReadyHandler;
//================================================================
//
// Constructors
//
//================================================================
/**
 * Construct a new conduit wrapping the given upstream conduits.
 * <p>
 * Allocates a receive and a send buffer from {@code socketBufferPool} and a read buffer from
 * {@code applicationBufferPool}. If any step fails, every buffer allocated so far is freed again.
 *
 * @param connection the owning connection
 * @param engine the SSL engine
 * @param sourceConduit the upstream source conduit
 * @param sinkConduit the upstream sink conduit
 * @param socketBufferPool the pool for the receive and send buffers; buffers must be at least the session's packet buffer size
 * @param applicationBufferPool the pool for the read buffer; buffers must be at least the session's application buffer size
 * @throws IllegalArgumentException if the source conduit's read thread differs from the sink conduit's write thread
 */
JsseStreamConduit(final JsseSslConnection connection, final SSLEngine engine, final StreamSourceConduit sourceConduit, final StreamSinkConduit sinkConduit, final Pool<ByteBuffer> socketBufferPool, final Pool<ByteBuffer> applicationBufferPool) {
    // Validate before allocating: previously this check ran after allocation and leaked all three pooled buffers.
    if (sourceConduit.getReadThread() != sinkConduit.getWriteThread()) {
        throw new IllegalArgumentException("Source and sink thread mismatch");
    }
    Pooled<ByteBuffer> receiveBuffer;
    Pooled<ByteBuffer> sendBuffer;
    Pooled<ByteBuffer> readBuffer;
    boolean ok = false;
    final SSLSession session = engine.getSession();
    final int packetBufferSize = session.getPacketBufferSize();
    receiveBuffer = socketBufferPool.allocate();
    try {
        sendBuffer = socketBufferPool.allocate();
        try {
            if (receiveBuffer.getResource().capacity() < packetBufferSize || sendBuffer.getResource().capacity() < packetBufferSize) {
                throw msg.socketBufferTooSmall();
            }
            final int applicationBufferSize = session.getApplicationBufferSize();
            readBuffer = applicationBufferPool.allocate();
            try {
                if (readBuffer.getResource().capacity() < applicationBufferSize) {
                    throw msg.appBufferTooSmall();
                }
                ok = true;
            } finally {
                // each level frees only its own buffer; outer levels free theirs
                if (! ok) readBuffer.free();
            }
        } finally {
            if (! ok) sendBuffer.free();
        }
    } finally {
        if (! ok) receiveBuffer.free();
    }
    // the receive buffer starts out empty: position 0, limit 0
    receiveBuffer.getResource().clear().limit(0);
    this.receiveBuffer = receiveBuffer;
    this.sendBuffer = sendBuffer;
    this.readBuffer = readBuffer;
    this.connection = connection;
    this.engine = engine;
    this.sourceConduit = sourceConduit;
    this.sinkConduit = sinkConduit;
    sourceConduit.setReadReadyHandler(readReady);
    sinkConduit.setWriteReadyHandler(writeReady);
}
//================================================================
//
// Flags
//
//================================================================
// Layout of the {@code state} bit set (the underscores in the binary literals separate the groups):
//   bits 17..21 - global flags (FLAG_*)
//   bits  8..16 - write flags (WRITE_FLAG_*)
//   bits  0..7  - read flags (READ_FLAG_*)
// global flags
/** TLS is enabled (set by beginHandshake()) */
private static final int FLAG_TLS = 0b00001_000000000_00000000;
/** Run tasks immediately rather than in the worker (part of the initial state) */
private static final int FLAG_INLINE_TASKS = 0b00010_000000000_00000000;
/** Set when task is queued; cleared when it runs */
private static final int FLAG_TASK_QUEUED = 0b00100_000000000_00000000;
/** Set when the engine needs the result of a delegated task to continue */
private static final int FLAG_NEED_ENGINE_TASK = 0b01000_000000000_00000000;
/** we need a flush to proceed with handshake */
private static final int FLAG_FLUSH_NEEDED = 0b10000_000000000_00000000;
// read state flags
/** shutdownReads() was called upstream */
private static final int READ_FLAG_SHUTDOWN = 0b00000_000000000_00000001;
/** -1 was read */
private static final int READ_FLAG_EOF = 0b00000_000000000_00000010;
/** resumeReads() called */
private static final int READ_FLAG_RESUMED = 0b00000_000000000_00000100;
/** upstream resumeReads() called */
private static final int READ_FLAG_UP_RESUMED = 0b00000_000000000_00001000;
/** wakeupReads() called */
private static final int READ_FLAG_WAKEUP = 0b00000_000000000_00010000;
/** user read handler should be called if resumed */
private static final int READ_FLAG_READY = 0b00000_000000000_00100000;
/** read needs wrap to proceed */
private static final int READ_FLAG_NEEDS_WRITE = 0b00000_000000000_01000000;
// write state flags
/** shutdownWrites() was called */
private static final int WRITE_FLAG_SHUTDOWN = 0b00000_000000001_00000000;
/** send buffer cleared for final wrap after shutdownWrites() */
private static final int WRITE_FLAG_SHUTDOWN2 = 0b00000_000000010_00000000;
/** wrapping was completed after shutdownWrites() was called */
private static final int WRITE_FLAG_SHUTDOWN3 = 0b00000_000000100_00000000;
/** flush() returned true after shutdown */
private static final int WRITE_FLAG_FINISHED = 0b00000_000001000_00000000;
/** resumeWrites() called */
private static final int WRITE_FLAG_RESUMED = 0b00000_000010000_00000000;
/** upstream resumeWrites() called */
private static final int WRITE_FLAG_UP_RESUMED = 0b00000_000100000_00000000;
/** wakeupWrites() called */
private static final int WRITE_FLAG_WAKEUP = 0b00000_001000000_00000000;
/** user write handler should be called if resumed */
private static final int WRITE_FLAG_READY = 0b00000_010000000_00000000;
/** write needs unwrap to proceed */
private static final int WRITE_FLAG_NEEDS_READ = 0b00000_100000000_00000000;
/**
 * Describe the flags that are currently set, grouped into general, read and write flags.
 *
 * @return a three-line, human-readable dump of the state flags (terminated by a newline)
 */
public String getStatus() {
    final int snapshot = this.state;
    final StringBuilder out = new StringBuilder("General flags:");
    appendStatusFlag(out, snapshot, FLAG_TLS, " TLS");
    appendStatusFlag(out, snapshot, FLAG_INLINE_TASKS, " INLINE_TASKS");
    appendStatusFlag(out, snapshot, FLAG_TASK_QUEUED, " TASK_QUEUED");
    appendStatusFlag(out, snapshot, FLAG_NEED_ENGINE_TASK, " NEED_ENGINE_TASK");
    appendStatusFlag(out, snapshot, FLAG_FLUSH_NEEDED, " FLUSH_NEEDED");
    out.append("\nRead flags:");
    appendStatusFlag(out, snapshot, READ_FLAG_SHUTDOWN, " SHUTDOWN");
    appendStatusFlag(out, snapshot, READ_FLAG_EOF, " EOF");
    appendStatusFlag(out, snapshot, READ_FLAG_RESUMED, " RESUMED");
    appendStatusFlag(out, snapshot, READ_FLAG_UP_RESUMED, " UP_RESUMED");
    appendStatusFlag(out, snapshot, READ_FLAG_WAKEUP, " WAKEUP");
    appendStatusFlag(out, snapshot, READ_FLAG_READY, " READY");
    appendStatusFlag(out, snapshot, READ_FLAG_NEEDS_WRITE, " NEEDS_WRITE");
    out.append("\nWrite flags:");
    appendStatusFlag(out, snapshot, WRITE_FLAG_SHUTDOWN, " SHUTDOWN");
    appendStatusFlag(out, snapshot, WRITE_FLAG_SHUTDOWN2, " SHUTDOWN2");
    appendStatusFlag(out, snapshot, WRITE_FLAG_SHUTDOWN3, " SHUTDOWN3");
    appendStatusFlag(out, snapshot, WRITE_FLAG_FINISHED, " FINISHED");
    appendStatusFlag(out, snapshot, WRITE_FLAG_RESUMED, " RESUMED");
    appendStatusFlag(out, snapshot, WRITE_FLAG_UP_RESUMED, " UP_RESUMED");
    appendStatusFlag(out, snapshot, WRITE_FLAG_WAKEUP, " WAKEUP");
    appendStatusFlag(out, snapshot, WRITE_FLAG_READY, " READY");
    appendStatusFlag(out, snapshot, WRITE_FLAG_NEEDS_READ, " NEEDS_READ");
    return out.append('\n').toString();
}

/** Append {@code label} to {@code out} when every bit of {@code flag} is set in {@code snapshot}. */
private static void appendStatusFlag(final StringBuilder out, final int snapshot, final int flag, final String label) {
    if (allAreSet(snapshot, flag)) {
        out.append(label);
    }
}
/**
 * Describe this conduit: its connection followed by the output of {@link #getStatus()}.
 */
@Override
public String toString() {
    final JsseSslConnection owner = connection;
    final String status = getStatus();
    return String.format("JSSE Stream Conduit for %s, status:%n%s", owner, status);
}
//================================================================
//
// Global API
//
//================================================================
/** Returns the worker of the owning connection's I/O thread. */
public XnioWorker getWorker() {
    final XnioIoThread ioThread = getReadThread();
    return ioThread.getWorker();
}
/** Returns the owning connection's I/O thread; reads and writes share this thread. */
public XnioIoThread getReadThread() {
    return connection.getIoThread();
}
/** Returns the owning connection's I/O thread; reads and writes share this thread. */
public XnioIoThread getWriteThread() {
    return getReadThread();
}
/**
 * Handler registered on the upstream sink conduit. It forwards readiness to {@link #writeReady()} and
 * termination events to the user's write handler, if any.
 */
private final WriteReadyHandler writeReady = new WriteReadyHandler() {
    /** Mark writes as fully shut down and finished, unless they already are finished. */
    private void markWritesFinished() {
        if (anyAreClear(state, WRITE_FLAG_FINISHED)) {
            state |= WRITE_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN2 | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
        }
    }

    @Override
    public void forceTermination() {
        markWritesFinished();
        final WriteReadyHandler handler = JsseStreamConduit.this.writeReadyHandler;
        if (handler == null) {
            return;
        }
        try {
            handler.forceTermination();
        } catch (Throwable ignored) {
            // exceptions from the user handler are deliberately ignored
        }
    }

    @Override
    public void terminated() {
        markWritesFinished();
        final WriteReadyHandler handler = JsseStreamConduit.this.writeReadyHandler;
        if (handler == null) {
            return;
        }
        try {
            handler.terminated();
        } catch (Throwable ignored) {
            // exceptions from the user handler are deliberately ignored
        }
    }

    @Override
    public void writeReady() {
        JsseStreamConduit.this.writeReady();
    }
};
/**
 * Handler registered on the upstream source conduit. It forwards readiness to {@link #readReady()} and
 * termination events to the user's read handler, if any.
 */
private final ReadReadyHandler readReady = new ReadReadyHandler() {
    /** Mark reads as shut down, unless they already are. */
    private void markReadsShutdown() {
        if (anyAreClear(state, READ_FLAG_SHUTDOWN)) {
            state |= READ_FLAG_SHUTDOWN;
        }
    }

    @Override
    public void forceTermination() {
        markReadsShutdown();
        final ReadReadyHandler handler = JsseStreamConduit.this.readReadyHandler;
        if (handler == null) {
            return;
        }
        try {
            handler.forceTermination();
        } catch (Throwable ignored) {
            // exceptions from the user handler are deliberately ignored
        }
    }

    @Override
    public void terminated() {
        markReadsShutdown();
        final ReadReadyHandler handler = JsseStreamConduit.this.readReadyHandler;
        if (handler == null) {
            return;
        }
        try {
            handler.terminated();
        } catch (Throwable ignored) {
            // exceptions from the user handler are deliberately ignored
        }
    }

    @Override
    public void readReady() {
        JsseStreamConduit.this.readReady();
    }
};
// non-public
/**
 * Switch this conduit into TLS mode (if it is not already) and start the engine handshake.
 *
 * @throws ClosedChannelException if EOF was read or writes were shut down
 * @throws IOException if the engine fails to begin the handshake
 */
void beginHandshake() throws IOException {
    final int current = state;
    if (anyAreSet(current, READ_FLAG_EOF | WRITE_FLAG_SHUTDOWN)) {
        throw new ClosedChannelException();
    }
    if (!allAreSet(current, FLAG_TLS)) {
        state = current | FLAG_TLS;
    }
    engine.beginHandshake();
}
/** Returns the engine's session when TLS is enabled, or {@code null} otherwise. */
SSLSession getSslSession() {
    if (!isTls()) {
        return null;
    }
    return engine.getSession();
}
/** Returns the SSL engine used by this conduit. */
SSLEngine getEngine() {
    return engine;
}
/** Returns {@code true} when the TLS flag is set (for example by {@code beginHandshake()}). */
boolean isTls() {
    final int current = state;
    return allAreSet(current, FLAG_TLS);
}
/**
 * Free the pooled buffers, then mark reads as shut down and writes as fully finished.
 *
 * @return {@code true} if the flags changed, i.e. reads were not already shut down or writes were not
 *     already finished; {@code false} if both were already the case
 */
boolean markTerminated() {
    readBuffer.free();
    receiveBuffer.free();
    sendBuffer.free();
    final int current = state;
    if (allAreSet(current, READ_FLAG_SHUTDOWN | WRITE_FLAG_FINISHED)) {
        return false;
    }
    state = current | READ_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN2 | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
    return true;
}
//================================================================
//
// Main run task
//
//================================================================
/**
 * Main task: dispatch readiness to the user write handler, then the user read handler, and re-arm upstream interest.
 * <p>
 * Must run on the I/O thread (asserted). It is called directly from {@link #writeReady()} / {@link #readReady()}
 * when no task is queued. Otherwise it is submitted via {@code getReadThread().execute(this)} with
 * {@code FLAG_TASK_QUEUED} set, and that flag is cleared here.
 * <p>
 * A local copy of {@code state} is used throughout. It is saved before each user handler call, reloaded
 * afterwards, and written back in the {@code finally} block only if it was modified. Exceptions thrown by
 * user handlers are ignored.
 *
 * @throws UnsupportedOperationException if {@code FLAG_NEED_ENGINE_TASK} is set
 */
public void run() {
    assert currentThread() == getWriteThread();
    int state = JsseStreamConduit.this.state;
    final boolean flagTaskQueued = allAreSet(state, FLAG_TASK_QUEUED);
    // if we were queued, the cleared FLAG_TASK_QUEUED has to be written back
    boolean modify = flagTaskQueued;
    boolean queueTask = false;
    state &= ~FLAG_TASK_QUEUED;
    try {
        // task(s)
        // delegated engine tasks are not handled by this method
        if (allAreSet(state, FLAG_NEED_ENGINE_TASK)) {
            throw new UnsupportedOperationException();
        }
        // write side
        if (anyAreSet(state, WRITE_FLAG_WAKEUP) || allAreSet(state, WRITE_FLAG_RESUMED | WRITE_FLAG_READY)) {
            final WriteReadyHandler writeReadyHandler = JsseStreamConduit.this.writeReadyHandler;
            if (allAreSet(state, WRITE_FLAG_WAKEUP)) {
                // a pending wakeup is converted into a regular resume
                state = state & ~WRITE_FLAG_WAKEUP | WRITE_FLAG_RESUMED;
                modify = true;
            }
            if (writeReadyHandler != null) {
                // NOTE(review): the outer condition plus the WAKEUP->RESUMED conversion above means
                // WRITE_FLAG_RESUMED is always set at this point, so the matching else branch below looks
                // unreachable - confirm.
                if (allAreSet(state, WRITE_FLAG_RESUMED)) {
                    try {
                        // save flags before calling the user handler, so that it sees them
                        if (modify) {
                            modify = false;
                            JsseStreamConduit.this.state = state;
                        }
                        writeReadyHandler.writeReady();
                    } catch (Throwable ignored) {
                        // user handler exceptions are deliberately ignored
                    } finally {
                        // restore flags, which the handler may have changed
                        // it is OK if this is stale
                        state = JsseStreamConduit.this.state & ~FLAG_TASK_QUEUED;
                        // make sure we flag to save state with no FLAG_TASK_QUEUED
                        modify = true;
                    }
                    // Thread safety notice:
                    //---> We must not modify flags unless read and/or write is still resumed; otherwise, the user might
                    // be doing something in another thread and we could end up overwriting each others' changes.
                    // level-triggering
                    if (allAreSet(state, WRITE_FLAG_RESUMED)) {
                        if (!allAreSet(state, WRITE_FLAG_READY) && allAreSet(state, WRITE_FLAG_NEEDS_READ) && allAreClear(state, READ_FLAG_UP_RESUMED)) {
                            // the write is blocked on an unwrap: resume upstream reads
                            state |= READ_FLAG_UP_RESUMED;
                            modify = true;
                            sourceConduit.resumeReads();
                        } else if (allAreClear(state, WRITE_FLAG_UP_RESUMED)) {
                            // NOTE(review): unlike resumeWrites(), WRITE_FLAG_UP_RESUMED is not set here - confirm intended.
                            sinkConduit.resumeWrites();
                        }
                    }
                } else {
                    if (allAreClear(state, READ_FLAG_NEEDS_WRITE | READ_FLAG_RESUMED) && allAreSet(state, WRITE_FLAG_UP_RESUMED)) {
                        state &= ~WRITE_FLAG_UP_RESUMED;
                        modify = true;
                        // NOTE(review): this calls this conduit's own suspendWrites(), which works on this.state
                        // rather than the local copy; sinkConduit.suspendWrites() may have been intended - confirm.
                        suspendWrites();
                    }
                }
            } else {
                // no handler, we should not be resumed
                state &= ~WRITE_FLAG_RESUMED;
                modify = true;
                if (allAreClear(state, READ_FLAG_NEEDS_WRITE | READ_FLAG_RESUMED) && allAreSet(state, WRITE_FLAG_UP_RESUMED)) {
                    state &= ~WRITE_FLAG_UP_RESUMED;
                    modify = true;
                    // NOTE(review): own suspendWrites(), see note above - confirm.
                    suspendWrites();
                }
            }
        }
        // read side
        if (anyAreSet(state, READ_FLAG_WAKEUP) || allAreSet(state, READ_FLAG_RESUMED | READ_FLAG_READY)) {
            final ReadReadyHandler readReadyHandler = JsseStreamConduit.this.readReadyHandler;
            if (allAreSet(state, READ_FLAG_WAKEUP)) {
                // a pending wakeup is converted into a regular resume
                state = state & ~READ_FLAG_WAKEUP | READ_FLAG_RESUMED;
                modify = true;
            }
            if (readReadyHandler != null) {
                // NOTE(review): as on the write side, READ_FLAG_RESUMED is always set here, so the matching
                // else branch below looks unreachable - confirm.
                if (allAreSet(state, READ_FLAG_RESUMED)) {
                    try {
                        // save flags before calling the user handler, so that it sees them
                        if (modify) {
                            modify = false;
                            JsseStreamConduit.this.state = state;
                        }
                        readReadyHandler.readReady();
                    } catch (Throwable ignored) {
                        // user handler exceptions are deliberately ignored
                    } finally {
                        // restore flags, which the handler may have changed
                        // it is OK if this is stale
                        state = JsseStreamConduit.this.state & ~FLAG_TASK_QUEUED;
                        modify = true;
                    }
                    // Thread safety notice:
                    //---> We must not modify flags unless read and/or write is still resumed; otherwise, the user might
                    // be doing something in another thread and we could end up overwriting each others' changes.
                    // level-triggering
                    if (allAreSet(state, READ_FLAG_RESUMED)) {
                        if (allAreSet(state, READ_FLAG_READY)) {
                            // still readable: re-queue, but only if this invocation was not itself a queued task
                            if (!flagTaskQueued) {
                                state |= FLAG_TASK_QUEUED;
                                modify = queueTask = true;
                            }
                        } else if (allAreSet(state, READ_FLAG_NEEDS_WRITE) && allAreClear(state, WRITE_FLAG_UP_RESUMED)) {
                            // the read is blocked on a wrap: resume upstream writes
                            state |= WRITE_FLAG_UP_RESUMED;
                            modify = true;
                            sinkConduit.resumeWrites();
                        } else if (allAreClear(state, READ_FLAG_UP_RESUMED)) {
                            // NOTE(review): unlike resumeReads(), READ_FLAG_UP_RESUMED is not set here - confirm intended.
                            sourceConduit.resumeReads();
                        }
                    }
                } else {
                    if (allAreClear(state, WRITE_FLAG_NEEDS_READ | WRITE_FLAG_RESUMED) && allAreSet(state, READ_FLAG_UP_RESUMED)) {
                        state &= ~READ_FLAG_UP_RESUMED;
                        modify = true;
                        // NOTE(review): own suspendReads(), not sourceConduit.suspendReads() - confirm.
                        suspendReads();
                    }
                }
            } else {
                // no handler, we should not be resumed
                state &= ~READ_FLAG_RESUMED;
                modify = true;
                if (allAreClear(state, WRITE_FLAG_NEEDS_READ | WRITE_FLAG_RESUMED) && allAreSet(state, READ_FLAG_UP_RESUMED)) {
                    state &= ~READ_FLAG_UP_RESUMED;
                    // NOTE(review): own suspendReads(), not sourceConduit.suspendReads() - confirm.
                    suspendReads();
                }
            }
        }
    } finally {
        if (modify) {
            JsseStreamConduit.this.state = state;
            // execute this on read thread only after updating the state
            if (queueTask) getReadThread().execute(this);
        }
    }
}
//================================================================
//
// Ready handlers
//
//================================================================
/** Register the user's write handler; {@code null} means no handler is registered. */
public void setWriteReadyHandler(final WriteReadyHandler handler) {
    writeReadyHandler = handler;
}
/** Register the user's read handler; {@code null} means no handler is registered. */
public void setReadReadyHandler(final ReadReadyHandler handler) {
    readReadyHandler = handler;
}
/**
 * Called by the upstream conduit when writes are ready.
 * <p>
 * Marks writes as ready, and reads as well if a read is waiting for a wrap. Runs the main task inline unless
 * one is already queued. Afterwards, suspends any upstream direction that is resumed but no longer needed.
 */
public void writeReady() {
    int flags = this.state | WRITE_FLAG_READY;
    if (anyAreSet(flags, READ_FLAG_NEEDS_WRITE)) {
        // a blocked read was waiting for this write opportunity
        flags |= READ_FLAG_READY;
    }
    this.state = flags;
    // a queued task will do the work itself; running here too would double-fire the handlers
    if (!anyAreSet(flags, FLAG_TASK_QUEUED)) {
        run();
    }
    flags = this.state;
    final boolean writesUnneeded = allAreClear(flags, WRITE_FLAG_RESUMED | READ_FLAG_NEEDS_WRITE);
    final boolean readsUnneeded = allAreClear(flags, READ_FLAG_RESUMED | WRITE_FLAG_NEEDS_READ);
    if (sinkConduit.isWriteResumed() && writesUnneeded) {
        sinkConduit.suspendWrites();
    }
    if (sourceConduit.isReadResumed() && readsUnneeded) {
        sourceConduit.suspendReads();
    }
}
/**
 * Called by the upstream conduit when reads are ready.
 * <p>
 * Marks reads as ready, and writes as well if a write is waiting for an unwrap. Runs the main task inline
 * unless one is already queued. Afterwards, suspends any upstream direction that is resumed but no longer needed.
 */
public void readReady() {
    int flags = this.state | READ_FLAG_READY;
    if (anyAreSet(flags, WRITE_FLAG_NEEDS_READ)) {
        // a blocked write was waiting for this read opportunity
        flags |= WRITE_FLAG_READY;
    }
    this.state = flags;
    // a queued task will do the work itself; running here too would double-fire the handlers
    if (!anyAreSet(flags, FLAG_TASK_QUEUED)) {
        run();
    }
    flags = this.state;
    final boolean readsUnneeded = allAreClear(flags, READ_FLAG_RESUMED | WRITE_FLAG_NEEDS_READ);
    final boolean writesUnneeded = allAreClear(flags, WRITE_FLAG_RESUMED | READ_FLAG_NEEDS_WRITE);
    if (sourceConduit.isReadResumed() && readsUnneeded) {
        sourceConduit.suspendReads();
    }
    if (sinkConduit.isWriteResumed() && writesUnneeded) {
        sinkConduit.suspendWrites();
    }
}
//================================================================
//
// Resume state management
//
//================================================================
// writes
/**
 * Suspend user write notifications.
 * <p>
 * If writes are resumed, clears {@code WRITE_FLAG_RESUMED}. It then suspends upstream writes if they are
 * up-resumed and no read needs a wrap. It also suspends upstream reads if they are up-resumed while user reads
 * are not resumed. The updated flags are stored even if an upstream call throws.
 */
public void suspendWrites() {
    int state = this.state;
    try {
        if (allAreSet(state, WRITE_FLAG_RESUMED)) {
            state &= ~WRITE_FLAG_RESUMED;
            if (allAreSet(state, WRITE_FLAG_UP_RESUMED) && allAreClear(state, READ_FLAG_NEEDS_WRITE)) {
                // upstream writes were resumed but now may be suspended
                state &= ~WRITE_FLAG_UP_RESUMED;
                sinkConduit.suspendWrites();
            }
            if (allAreSet(state, READ_FLAG_UP_RESUMED) && allAreClear(state, READ_FLAG_RESUMED)) {
                // reads were likely resumed because of us; clear it
                state &= ~READ_FLAG_UP_RESUMED;
                sourceConduit.suspendReads();
            }
        }
    } finally {
        this.state = state;
    }
}
/**
 * Resume user write notifications.
 * <p>
 * Does nothing if writes are already resumed. If writes are already finished, a single
 * {@link #wakeupWrites()} is issued instead. Otherwise {@code WRITE_FLAG_RESUMED} is set, and one of three
 * things happens:
 * <ul>
 * <li>if writes are ready, the main task is queued (unless one already is);</li>
 * <li>if a write is waiting for an unwrap, upstream reads are resumed;</li>
 * <li>otherwise, upstream writes are resumed if they were not already.</li>
 * </ul>
 * The flags are stored before the task is queued.
 */
public void resumeWrites() {
    int state = this.state;
    if (allAreClear(state, WRITE_FLAG_RESUMED)) {
        if (allAreSet(state, WRITE_FLAG_FINISHED)) {
            // just re-call the handler one time
            wakeupWrites();
            return;
        }
        boolean queueTask = false;
        try {
            state |= WRITE_FLAG_RESUMED;
            if (allAreSet(state, WRITE_FLAG_READY)) {
                // queue the task only if it is not already queued
                if (queueTask = allAreClear(state, FLAG_TASK_QUEUED)) {
                    state |= FLAG_TASK_QUEUED;
                }
            } else if (allAreSet(state, WRITE_FLAG_NEEDS_READ) && allAreClear(state, READ_FLAG_UP_RESUMED)) {
                // need to resume reads to make this happen
                state |= READ_FLAG_UP_RESUMED;
                sourceConduit.resumeReads();
            } else if (allAreClear(state, WRITE_FLAG_UP_RESUMED)) {
                // upstream writes were not resumed
                state |= WRITE_FLAG_UP_RESUMED;
                sinkConduit.resumeWrites();
            }
        } finally {
            this.state = state;
            // execute this on read thread only after updating the state
            if (queueTask) getReadThread().execute(this);
        }
    }
}
/**
 * Request a call to the user write handler. Queues the main task on the I/O thread unless a task is already
 * queued. Does nothing if a write wakeup is already pending.
 */
public void wakeupWrites() {
    final int current = this.state;
    if (anyAreSet(current, WRITE_FLAG_WAKEUP)) {
        // a wakeup is already pending
        return;
    }
    final boolean mustQueue = allAreClear(current, FLAG_TASK_QUEUED);
    if (mustQueue) {
        this.state = current | WRITE_FLAG_WAKEUP | FLAG_TASK_QUEUED;
        getReadThread().execute(this);
    } else {
        this.state = current | WRITE_FLAG_WAKEUP;
    }
}
/**
 * Shut down writes. Does nothing if writes are already finished.
 * <p>
 * {@code WRITE_FLAG_SHUTDOWN} is set first.
 * <ul>
 * <li>Without TLS, upstream writes are terminated directly.</li>
 * <li>With TLS, the engine's outbound side is closed (only if it is not handshaking) and a flush is attempted.
 * Upstream writes are terminated only once that flush has set {@code WRITE_FLAG_FINISHED}; otherwise a later
 * flush is presumably expected to complete the shutdown.</li>
 * </ul>
 * On any failure in the TLS path, writes are marked finished and upstream writes are truncated. The original
 * throwable is then rethrown, with any truncation failure added as suppressed.
 *
 * @throws IOException if the flush or upstream termination fails
 */
public void terminateWrites() throws IOException {
    int state = this.state;
    if (allAreClear(state, WRITE_FLAG_FINISHED)) {
        this.state = state | WRITE_FLAG_SHUTDOWN;
        if (allAreSet(state, FLAG_TLS)) try {
            if (engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
                engine.closeOutbound();
            }
            performIO(IO_GOAL_FLUSH, NO_BUFFERS, 0, 0, NO_BUFFERS, 0, 0);
            // re-read the field: performIO may have finished the shutdown
            if (allAreSet(this.state, WRITE_FLAG_FINISHED)) {
                sinkConduit.terminateWrites();
            }
        } catch (Throwable t) {
            this.state |= WRITE_FLAG_FINISHED;
            try {
                sinkConduit.truncateWrites();
            } catch (Throwable t2) {
                t.addSuppressed(t2);
            }
            throw t;
        } else {
            sinkConduit.terminateWrites();
        }
    }
}
/**
 * Shut down writes immediately, without performing a final flush.
 * <p>
 * Sets {@code WRITE_FLAG_SHUTDOWN}, {@code WRITE_FLAG_SHUTDOWN3} and {@code WRITE_FLAG_FINISHED}. With TLS, the
 * engine's outbound side is closed first. In all cases upstream writes are truncated. If closing the engine
 * fails, upstream writes are still truncated, and the engine failure is rethrown with any truncation failure
 * suppressed. With TLS, the flags are stored even if an exception is thrown.
 * <p>
 * NOTE(review): the guard tests {@code WRITE_FLAG_SHUTDOWN}, so after a {@code terminateWrites()} whose flush
 * has not yet finished, this method does nothing and upstream writes are not truncated - confirm intended.
 *
 * @throws IOException if upstream truncation fails
 */
public void truncateWrites() throws IOException {
    int state = this.state;
    if (allAreClear(state, WRITE_FLAG_SHUTDOWN)) {
        if (allAreSet(state, FLAG_TLS)) try {
            state |= WRITE_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
            try {
                engine.closeOutbound();
            } catch (Throwable t) {
                try {
                    sinkConduit.truncateWrites();
                } catch (Throwable t2) {
                    t.addSuppressed(t2);
                }
                throw t;
            }
            sinkConduit.truncateWrites();
        } finally {
            this.state = state;
        } else {
            this.state = state | WRITE_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
            sinkConduit.truncateWrites();
        }
    }
}
// queries
/** Returns {@code true} if user writes are resumed or a write wakeup is pending. */
public boolean isWriteResumed() {
    final int current = state;
    return anyAreSet(current, WRITE_FLAG_RESUMED | WRITE_FLAG_WAKEUP);
}
/** Returns {@code true} once writes have been shut down (terminated or truncated). */
public boolean isWriteShutdown() {
    final int current = state;
    return allAreSet(current, WRITE_FLAG_SHUTDOWN);
}
// await
/**
 * Block until writes may be able to proceed.
 * <p>
 * If an engine task is pending, first waits for the task counter to reach zero. Then, unless writes are already
 * ready, waits upstream: for readability if the write is blocked on an unwrap, otherwise for writability.
 *
 * @throws InterruptedIOException if interrupted while waiting for engine tasks (the interrupt flag is restored)
 * @throws IOException if the upstream wait fails
 */
public void awaitWritable() throws IOException {
    int current = this.state;
    if (allAreSet(current, FLAG_NEED_ENGINE_TASK)) {
        synchronized (this) {
            while (tasks != 0) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new InterruptedIOException();
                }
            }
            current &= ~FLAG_NEED_ENGINE_TASK;
            this.state = current;
        }
    }
    if (allAreSet(current, WRITE_FLAG_READY)) {
        return;
    }
    if (allAreSet(current, WRITE_FLAG_NEEDS_READ)) {
        sourceConduit.awaitReadable();
    } else {
        sinkConduit.awaitWritable();
    }
}
/**
 * Block until writes may be able to proceed, or until the timeout elapses.
 * <p>
 * Waiting for pending engine tasks uses up part of the timeout. If the timeout runs out during that phase,
 * the method returns without waiting upstream. Otherwise the remaining time is passed to the upstream wait.
 *
 * @param time the maximum time to wait
 * @param timeUnit the unit of {@code time}
 * @throws InterruptedIOException if interrupted while waiting for engine tasks (the interrupt flag is restored)
 * @throws IOException if the upstream wait fails
 */
public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
    int current = this.state;
    long remaining = timeUnit.toNanos(time);
    if (allAreSet(current, FLAG_NEED_ENGINE_TASK)) {
        synchronized (this) {
            long mark = System.nanoTime();
            while (tasks != 0) {
                if (remaining <= 0) {
                    return;
                }
                try {
                    wait(remaining / 1_000_000, (int) (remaining % 1_000_000));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new InterruptedIOException();
                }
                final long now = System.nanoTime();
                remaining -= now - mark;
                mark = now;
            }
            current &= ~FLAG_NEED_ENGINE_TASK;
            this.state = current;
        }
    }
    if (allAreSet(current, WRITE_FLAG_READY)) {
        return;
    }
    if (allAreSet(current, WRITE_FLAG_NEEDS_READ)) {
        sourceConduit.awaitReadable(remaining, TimeUnit.NANOSECONDS);
    } else {
        sinkConduit.awaitWritable(remaining, TimeUnit.NANOSECONDS);
    }
}
// reads
/**
 * Suspend user read notifications.
 * <p>
 * If reads are resumed, clears {@code READ_FLAG_RESUMED}. It then suspends upstream reads if they are
 * up-resumed and no write needs an unwrap. It also suspends upstream writes if they are up-resumed while user
 * writes are not resumed. The updated flags are stored even if an upstream call throws.
 */
public void suspendReads() {
    int state = this.state;
    try {
        if (allAreSet(state, READ_FLAG_RESUMED)) {
            state &= ~READ_FLAG_RESUMED;
            if (allAreSet(state, READ_FLAG_UP_RESUMED) && allAreClear(state, WRITE_FLAG_NEEDS_READ)) {
                // upstream reads were resumed but now may be suspended
                state &= ~READ_FLAG_UP_RESUMED;
                sourceConduit.suspendReads();
            }
            if (allAreSet(state, WRITE_FLAG_UP_RESUMED) && allAreClear(state, WRITE_FLAG_RESUMED)) {
                // writes were likely resumed because of us; clear it
                state &= ~WRITE_FLAG_UP_RESUMED;
                sinkConduit.suspendWrites();
            }
        }
    } finally {
        this.state = state;
    }
}
/**
 * Resume user read notifications. Does nothing if reads are already resumed.
 * <p>
 * Sets {@code READ_FLAG_RESUMED}. Unless user writes are resumed, it also forces {@code READ_FLAG_READY}, so
 * that the main task gets queued (see the comment below). Then one of three things happens:
 * <ul>
 * <li>if reads are ready, the main task is queued (unless one already is);</li>
 * <li>if a read is waiting for a wrap, upstream writes are resumed;</li>
 * <li>otherwise, upstream reads are resumed if they were not already.</li>
 * </ul>
 * The flags are stored before the task is queued.
 */
public void resumeReads() {
    int state = this.state;
    boolean queueTask = false;
    if (allAreClear(state, READ_FLAG_RESUMED)) try {
        state |= READ_FLAG_RESUMED;
        // in the absence of a writer, we need to wake up the reader to make sure that handshake will take place
        // (the read listener could never be called if this is the side that is supposed to make the first wrap
        // for handshake; if that happens, this side of the connection would starve)
        if (allAreClear(state, WRITE_FLAG_RESUMED)) {
            state |= READ_FLAG_READY;
        }
        if (allAreSet(state, READ_FLAG_READY)) {
            // queue the task only if it is not already queued
            if (queueTask = allAreClear(state, FLAG_TASK_QUEUED)) {
                state |= FLAG_TASK_QUEUED;
            }
        } else if (allAreSet(state, READ_FLAG_NEEDS_WRITE) && allAreClear(state, WRITE_FLAG_UP_RESUMED)) {
            // need to resume writes to make this happen
            state |= WRITE_FLAG_UP_RESUMED;
            sinkConduit.resumeWrites();
        } else if (allAreClear(state, READ_FLAG_UP_RESUMED)) {
            // upstream reads were not resumed
            state |= READ_FLAG_UP_RESUMED;
            sourceConduit.resumeReads();
        }
    } finally {
        this.state = state;
        // execute this on read thread only after updating the state
        if (queueTask) getReadThread().execute(this);
    }
}
/**
 * Request a call to the user read handler. Queues the main task on the I/O thread unless a task is already
 * queued. Does nothing if a read wakeup is already pending.
 */
public void wakeupReads() {
    final int current = this.state;
    if (anyAreSet(current, READ_FLAG_WAKEUP)) {
        // a wakeup is already pending
        return;
    }
    final boolean mustQueue = allAreClear(current, FLAG_TASK_QUEUED);
    if (mustQueue) {
        this.state = current | READ_FLAG_WAKEUP | FLAG_TASK_QUEUED;
        getReadThread().execute(this);
    } else {
        this.state = current | READ_FLAG_WAKEUP;
    }
}
/**
 * Shut down reads. Does nothing if reads are already shut down.
 * <p>
 * Without TLS, upstream reads are terminated directly.
 * NOTE(review): in that path {@code READ_FLAG_SHUTDOWN} is not set here; it is presumably set later by the
 * upstream {@code terminated()} callback - confirm.
 * <p>
 * With TLS, {@code READ_FLAG_SHUTDOWN} is set. Unless EOF was already seen, a flush is attempted, the engine's
 * inbound side is closed (if not already done and not handshaking), and a final read is performed; a result
 * of -1 marks EOF. Upstream reads are then terminated. If EOF was still not reached, or data remains in the
 * receive buffer, the exception from {@code msg.connectionClosedEarly()} is thrown after terminating, with any
 * termination failure suppressed.
 *
 * @throws EOFException if EOF was not observed or unprocessed incoming data remains
 * @throws IOException if an I/O operation fails
 */
public void terminateReads() throws IOException {
    int state = this.state;
    if (allAreClear(state, READ_FLAG_SHUTDOWN)) {
        if (allAreClear(state, FLAG_TLS)) {
            // never fired up TLS in the first place
            sourceConduit.terminateReads();
        } else {
            // indicate that the user doesn't want any more data
            this.state = state | READ_FLAG_SHUTDOWN;
            if (allAreClear(state, READ_FLAG_EOF)) {
                performIO(IO_GOAL_FLUSH, NO_BUFFERS, 0, 0, NO_BUFFERS, 0, 0);
                // NOTE(review): these checks use the local snapshot taken before performIO, and the inner EOF
                // check is always true here; the effect is an early return (without terminating upstream reads)
                // whenever the snapshot had WRITE_FLAG_NEEDS_READ. this.state may have been intended - confirm.
                if (allAreSet(state, WRITE_FLAG_NEEDS_READ)) {
                    if (allAreClear(state, READ_FLAG_EOF)) {
                        return;
                    }
                }
                if (!engine.isInboundDone() && engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
                    engine.closeInbound();
                }
                final long res = performIO(IO_GOAL_READ, NO_BUFFERS, 0, 0, NO_BUFFERS, 0, 0);
                if (res == -1) {
                    this.state |= READ_FLAG_EOF;
                }
            }
            if (allAreClear(this.state, READ_FLAG_EOF) || this.receiveBuffer.getResource().hasRemaining()) {
                // potentially unread data :(
                final EOFException exception = msg.connectionClosedEarly();
                try {
                    sourceConduit.terminateReads();
                } catch (IOException e) {
                    exception.addSuppressed(e);
                }
                throw exception;
            } else {
                sourceConduit.terminateReads();
            }
        }
    }
}
// queries
/** Returns {@code true} if user reads are resumed or a read wakeup is pending. */
public boolean isReadResumed() {
    final int current = state;
    return anyAreSet(current, READ_FLAG_RESUMED | READ_FLAG_WAKEUP);
}
/** Returns {@code true} once reads have been shut down. */
public boolean isReadShutdown() {
    final int current = state;
    return allAreSet(current, READ_FLAG_SHUTDOWN);
}
// await
/**
 * Block until reads may be able to proceed.
 * <p>
 * If an engine task is pending, first waits for the task counter to reach zero. Then, unless reads are already
 * ready, waits upstream: for writability if the read is blocked on a wrap, otherwise for readability.
 *
 * @throws InterruptedIOException if interrupted while waiting for engine tasks (the interrupt flag is restored)
 * @throws IOException if the upstream wait fails
 */
public void awaitReadable() throws IOException {
    int current = this.state;
    if (allAreSet(current, FLAG_NEED_ENGINE_TASK)) {
        synchronized (this) {
            while (tasks != 0) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new InterruptedIOException();
                }
            }
            current &= ~FLAG_NEED_ENGINE_TASK;
            this.state = current;
        }
    }
    if (allAreSet(current, READ_FLAG_READY)) {
        return;
    }
    if (allAreSet(current, READ_FLAG_NEEDS_WRITE)) {
        sinkConduit.awaitWritable();
    } else {
        sourceConduit.awaitReadable();
    }
}
/**
 * Block until reads may be able to proceed, or until the timeout elapses.
 * <p>
 * Waiting for pending engine tasks uses up part of the timeout. If the timeout runs out during that phase,
 * the method returns without waiting upstream. Otherwise the remaining time is passed to the upstream wait.
 *
 * @param time the maximum time to wait
 * @param timeUnit the unit of {@code time}
 * @throws InterruptedIOException if interrupted while waiting for engine tasks (the interrupt flag is restored)
 * @throws IOException if the upstream wait fails
 */
public void awaitReadable(final long time, final TimeUnit timeUnit) throws IOException {
    int current = this.state;
    long remaining = timeUnit.toNanos(time);
    if (allAreSet(current, FLAG_NEED_ENGINE_TASK)) {
        synchronized (this) {
            long mark = System.nanoTime();
            while (tasks != 0) {
                if (remaining <= 0) {
                    return;
                }
                try {
                    wait(remaining / 1_000_000, (int) (remaining % 1_000_000));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new InterruptedIOException();
                }
                final long now = System.nanoTime();
                remaining -= now - mark;
                mark = now;
            }
            current &= ~FLAG_NEED_ENGINE_TASK;
            this.state = current;
        }
    }
    if (allAreSet(current, READ_FLAG_READY)) {
        return;
    }
    if (allAreSet(current, READ_FLAG_NEEDS_WRITE)) {
        sinkConduit.awaitWritable(remaining, TimeUnit.NANOSECONDS);
    } else {
        sourceConduit.awaitReadable(remaining, TimeUnit.NANOSECONDS);
    }
}
//================================================================
//
// I/O methods - read
//
//================================================================
/** Reusable single-element array used to pass a lone destination buffer to {@code performIO}. */
private final ByteBuffer[] readBufferHolder = new ByteBuffer[1];
/**
 * Read clear data into {@code dst}.
 * <ul>
 * <li>After read shutdown, returns -1.</li>
 * <li>After EOF, drains any clear data still held in the read buffer, then returns -1.</li>
 * <li>Without TLS, reads straight from upstream and updates the READY/EOF flags from the result.</li>
 * <li>With TLS, delegates to {@code performIO}.</li>
 * </ul>
 *
 * @param dst the destination buffer
 * @return the number of bytes read, or -1 at end of stream
 * @throws IOException if the read fails
 */
public int read(final ByteBuffer dst) throws IOException {
    final int current = this.state;
    if (anyAreSet(current, READ_FLAG_SHUTDOWN)) {
        return -1;
    }
    if (anyAreSet(current, READ_FLAG_EOF)) {
        // read data
        final ByteBuffer pending = readBuffer.getResource();
        if (pending.position() <= 0) {
            return -1;
        }
        pending.flip();
        try {
            if (TRACE_SSL) msg.tracef("TLS copy unwrapped data from %s to %s", Buffers.debugString(pending), Buffers.debugString(dst));
            return Buffers.copy(dst, pending);
        } finally {
            pending.compact();
        }
    }
    if (allAreClear(current, FLAG_TLS)) {
        final int res = sourceConduit.read(dst);
        if (res == -1) {
            this.state = (current | READ_FLAG_EOF) & ~READ_FLAG_READY;
        } else if (res == 0 && allAreSet(current, READ_FLAG_READY)) {
            this.state = current & ~READ_FLAG_READY;
        }
        return res;
    }
    // regular TLS time
    final ByteBuffer[] holder = this.readBufferHolder;
    holder[0] = dst;
    try {
        return (int) performIO(IO_GOAL_READ, NO_BUFFERS, 0, 0, holder, 0, 1);
    } finally {
        holder[0] = null;
    }
}
/**
 * Scatter-read clear data into {@code dsts[offs .. offs+len)}.
 * <ul>
 * <li>After read shutdown, returns -1.</li>
 * <li>After EOF, drains any clear data still held in the read buffer, then returns -1.</li>
 * <li>Without TLS, reads straight from upstream and updates the READY/EOF flags from the result.</li>
 * <li>With TLS, delegates to {@code performIO}.</li>
 * </ul>
 *
 * @param dsts the destination buffers
 * @param offs the offset of the first buffer to use
 * @param len the number of buffers to use
 * @return the number of bytes read, or -1 at end of stream
 * @throws IOException if the read fails
 */
public long read(final ByteBuffer[] dsts, final int offs, final int len) throws IOException {
    final int current = this.state;
    if (anyAreSet(current, READ_FLAG_SHUTDOWN)) {
        return -1;
    }
    if (anyAreSet(current, READ_FLAG_EOF)) {
        final ByteBuffer pending = readBuffer.getResource();
        if (pending.position() <= 0) {
            return -1;
        }
        pending.flip();
        try {
            if (TRACE_SSL) msg.tracef("TLS copy unwrapped data from %s to %s", Buffers.debugString(pending), Buffers.debugString(dsts, offs, len));
            return Buffers.copy(dsts, offs, len, pending);
        } finally {
            pending.compact();
        }
    }
    if (allAreClear(current, FLAG_TLS)) {
        final long res = sourceConduit.read(dsts, offs, len);
        if (res == -1) {
            this.state = (current | READ_FLAG_EOF) & ~READ_FLAG_READY;
        } else if (res == 0 && allAreSet(current, READ_FLAG_READY)) {
            this.state = current & ~READ_FLAG_READY;
        }
        return res;
    }
    // regular TLS time
    return performIO(IO_GOAL_READ, NO_BUFFERS, 0, 0, dsts, offs, len);
}
/**
 * Transfer up to {@code count} bytes into {@code target} at {@code position}. With TLS, the data is pulled
 * through this conduit's own read path; without TLS, the transfer is delegated upstream.
 */
public long transferTo(final long position, final long count, final FileChannel target) throws IOException {
    if (allAreSet(state, FLAG_TLS)) {
        return target.transferFrom(new ConduitReadableByteChannel(this), position, count);
    }
    return sourceConduit.transferTo(position, count, target);
}
/**
 * Transfer up to {@code count} bytes into {@code target} via {@code throughBuffer}. With TLS, the data passes
 * through this conduit's own read path; without TLS, the transfer is delegated upstream.
 */
public long transferTo(final long count, final ByteBuffer throughBuffer, final StreamSinkChannel target) throws IOException {
    if (allAreSet(state, FLAG_TLS)) {
        // todo - transfer via application read buffer
        return Conduits.transfer(this, count, throughBuffer, target);
    }
    return sourceConduit.transferTo(count, throughBuffer, target);
}
//================================================================
//
// I/O methods - write
//
//================================================================
/** Reusable single-element array used to route a lone source buffer through the array-based write path. */
private final ByteBuffer[] writeBufferHolder = new ByteBuffer[1];
/**
 * Write clear data from {@code src}.
 *
 * @param src the source buffer
 * @return the number of bytes consumed from {@code src}
 * @throws ClosedChannelException if writes were shut down
 * @throws IOException if the write fails
 */
public int write(final ByteBuffer src) throws IOException {
    if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
        throw new ClosedChannelException();
    }
    if (allAreSet(state, FLAG_TLS)) {
        final ByteBuffer[] holder = this.writeBufferHolder;
        holder[0] = src;
        try {
            return (int) write(holder, 0, 1);
        } finally {
            holder[0] = null;
        }
    }
    return sinkConduit.write(src);
}
/**
 * Write clear data from {@code src} as the final write.
 *
 * @param src the source buffer
 * @return the number of bytes consumed from {@code src}
 * @throws ClosedChannelException if writes were shut down
 * @throws IOException if the write fails
 */
public int writeFinal(final ByteBuffer src) throws IOException {
    if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
        throw new ClosedChannelException();
    }
    if (allAreSet(state, FLAG_TLS)) {
        final ByteBuffer[] holder = this.writeBufferHolder;
        holder[0] = src;
        try {
            return (int) writeFinal(holder, 0, 1);
        } finally {
            holder[0] = null;
        }
    }
    return sinkConduit.writeFinal(src);
}
/**
 * Gathering write from {@code srcs[offs .. offs + len)}.
 *
 * <p>With TLS the engine loop is driven with the write goal. The result is measured as how much the
 * sources' total remaining count shrank.
 *
 * @param srcs the source buffers
 * @param offs the index of the first buffer to use
 * @param len the number of buffers to use
 * @return the number of bytes consumed from the sources
 * @throws ClosedChannelException if writes have been shut down
 * @throws IOException if an I/O error occurs
 */
public long write(final ByteBuffer[] srcs, final int offs, final int len) throws IOException {
    if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
        throw new ClosedChannelException();
    }
    if (anyAreSet(state, FLAG_TLS)) {
        final long before = Buffers.remaining(srcs, offs, len);
        performIO(IO_GOAL_WRITE, srcs, offs, len, NO_BUFFERS, 0, 0);
        final long after = Buffers.remaining(srcs, offs, len);
        return before - after;
    }
    return sinkConduit.write(srcs, offs, len);
}
/**
 * Gathering final write from {@code srcs[offs .. offs + len)}.
 *
 * <p>With TLS the engine loop is driven with the write-final goal. The result is measured as how much the
 * sources' total remaining count shrank.
 *
 * @param srcs the source buffers
 * @param offs the index of the first buffer to use
 * @param len the number of buffers to use
 * @return the number of bytes consumed from the sources
 * @throws ClosedChannelException if writes have been shut down
 * @throws IOException if an I/O error occurs
 */
public long writeFinal(final ByteBuffer[] srcs, final int offs, final int len) throws IOException {
    if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
        throw new ClosedChannelException();
    }
    if (anyAreSet(state, FLAG_TLS)) {
        final long before = Buffers.remaining(srcs, offs, len);
        performIO(IO_GOAL_WRITE_FINAL, srcs, offs, len, NO_BUFFERS, 0, 0);
        final long after = Buffers.remaining(srcs, offs, len);
        return before - after;
    }
    return sinkConduit.writeFinal(srcs, offs, len);
}
/**
 * Flush pending output.
 *
 * <p>The checks run in this order:
 * <ul>
 *   <li>Already finished: return {@code true}.</li>
 *   <li>Final wrap done: only the upstream flush remains; mark finished when it succeeds.</li>
 *   <li>No TLS: flush upstream; if writes were shut down, mark all shutdown stages complete when the
 *       flush succeeds.</li>
 *   <li>TLS: drive the engine loop with the flush goal. This covers both regular flushes and the
 *       post-shutdown final wrap, upstream shutdown and flush.</li>
 * </ul>
 *
 * @return {@code true} if everything was flushed
 * @throws IOException if an I/O error occurs
 */
public boolean flush() throws IOException {
    final int current = this.state;
    if (allAreSet(current, WRITE_FLAG_FINISHED)) {
        return true;
    }
    if (allAreSet(current, WRITE_FLAG_SHUTDOWN3)) {
        // just waiting for upstream flush
        final boolean upstreamDone = sinkConduit.flush();
        if (upstreamDone) {
            this.state = current | WRITE_FLAG_FINISHED;
        }
        return upstreamDone;
    }
    if (allAreClear(current, FLAG_TLS)) {
        final boolean upstreamDone = sinkConduit.flush();
        if (upstreamDone && allAreSet(current, WRITE_FLAG_SHUTDOWN)) {
            this.state = current | WRITE_FLAG_SHUTDOWN2 | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
        }
        return upstreamDone;
    }
    // TLS, with or without a pending shutdown: the engine loop does the work either way
    return performIO(IO_GOAL_FLUSH, NO_BUFFERS, 0, 0, NO_BUFFERS, 0, 0) != 0L;
}
/**
 * Transfer bytes from a file into this conduit.
 *
 * <p>Before TLS is active the delegate sink conduit performs the transfer. Once TLS is active, the file
 * pushes into a channel view of this conduit so the bytes pass through this conduit's write path.
 *
 * @param src the source file
 * @param position the file position at which to start reading
 * @param count the maximum number of bytes to transfer
 * @return the number of bytes transferred
 * @throws IOException if an I/O error occurs
 */
public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
    if (anyAreSet(state, FLAG_TLS)) {
        final ConduitWritableByteChannel throughThis = new ConduitWritableByteChannel(this);
        return src.transferTo(position, count, throughThis);
    }
    return sinkConduit.transferFrom(src, position, count);
}
/**
 * Transfer up to {@code count} bytes from {@code source} into this conduit.
 *
 * <p>Before TLS is active the delegate sink conduit handles the transfer. Otherwise the generic
 * {@code Conduits.transfer} helper writes through this conduit. {@code throughBuffer} is handed to
 * whichever path is taken as its intermediate buffer.
 *
 * @param source the source channel
 * @param count the maximum number of bytes to transfer
 * @param throughBuffer the intermediate buffer
 * @return the value returned by the delegate or the transfer helper
 * @throws IOException if an I/O error occurs
 */
public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
    if (anyAreSet(state, FLAG_TLS)) {
        return Conduits.transfer(source, count, throughBuffer, this);
    }
    return sinkConduit.transferFrom(source, count, throughBuffer);
}
//================================================================
//
// Main SSLEngine I/O loop
//
//================================================================
/** performIO goal: unwrap inbound data into the caller's destination buffers. */
private static final int IO_GOAL_READ = 0;
/** performIO goal: wrap the caller's source buffers and send the result. */
private static final int IO_GOAL_WRITE = 1;
/** performIO goal: push pending wrapped data upstream and flush the sink conduit. */
private static final int IO_GOAL_FLUSH = 2;
/** performIO goal used by the array-based {@code writeFinal}. */
private static final int IO_GOAL_WRITE_FINAL = 3;
/** Shared empty array passed to performIO for whichever of sources/destinations is unused. */
private static final ByteBuffer[] NO_BUFFERS = new ByteBuffer[0];
/**
 * Translate performIO's internal bookkeeping into the value returned to callers.
 *
 * @param xfer bytes moved to/from the user buffers
 * @param goal the {@code IO_GOAL_*} being served
 * @param flushed whether the upstream flush succeeded
 * @param eof whether inbound EOF was observed
 * @return {@code 1} for a successful flush, {@code -1} for a read that hit EOF without transferring
 *         anything, otherwise {@code xfer}
 */
private static long actualIOResult(final long xfer, final int goal, final boolean flushed, final boolean eof) {
    final long result;
    if (goal == IO_GOAL_FLUSH && flushed) {
        result = 1L;
    } else if (goal == IO_GOAL_READ && eof && xfer == 0L) {
        result = -1L;
    } else {
        result = xfer;
    }
    if (TRACE_SSL) {
        msg.tracef("returned TLS result %d", result);
    }
    return result;
}
/**
 * Render an {@code IO_GOAL_*} value as a readable name for trace output.
 *
 * @param goal the goal value
 * @return the goal's name, or {@code "UNKNOWN(<goal>)"} for an unrecognized value
 */
private static String decodeGoal(final int goal) {
    if (goal == IO_GOAL_READ) {
        return "READ";
    } else if (goal == IO_GOAL_WRITE) {
        return "WRITE";
    } else if (goal == IO_GOAL_FLUSH) {
        return "FLUSH";
    } else if (goal == IO_GOAL_WRITE_FINAL) {
        return "WRITE_FINAL";
    }
    return "UNKNOWN(" + goal + ")";
}
/**
 * Drive the {@link SSLEngine} wrap/unwrap/handshake loop until the goal is met or I/O would block.
 *
 * <p>At most one of {@code srcs} and {@code dsts} is non-empty; the unused side is {@code NO_BUFFERS}.
 * Writes supply sources, reads supply destinations, and flushes supply neither. Unwrapped data that
 * does not fit in the caller's buffers is unwrapped into the internal read buffer and copied out on later
 * reads.
 *
 * <p>While a delegated engine task is outstanding ({@code FLAG_NEED_ENGINE_TASK}) this returns {@code 0}
 * immediately. The working copy of the state flags is always written back to {@code this.state} on exit.
 * If unwrap produced data or hit CLOSED while serving a non-read goal, reads are woken.
 *
 * @param goal one of the {@code IO_GOAL_*} values
 * @param srcs the source buffers (writes only)
 * @param srcOff the index of the first source buffer
 * @param srcLen the number of source buffers
 * @param dsts the destination buffers (reads only)
 * @param dstOff the index of the first destination buffer
 * @param dstLen the number of destination buffers
 * @return for a flush, {@code 1} once the upstream flush succeeded; for a read that hit EOF without
 *         transferring anything, {@code -1}; otherwise the number of bytes moved to/from the user buffers
 * @throws IOException if an I/O error occurs, a delegated task fails, or the engine reports an unexpected status
 */
private long performIO(final int goal, final ByteBuffer[] srcs, final int srcOff, final int srcLen, final ByteBuffer[] dsts, final int dstOff, final int dstLen) throws IOException {
    if (TRACE_SSL) msg.tracef("performing TLS I/O operation, goal %s, src: %s, dst: %s", decodeGoal(goal), Buffers.debugString(srcs, srcOff, srcLen), Buffers.debugString(dsts, dstOff, dstLen));
    // one of these arrays is empty
    assert srcs == NO_BUFFERS || dsts == NO_BUFFERS;
    int state = this.state;
    // this contradiction should never occur
    assert ! allAreSet(state, READ_FLAG_NEEDS_WRITE | WRITE_FLAG_NEEDS_READ);
    if (allAreSet(state, FLAG_NEED_ENGINE_TASK)) {
        // can't do anything until the task is done
        return 0L;
    }
    final SSLEngine engine = this.engine;
    final ByteBuffer sendBuffer = this.sendBuffer.getResource();
    final ByteBuffer receiveBuffer = this.receiveBuffer.getResource();
    final ByteBuffer readBuffer = this.readBuffer.getResource();
    // unwrap into our read buffer if necessary to avoid underflow problems:
    // realDsts = dsts[dstOff .. dstOff + dstLen) followed by one extra slot holding readBuffer.
    // copyOfRange's third argument is an absolute, exclusive end index (not a length), so it must include
    // dstOff; when it runs past the end of dsts the trailing slot is padded with null and overwritten below.
    final ByteBuffer[] realDsts = Arrays.copyOfRange(dsts, dstOff, dstOff + dstLen + 1);
    realDsts[dstLen] = readBuffer;
    long remaining = max(Buffers.remaining(srcs, srcOff, srcLen), Buffers.remaining(dsts, dstOff, dstLen));
    boolean wrap = goal == IO_GOAL_READ ? anyAreSet(state, READ_FLAG_NEEDS_WRITE | FLAG_FLUSH_NEEDED) : allAreSet(state, FLAG_FLUSH_NEEDED) || allAreClear(state, WRITE_FLAG_NEEDS_READ);
    boolean unwrap = !wrap;
    boolean flushed = false;
    boolean eof = false;
    boolean readBlocked = false;
    boolean writeBlocked = false;
    boolean copiedUnwrappedBytes = false;
    boolean wakeupReads = false;
    SSLEngineResult result;
    SSLEngineResult.HandshakeStatus handshakeStatus;
    int rv = 0;
    // amount of data moved to/from the user buffers (remember only zero or one of srcs/dsts can be given)
    long xfer = 0L;
    if (TRACE_SSL) msg.trace("TLS perform IO");
    try {
        for (;;) {
            if (TRACE_SSL) msg.trace("TLS begin IO operation");
            if (goal == IO_GOAL_READ && remaining > 0 && readBuffer.position() > 0) {
                // read data left over in our read buffer from an earlier unwrap
                readBuffer.flip();
                try {
                    if (TRACE_SSL) msg.tracef("TLS copy unwrapped data from %s to %s", Buffers.debugString(readBuffer), Buffers.debugString(dsts, dstOff, dstLen));
                    rv = Buffers.copy(dsts, dstOff, dstLen, readBuffer);
                } finally {
                    readBuffer.compact();
                }
                if (rv > 0) {
                    copiedUnwrappedBytes = true;
                    xfer += rv;
                    if ((remaining -= rv) == 0L) {
                        return actualIOResult(xfer, goal, flushed, eof);
                    }
                }
            }
            assert ! (wrap && unwrap);
            if (wrap) {
                if (TRACE_SSL) msg.tracef("TLS wrap from %s to %s", Buffers.debugString(srcs, srcOff, srcLen), Buffers.debugString(sendBuffer));
                result = engine.wrap(srcs, srcOff, srcLen, sendBuffer);
                WRAP_RESULT: switch (result.getStatus()) {
                    case BUFFER_UNDERFLOW: {
                        assert result.bytesConsumed() == 0;
                        assert result.bytesProduced() == 0;
                        // move directly to handshake result
                        if (TRACE_SSL) msg.trace("TLS wrap operation UNDERFLOW");
                        break;
                    }
                    case BUFFER_OVERFLOW: {
                        assert result.bytesConsumed() == 0;
                        assert result.bytesProduced() == 0;
                        if (TRACE_SSL) msg.trace("TLS wrap operation OVERFLOW");
                        if (sendBuffer.position() == 0) {
                            // our buffer is empty, and definitely large enough, so just throw an exception
                            throw msg.wrongBufferExpansion();
                        } else {
                            // there's some data in there, so send it first
                            sendBuffer.flip();
                            try {
                                while (sendBuffer.hasRemaining()) {
                                    if (TRACE_SSL) msg.tracef("TLS wrap operation send %s", Buffers.debugString(sendBuffer));
                                    final int res = sinkConduit.write(sendBuffer);
                                    if (res == 0) {
                                        writeBlocked = true;
                                        state &= ~WRITE_FLAG_READY;
                                        // not flushed
                                        assert goal != IO_GOAL_FLUSH || xfer == 0L;
                                        flushed = false;
                                        wrap = false;
                                        break WRAP_RESULT;
                                    }
                                }
                            } finally {
                                sendBuffer.compact();
                            }
                            if (goal == IO_GOAL_FLUSH || allAreSet(state, FLAG_FLUSH_NEEDED)) {
                                if (flushed = sinkConduit.flush()) {
                                    state &= ~FLAG_FLUSH_NEEDED;
                                }
                            }
                            if (goal == IO_GOAL_FLUSH && allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
                                state |= WRITE_FLAG_SHUTDOWN2;
                            }
                        }
                        // move to handshake result
                        break;
                    }
                    case CLOSED: {
                        if (TRACE_SSL) msg.trace("TLS wrap operation CLOSED");
                        if (allAreClear(state, WRITE_FLAG_SHUTDOWN) && result.bytesProduced() == 0) {
                            if (goal == IO_GOAL_FLUSH) {
                                // this is okay, the client is flushing and we may have received a close from the other
                                // end before we actually close this side
                                wrap = false;
                                // the goal is FLUSH here, so the upstream flush is always attempted
                                if (flushed = sinkConduit.flush()) {
                                    state &= ~FLAG_FLUSH_NEEDED;
                                }
                                break;
                            }
                            // attempted write after shutdown (should not be possible)
                            state &= ~(WRITE_FLAG_NEEDS_READ | READ_FLAG_NEEDS_WRITE);
                            state |= WRITE_FLAG_SHUTDOWN | WRITE_FLAG_SHUTDOWN2 | WRITE_FLAG_SHUTDOWN3 | WRITE_FLAG_FINISHED;
                            final ClosedChannelException exception = new ClosedChannelException();
                            try {
                                sinkConduit.truncateWrites();
                            } catch (IOException e) {
                                exception.addSuppressed(e);
                            }
                            throw exception;
                        }
                        if (allAreSet(state, WRITE_FLAG_SHUTDOWN2)) {
                            state |= WRITE_FLAG_SHUTDOWN3;
                        }
                        // else treat as OK
                        // fall thru!!!
                    }
                    case OK: {
                        if (TRACE_SSL) msg.tracef("TLS wrap operation OK consumed: %d produced: %d", result.bytesConsumed(), result.bytesProduced());
                        state &= ~(WRITE_FLAG_NEEDS_READ | READ_FLAG_NEEDS_WRITE);
                        final int consumed = result.bytesConsumed();
                        if (goal == IO_GOAL_READ) {
                            // sources should be empty
                            assert consumed == 0;
                            wrap = false;
                            unwrap = true;
                        } else {
                            if (consumed > 0 || remaining == 0) {
                                // if consumed > 0 then remaining must also be 0
                                assert remaining != 0 || consumed == 0;
                                // we've returned some data, or else there's nothing to consume
                                wrap = false;
                            }
                            xfer += consumed;
                            remaining -= consumed;
                        }
                        // try to send the generated bytes
                        sendBuffer.flip();
                        try {
                            flushed = false;
                            while (sendBuffer.hasRemaining()) {
                                final int res = allAreSet(state, WRITE_FLAG_SHUTDOWN3) ? sinkConduit.writeFinal(sendBuffer) : sinkConduit.write(sendBuffer);
                                if (res == 0) {
                                    // not flushed; probably can't wrap any more anyway
                                    writeBlocked = true;
                                    wrap = false;
                                    break;
                                }
                            }
                        } finally {
                            sendBuffer.compact();
                        }
                        // make sure we *really* flushed
                        if (sendBuffer.position() == 0) {
                            if (goal == IO_GOAL_FLUSH || allAreSet(state, FLAG_FLUSH_NEEDED)) {
                                if (flushed = sinkConduit.flush()) {
                                    state &= ~FLAG_FLUSH_NEEDED;
                                }
                            }
                            if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
                                if (allAreClear(state, WRITE_FLAG_SHUTDOWN2)) {
                                    // send buffer should already be cleared, thanks to above write block
                                    assert sendBuffer.position() == 0;
                                    state |= WRITE_FLAG_SHUTDOWN2;
                                    if (result.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
                                        // make sure we get to shutdown 3 right away if we are no longer handshaking
                                        state |= WRITE_FLAG_SHUTDOWN3;
                                    }
                                }
                                if (allAreSet(state, WRITE_FLAG_SHUTDOWN3)) {
                                    // the last wrap has occurred, and writes were shut down; we just need last flush
                                    if (goal == IO_GOAL_FLUSH || sinkConduit.flush()) {
                                        state |= WRITE_FLAG_FINISHED;
                                    }
                                    sinkConduit.terminateWrites();
                                }
                            }
                        }
                        // move to handshake result
                        break;
                    }
                    default: {
                        throw msg.unexpectedWrapResult(result.getStatus());
                    }
                }
            } else if (unwrap) {
                if (TRACE_SSL) msg.tracef("TLS unwrap from %s to %s", Buffers.debugString(receiveBuffer), Buffers.debugString(realDsts, 0, dstLen + 1));
                // use dstLen + 1 so that any leftovers are unwrapped into our read buffer to avoid underflow
                // * offset is 0 because realDsts is a copyOfRange of the original dsts with one extra buf at the end
                assert realDsts.length == 1 || realDsts[0] == dsts[dstOff];
                assert realDsts[dstLen] == readBuffer;
                // user-visible counts
                final long preRem = Buffers.remaining(dsts, dstOff, dstLen);
                result = engine.unwrap(receiveBuffer, realDsts, 0, dstLen + 1);
                final long userProduced = preRem - Buffers.remaining(dsts, dstOff, dstLen);
                switch (result.getStatus()) {
                    case BUFFER_OVERFLOW: {
                        assert result.bytesConsumed() == 0;
                        assert result.bytesProduced() == 0;
                        assert userProduced == 0;
                        if (TRACE_SSL) msg.trace("TLS unwrap operation OVERFLOW");
                        // not enough space in destination buffer; caller should consume the data they have first
                        if (!copiedUnwrappedBytes) { // realDsts is too small for message to unwrap
                            return actualIOResult(xfer, goal, flushed, eof);
                        }
                        unwrap = false;
                        break;
                    }
                    case BUFFER_UNDERFLOW: {
                        assert result.bytesConsumed() == 0;
                        assert result.bytesProduced() == 0;
                        assert userProduced == 0;
                        if (TRACE_SSL) msg.trace("TLS unwrap operation UNDERFLOW");
                        // fill the rest of the buffer, then retry!
                        receiveBuffer.compact();
                        try {
                            int res;
                            res = sourceConduit.read(receiveBuffer);
                            if (TRACE_SSL) msg.tracef("TLS unwrap operation read %s", Buffers.debugString(receiveBuffer));
                            if (res == -1) {
                                state &= ~READ_FLAG_READY;
                                engine.closeInbound();
                            } else if (res == 0) {
                                readBlocked = true;
                                state &= ~READ_FLAG_READY;
                                unwrap = false;
                            } else if (receiveBuffer.hasRemaining()) {
                                do {
                                    // try more reads just in case
                                    res = sourceConduit.read(receiveBuffer);
                                } while (res > 0 && receiveBuffer.hasRemaining());
                                if (res == 0) {
                                    state &= ~READ_FLAG_READY;
                                }
                            }
                        } finally {
                            receiveBuffer.flip();
                        }
                        // we should now be able to unwrap.
                        break;
                    }
                    case CLOSED: {
                        if (result.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
                            // treat as buffer underflow
                            // fill the rest of the buffer, then retry!
                            receiveBuffer.compact();
                            try {
                                int res;
                                res = sourceConduit.read(receiveBuffer);
                                if (TRACE_SSL) msg.tracef("TLS unwrap operation read %s", Buffers.debugString(receiveBuffer));
                                if (res == -1) {
                                    state &= ~READ_FLAG_READY;
                                    engine.closeInbound();
                                    return actualIOResult(xfer, goal, flushed, eof);
                                } else if (res == 0) {
                                    readBlocked = true;
                                    state &= ~READ_FLAG_READY;
                                    unwrap = false;
                                    return actualIOResult(xfer, goal, flushed, eof);
                                } else if (receiveBuffer.hasRemaining()) {
                                    do {
                                        // try more reads just in case
                                        res = sourceConduit.read(receiveBuffer);
                                    } while (res > 0 && receiveBuffer.hasRemaining());
                                    if (res == 0) {
                                        state &= ~READ_FLAG_READY;
                                    }
                                }
                            } finally {
                                receiveBuffer.flip();
                            }
                            // we should now be able to unwrap.
                            break;
                        }
                        if (TRACE_SSL) msg.trace("TLS unwrap operation CLOSED");
                        state &= ~(WRITE_FLAG_NEEDS_READ | READ_FLAG_NEEDS_WRITE);
                        if (goal == IO_GOAL_READ) {
                            xfer += userProduced;
                            remaining -= userProduced;
                            // if we are performing any action that not read, we don't want to disable read handler yet
                            // (the handler needs to read -1 before that)
                            state = (state & ~READ_FLAG_READY) | READ_FLAG_EOF;
                        } else {
                            wakeupReads = true;
                        }
                        // if unwrap processed any data, it should return bytes produced instead of -1
                        eof = true;
                        unwrap = false;
                        if (goal == IO_GOAL_FLUSH) {
                            wrap = true;
                        }
                        break;
                    }
                    case OK: {
                        if (TRACE_SSL) msg.tracef("TLS unwrap operation OK consumed: %d produced: %d", result.bytesConsumed(), result.bytesProduced());
                        if (allAreClear(state, READ_FLAG_READY)) {
                            // make sure the caller keeps reading until the unwrapped data is consumed
                            state |= READ_FLAG_READY;
                        }
                        state &= ~(WRITE_FLAG_NEEDS_READ | READ_FLAG_NEEDS_WRITE);
                        if (goal == IO_GOAL_READ) {
                            xfer += userProduced;
                            remaining -= userProduced;
                        } else {
                            wrap = true;
                            unwrap = false;
                            if (result.bytesProduced() > 0) {
                                wakeupReads = true;
                            }
                        }
                        // handshake result
                        break;
                    }
                    default: {
                        throw msg.unexpectedUnwrapResult(result.getStatus());
                    }
                }
            } else {
                // done
                return actualIOResult(xfer, goal, flushed, eof);
            }
            // now handle handshake
            handshakeStatus = result.getHandshakeStatus();
            HS: for (;;) {
                switch (handshakeStatus) {
                    case FINISHED: {
                        if (TRACE_SSL) msg.trace("TLS handshake FINISHED");
                        connection.invokeHandshakeListener();
                        // try original op again
                        // fall thru!
                    }
                    case NOT_HANDSHAKING: {
                        if (allAreSet(state, WRITE_FLAG_SHUTDOWN)) {
                            engine.closeOutbound();
                        }
                        // move on to next operation until I/O blocks
                        break HS;
                    }
                    case NEED_TASK: {
                        if (TRACE_SSL) msg.trace("TLS handshake NEED_TASK");
                        if (xfer != 0L) {
                            // only queue a task if the user isn't going to retry an I/O op immediately after
                            return actualIOResult(xfer, goal, flushed, eof);
                        }
                        if (allAreSet(state, FLAG_INLINE_TASKS)) {
                            Runnable task;
                            for (; ; ) {
                                task = engine.getDelegatedTask();
                                if (task == null) {
                                    break;
                                }
                                try {
                                    task.run();
                                } catch (Throwable cause) {
                                    throw new SSLException("Delegated task threw an exception", cause);
                                }
                            }
                            // and that's that; loop again
                            handshakeStatus = engine.getHandshakeStatus();
                            // retry handshake evaluation (breaks the switch, not the HS loop)
                            break;
                        } else {
                            state |= FLAG_NEED_ENGINE_TASK;
                            // await methods or the handler blob will take care of this
                            final ArrayList<Runnable> tasks = new ArrayList<>(4);
                            Runnable task;
                            for (;;) {
                                task = engine.getDelegatedTask();
                                if (task != null) {
                                    tasks.add(task);
                                } else {
                                    break;
                                }
                            }
                            final int size = tasks.size();
                            synchronized (JsseStreamConduit.this) {
                                this.tasks = size;
                            }
                            // use indexes to avoid iterator creation (which does the same thing anyway)
                            //noinspection ForLoopReplaceableByForEach
                            for (int i = 0; i < size; i ++) {
                                getWorker().execute(new TaskWrapper(tasks.get(i)));
                            }
                            return actualIOResult(xfer, goal, flushed, eof);
                        }
                    }
                    case NEED_WRAP: {
                        if (TRACE_SSL) msg.trace("TLS handshake NEED_WRAP");
                        state |= READ_FLAG_NEEDS_WRITE | FLAG_FLUSH_NEEDED;
                        if (writeBlocked) {
                            return actualIOResult(xfer, goal, flushed, eof);
                        }
                        wrap = true;
                        unwrap = false;
                        break HS;
                    }
                    case NEED_UNWRAP: {
                        if (TRACE_SSL) msg.trace("TLS handshake NEED_UNWRAP");
                        if (wrap && ! flushed && ! sinkConduit.flush()) {
                            // our wrap operation was probably actually a handshake msg, but we couldn't flush it
                            // we need to flush it to proceed else the other side may never send us a response
                            state |= FLAG_FLUSH_NEEDED;
                        }
                        state |= WRITE_FLAG_NEEDS_READ;
                        if (readBlocked) {
                            return actualIOResult(xfer, goal, flushed, eof);
                        }
                        wrap = false;
                        unwrap = true;
                        break HS;
                    }
                    default: {
                        // report the status actually being evaluated; it may have been refreshed from the
                        // engine after running inline tasks and can differ from result.getHandshakeStatus()
                        throw msg.unexpectedHandshakeStatus(handshakeStatus);
                    }
                }
            }
        }
    } finally {
        this.state = state;
        if (wakeupReads) {
            wakeupReads();
        }
    }
}
/**
 * Runs one delegated {@link SSLEngine} task on a worker thread.
 *
 * <p>After the task completes, whether normally or exceptionally, this decrements the enclosing
 * conduit's {@code tasks} counter under the conduit's monitor. When the counter reaches zero it calls
 * {@code notifyAll()} on that monitor.
 */
class TaskWrapper implements Runnable {
    private final Runnable task;

    TaskWrapper(final Runnable task) {
        this.task = task;
    }

    @Override
    public void run() {
        try {
            task.run();
        } finally {
            final JsseStreamConduit owner = JsseStreamConduit.this;
            synchronized (owner) {
                if (--owner.tasks == 0) {
                    owner.notifyAll();
                }
            }
        }
    }
}
}