// io.nats.client.impl.NatsConnectionReader Maven / Gradle / Ivy (web page header, not part of the source)
// Copyright 2015-2018 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package io.nats.client.impl;
import io.nats.client.support.IncomingHeadersProcessor;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import static io.nats.client.support.NatsConstants.*;
class NatsConnectionReader implements Runnable {
/**
 * Parser states of the reader. run() dispatches each chunk of received bytes
 * to the gather method that matches the current mode.
 */
enum Mode {
// Collecting the operation name, up to the first space/tab or CRLF.
GATHER_OP,
// Collecting the rest of a non-message protocol line into protocolBuffer.
GATHER_PROTO,
// Collecting the rest of a MSG/HMSG control line.
GATHER_MSG_HMSG_PROTO,
// A full CRLF-terminated line has been gathered and is ready for parseProtocolMessage().
PARSE_PROTO,
// Copying the header bytes announced by an HMSG control line.
GATHER_HEADERS,
// Copying the message body bytes and consuming the trailing CRLF.
GATHER_DATA
};
// Connection that owns this reader; it receives parsed messages and protocol callbacks.
private final NatsConnection connection;
private ByteBuffer protocolBuffer; // use a byte buffer to assist character decoding
// True when the previous byte was CR, so the next byte must be LF.
private boolean gotCR;
// Operation resolved by opFor() for the protocol line currently being processed.
private String op;
// Characters of the operation name gathered so far, and the next write index into them.
private final char[] opArray;
private int opPos;
// Characters of the MSG/HMSG control line, and the cursor used while filling/tokenizing them.
private final char[] msgLineChars;
private int msgLinePosition;
// Current parser state.
private Mode mode;
// Factory for the message currently being assembled (set when a MSG/HMSG line is parsed).
private IncomingMessageFactory incoming;
// Destination arrays for the header and body bytes of the message being assembled,
// with the number of bytes already copied into each.
private byte[] msgHeaders;
private byte[] msgData;
private int msgHeadersPosition;
private int msgDataPosition;
// Raw bytes from the last dataPort.read() and the index of the next unconsumed byte.
private final byte[] buffer;
private int bufferPosition;
// Completes when the reader task exits; replaced on every start().
// Parameterized (not raw) so callers and run() get typed results without casts.
private Future<Boolean> stopped;
// Supplies the data port to read from; resolved at the top of run().
private Future<DataPort> dataPortFuture;
private DataPort dataPort;
// True while the reader loop should keep going.
private final AtomicBoolean running;
// When true, MSG/HMSG control lines are gathered as bytes and decoded as UTF-8.
private final boolean utf8Mode;
NatsConnectionReader(NatsConnection connection) {
this.connection = connection;
this.running = new AtomicBoolean(false);
this.stopped = new CompletableFuture<>();
((CompletableFuture)this.stopped).complete(Boolean.TRUE); // we are stopped on creation
this.protocolBuffer = ByteBuffer.allocate(this.connection.getOptions().getMaxControlLine());
this.msgLineChars = new char[this.connection.getOptions().getMaxControlLine()];
this.opArray = new char[MAX_PROTOCOL_RECEIVE_OP_LENGTH];
this.buffer = new byte[connection.getOptions().getBufferSize()];
this.bufferPosition = 0;
this.utf8Mode = connection.getOptions().supportUTF8Subjects();
}
/**
 * Submits this reader to the connection's executor.
 *
 * <p>Should only be called if the previous reader task has exited; use the
 * Future from stop() to determine if it is ok to call this. This method
 * replaces that Future, so mistiming can result in badness.
 *
 * @param dataPortFuture future that yields the data port to read from; run()
 *                       waits on it before reading
 */
void start(Future<DataPort> dataPortFuture) {
    this.dataPortFuture = dataPortFuture;
    this.running.set(true);
    this.stopped = connection.getExecutor().submit(this, Boolean.TRUE);
}
/**
 * Stops the reader and shuts down the data port's input side.
 *
 * @return a future that completes when the reader task has exited
 */
Future<Boolean> stop() {
    return stop(true);
}
/**
 * Clears the running flag and optionally shuts down the data port's input.
 * May be called several times on an error; only a call that observes the
 * reader as running touches the data port.
 *
 * @param shutdownDataPort whether to call shutdownInput() on the data port
 *                         (skipped when no data port has been obtained yet)
 * @return a future that is completed when the reader task completes, not
 *         when this method does
 */
Future<Boolean> stop(boolean shutdownDataPort) {
    if (running.get()) {
        running.set(false);
        if (shutdownDataPort && dataPort != null) {
            try {
                dataPort.shutdownInput();
            } catch (IOException ignored) {
                // we don't care, we are shutting down anyway
            }
        }
    }
    return stopped;
}
/**
 * Reports the current value of the running flag.
 *
 * @return true while the reader loop is supposed to keep running
 */
boolean isRunning() {
    final boolean active = this.running.get();
    return active;
}
/**
 * Reader loop. Waits for the data port, then repeatedly reads into the
 * buffer and feeds the bytes through the gather/parse state machine until
 * the running flag is cleared or the channel fails.
 *
 * <p>An IOException is reported to the connection only if the reader was
 * still supposed to be running. Cancellation, a failed data port future, or
 * an interrupt simply end the loop; on interrupt the thread's interrupt
 * status is restored so the owner of the thread can observe it.
 */
@Override
public void run() {
    try {
        dataPort = this.dataPortFuture.get(); // Will wait for the future to complete
        this.mode = Mode.GATHER_OP;
        this.gotCR = false;
        this.opPos = 0;

        while (this.running.get()) {
            this.bufferPosition = 0;
            int bytesRead = dataPort.read(this.buffer, 0, this.buffer.length);

            if (bytesRead > 0) {
                connection.getNatsStatistics().registerRead(bytesRead);

                // Consume everything we just read; each gather method advances bufferPosition.
                while (this.bufferPosition < bytesRead) {
                    if (this.mode == Mode.GATHER_OP) {
                        this.gatherOp(bytesRead);
                    } else if (this.mode == Mode.GATHER_MSG_HMSG_PROTO) {
                        // UTF-8 subjects need byte-level gathering so the line can be decoded later.
                        if (this.utf8Mode) {
                            this.gatherProtocol(bytesRead);
                        } else {
                            this.gatherMessageProtocol(bytesRead);
                        }
                    } else if (this.mode == Mode.GATHER_PROTO) {
                        this.gatherProtocol(bytesRead);
                    } else if (this.mode == Mode.GATHER_HEADERS) {
                        this.gatherHeaders(bytesRead);
                    } else { // Mode.GATHER_DATA
                        this.gatherMessageData(bytesRead);
                    }

                    if (this.mode == Mode.PARSE_PROTO) { // Could be the end of the read
                        this.parseProtocolMessage();
                        this.protocolBuffer.clear();
                    }
                }
            } else if (bytesRead < 0) {
                throw new IOException("Read channel closed.");
            } else {
                this.connection.getNatsStatistics().registerRead(bytesRead); // track the 0
            }
        }
    } catch (IOException io) {
        // if already not running, an IOE is not unreasonable in a transition state
        if (running.get()) {
            this.connection.handleCommunicationIssue(io);
        }
    } catch (CancellationException | ExecutionException ex) {
        // Exit
    } catch (InterruptedException ex) {
        // Exit, but do not swallow the interrupt.
        Thread.currentThread().interrupt();
    } finally {
        this.running.set(false);
        // Clear the buffers, since they are only used inside this try/catch
        // We will reuse later
        this.protocolBuffer.clear();
    }
}
/**
 * Collects the operation name from the read buffer, stopping at the first
 * space/tab (more of the line follows) or at CRLF (the line is complete).
 *
 * @param maxPos index one past the last valid byte in the read buffer
 * @throws IOException wrapping any protocol violation, including an
 *                     operation name longer than the op array
 */
void gatherOp(int maxPos) throws IOException {
    try {
        while (this.bufferPosition < maxPos) {
            final byte current = this.buffer[this.bufferPosition];
            this.bufferPosition++;

            if (this.gotCR) {
                if (current != LF) {
                    throw new IllegalStateException("Bad socket data, no LF after CR");
                }
                // Got CRLF, the op is the whole line: jump to parsing
                this.op = opFor(this.opArray, this.opPos);
                this.gotCR = false;
                this.opPos = 0;
                this.mode = Mode.PARSE_PROTO;
                return;
            }

            if (current == SP || current == TAB) {
                // Got a separator, the rest of the protocol line still has to be gathered
                this.op = opFor(this.opArray, this.opPos);
                this.opPos = 0;
                final boolean isMessageOp = this.op.equals(OP_MSG) || this.op.equals(OP_HMSG);
                if (isMessageOp) {
                    this.msgLinePosition = 0;
                    this.mode = Mode.GATHER_MSG_HMSG_PROTO;
                } else {
                    this.mode = Mode.GATHER_PROTO;
                }
                return;
            }

            if (current == CR) {
                this.gotCR = true;
                continue;
            }

            // Ordinary character of the op name; overflow surfaces as an
            // ArrayIndexOutOfBoundsException handled below.
            this.opArray[this.opPos] = (char) current;
            this.opPos++;
        }
    } catch (ArrayIndexOutOfBoundsException | IllegalStateException | NumberFormatException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}
/**
 * Copies the remainder of a MSG/HMSG control line into the message-line
 * char array (one char per byte), from which subject, sid, reply-to and
 * lengths are later tokenized. Switches to PARSE_PROTO on CRLF.
 *
 * @param maxPos index one past the last valid byte in the read buffer
 * @throws IOException wrapping any protocol violation, such as a line that
 *                     does not fit the char array or a CR without LF
 */
void gatherMessageProtocol(int maxPos) throws IOException {
    try {
        while (this.bufferPosition < maxPos) {
            final byte current = this.buffer[this.bufferPosition];
            this.bufferPosition++;

            if (this.gotCR) {
                if (current != LF) {
                    throw new IllegalStateException("Bad socket data, no LF after CR");
                }
                this.mode = Mode.PARSE_PROTO;
                this.gotCR = false;
                return;
            }

            if (current == CR) {
                this.gotCR = true;
                continue;
            }

            if (this.msgLinePosition >= this.msgLineChars.length) {
                throw new IllegalStateException("Protocol line is too long");
            }
            // Assumes ascii, as per protocol doc
            this.msgLineChars[this.msgLinePosition] = (char) current;
            this.msgLinePosition++;
        }
    } catch (IllegalStateException | NumberFormatException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}
/**
 * Appends the remainder of a protocol line to the protocol byte buffer and,
 * on CRLF, flips the buffer for reading and switches to PARSE_PROTO. The
 * buffer starts at the maximum control line size and is enlarged through
 * the connection if a line turns out to be longer.
 *
 * @param maxPos index one past the last valid byte in the read buffer
 * @throws IOException wrapping any protocol violation, such as a CR
 *                     without LF
 */
void gatherProtocol(int maxPos) throws IOException {
    try {
        while (this.bufferPosition < maxPos) {
            final byte current = this.buffer[this.bufferPosition];
            this.bufferPosition++;

            if (this.gotCR) {
                if (current != LF) {
                    throw new IllegalStateException("Bad socket data, no LF after CR");
                }
                this.protocolBuffer.flip();
                this.mode = Mode.PARSE_PROTO;
                this.gotCR = false;
                return;
            }

            if (current == CR) {
                this.gotCR = true;
                continue;
            }

            if (!this.protocolBuffer.hasRemaining()) {
                // just double it
                this.protocolBuffer = this.connection.enlargeBuffer(this.protocolBuffer);
            }
            this.protocolBuffer.put(current);
        }
    } catch (IllegalStateException | NumberFormatException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}
/**
 * Copies header bytes from the read buffer into the header array until it
 * is full, then hands the headers to the message factory and switches to
 * GATHER_DATA. If the read buffer runs out first, the method returns and
 * continues on the next call.
 *
 * @param maxPos index one past the last valid byte in the read buffer
 * @throws IOException wrapping any protocol violation
 */
void gatherHeaders(int maxPos) throws IOException {
    try {
        while (this.bufferPosition < maxPos) {
            final int available = maxPos - this.bufferPosition;
            final int missing = this.msgHeaders.length - this.msgHeadersPosition;

            if (missing > 0) {
                // Grab all we can, but never more than the headers still need
                final int chunk = Math.min(missing, available);
                System.arraycopy(this.buffer, this.bufferPosition, this.msgHeaders, this.msgHeadersPosition, chunk);
                this.msgHeadersPosition += chunk;
                this.bufferPosition += chunk;
                continue;
            }

            if (this.msgHeadersPosition != this.msgHeaders.length) {
                throw new IllegalStateException("Bad socket data, headers do not match expected length");
            }

            // Headers complete: pass them on and move to the body
            this.incoming.setHeaders(new IncomingHeadersProcessor(this.msgHeaders));
            this.msgHeaders = null;
            this.msgHeadersPosition = -1;
            this.mode = Mode.GATHER_DATA;
            return;
        }
    } catch (IllegalStateException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}
/**
 * Copies message body bytes from the read buffer into the data array, then
 * consumes the CRLF that must follow the body. Once the LF arrives the data
 * is attached to the pending message, the message is delivered to the
 * connection, and the parser returns to GATHER_OP.
 *
 * @param maxPos index one past the last valid byte in the read buffer
 * @throws IOException wrapping any protocol violation, such as a body that
 *                     is not followed by CRLF
 */
void gatherMessageData(int maxPos) throws IOException {
    try {
        while (this.bufferPosition < maxPos) {
            final int available = maxPos - this.bufferPosition;
            final int missing = this.msgData.length - this.msgDataPosition;

            if (missing > 0) {
                // Grab all we can, but never more than the body still needs
                final int chunk = Math.min(missing, available);
                System.arraycopy(this.buffer, this.bufferPosition, this.msgData, this.msgDataPosition, chunk);
                this.msgDataPosition += chunk;
                this.bufferPosition += chunk;
                continue;
            }

            // Body is complete; only the CRLF terminator may follow.
            final byte terminator = this.buffer[this.bufferPosition];
            this.bufferPosition++;

            if (!this.gotCR) {
                if (terminator != CR) {
                    throw new IllegalStateException("Bad socket data, no CRLF after data");
                }
                this.gotCR = true;
                continue;
            }

            if (terminator != LF) {
                throw new IllegalStateException("Bad socket data, no LF after CR");
            }

            this.incoming.setData(this.msgData);
            this.connection.deliverMessage(this.incoming.getMessage());

            // Reset per-message state and go back to reading the next op
            this.msgData = null;
            this.msgDataPosition = 0;
            this.incoming = null;
            this.gotCR = false;
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            return;
        }
    } catch (IllegalStateException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}
/**
 * Returns the next space- or tab-delimited token of the message control
 * line, starting at the current line cursor, and advances the cursor past
 * the token and its delimiter.
 *
 * @param max index one past the last valid character of the control line
 * @return the token (possibly empty when two delimiters are adjacent), or
 *         null when the cursor is already at or beyond {@code max}
 */
public String grabNextMessageLineElement(int max) {
    final int tokenStart = this.msgLinePosition;
    if (tokenStart >= max) {
        return null;
    }

    while (this.msgLinePosition < max) {
        final int at = this.msgLinePosition;
        final char current = this.msgLineChars[at];
        this.msgLinePosition = at + 1;

        if (current == SP || current == TAB) {
            // don't grab the delimiter; build the String straight from the array
            return new String(this.msgLineChars, tokenStart, at - tokenStart);
        }
    }

    // Ran off the end of the line: the token is everything that was left
    return new String(this.msgLineChars, tokenStart, this.msgLinePosition - tokenStart);
}
/**
 * Maps the gathered operation characters to one of the known operation
 * constants, ignoring ASCII case. Only the 3-character ops (MSG, +OK) and
 * the 4-character ops (PING, PONG, -ERR, INFO, HMSG) are recognized.
 *
 * @param chars  array holding the operation characters
 * @param length number of valid characters at the start of {@code chars}
 * @return the matching operation constant, or UNKNOWN_OP
 */
static String opFor(char[] chars, int length) {
    switch (length) {
        case 3:
            if (isEither(chars[0], 'M', 'm')
                    && isEither(chars[1], 'S', 's')
                    && isEither(chars[2], 'G', 'g')) {
                return OP_MSG;
            }
            if (chars[0] == '+'
                    && isEither(chars[1], 'O', 'o')
                    && isEither(chars[2], 'K', 'k')) {
                return OP_OK;
            }
            return UNKNOWN_OP;

        case 4:
            // PING/PONG test index 1 first: it is the character that tells them apart
            if (isEither(chars[1], 'I', 'i')
                    && isEither(chars[0], 'P', 'p')
                    && isEither(chars[2], 'N', 'n')
                    && isEither(chars[3], 'G', 'g')) {
                return OP_PING;
            }
            if (isEither(chars[1], 'O', 'o')
                    && isEither(chars[0], 'P', 'p')
                    && isEither(chars[2], 'N', 'n')
                    && isEither(chars[3], 'G', 'g')) {
                return OP_PONG;
            }
            if (chars[0] == '-'
                    && isEither(chars[1], 'E', 'e')
                    && isEither(chars[2], 'R', 'r')
                    && isEither(chars[3], 'R', 'r')) {
                return OP_ERR;
            }
            if (isEither(chars[0], 'I', 'i')
                    && isEither(chars[1], 'N', 'n')
                    && isEither(chars[2], 'F', 'f')
                    && isEither(chars[3], 'O', 'o')) {
                return OP_INFO;
            }
            if (isEither(chars[0], 'H', 'h')
                    && isEither(chars[1], 'M', 'm')
                    && isEither(chars[2], 'S', 's')
                    && isEither(chars[3], 'G', 'g')) {
                return OP_HMSG;
            }
            return UNKNOWN_OP;

        default:
            return UNKNOWN_OP;
    }
}

// True when c equals one of the two given spellings (upper/lower case) of a letter.
private static boolean isEither(char c, char upper, char lower) {
    return c == upper || c == lower;
}
// Powers of ten indexed by decimal position; its length is also the maximum digit count of an int.
private static final int[] TENS = new int[] { 1, 10, 100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000};

/**
 * Parses a non-negative decimal length consisting solely of the digits
 * '0'-'9' (no sign, no whitespace).
 *
 * <p>An empty string yields 0. A null argument results in a
 * NullPointerException.
 *
 * @param s the digits to parse
 * @return the parsed value, between 0 and Integer.MAX_VALUE
 * @throws NumberFormatException if {@code s} has more digits than an int
 *         can hold, contains any character other than '0'-'9', or denotes
 *         a value larger than Integer.MAX_VALUE
 */
public static int parseLength(String s) throws NumberFormatException {
    int length = s.length();

    if (length > TENS.length) {
        throw new NumberFormatException("Long in message length \"" + s + "\" "+length+" > "+TENS.length);
    }

    // Accumulate in a long: a 10-digit value can exceed Integer.MAX_VALUE
    // and would otherwise wrap around silently.
    long value = 0;

    for (int i = length - 1; i >= 0; i--) {
        char c = s.charAt(i);
        int d = c - '0';

        // Reject both sides of the digit range; characters below '0' (such as
        // '-', '+' or space) would otherwise contribute negative digits.
        if (d < 0 || d > 9) {
            throw new NumberFormatException("Invalid char in message length '" + c + "'");
        }

        value += (long) d * TENS[length - i - 1];
    }

    if (value > Integer.MAX_VALUE) {
        throw new NumberFormatException("Message length \"" + s + "\" exceeds " + Integer.MAX_VALUE);
    }

    return (int) value;
}
/**
 * Interprets the protocol line that was just gathered, based on the current
 * operation.
 *
 * <p>MSG and HMSG lines are tokenized into subject, sid, optional reply-to
 * and length fields; the destination arrays are allocated and the parser
 * moves on to GATHER_DATA or GATHER_HEADERS. +OK, -ERR, PING, PONG and INFO
 * are forwarded to the connection and the parser returns to GATHER_OP.
 *
 * <p>Malformed lines are reported as protocol errors rather than escaping
 * as unchecked exceptions. This covers missing fields, non-numeric,
 * negative or inconsistent lengths, and over-long UTF-8 control lines.
 *
 * @throws IOException wrapping any protocol violation or an unknown operation
 */
void parseProtocolMessage() throws IOException {
    try {
        switch (this.op) {
        case OP_MSG: {
            int protocolLength = this.msgLinePosition; //This is just after the last character
            int protocolLineLength = protocolLength + 4; // 4 for the "MSG "

            if (this.utf8Mode) {
                // In UTF-8 mode the line was gathered as bytes; the line length is
                // counted in bytes, the tokenizer limit in decoded chars.
                protocolLineLength = protocolBuffer.remaining() + 4;
                protocolLength = decodeControlLineIntoMsgLineChars();
            }

            this.msgLinePosition = 0;
            String subject = grabNextMessageLineElement(protocolLength);
            String sid = grabNextMessageLineElement(protocolLength);
            String replyTo = grabNextMessageLineElement(protocolLength);
            String lengthChars = null;

            // A fourth token means the third was the reply-to; otherwise the third is the length.
            if (this.msgLinePosition < protocolLength) {
                lengthChars = grabNextMessageLineElement(protocolLength);
            } else {
                lengthChars = replyTo;
                replyTo = null;
            }

            if (subject == null || subject.isEmpty() || sid == null || sid.isEmpty() || lengthChars == null) {
                throw new IllegalStateException("Bad MSG control line, missing required fields");
            }

            int incomingLength = parseLength(lengthChars);
            if (incomingLength < 0) {
                // Guard the allocation below: a negative size would throw an
                // exception that is not treated as a protocol error.
                throw new IllegalStateException("Bad MSG control line, negative payload length");
            }

            this.incoming = new IncomingMessageFactory(sid, subject, replyTo, protocolLineLength, utf8Mode);
            this.mode = Mode.GATHER_DATA;
            this.msgData = new byte[incomingLength];
            this.msgDataPosition = 0;
            this.msgLinePosition = 0;
            break;
        }
        case OP_HMSG: {
            int hProtocolLength = this.msgLinePosition; //This is just after the last character
            int hProtocolLineLength = hProtocolLength + 5; // 5 for the "HMSG "

            if (this.utf8Mode) {
                hProtocolLineLength = protocolBuffer.remaining() + 5;
                hProtocolLength = decodeControlLineIntoMsgLineChars();
            }

            this.msgLinePosition = 0;
            String hSubject = grabNextMessageLineElement(hProtocolLength);
            String hSid = grabNextMessageLineElement(hProtocolLength);
            String replyToOrHdrLen = grabNextMessageLineElement(hProtocolLength);
            String hdrLenOrTotLen = grabNextMessageLineElement(hProtocolLength);

            if (hSubject == null || hSubject.isEmpty() || hSid == null || hSid.isEmpty()
                    || replyToOrHdrLen == null || hdrLenOrTotLen == null) {
                throw new IllegalStateException("Bad HMSG control line, missing required fields");
            }

            String hReplyTo = null;
            int hdrLen;
            int totLen;

            // if there is more it must be replyTo hdrLen totLen instead of just hdrLen totLen
            if (this.msgLinePosition < hProtocolLength) {
                hReplyTo = replyToOrHdrLen;
                hdrLen = parseLength(hdrLenOrTotLen);
                totLen = parseLength(grabNextMessageLineElement(hProtocolLength));
            } else {
                hdrLen = parseLength(replyToOrHdrLen);
                totLen = parseLength(hdrLenOrTotLen);
            }

            if (hdrLen < 0 || totLen < hdrLen) {
                // The body size is totLen - hdrLen; reject lines that would make
                // either allocation below negative.
                throw new IllegalStateException("Bad HMSG control line, header length " + hdrLen
                        + " is inconsistent with total length " + totLen);
            }

            this.incoming = new IncomingMessageFactory(hSid, hSubject, hReplyTo, hProtocolLineLength, utf8Mode);
            this.msgHeaders = new byte[hdrLen];
            this.msgData = new byte[totLen - hdrLen];
            this.mode = Mode.GATHER_HEADERS;
            this.msgHeadersPosition = 0;
            this.msgDataPosition = 0;
            this.msgLinePosition = 0;
            break;
        }
        case OP_OK:
            this.connection.processOK();
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            break;
        case OP_ERR:
            // The error text has its single quotes stripped before being handed on
            String errorText = StandardCharsets.UTF_8.decode(protocolBuffer).toString().replace("'", "");
            this.connection.processError(errorText);
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            break;
        case OP_PING:
            this.connection.sendPong();
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            break;
        case OP_PONG:
            this.connection.handlePong();
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            break;
        case OP_INFO:
            String info = StandardCharsets.UTF_8.decode(protocolBuffer).toString();
            this.connection.handleInfo(info);
            this.op = UNKNOWN_OP;
            this.mode = Mode.GATHER_OP;
            break;
        default:
            throw new IllegalStateException("Unknown protocol operation "+op);
        }
    } catch (IllegalStateException | NumberFormatException | NullPointerException ex) {
        this.encounteredProtocolError(ex);
    }
}

// Decodes the UTF-8 control line held in protocolBuffer into msgLineChars and returns the
// number of chars written; rejects lines that do not fit (the byte buffer may have been enlarged).
private int decodeControlLineIntoMsgLineChars() {
    CharBuffer decoded = StandardCharsets.UTF_8.decode(this.protocolBuffer);
    int charCount = decoded.remaining();
    if (charCount > this.msgLineChars.length) {
        throw new IllegalStateException("Protocol line is too long");
    }
    decoded.get(this.msgLineChars, 0, charCount);
    return charCount;
}
/**
 * Converts a protocol-level failure into an IOException so that it travels
 * the same path as any other read failure.
 *
 * @param ex the failure detected while gathering or parsing; becomes the cause
 * @throws IOException always
 */
void encounteredProtocolError(Exception ex) throws IOException {
    final IOException wrapped = new IOException(ex);
    throw wrapped;
}
/**
 * For testing: loads the given bytes at the start of the read buffer and
 * puts the parser back into GATHER_OP with an unknown operation, as if the
 * bytes had just been read. The CR flag and the op gathering position are
 * left as they are.
 *
 * @param bytes the bytes to place in the read buffer; must fit into it
 */
void fakeReadForTest(byte[] bytes) {
    final int count = bytes.length;
    System.arraycopy(bytes, 0, this.buffer, 0, count);

    this.mode = Mode.GATHER_OP;
    this.op = UNKNOWN_OP;
    this.bufferPosition = 0;
}
/**
 * Exposes the operation most recently resolved (or reset) by the parser.
 *
 * @return the current operation string; may be null before any operation
 *         has been gathered
 */
String currentOp() {
    final String current = this.op;
    return current;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (web page footer, not part of the source)