org.bouncycastle.tls.DTLSRecordLayer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of bctls-fips Show documentation
Show all versions of bctls-fips Show documentation
The Bouncy Castle Java APIs for TLS, including a JSSE provider. The APIs are designed primarily to be used in conjunction with the BC FIPS provider. The APIs may also be used with other providers, although if they are used in a FIPS context it is the responsibility of the user to ensure that any other providers used are FIPS certified and used appropriately.
package org.bouncycastle.tls;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsNullNullCipher;
import org.bouncycastle.util.Arrays;
class DTLSRecordLayer
implements DatagramTransport
{
// Size in bytes of the record header as read and written by this class:
// content type (1) + version (2) + epoch (2) + sequence number (6) + fragment length (2).
private static final int RECORD_HEADER_LENGTH = 13;
// 2^14 bytes; the constructor uses this as the initial plaintext limit.
private static final int MAX_FRAGMENT_LENGTH = 1 << 14;
// Two minutes expressed in milliseconds (presumably the TCP maximum segment lifetime, going by the name).
private static final long TCP_MSL = 1000L * 60 * 2;
// Twice TCP_MSL; passed to the Timeout created in handshakeSuccessful() that bounds how long
// the final-flight retransmit state is retained.
private static final long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
/**
 * Validates that the given datagram consists of exactly one epoch-0 handshake record with a
 * record version of DTLS 1.0 or later, and returns a copy of that record's fragment.
 *
 * @param data    buffer holding the datagram
 * @param dataOff offset of the datagram within {@code data}
 * @param dataLen length of the datagram
 * @return a copy of the record fragment, or {@code null} if any of the checks fails
 * @throws IOException if reading the header fields fails
 */
static byte[] receiveClientHelloRecord(byte[] data, int dataOff, int dataLen) throws IOException
{
    // Not even enough bytes for a record header
    if (dataLen < RECORD_HEADER_LENGTH)
    {
        return null;
    }

    final short recordType = TlsUtils.readUint8(data, dataOff);
    if (recordType != ContentType.handshake)
    {
        return null;
    }

    final ProtocolVersion recordVersion = TlsUtils.readVersion(data, dataOff + 1);
    if (!ProtocolVersion.DTLSv10.isEqualOrEarlierVersionOf(recordVersion))
    {
        return null;
    }

    final int recordEpoch = TlsUtils.readUint16(data, dataOff + 3);
    if (recordEpoch != 0)
    {
        return null;
    }

    // NOTE: The 48-bit sequence number at offset 5 is not examined here.

    final int fragmentLength = TlsUtils.readUint16(data, dataOff + 11);
    if (dataLen - RECORD_HEADER_LENGTH != fragmentLength)
    {
        // The datagram must contain the record and nothing else
        return null;
    }

    final int fragmentStart = dataOff + RECORD_HEADER_LENGTH;
    final int fragmentEnd = dataOff + dataLen;
    return Arrays.copyOfRange(data, fragmentStart, fragmentEnd);
}
/**
 * Wraps the given message in an epoch-0 handshake record (record version DTLS 1.0) and sends
 * it as a single datagram.
 *
 * @param sender    destination for the datagram
 * @param recordSeq sequence number to place in the record header
 * @param message   the record fragment; its length is checked with {@code TlsUtils.checkUint16}
 * @throws IOException if the length check or the send fails
 */
static void sendHelloVerifyRequestRecord(DatagramSender sender, long recordSeq, byte[] message) throws IOException
{
    final int fragmentLength = message.length;
    TlsUtils.checkUint16(fragmentLength);

    final byte[] datagram = new byte[RECORD_HEADER_LENGTH + fragmentLength];

    // Header: type | version | epoch (always 0 here) | sequence number | fragment length
    TlsUtils.writeUint8(ContentType.handshake, datagram, 0);
    TlsUtils.writeVersion(ProtocolVersion.DTLSv10, datagram, 1);
    TlsUtils.writeUint16(0, datagram, 3);
    TlsUtils.writeUint48(recordSeq, datagram, 5);
    TlsUtils.writeUint16(fragmentLength, datagram, 11);

    // Fragment
    System.arraycopy(message, 0, datagram, RECORD_HEADER_LENGTH, fragmentLength);

    sendDatagram(sender, datagram);
}
/**
 * Sends the complete record through the given sender.
 *
 * @param sender the sender to use
 * @param record the encoded record; sent in full
 * @throws IOException if the send fails; an {@link InterruptedIOException} is rethrown with
 *             its {@code bytesTransferred} field reset to zero
 */
private static void sendDatagram(DatagramSender sender, byte[] record)
    throws IOException
{
    final int recordLength = record.length;
    try
    {
        sender.send(record, 0, recordLength);
    }
    catch (InterruptedIOException interrupted)
    {
        // Report no bytes as transferred before passing the exception on
        interrupted.bytesTransferred = 0;
        throw interrupted;
    }
}
// Context handed to HeartbeatMessage.create when building heartbeat messages.
private final TlsContext context;
// Notified of alerts raised by, and received at, this record layer.
private final TlsPeer peer;
// Underlying datagram transport that records are sent to and received from.
private final DatagramTransport transport;
// Holds the bytes of a received datagram that follow its first record, for later calls to receiveRecord.
private final ByteQueue recordQueue = new ByteQueue();
// Set once closeTransport() has run.
private volatile boolean closed = false;
// Set by fail()/failed(); when set, closeTransport() does not send close_notify.
private volatile boolean failed = false;
// TODO[dtls13] Review the draft/RFC (legacy_record_version) to see if readVersion can be removed
// readVersion: record version expected on incoming records (taken from the first processed record if unset).
// writeVersion: record version written to outgoing headers; while null, sendRecord() sends nothing.
private volatile ProtocolVersion readVersion = null, writeVersion = null;
// Set when a datagram of at least header length has been received, or by resetAfterHelloVerifyRequestServer().
private volatile boolean inConnection;
// True from construction until handshakeSuccessful() is called.
private volatile boolean inHandshake;
// Upper bound on plaintext length, enforced in sendRecord() and processRecord().
private volatile int plaintextLimit;
// currentEpoch: the established epoch; pendingEpoch: the epoch being negotiated (null if none).
private DTLSEpoch currentEpoch, pendingEpoch;
// Epochs currently used for decoding incoming and encoding outgoing records.
private DTLSEpoch readEpoch, writeEpoch;
// Retransmit state kept after a successful handshake; cleared on timeout or when data arrives in the new epoch.
private DTLSHandshakeRetransmit retransmit = null;
private DTLSEpoch retransmitEpoch = null;
private Timeout retransmitTimeout = null;
private TlsHeartbeat heartbeat = null; // If non-null, controls the sending of heartbeat requests
private boolean heartBeatResponder = false; // Whether we should send heartbeat responses
private HeartbeatMessage heartbeatInFlight = null; // The current in-flight heartbeat request, if any
private Timeout heartbeatTimeout = null; // Idle timeout (if none in-flight), else expiry timeout for response
private int heartbeatResendMillis = -1; // Delay before retransmit of current in-flight heartbeat request
private Timeout heartbeatResendTimeout = null; // Timeout for next retransmit of the in-flight heartbeat request
DTLSRecordLayer(TlsContext context, TlsPeer peer, DatagramTransport transport)
{
this.context = context;
this.peer = peer;
this.transport = transport;
this.inHandshake = true;
this.currentEpoch = new DTLSEpoch(0, new TlsNullNullCipher());
this.pendingEpoch = null;
this.readEpoch = currentEpoch;
this.writeEpoch = currentEpoch;
setPlaintextLimit(MAX_FRAGMENT_LENGTH);
}
/**
 * Resets the replay window of the current epoch.
 */
void resetAfterHelloVerifyRequestClient()
{
    this.currentEpoch.getReplayWindow().reset();
}
/**
 * Marks this record layer as being in a connection and sets the sequence number of the
 * current epoch.
 *
 * @param writeRecordSeqNo the sequence number to set on the current epoch
 */
void resetAfterHelloVerifyRequestServer(long writeRecordSeqNo)
{
    inConnection = true;

    this.currentEpoch.setSequenceNumber(writeRecordSeqNo);

    /*
     * TODO The sequence number reflects one that the client has already sent, so we could
     * initialize the replay window here accordingly.
     */
}
/**
 * Sets the maximum plaintext length accepted by this record layer when sending and receiving.
 *
 * @param plaintextLimit the new limit, in bytes
 */
void setPlaintextLimit(int plaintextLimit)
{
    // Qualified with 'this' because the parameter shadows the field
    this.plaintextLimit = plaintextLimit;
}
/**
 * @return the epoch number of the epoch currently used for reading
 */
int getReadEpoch()
{
    return this.readEpoch.getEpoch();
}
/**
 * @return the record version currently expected on incoming records; may be {@code null}
 */
ProtocolVersion getReadVersion()
{
    return this.readVersion;
}
/**
 * Sets the record version expected on incoming records.
 *
 * @param readVersion the version to expect
 */
void setReadVersion(ProtocolVersion readVersion)
{
    // Qualified with 'this' because the parameter shadows the field
    this.readVersion = readVersion;
}
/**
 * Sets the record version written into the header of outgoing records. Until this is set to
 * a non-null value, no records are sent.
 *
 * @param writeVersion the version to write
 */
void setWriteVersion(ProtocolVersion writeVersion)
{
    // Qualified with 'this' because the parameter shadows the field
    this.writeVersion = writeVersion;
}
/**
 * Creates the pending epoch, numbered one higher than the current write epoch, using the
 * given cipher.
 *
 * @param pendingCipher the cipher for the pending epoch
 * @throws IllegalStateException if a pending epoch already exists
 */
void initPendingEpoch(TlsCipher pendingCipher)
{
    if (null != pendingEpoch)
    {
        throw new IllegalStateException();
    }

    /*
     * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations
     * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment
     * lifetime."
     */

    // TODO Check for overflow
    final int nextEpochNumber = writeEpoch.getEpoch() + 1;
    this.pendingEpoch = new DTLSEpoch(nextEpochNumber, pendingCipher);
}
/**
 * Ends the handshake: the pending epoch becomes the current epoch. If a retransmit handler is
 * supplied, it is retained together with the outgoing current epoch and a timeout of
 * {@code RETRANSMIT_TIMEOUT}.
 *
 * @param retransmit handler for handshake records that arrive after the handshake, or
 *            {@code null}
 * @throws IllegalStateException if reading or writing is still using the current epoch
 */
void handshakeSuccessful(DTLSHandshakeRetransmit retransmit)
{
    final boolean readMoved = readEpoch != currentEpoch;
    final boolean writeMoved = writeEpoch != currentEpoch;
    if (!(readMoved && writeMoved))
    {
        // TODO
        throw new IllegalStateException();
    }

    if (retransmit != null)
    {
        // Remember the epoch being left, so that its handshake records can still be handled
        this.retransmit = retransmit;
        this.retransmitEpoch = currentEpoch;
        this.retransmitTimeout = new Timeout(RETRANSMIT_TIMEOUT);
    }

    this.inHandshake = false;
    this.currentEpoch = pendingEpoch;
    this.pendingEpoch = null;
}
/**
 * Configures heartbeat handling. May only be called once the handshake has completed.
 *
 * @param heartbeat          controls the sending of heartbeat requests; {@code null} if no
 *            requests should be sent
 * @param heartbeatResponder whether heartbeat requests from the peer should be answered
 * @throws IllegalStateException if the handshake is still in progress
 */
void initHeartbeat(TlsHeartbeat heartbeat, boolean heartbeatResponder)
{
    if (inHandshake)
    {
        throw new IllegalStateException();
    }

    this.heartbeat = heartbeat;
    this.heartBeatResponder = heartbeatResponder;

    if (heartbeat == null)
    {
        return;
    }

    // Start the idle timer for the first request
    resetHeartbeat();
}
/**
 * Points the write epoch at the retransmit epoch if there is one, otherwise at the current
 * epoch.
 */
void resetWriteEpoch()
{
    this.writeEpoch = (null != retransmitEpoch) ? retransmitEpoch : currentEpoch;
}
/**
 * @return the smaller of the configured plaintext limit and the plaintext limit that the read
 *         cipher reports for the transport's receive limit less the record header
 * @throws IOException if the transport's limit cannot be determined
 */
public int getReceiveLimit()
    throws IOException
{
    final int configuredLimit = this.plaintextLimit;
    final int cipherLimit = readEpoch.getCipher().getPlaintextLimit(
        transport.getReceiveLimit() - RECORD_HEADER_LENGTH);

    return Math.min(configuredLimit, cipherLimit);
}
/**
 * @return the smaller of the configured plaintext limit and the plaintext limit that the write
 *         cipher reports for the transport's send limit less the record header
 * @throws IOException if the transport's limit cannot be determined
 */
public int getSendLimit()
    throws IOException
{
    final int configuredLimit = this.plaintextLimit;
    final int cipherLimit = writeEpoch.getCipher().getPlaintextLimit(
        transport.getSendLimit() - RECORD_HEADER_LENGTH);

    return Math.min(configuredLimit, cipherLimit);
}
/**
 * Receives records until one yields data for the caller or the wait is exhausted. On every
 * pass this also discards expired retransmit state and sends (or resends) a heartbeat
 * request when the corresponding timeout has expired.
 *
 * @param buf        destination for the received plaintext
 * @param off        offset in {@code buf} at which to store the plaintext
 * @param len        used, together with {@link #getReceiveLimit()}, to size the record buffer
 * @param waitMillis maximum time to wait; a negative value skips the loop entirely
 * @return the number of plaintext bytes stored in {@code buf}, or -1 if nothing was delivered
 * @throws TlsTimeoutException if the heartbeat timeout expires while a request is in flight
 * @throws IOException if sending, receiving or record processing fails
 */
public int receive(byte[] buf, int off, int len, int waitMillis)
throws IOException
{
long currentTimeMillis = System.currentTimeMillis();
// Overall deadline for this call; used to recompute waitMillis after each pass
Timeout timeout = Timeout.forWaitMillis(waitMillis, currentTimeMillis);
// Record buffer, allocated lazily and reused across passes while it is large enough
byte[] record = null;
// NOTE(review): the loop relies on Timeout.getWaitMillis yielding a negative value once the
// deadline has passed - confirm against Timeout
while (waitMillis >= 0)
{
// Drop the final-flight retransmit state once its timeout has run out
if (null != retransmitTimeout && retransmitTimeout.remainingMillis(currentTimeMillis) < 1)
{
retransmit = null;
retransmitEpoch = null;
retransmitTimeout = null;
}
if (Timeout.hasExpired(heartbeatTimeout, currentTimeMillis))
{
// With a request in flight this timeout is the response deadline, so expiry is fatal
if (null != heartbeatInFlight)
{
throw new TlsTimeoutException("Heartbeat timed out");
}
// Otherwise the idle period has elapsed: start a new request and its resend schedule
this.heartbeatInFlight = HeartbeatMessage.create(context, HeartbeatMessageType.heartbeat_request,
heartbeat.generatePayload());
this.heartbeatTimeout = new Timeout(heartbeat.getTimeoutMillis(), currentTimeMillis);
this.heartbeatResendMillis = DTLSReliableHandshake.INITIAL_RESEND_MILLIS;
this.heartbeatResendTimeout = new Timeout(heartbeatResendMillis, currentTimeMillis);
sendHeartbeatMessage(heartbeatInFlight);
}
else if (Timeout.hasExpired(heartbeatResendTimeout, currentTimeMillis))
{
// Resend the in-flight request, with the next delay obtained from backOff
this.heartbeatResendMillis = DTLSReliableHandshake.backOff(heartbeatResendMillis);
this.heartbeatResendTimeout = new Timeout(heartbeatResendMillis, currentTimeMillis);
sendHeartbeatMessage(heartbeatInFlight);
}
// Constrain the wait by both heartbeat timeouts (presumably so the wait ends no later than either)
waitMillis = Timeout.constrainWaitMillis(waitMillis, heartbeatTimeout, currentTimeMillis);
waitMillis = Timeout.constrainWaitMillis(waitMillis, heartbeatResendTimeout, currentTimeMillis);
// NOTE: Guard against bad logic giving a negative value
if (waitMillis < 0)
{
waitMillis = 1;
}
// Room for the record header plus the smaller of the caller's len and the receive limit
int receiveLimit = Math.min(len, getReceiveLimit()) + RECORD_HEADER_LENGTH;
if (null == record || record.length < receiveLimit)
{
record = new byte[receiveLimit];
}
int received = receiveRecord(record, 0, receiveLimit, waitMillis);
// A negative result means the record produced nothing for the caller; keep waiting
int processed = processRecord(received, record, buf, off);
if (processed >= 0)
{
return processed;
}
// Recompute the remaining wait from the overall deadline
currentTimeMillis = System.currentTimeMillis();
waitMillis = Timeout.getWaitMillis(timeout, currentTimeMillis);
}
return -1;
}
/**
 * Sends the given data as one record. The data is sent as a handshake record while the
 * handshake is in progress or while writing in the retransmit epoch, otherwise as application
 * data. A handshake message of type {@code finished} is preceded by a change_cipher_spec
 * record, after which the write epoch is switched.
 *
 * @param buf buffer holding the data
 * @param off offset of the data in {@code buf}
 * @param len length of the data
 * @throws IllegalStateException if a finished message is sent and there is no epoch to switch to
 * @throws IOException if sending fails
 */
public void send(byte[] buf, int off, int len)
    throws IOException
{
    final boolean handshakeData = this.inHandshake || this.writeEpoch == this.retransmitEpoch;
    if (!handshakeData)
    {
        sendRecord(ContentType.application_data, buf, off, len);
        return;
    }

    final short handshakeType = TlsUtils.readUint8(buf, off);
    if (HandshakeType.finished == handshakeType)
    {
        DTLSEpoch nextEpoch = null;
        if (this.inHandshake)
        {
            nextEpoch = pendingEpoch;
        }
        else if (this.writeEpoch == this.retransmitEpoch)
        {
            nextEpoch = currentEpoch;
        }

        if (null == nextEpoch)
        {
            // TODO
            throw new IllegalStateException();
        }

        // Implicitly send change_cipher_spec and change to pending cipher state

        // TODO Send change_cipher_spec and finished records in single datagram?
        final byte[] changeCipherSpec = new byte[]{ 1 };
        sendRecord(ContentType.change_cipher_spec, changeCipherSpec, 0, changeCipherSpec.length);

        writeEpoch = nextEpoch;
    }

    sendRecord(ContentType.handshake, buf, off, len);
}
/**
 * Closes this record layer if it is not already closed. If a handshake is still in progress
 * on an established connection, a {@code user_canceled} warning is raised first.
 *
 * @throws IOException if raising the warning fails
 */
public void close()
    throws IOException
{
    if (closed)
    {
        return;
    }

    if (inHandshake && inConnection)
    {
        warn(AlertDescription.user_canceled, "User canceled handshake");
    }

    closeTransport();
}
/**
 * Fails this record layer: when in a connection, a fatal alert with the given description is
 * raised on a best-effort basis; then the layer is marked as failed and the transport closed.
 * Does nothing if already closed.
 *
 * @param alertDescription the description of the fatal alert to raise
 */
void fail(short alertDescription)
{
    if (closed)
    {
        return;
    }

    if (inConnection)
    {
        try
        {
            raiseAlert(AlertLevel.fatal, alertDescription, null, null);
        }
        catch (Exception e)
        {
            // Ignore
        }
    }

    failed = true;

    closeTransport();
}
/**
 * Marks this record layer as failed and closes the transport, without raising an alert.
 * Does nothing if already closed.
 */
void failed()
{
    if (closed)
    {
        return;
    }

    failed = true;

    closeTransport();
}
/**
 * Raises a warning-level alert with no cause.
 *
 * @param alertDescription the alert description
 * @param message          message passed on to the peer notification; may be {@code null}
 * @throws IOException if raising the alert fails
 */
void warn(short alertDescription, String message)
    throws IOException
{
    final Throwable noCause = null;
    raiseAlert(AlertLevel.warning, alertDescription, message, noCause);
}
/**
 * Closes the underlying transport, first sending a {@code close_notify} warning unless this
 * record layer has been marked as failed. Both steps are best effort: exceptions are
 * swallowed, and a failure to send the alert does not prevent the transport from being
 * closed. Does nothing if already closed.
 */
private void closeTransport()
{
    if (closed)
    {
        return;
    }

    /*
     * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is
     * required to send a close_notify alert before closing the write side of the
     * connection. The other party MUST respond with a close_notify alert of its own and
     * close down the connection immediately, discarding any pending writes.
     */
    try
    {
        if (!failed)
        {
            warn(AlertDescription.close_notify, null);
        }
    }
    catch (Exception ignored)
    {
        // Best effort only: being unable to send close_notify must not stop the close below
    }

    // Closed in its own try block so that the transport is released even if the alert failed
    try
    {
        transport.close();
    }
    catch (Exception ignored)
    {
        // Ignore: there is nothing useful to do if the transport fails to close
    }

    closed = true;
}
/**
 * Notifies the peer object that an alert is being raised, then sends the alert as a record.
 *
 * @param alertLevel       the alert level
 * @param alertDescription the alert description
 * @param message          message passed on to the peer notification; may be {@code null}
 * @param cause            cause passed on to the peer notification; may be {@code null}
 * @throws IOException if sending the alert record fails
 */
private void raiseAlert(short alertLevel, short alertDescription, String message, Throwable cause)
    throws IOException
{
    peer.notifyAlertRaised(alertLevel, alertDescription, message, cause);

    // An alert is two bytes: level followed by description
    final byte[] alert = new byte[]{ (byte)alertLevel, (byte)alertDescription };

    sendRecord(ContentType.alert, alert, 0, alert.length);
}
/**
 * Receives from the underlying transport, reporting a socket timeout as -1 instead of an
 * exception.
 *
 * @param buf        destination buffer
 * @param off        offset in {@code buf}
 * @param len        maximum number of bytes to receive
 * @param waitMillis maximum time to wait
 * @return the value returned by the transport, or -1 on a {@link SocketTimeoutException}
 * @throws IOException if the receive fails; an {@link InterruptedIOException} other than a
 *             socket timeout is rethrown with {@code bytesTransferred} reset to zero
 */
private int receiveDatagram(byte[] buf, int off, int len, int waitMillis)
    throws IOException
{
    int result;
    try
    {
        result = transport.receive(buf, off, len, waitMillis);
    }
    catch (SocketTimeoutException timedOut)
    {
        // Caught before the more general InterruptedIOException: a timeout just means no data
        result = -1;
    }
    catch (InterruptedIOException interrupted)
    {
        interrupted.bytesTransferred = 0;
        throw interrupted;
    }
    return result;
}
// TODO Include 'currentTimeMillis' as an argument, use with Timeout, resetHeartbeat
/**
 * Validates, decodes and dispatches a single received record.
 * <p>
 * The record is dropped (return value -1) if it is shorter than a header, its length field
 * does not match the number of bytes received, its content type is unknown, its epoch is
 * neither the read epoch nor (for handshake records) the retransmit epoch, the replay window
 * rejects its sequence number, its version is not a DTLS version or does not match the
 * expected read version (except for ClientHello fragments in epoch 0), or the decoded
 * plaintext exceeds the plaintext limit.
 * <p>
 * Alert, change_cipher_spec and heartbeat records are handled internally and never yield
 * data. Application data received during the handshake, and handshake records received after
 * it, yield no data either (the latter are passed to the retransmit handler, if any).
 * <p>
 * NOTE(review): the plaintext is copied into {@code buf} without a check against the space
 * remaining after {@code off}; this method relies on the caller for that.
 *
 * @param received number of bytes in {@code record}; negative for a timeout
 * @param record   buffer holding the record, starting at index 0
 * @param buf      destination for the plaintext of records that carry data for the caller
 * @param off      offset in {@code buf} at which to store the plaintext
 * @return the number of plaintext bytes stored in {@code buf}, or -1 if the record yielded none
 * @throws TlsFatalAlert if a fatal alert was received
 * @throws IOException if decoding the record fails
 */
private int processRecord(int received, byte[] record, byte[] buf, int off)
    throws IOException
{
    // NOTE: received < 0 (timeout) is covered by this first case
    if (received < RECORD_HEADER_LENGTH)
    {
        return -1;
    }
    int length = TlsUtils.readUint16(record, 11);
    if (received != (length + RECORD_HEADER_LENGTH))
    {
        return -1;
    }

    short type = TlsUtils.readUint8(record, 0);
    switch (type)
    {
    case ContentType.alert:
    case ContentType.application_data:
    case ContentType.change_cipher_spec:
    case ContentType.handshake:
    case ContentType.heartbeat:
        break;
    default:
        return -1;
    }

    // Select the epoch used to decode the record; records from any other epoch are dropped
    int epoch = TlsUtils.readUint16(record, 3);

    DTLSEpoch recordEpoch = null;
    if (epoch == readEpoch.getEpoch())
    {
        recordEpoch = readEpoch;
    }
    else if (type == ContentType.handshake && null != retransmitEpoch
        && epoch == retransmitEpoch.getEpoch())
    {
        recordEpoch = retransmitEpoch;
    }

    if (null == recordEpoch)
    {
        return -1;
    }

    long seq = TlsUtils.readUint48(record, 5);
    if (recordEpoch.getReplayWindow().shouldDiscard(seq))
    {
        return -1;
    }

    ProtocolVersion version = TlsUtils.readVersion(record, 1);
    if (!version.isDTLS())
    {
        return -1;
    }

    if (null != readVersion && !readVersion.equals(version))
    {
        /*
         * Special-case handling for retransmitted ClientHello records.
         *
         * TODO Revisit how 'readVersion' works, since this is quite awkward.
         */
        boolean isClientHelloFragment =
                getReadEpoch() == 0
            &&  length > 0
            &&  ContentType.handshake == type
            &&  HandshakeType.client_hello == TlsUtils.readUint8(record, RECORD_HEADER_LENGTH);

        if (!isClientHelloFragment)
        {
            return -1;
        }
    }

    long macSeqNo = getMacSequenceNumber(recordEpoch.getEpoch(), seq);
    byte[] plaintext = recordEpoch.getCipher().decodeCiphertext(macSeqNo, type, record,
        RECORD_HEADER_LENGTH, length);

    // Only reached if decoding did not throw, so the sequence number can be recorded as seen
    recordEpoch.getReplayWindow().reportAuthenticated(seq);

    if (plaintext.length > this.plaintextLimit)
    {
        return -1;
    }

    if (null == readVersion)
    {
        readVersion = version;
    }

    switch (type)
    {
    case ContentType.alert:
    {
        if (plaintext.length == 2)
        {
            // Read as unsigned: a plain byte-to-short assignment would sign-extend, turning
            // codes of 128 and above into negative values
            short alertLevel = TlsUtils.readUint8(plaintext, 0);
            short alertDescription = TlsUtils.readUint8(plaintext, 1);

            peer.notifyAlertReceived(alertLevel, alertDescription);

            if (alertLevel == AlertLevel.fatal)
            {
                failed();
                throw new TlsFatalAlert(alertDescription);
            }

            // TODO Can close_notify be a fatal alert?
            if (alertDescription == AlertDescription.close_notify)
            {
                closeTransport();
            }
        }

        return -1;
    }
    case ContentType.application_data:
    {
        if (inHandshake)
        {
            // TODO Consider buffering application data for new epoch that arrives
            // out-of-order with the Finished message
            return -1;
        }
        break;
    }
    case ContentType.change_cipher_spec:
    {
        // Implicitly receive change_cipher_spec and change to pending cipher state

        for (int i = 0; i < plaintext.length; ++i)
        {
            short message = TlsUtils.readUint8(plaintext, i);
            if (message != ChangeCipherSpec.change_cipher_spec)
            {
                continue;
            }

            if (pendingEpoch != null)
            {
                readEpoch = pendingEpoch;
            }
        }

        return -1;
    }
    case ContentType.handshake:
    {
        if (!inHandshake)
        {
            if (null != retransmit)
            {
                retransmit.receivedHandshakeRecord(epoch, plaintext, 0, plaintext.length);
            }

            // TODO Consider support for HelloRequest
            return -1;
        }
        break;
    }
    case ContentType.heartbeat:
    {
        if (null != heartbeatInFlight || heartBeatResponder)
        {
            try
            {
                ByteArrayInputStream input = new ByteArrayInputStream(plaintext);
                HeartbeatMessage heartbeatMessage = HeartbeatMessage.parse(input);

                if (null != heartbeatMessage)
                {
                    switch (heartbeatMessage.getType())
                    {
                    case HeartbeatMessageType.heartbeat_request:
                    {
                        if (heartBeatResponder)
                        {
                            // Answer with a response carrying the request's payload
                            HeartbeatMessage response = HeartbeatMessage.create(context,
                                HeartbeatMessageType.heartbeat_response, heartbeatMessage.getPayload());

                            sendHeartbeatMessage(response);
                        }
                        break;
                    }
                    case HeartbeatMessageType.heartbeat_response:
                    {
                        // Only a response matching the in-flight request's payload counts
                        if (null != heartbeatInFlight
                            && Arrays.areEqual(heartbeatMessage.getPayload(), heartbeatInFlight.getPayload()))
                        {
                            resetHeartbeat();
                        }
                        break;
                    }
                    default:
                        break;
                    }
                }
            }
            catch (Exception e)
            {
                // Ignore
            }
        }

        return -1;
    }
    }

    /*
     * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
     * received our final flight.
     */
    if (!inHandshake && null != retransmit)
    {
        this.retransmit = null;
        this.retransmitEpoch = null;
        this.retransmitTimeout = null;
    }

    System.arraycopy(plaintext, 0, buf, off, plaintext.length);
    return plaintext.length;
}
/**
 * Obtains the next record, either from the queue of bytes left over from an earlier datagram
 * or, if that queue is empty, from a newly received datagram. When a datagram holds more
 * bytes than its first record, the remainder is queued for subsequent calls.
 * <p>
 * A queued record that does not fit in {@code len} bytes is removed from the queue and
 * dropped rather than copied, since the buffer supplied for this call may be smaller than the
 * one in use when the bytes were queued.
 *
 * @param buf        destination buffer
 * @param off        offset in {@code buf}
 * @param len        maximum number of bytes to store in {@code buf}
 * @param waitMillis maximum time to wait for a datagram
 * @return the number of bytes stored in {@code buf}, or -1 if a queued record was dropped or
 *         the datagram receive reported nothing
 * @throws IOException if the receive fails
 */
private int receiveRecord(byte[] buf, int off, int len, int waitMillis)
    throws IOException
{
    if (recordQueue.available() > 0)
    {
        int length = 0;
        if (recordQueue.available() >= RECORD_HEADER_LENGTH)
        {
            // Peek at the length field (offset 11) without consuming anything
            byte[] lengthBytes = new byte[2];
            recordQueue.read(lengthBytes, 0, 2, 11);
            length = TlsUtils.readUint16(lengthBytes, 0);
        }

        int received = Math.min(recordQueue.available(), RECORD_HEADER_LENGTH + length);
        if (received > len)
        {
            // Too large for the caller's buffer: consume it into a scratch buffer and drop it
            recordQueue.removeData(new byte[received], 0, received, 0);
            return -1;
        }

        recordQueue.removeData(buf, off, received, 0);
        return received;
    }

    int received = receiveDatagram(buf, off, len, waitMillis);
    if (received >= RECORD_HEADER_LENGTH)
    {
        this.inConnection = true;

        int fragmentLength = TlsUtils.readUint16(buf, off + 11);
        int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
        if (received > recordLength)
        {
            // Keep whatever follows the first record for the next call
            recordQueue.addData(buf, off + recordLength, received - recordLength);
            received = recordLength;
        }
    }

    return received;
}
/**
 * Clears the in-flight heartbeat request and its resend schedule, and restarts the heartbeat
 * timeout using the idle period reported by the heartbeat configuration.
 */
private void resetHeartbeat()
{
    // Nothing in flight any more, hence nothing to resend
    heartbeatInFlight = null;
    heartbeatResendMillis = -1;
    heartbeatResendTimeout = null;

    // Back to waiting for the idle period before the next request
    final Timeout idleTimeout = new Timeout(heartbeat.getIdleMillis());
    heartbeatTimeout = idleTimeout;
}
/**
 * Encodes the given heartbeat message and sends it as a heartbeat record.
 *
 * @param heartbeatMessage the message to send
 * @throws IOException if encoding or sending fails
 */
private void sendHeartbeatMessage(HeartbeatMessage heartbeatMessage)
    throws IOException
{
    final ByteArrayOutputStream encoder = new ByteArrayOutputStream();
    heartbeatMessage.encode(encoder);

    final byte[] encoded = encoder.toByteArray();
    sendRecord(ContentType.heartbeat, encoded, 0, encoded.length);
}
/*
 * Currently synchronized here to ensure heartbeat sends and application data sends don't
 * interfere with each other. It may be overly cautious; the sequence number allocation is
 * atomic, and if we synchronize only on the datagram send instead, then the only effect should
 * be possible reordering of records (which might surprise a reliable transport implementation).
 */
/**
 * Encodes the given plaintext with the write epoch's cipher, prepends a record header and
 * sends the result as one datagram. Nothing is sent while no write version has been set.
 *
 * @param contentType the record content type
 * @param buf         buffer holding the plaintext
 * @param off         offset of the plaintext in {@code buf}
 * @param len         length of the plaintext
 * @throws TlsFatalAlert ({@code internal_error}) if {@code len} exceeds the plaintext limit,
 *             or is zero for a content type other than application data
 * @throws IOException if encoding or sending fails
 */
private synchronized void sendRecord(short contentType, byte[] buf, int off, int len)
    throws IOException
{
    // Never send anything until a valid ClientHello has been received
    if (null == writeVersion)
    {
        return;
    }

    /*
     * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
     * or ChangeCipherSpec content types.
     */
    final boolean tooLong = len > this.plaintextLimit;
    if (tooLong || (len < 1 && ContentType.application_data != contentType))
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    final int epochNumber = writeEpoch.getEpoch();
    final long sequenceNumber = writeEpoch.allocateSequenceNumber();
    final long macSequenceNumber = getMacSequenceNumber(epochNumber, sequenceNumber);

    final byte[] fragment = writeEpoch.getCipher().encodePlaintext(macSequenceNumber, contentType, buf,
        off, len);

    // TODO Check the ciphertext length?

    final int fragmentLength = fragment.length;
    final byte[] datagram = new byte[RECORD_HEADER_LENGTH + fragmentLength];

    // Header: type | version | epoch | sequence number | fragment length
    TlsUtils.writeUint8(contentType, datagram, 0);
    TlsUtils.writeVersion(writeVersion, datagram, 1);
    TlsUtils.writeUint16(epochNumber, datagram, 3);
    TlsUtils.writeUint48(sequenceNumber, datagram, 5);
    TlsUtils.writeUint16(fragmentLength, datagram, 11);

    System.arraycopy(fragment, 0, datagram, RECORD_HEADER_LENGTH, fragmentLength);

    sendDatagram(transport, datagram);
}
/**
 * Combines an epoch and a record sequence number into the single value passed to the cipher:
 * the epoch is placed above bit 48 and OR-ed with the sequence number.
 *
 * @param epoch           the record epoch
 * @param sequence_number the record sequence number
 * @return the combined value
 */
private static long getMacSequenceNumber(int epoch, long sequence_number)
{
    // Widen without sign extension before shifting
    final long unsignedEpoch = epoch & 0xFFFFFFFFL;
    final long epochBits = unsignedEpoch << 48;
    return epochBits | sequence_number;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy