All downloads are free. Search and download functionality uses the official Maven repository.

org.dellroad.muxable.simple.ProtocolReader Maven / Gradle / Ivy

/*
 * Copyright (C) 2021 Archie L. Cobbs. All rights reserved.
 */

package org.dellroad.muxable.simple;

import io.permazen.util.LongEncoder;

import java.io.IOException;
import java.nio.ByteBuffer;

import org.dellroad.muxable.Directions;
import org.slf4j.Logger;

/**
 * Input state machine for the {@link SimpleMuxableChannel} framing protocol.
 *
 * 

* Instances are not thread safe. */ public class ProtocolReader extends LoggingSupport { // Configuration setup private final ChannelIds channelIds; // channel ID tracker private final InputHandler inputHandler; // callback object // Stream position tracking private long offset; // total number of bytes read so far // Long value buffer/decoding private final ByteBuffer longValueBuffer // buffers a long value, possibly encoded = ByteBuffer.allocate(LongEncoder.MAX_ENCODED_LENGTH); private long longValueOffset; // stream offset of start of long value private long longValue; // the long value once completed // Incoming payload info private long payloadChannelId; // payload's channel ID (always positive) private boolean payloadChannelIdIsLocal; // true if channel ID is a local channel private boolean newChannelRequest; // payload is a new channel request private Directions newChannelDirections; // new channel request input and/or output private ByteBuffer payloadBuffer; // payload data buffer // State machine state private State state = State.READING_PROTOCOL_COOKIE; // state machine current state private ProtocolViolationException violation; // previous protocol violation, if any private boolean reentrantHandler; // we are invoking the input handler /** * Constructor. * * @param channelIds channel ID tracker (should be shared with the {@link ProtocolWriter}) * @param inputHandler callback interface for generated events * @throws IllegalArgumentException if {@code inputHandler} is null */ public ProtocolReader(ChannelIds channelIds, InputHandler inputHandler) { if (channelIds == null) throw new IllegalArgumentException("null channelIds"); if (inputHandler == null) throw new IllegalArgumentException("null inputHandler"); this.channelIds = channelIds; this.inputHandler = inputHandler; } /** * Constructor. 
* * @param log {@link Logger} to use * @param logPrefix prefix for all log messages, or null for empty string * @param channelIds channel ID tracker (should be shared with the {@link ProtocolWriter}) * @param inputHandler callback interface for generated events * @throws IllegalArgumentException if {@code log} or {@code inputHandler} is null */ public ProtocolReader(Logger log, String logPrefix, ChannelIds channelIds, InputHandler inputHandler) { super(log, logPrefix); if (channelIds == null) throw new IllegalArgumentException("null channelIds"); if (inputHandler == null) throw new IllegalArgumentException("null inputHandler"); this.channelIds = channelIds; this.inputHandler = inputHandler; } /** * Input new data from the remote side into the protocol state machine. * *

* Any generated events will be delivered to the configured {@link InputHandler} synchronously (in the current thread). * However, this method must not be invoked re-entrantly. * *

* If a protocol violation is detected, {@link ProtocolViolationException} is thrown and this instance becomes * unusable (any further invocations of this method will result in an {@link IllegalStateException}). * *

* For performance reasons, this instance assumes it may take ownership of {@code data}, i.e., retain a reference, * modify, and potentially pass {@code data} back to the {@link InputHandler} at some later point. * *

* This method consumes all of {@code data}, unless a "close connection" frame is read; in that case, any data * after the "close connection" frame will remain in {@code data}, this method will return false, and any subsequent * invocations of this method will throw {@link IllegalStateException}. * * @param data new data received * @return normally true, or false if a "close connection" frame is read * @throws IOException if the {@link InputHandler} does * @throws ProtocolViolationException if the input violates the framing protocol * @throws IllegalArgumentException if {@code data} is null * @throws IllegalStateException if this method is invoked re-entrantly by the {@link InputHandler} * @throws IllegalStateException if this instance has previously encountered a protocol violation * @throws IllegalStateException if the remote instance has closed the connection */ public boolean input(ByteBuffer data) throws IOException { // Sanity check if (this.violation != null) throw new IllegalStateException("a protocol violation has already occurred", this.violation); if (this.reentrantHandler) throw new IllegalStateException("illegal re-entrant invocation"); // Consume the data until if/when a "close connection" frame is seen while (data.hasRemaining()) { this.trace("state %s: input data %s", this.state, LoggingSupport.toString(data, 64)); if (!this.state.inputData(this, data)) return false; } // Done return true; } /** * Get the current stream offset. 
* * @return stream offset */ public long getOffset() { return this.offset; } // State Machine // State READING_PROTOCOL_COOKIE private boolean inputProtocolCookieData(ByteBuffer data) { // Decode plain (big endian) long value if (!this.inputPlainLongValueData(data)) return true; // Validate protocol cookie final long protocolCookie = this.longValue; this.trace("state %s: read protocol cookie 0x%016x", this.state, protocolCookie); if (protocolCookie != ProtocolConstants.PROTOCOL_COOKIE) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd invalid protocol cookie 0x%016x != 0x%016x", protocolCookie, ProtocolConstants.PROTOCOL_COOKIE)); } // Proceed this.state = State.READING_PROTOCOL_VERSION; return true; } // State READING_PROTOCOL_VERSION private boolean inputProtocolVersionData(ByteBuffer data) { // Decode encoded long value if (!this.inputEncodedLongValueData(data)) return true; // Validate protocol version final long protocolVersion = this.longValue; this.trace("state %s: read protocol version %d", this.state, protocolVersion); if (protocolVersion != ProtocolConstants.CURRENT_PROTOCOL_VERSION) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd unsupported protocol version %d (current is %d)", protocolVersion, ProtocolConstants.CURRENT_PROTOCOL_VERSION)); } // Proceed this.state = State.READING_CHANNEL_ID; return true; } // State READING_CHANNEL_ID private boolean inputChannelIdData(ByteBuffer data) { // Decode encoded long value if (!this.inputEncodedLongValueData(data)) return true; // Zero means close the whole thing down this.trace("state %s: read peer's encoded channel ID %d (channel %s%d)", this.state, this.longValue, this.longValue <= 0 ? 
"L" : "R", Math.abs(this.longValue)); if (this.longValue == 0) { this.state = State.CLOSED; return false; } // Validate channel ID if (this.longValue == Long.MIN_VALUE) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd invalid frame: invalid encoded channel ID %d", this.longValue)); } // Value was encoded by peer, so negative & positive are swapped this.payloadChannelIdIsLocal = this.longValue < 0; this.payloadChannelId = Math.abs(this.longValue); // Handle local channel ID vs. remote channel ID try { if (this.payloadChannelIdIsLocal) { this.newChannelRequest = false; this.channelIds.validateLocalChannelId(this.payloadChannelId); } else { this.newChannelRequest = this.channelIds.allocateRemoteChannelId(this.payloadChannelId); if (this.newChannelRequest) { this.state = State.READING_DIRECTIONS; return true; } this.channelIds.validateRemoteChannelId(this.payloadChannelId); } } catch (IllegalArgumentException e) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd invalid frame: %s", e.getMessage())); } // Proceed this.state = State.READING_PAYLOAD_LENGTH; return true; } // State READING_DIRECTIONS private boolean inputDirections(ByteBuffer data) { // Get single flags byte final byte flags = data.get(); switch (flags) { case ProtocolConstants.FLAG_DIRECTION_INPUT: this.newChannelDirections = Directions.OUTPUT_ONLY; // reverse peer's view of the world break; case ProtocolConstants.FLAG_DIRECTION_OUTPUT: this.newChannelDirections = Directions.INPUT_ONLY; // reverse peer's view of the world break; case ProtocolConstants.FLAG_DIRECTION_INPUT | ProtocolConstants.FLAG_DIRECTION_OUTPUT: this.newChannelDirections = Directions.BIDIRECTIONAL; break; default: throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd invalid frame: invalid flags byte 0x%02x", flags & 0xff)); } this.trace("state %s: new remote channel %d is %s", this.state, 
this.payloadChannelId, this.newChannelDirections); // Proceed this.state = State.READING_PAYLOAD_LENGTH; return true; } // State READING_PAYLOAD_LENGTH private boolean inputPayloadLengthData(ByteBuffer data) throws IOException { // Decode encoded long value if (!this.inputEncodedLongValueData(data)) return true; // Check value is within range this.trace("state %s: read length %d for %s on channel %s%d", this.state, this.longValue, this.newChannelRequest ? "requestData" : "payload", this.payloadChannelIdIsLocal ? "L" : "R", this.payloadChannelId); if (this.longValue < 0 || this.longValue > Integer.MAX_VALUE) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("rec'd frame on channel %s%d with invalid payload length %d", this.payloadChannelIdIsLocal ? "L" : "R", this.payloadChannelId, this.longValue)); } final int payloadLength = (int)this.longValue; // Check for zero length payload on an already-open channel, which means "close channel" if (!this.newChannelRequest && payloadLength == 0) { // Deallocate channel this.trace("state %s: closing channel %s%d", this.state, this.payloadChannelIdIsLocal ? "L" : "R", this.payloadChannelId); this.channelIds.freeChannelId(this.payloadChannelId, this.payloadChannelIdIsLocal); // Notify input handler final long encodedChannelId = this.getEncodedChannelId(); this.reentrantHandler = true; try { this.inputHandler.nestedChannelClosed(encodedChannelId); } finally { this.reentrantHandler = false; } // Read the next frame this.trace("state %s: closed channel %s%d", this.state, this.payloadChannelIdIsLocal ? 
"L" : "R", this.payloadChannelId); this.state = State.READING_CHANNEL_ID; return true; } // If we received the entire payload in this data buffer, deliver it directly to the handler without doing any copying if (data.remaining() >= payloadLength) { this.deliverPayload(this.readOut(data, payloadLength)); return true; } // Partial payload received: initialize our payload buffer and copy the initial portion of the payload into it this.payloadBuffer = ByteBuffer.allocate(payloadLength); this.payloadBuffer.put(data); this.state = State.READING_PAYLOAD; return true; } // State READING_PAYLOAD private boolean inputPayloadData(ByteBuffer data) throws IOException { // How much more data do we need? final int payloadRemaining = this.payloadBuffer.remaining(); // If payload still not incomplete, go back for more if (data.remaining() < payloadRemaining) { this.payloadBuffer.put(data); return true; } // Complete the payload using the new data this.payloadBuffer.put(this.readOut(data, payloadRemaining)); // Deliver it to the input handler this.deliverPayload(this.payloadBuffer.flip()); return true; } // State CLOSED private boolean inputClosed(ByteBuffer data) { throw this.violation = new ProtocolViolationException(this.offset, "the connection has been closed by the remote side"); } // Internal methods // Deliver completed payload to input handler private void deliverPayload(ByteBuffer payload) throws IOException { // Debug this.trace("state %s: deliver %d byte payload from channel %s%d", this.state, payload.remaining(), this.payloadChannelIdIsLocal ? 
"L" : "R", this.payloadChannelId); // Check whether channel is still open, and if so deliver payload to handler final long encodedChannelId = this.getEncodedChannelId(); if (this.channelIds.isChannelOpen(encodedChannelId)) { this.reentrantHandler = true; try { if (this.newChannelRequest) this.inputHandler.nestedChannelRequest(encodedChannelId, payload, this.newChannelDirections); else this.inputHandler.nestedChannelData(encodedChannelId, payload); } finally { this.reentrantHandler = false; } } else { this.trace("state %s: discarding %d byte payload on closed channel %s%d", this.state, payload.remaining(), this.payloadChannelIdIsLocal ? "L" : "R", this.payloadChannelId); } // Reset state and start reading the next frame this.payloadBuffer = null; this.payloadChannelId = 0; this.payloadChannelIdIsLocal = false; this.newChannelRequest = false; this.state = State.READING_CHANNEL_ID; } // Decode a LongEncoder-encoded long value private boolean inputEncodedLongValueData(ByteBuffer data) { // Remember the offset of the start of the long value; calculate total length final int decodeLength; assert data.hasRemaining(); if (this.longValueBuffer.position() == 0) { this.longValueOffset = this.offset; final byte firstByte = data.get(); this.longValueBuffer.put(firstByte); try { decodeLength = LongEncoder.decodeLength(firstByte); } catch (IllegalArgumentException e) { throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("read invalid encoded long value (in state %s)", this.state), e); } this.offset++; } else decodeLength = LongEncoder.decodeLength(this.longValueBuffer.get(0)); // Add more bytes to accumulator until we get a whole value while (this.longValueBuffer.position() < decodeLength) { if (!data.hasRemaining()) return false; this.longValueBuffer.put(data.get()); this.offset++; } // Decode the encoded value this.longValueBuffer.flip(); try { this.longValue = LongEncoder.read(this.longValueBuffer); } catch (IllegalArgumentException e) { 
throw this.violation = new ProtocolViolationException(this.longValueOffset, String.format("read invalid encoded long value (in state %s)", this.state), e); } this.longValueBuffer.clear(); // reset for next time return true; } // Decode an 8 byte, big endian long value private boolean inputPlainLongValueData(ByteBuffer data) { // Remember the offset of the start of the long value if (this.longValueBuffer.position() == 0) this.longValueOffset = this.offset; // Attempt to fill buffer with 8 bytes while (this.longValueBuffer.position() < 8) { if (!data.hasRemaining()) return false; this.longValueBuffer.put(data.get()); this.offset++; } // Decode the (big endian) long value this.longValueBuffer.flip(); this.longValue = this.longValueBuffer.getLong(); this.longValueBuffer.clear(); // reset for next time return true; } private long getEncodedChannelId() { assert this.payloadChannelId >= 1; return this.payloadChannelIdIsLocal ? this.payloadChannelId : -this.payloadChannelId; } // Read out the next "length" bytes from the given ByteBuffer and return them in a new ByteBuffer private ByteBuffer readOut(ByteBuffer buffer, int length) { // Sanity check final int available = buffer.remaining(); if (available < length) throw new IllegalArgumentException("invalid length"); // Extract "length" bytes from what's available final ByteBuffer slice = buffer.slice().limit(length); // Advance the underlying buffer buffer.position(buffer.position() + length); // Done return slice; } // InputHandler /** * Callback interface invoked by a {@link ProtocolReader}. * *

* Channel ID's are {@code long} values between one and {@link Long#MAX_VALUE}, inclusive. To disambiguate * local vs. remote channels, local channel ID's are encoded as their positive values, while remote channel ID's * are encoded as the negative of their values (note: zero is not a valid channel ID). */ public interface InputHandler { /** * A new nested channel request has been received. * *

* Because the channel was created by the remote side, {@code channelId} will always be negative. * * @param channelId the encoded channel ID of the new nested channel (always negative) * @param requestData associated request data (read-only) * @param directions which I/O direction(s) are being established (from the local point of view) * @throws IOException if an I/O error occurs */ void nestedChannelRequest(long channelId, ByteBuffer requestData, Directions directions) throws IOException; /** * Data has been received on a nested channel. * *

* Note that it's possible to receive data on a nested channel after the local side has closed that channel, * because the remote side may not yet know that the nested channel has been closed. * * @param channelId the encoded ID of a nested channel (postive for local channels, negative for remote channels) * @param data new channel data * @throws IOException if an I/O error occurs */ void nestedChannelData(long channelId, ByteBuffer data) throws IOException; /** * An existing nested channel has been closed by the remote side. * *

* No more incoming data will {@linkplain #nestedChannelData appear} on the specified nested channel. * *

* Although nested channels are in general bidirectional, both directions are opened and closed at the same time. * This method implies both the incoming and outgoing directions are being closed. * * @param channelId the encoded ID of an open nested channel (postive for local channels, negative for remote channels) * @throws IOException if an I/O error occurs */ void nestedChannelClosed(long channelId) throws IOException; } // State @FunctionalInterface private interface StateHook { boolean inputData(ProtocolReader reader, ByteBuffer data) throws IOException; } private enum State { READING_PROTOCOL_COOKIE(ProtocolReader::inputProtocolCookieData), READING_PROTOCOL_VERSION(ProtocolReader::inputProtocolVersionData), READING_CHANNEL_ID(ProtocolReader::inputChannelIdData), READING_DIRECTIONS(ProtocolReader::inputDirections), READING_PAYLOAD_LENGTH(ProtocolReader::inputPayloadLengthData), READING_PAYLOAD(ProtocolReader::inputPayloadData), CLOSED(ProtocolReader::inputClosed); private final StateHook hook; State(StateHook hook) { this.hook = hook; } public boolean inputData(ProtocolReader input, ByteBuffer data) throws IOException { return this.hook.inputData(input, data); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy