All downloads are free. The search and download functionality uses the official Maven repository.

org.bouncycastle.tls.RecordStream Maven / Gradle / Ivy

Go to download

The Bouncy Castle Java APIs for TLS, including a JSSE provider. The APIs are designed primarily to be used in conjunction with the BC FIPS provider. The APIs may also be used with other providers, although when they are used in a FIPS context it is the responsibility of the user to ensure that any other providers used are FIPS certified and used appropriately.

There is a newer version: 2.0.19
Show newest version
package org.bouncycastle.tls;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;

import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsDecodeResult;
import org.bouncycastle.tls.crypto.TlsEncodeResult;
import org.bouncycastle.tls.crypto.TlsNullNullCipher;

/**
 * The TLS record layer used by {@link TlsProtocol}: it frames and protects outgoing records and
 * reads, length-checks and unprotects incoming ones before handing them to the protocol handler.
 * <p>
 * In addition to the classic (RFC 5246) record handling, this class carries the hooks used for
 * TLS 1.3 (RFC 8446): key updates, ciphers that use an opaque record type, and optionally ignoring
 * ChangeCipherSpec records (RFC 8446 D.4).
 */
class RecordStream
{
    /** Default maximum plaintext fragment length (2^14, see RFC 5246 6.2.1). */
    private static final int DEFAULT_PLAINTEXT_LIMIT = (1 << 14);

    /**
     * Number of records written since the write sequence number was last reset at or above which
     * {@link #needsKeyUpdate()} reports <code>true</code>.
     */
    private static final long KEY_UPDATE_THRESHOLD = (1L << 20);

    /** Reusable buffer for the record currently being read from {@link #input}. */
    private final Record inputRecord = new Record();
    private final SequenceNumber readSeqNo = new SequenceNumber(), writeSeqNo = new SequenceNumber();

    private final TlsProtocol handler;
    private final InputStream input;
    private final OutputStream output;

    /** Cipher negotiated by the handshake but not yet (fully) activated; null outside a handshake. */
    private TlsCipher pendingCipher = null;
    private TlsCipher readCipher = TlsNullNullCipher.INSTANCE;
    /** Read cipher that will be activated when the next application_data record header is seen. */
    private TlsCipher readCipherDeferred = null;
    private TlsCipher writeCipher = TlsNullNullCipher.INSTANCE;

    /** Version written into outgoing record headers; while null, nothing is written at all. */
    private ProtocolVersion writeVersion = null;

    private int plaintextLimit = DEFAULT_PLAINTEXT_LIMIT;
    /** Largest acceptable incoming fragment length; derived from the read cipher and plaintextLimit. */
    private int ciphertextLimit = DEFAULT_PLAINTEXT_LIMIT;
    private boolean ignoreChangeCipherSpec = false;

    /**
     * @param handler receives every decoded record via <code>processRecord</code>.
     * @param input stream that records are read from (used by {@link #readRecord()}).
     * @param output stream that records are written to.
     */
    RecordStream(TlsProtocol handler, InputStream input, OutputStream output)
    {
        this.handler = handler;
        this.input = input;
        this.output = output;
    }

    /**
     * @return the current maximum plaintext fragment length.
     */
    int getPlaintextLimit()
    {
        return plaintextLimit;
    }

    /**
     * Sets the maximum plaintext fragment length and recomputes the matching ciphertext limit from
     * the current read cipher.
     */
    void setPlaintextLimit(int plaintextLimit)
    {
        this.plaintextLimit = plaintextLimit;
        this.ciphertextLimit = readCipher.getCiphertextDecodeLimit(plaintextLimit);
    }

    /**
     * Sets the protocol version placed in outgoing record headers. Until a non-null version has
     * been set, {@link #writeRecord(short, byte[], int, int)} silently writes nothing.
     */
    void setWriteVersion(ProtocolVersion writeVersion)
    {
        this.writeVersion = writeVersion;
    }

    /**
     * When enabled, incoming ChangeCipherSpec records are only checked for well-formedness and then
     * discarded instead of being decoded and passed to the handler.
     */
    void setIgnoreChangeCipherSpec(boolean ignoreChangeCipherSpec)
    {
        this.ignoreChangeCipherSpec = ignoreChangeCipherSpec;
    }

    /**
     * Stores the cipher that the <code>enablePendingCipher*</code> methods will activate.
     */
    void setPendingCipher(TlsCipher tlsCipher)
    {
        this.pendingCipher = tlsCipher;
    }

    /**
     * Activates the pending cipher for reading immediately.
     *
     * @throws TlsFatalAlert (unexpected_message) if there is no pending cipher.
     */
    void notifyChangeCipherSpecReceived()
        throws IOException
    {
        if (pendingCipher == null)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message, "No pending cipher");
        }

        enablePendingCipherRead(false);
    }

    /**
     * Activates the pending cipher for reading.
     *
     * @param deferred if true, the switch is postponed until the next application_data record
     *            header is checked; otherwise the read cipher, ciphertext limit and read sequence
     *            number are updated right away.
     * @throws TlsFatalAlert (internal_error) if there is no pending cipher, or a deferred read
     *             cipher is already waiting.
     */
    void enablePendingCipherRead(boolean deferred)
        throws IOException
    {
        if (pendingCipher == null)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }
        if (readCipherDeferred != null)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }
        if (deferred)
        {
            this.readCipherDeferred = pendingCipher;
        }
        else
        {
            this.readCipher = pendingCipher;
            this.ciphertextLimit = readCipher.getCiphertextDecodeLimit(plaintextLimit);
            readSeqNo.reset();
        }
    }

    /**
     * Activates the pending cipher for writing and restarts the write sequence number.
     *
     * @throws TlsFatalAlert (internal_error) if there is no pending cipher.
     */
    void enablePendingCipherWrite()
        throws IOException
    {
        if (pendingCipher == null)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }
        this.writeCipher = this.pendingCipher;
        writeSeqNo.reset();
    }

    /**
     * Confirms that the pending cipher is active in both directions, then clears it.
     *
     * @throws TlsFatalAlert (handshake_failure) if either the read or the write cipher is not the
     *             pending cipher.
     */
    void finaliseHandshake()
        throws IOException
    {
        if (readCipher != pendingCipher || writeCipher != pendingCipher)
        {
            throw new TlsFatalAlert(AlertDescription.handshake_failure);
        }
        this.pendingCipher = null;
    }

    /**
     * @return true once the write sequence number has reached {@link #KEY_UPDATE_THRESHOLD}.
     */
    boolean needsKeyUpdate()
    {
        return writeSeqNo.currentValue() >= KEY_UPDATE_THRESHOLD;
    }

    /**
     * Rekeys the read cipher's decoder and restarts the read sequence number.
     */
    void notifyKeyUpdateReceived() throws IOException
    {
        readCipher.rekeyDecoder();
        readSeqNo.reset();
    }

    /**
     * Rekeys the write cipher's encoder and restarts the write sequence number.
     */
    void notifyKeyUpdateSent() throws IOException
    {
        writeCipher.rekeyEncoder();
        writeSeqNo.reset();
    }

    /**
     * Inspects a record header without consuming any record data.
     * <p>
     * NOTE: the record type check may, as a side effect, activate a deferred read cipher.
     *
     * @param recordHeader buffer holding a record header starting at index 0.
     * @return the total record size (header plus fragment) and, if the record is application data
     *         and the handler is ready for it, an upper bound on the application data it carries
     *         (otherwise 0).
     * @throws TlsFatalAlert if the record type is not acceptable (unexpected_message) or the
     *             fragment length exceeds the ciphertext limit (record_overflow).
     */
    RecordPreview previewRecordHeader(byte[] recordHeader) throws IOException
    {
        short recordType = checkRecordType(recordHeader, RecordFormat.TYPE_OFFSET);

        // NOTE: the record version (at RecordFormat.VERSION_OFFSET) is not read or checked here.

        int length = TlsUtils.readUint16(recordHeader, RecordFormat.LENGTH_OFFSET);

        checkLength(length, ciphertextLimit, AlertDescription.record_overflow);

        int recordSize = RecordFormat.FRAGMENT_OFFSET + length;
        int applicationDataLimit = 0;

        // NOTE: For TLS 1.3, this only MIGHT be application data
        if (ContentType.application_data == recordType && handler.isApplicationDataReady())
        {
            applicationDataLimit = Math.max(0, Math.min(plaintextLimit, readCipher.getPlaintextLimit(length)));
        }

        return new RecordPreview(recordSize, applicationDataLimit);
    }

    /**
     * Previews the record that would be produced for the given amount of content.
     *
     * @param contentLength requested content length; clamped to the range [0, plaintextLimit].
     * @return the resulting record size together with the clamped content length.
     */
    RecordPreview previewOutputRecord(int contentLength)
    {
        int contentLimit = Math.max(0, Math.min(plaintextLimit, contentLength));
        int recordSize = previewOutputRecordSize(contentLimit);
        return new RecordPreview(recordSize, contentLimit);
    }

    /**
     * @param contentLength plaintext length; callers are expected to pass a value no larger than
     *            the plaintext limit (this is not checked here).
     * @return the header size plus the write cipher's ciphertext limit for that content length.
     */
    int previewOutputRecordSize(int contentLength)
    {
//        assert contentLength <= plaintextLimit
        return RecordFormat.FRAGMENT_OFFSET + writeCipher.getCiphertextEncodeLimit(contentLength, plaintextLimit);
    }

    /**
     * Processes a record held in a caller-supplied buffer, provided the given range contains
     * exactly one complete record.
     *
     * @param input buffer containing the record.
     * @param inputOff offset of the record header in <code>input</code>.
     * @param inputLen number of bytes available from <code>inputOff</code>.
     * @return false if the range is shorter than a header or its length does not equal the header
     *         size plus the fragment length announced in the header; true once the record has been
     *         handled (passed to the handler, or discarded as an ignored ChangeCipherSpec).
     * @throws IOException if the record is rejected or the handler fails.
     */
    boolean readFullRecord(byte[] input, int inputOff, int inputLen)
        throws IOException
    {
        if (inputLen < RecordFormat.FRAGMENT_OFFSET)
        {
            return false;
        }

        int length = TlsUtils.readUint16(input, inputOff + RecordFormat.LENGTH_OFFSET);
        if (inputLen != (RecordFormat.FRAGMENT_OFFSET + length))
        {
            return false;
        }

        short recordType = checkRecordType(input, inputOff + RecordFormat.TYPE_OFFSET);

        ProtocolVersion recordVersion = TlsUtils.readVersion(input, inputOff + RecordFormat.VERSION_OFFSET);

        checkLength(length, ciphertextLimit, AlertDescription.record_overflow);

        if (ignoreChangeCipherSpec && ContentType.change_cipher_spec == recordType)
        {
            // Validated and dropped: no sequence number is consumed and the handler is not called.
            checkChangeCipherSpec(input, inputOff + RecordFormat.FRAGMENT_OFFSET, length);
            return true;
        }

        TlsDecodeResult decoded = decodeAndVerify(recordType, recordVersion, input,
            inputOff + RecordFormat.FRAGMENT_OFFSET, length);

        handler.processRecord(decoded.contentType, decoded.buf, decoded.off, decoded.len);
        return true;
    }

    /**
     * Reads one record from the input stream and processes it.
     *
     * @return false if the stream ended before any byte of a new record header was read; true once
     *         a record has been handled (passed to the handler, or discarded as an ignored
     *         ChangeCipherSpec).
     * @throws EOFException if the stream ends part-way through a record.
     * @throws IOException if reading fails, the record is rejected or the handler fails.
     */
    boolean readRecord()
        throws IOException
    {
        if (!inputRecord.readHeader(input))
        {
            return false;
        }

        short recordType = checkRecordType(inputRecord.buf, RecordFormat.TYPE_OFFSET);

        ProtocolVersion recordVersion = TlsUtils.readVersion(inputRecord.buf, RecordFormat.VERSION_OFFSET);

        int length = TlsUtils.readUint16(inputRecord.buf, RecordFormat.LENGTH_OFFSET);

        checkLength(length, ciphertextLimit, AlertDescription.record_overflow);

        inputRecord.readFragment(input, length);

        TlsDecodeResult decoded;
        try
        {
            if (ignoreChangeCipherSpec && ContentType.change_cipher_spec == recordType)
            {
                checkChangeCipherSpec(inputRecord.buf, RecordFormat.FRAGMENT_OFFSET, length);
                return true;
            }

            decoded = decodeAndVerify(recordType, recordVersion, inputRecord.buf, RecordFormat.FRAGMENT_OFFSET, length);
        }
        finally
        {
            // Once a complete record has been read, always release the record buffer, whether it
            // was decoded, ignored or rejected, so the next call starts with a fresh header.
            inputRecord.reset();
        }

        handler.processRecord(decoded.contentType, decoded.buf, decoded.off, decoded.len);
        return true;
    }

    /**
     * Decodes a record fragment with the current read cipher, consuming one read sequence number,
     * and validates the decoded length.
     *
     * @return the decode result produced by the read cipher.
     * @throws TlsFatalAlert if the read sequence numbers are exhausted (unexpected_message), the
     *             decoded length exceeds the plaintext limit (record_overflow), or a
     *             non-application_data record decodes to an empty fragment (illegal_parameter).
     */
    TlsDecodeResult decodeAndVerify(short recordType, ProtocolVersion recordVersion, byte[] ciphertext, int off, int len)
        throws IOException
    {
        long seqNo = readSeqNo.nextValue(AlertDescription.unexpected_message);
        TlsDecodeResult decoded = readCipher.decodeCiphertext(seqNo, recordType, recordVersion, ciphertext, off, len);

        checkLength(decoded.len, plaintextLimit, AlertDescription.record_overflow);

        /*
         * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
         * or ChangeCipherSpec content types.
         */
        if (decoded.len < 1 && decoded.contentType != ContentType.application_data)
        {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }

        return decoded;
    }

    /**
     * Encodes the given plaintext as a single record with the current write cipher, writes it to
     * the output stream and flushes. Does nothing if no write version has been set yet.
     *
     * @param contentType content type of the plaintext.
     * @param plaintext buffer holding the plaintext.
     * @param plaintextOffset offset of the plaintext in the buffer.
     * @param plaintextLength length of the plaintext; must not exceed the plaintext limit and must
     *            be at least 1 unless the content type is application_data.
     * @throws TlsFatalAlert (internal_error) if the length rules are violated, the write sequence
     *             numbers are exhausted, or the write itself is interrupted.
     * @throws IOException on any other I/O or encoding failure.
     */
    void writeRecord(short contentType, byte[] plaintext, int plaintextOffset, int plaintextLength)
        throws IOException
    {
        // Never send anything until a valid ClientHello has been received
        if (writeVersion == null)
        {
            return;
        }

        /*
         * RFC 5246 6.2.1 The length should not exceed 2^14.
         */
        checkLength(plaintextLength, plaintextLimit, AlertDescription.internal_error);

        /*
         * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
         * or ChangeCipherSpec content types.
         */
        if (plaintextLength < 1 && contentType != ContentType.application_data)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        long seqNo = writeSeqNo.nextValue(AlertDescription.internal_error);
        ProtocolVersion recordVersion = writeVersion;

        // The cipher is asked to leave FRAGMENT_OFFSET bytes of headroom for the record header.
        TlsEncodeResult encoded = writeCipher.encodePlaintext(seqNo, contentType, recordVersion,
            RecordFormat.FRAGMENT_OFFSET, plaintext, plaintextOffset, plaintextLength);

        int ciphertextLength = encoded.len - RecordFormat.FRAGMENT_OFFSET;
        TlsUtils.checkUint16(ciphertextLength);

        // Fill in the header using the record type reported by the cipher.
        TlsUtils.writeUint8(encoded.recordType, encoded.buf, encoded.off + RecordFormat.TYPE_OFFSET);
        TlsUtils.writeVersion(recordVersion, encoded.buf, encoded.off + RecordFormat.VERSION_OFFSET);
        TlsUtils.writeUint16(ciphertextLength, encoded.buf, encoded.off + RecordFormat.LENGTH_OFFSET);

        try
        {
            output.write(encoded.buf, encoded.off, encoded.len);
        }
        catch (InterruptedIOException e)
        {
            // An interrupted write is converted to a fatal alert rather than rethrown as-is.
            throw new TlsFatalAlert(AlertDescription.internal_error, e);
        }

        output.flush();
    }

    /**
     * Resets the input record buffer and closes both streams. Closing the output is attempted even
     * if closing the input fails.
     *
     * @throws IOException the first exception raised while closing; a second one is dropped.
     */
    void close() throws IOException
    {
        inputRecord.reset();

        IOException io = null;
        try
        {
            input.close();
        }
        catch (IOException e)
        {
            io = e;
        }

        try
        {
            output.close();
        }
        catch (IOException e)
        {
            if (io == null)
            {
                io = e;
            }
            else
            {
                // TODO[tls] Available from JDK 7
//                io.addSuppressed(e);
            }
        }

        if (io != null)
        {
            throw io;
        }
    }

    /**
     * Checks that a ChangeCipherSpec fragment consists of exactly one byte with the expected value.
     *
     * @throws TlsFatalAlert (unexpected_message) otherwise.
     */
    private void checkChangeCipherSpec(byte[] buf, int off, int len)
        throws IOException
    {
        if (1 != len || (byte)ChangeCipherSpec.change_cipher_spec != buf[off])
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message,
                "Malformed " + ContentType.getText(ContentType.change_cipher_spec));
        }
    }

    /**
     * Reads the record type at the given offset and checks that it is acceptable in the current
     * state.
     * <p>
     * Side effect: if a deferred read cipher is waiting and the record type is application_data,
     * the deferred cipher becomes the read cipher (with the ciphertext limit recomputed and the
     * read sequence number restarted) and no further type check is applied.
     *
     * @return the record type that was read.
     * @throws TlsFatalAlert (unexpected_message) if the record type is not acceptable.
     */
    private short checkRecordType(byte[] buf, int off)
        throws IOException
    {
        short recordType = TlsUtils.readUint8(buf, off);

        if (null != readCipherDeferred && recordType == ContentType.application_data)
        {
            this.readCipher = readCipherDeferred;
            this.readCipherDeferred = null;
            this.ciphertextLimit = readCipher.getCiphertextDecodeLimit(plaintextLimit);
            readSeqNo.reset();
        }
        else if (readCipher.usesOpaqueRecordType())
        {
            // With an opaque record type, only application_data is accepted on the wire, plus
            // ChangeCipherSpec when those records are being ignored.
            if (ContentType.application_data != recordType)
            {
                if (ignoreChangeCipherSpec && ContentType.change_cipher_spec == recordType)
                {
                    // See RFC 8446 D.4.
                }
                else
                {
                    throw new TlsFatalAlert(AlertDescription.unexpected_message,
                        "Opaque " + ContentType.getText(recordType));
                }
            }
        }
        else
        {
            switch (recordType)
            {
            case ContentType.application_data:
            {
                if (!handler.isApplicationDataReady())
                {
                    throw new TlsFatalAlert(AlertDescription.unexpected_message,
                        "Not ready for " + ContentType.getText(ContentType.application_data));
                }
                break;
            }
            case ContentType.alert:
            case ContentType.change_cipher_spec:
            case ContentType.handshake:
    //        case ContentType.heartbeat:
                break;
            default:
                throw new TlsFatalAlert(AlertDescription.unexpected_message,
                    "Unsupported " + ContentType.getText(recordType));
            }
        }

        return recordType;
    }

    /**
     * @throws TlsFatalAlert with the given alert description if <code>length</code> exceeds
     *             <code>limit</code>.
     */
    private static void checkLength(int length, int limit, short alertDescription)
        throws IOException
    {
        if (length > limit)
        {
            throw new TlsFatalAlert(alertDescription);
        }
    }

    /**
     * Accumulates the bytes of one incoming record. The buffer starts out as a fixed header-sized
     * array and is replaced by a larger array when the fragment is read; {@link #reset()} returns
     * to the header array. Progress is kept in {@link #pos}, so a read that was cut short by an
     * {@link InterruptedIOException} can be continued by a later call.
     */
    private static final class Record
    {
        private final byte[] header = new byte[RecordFormat.FRAGMENT_OFFSET];

        volatile byte[] buf = header;
        volatile int pos = 0;

        /**
         * Reads from the stream until <code>pos</code> reaches <code>length</code> or the stream
         * ends; callers detect end-of-stream by comparing <code>pos</code> afterwards.
         */
        void fillTo(InputStream input, int length) throws IOException
        {
            while (pos < length)
            {
                try
                {
                    int numRead = input.read(buf, pos, length - pos);
                    if (numRead < 0)
                    {
                        break;
                    }
                    pos += numRead;
                }
                catch (InterruptedIOException e)
                {
                    /*
                     * Although modifying the bytesTransferred doesn't seem ideal, it's the simplest
                     * way to make sure we don't break client code that depends on the exact type,
                     * e.g. in Apache's httpcomponents-core-4.4.9, BHttpConnectionBase.isStale
                     * depends on the exception type being SocketTimeoutException (or a subclass).
                     *
                     * We can set to 0 here because the only relevant callstack (via
                     * TlsProtocol.readApplicationData) only ever processes one non-empty record (so
                     * interruption after partial output cannot occur).
                     */
                    pos += e.bytesTransferred;
                    e.bytesTransferred = 0;
                    throw e;
                }
            }
        }

        /**
         * Reads the fragment that follows an already-read header, growing the buffer if needed.
         *
         * @throws EOFException if the stream ends before the whole record has been read.
         */
        void readFragment(InputStream input, int fragmentLength) throws IOException
        {
            int recordLength = RecordFormat.FRAGMENT_OFFSET + fragmentLength;
            resize(recordLength);
            fillTo(input, recordLength);
            if (pos < recordLength)
            {
                throw new EOFException();
            }
        }

        /**
         * Reads a record header.
         *
         * @return false if the stream ended before any header byte was available, true if a full
         *         header has been read.
         * @throws EOFException if the stream ends part-way through the header.
         */
        boolean readHeader(InputStream input) throws IOException
        {
            fillTo(input, RecordFormat.FRAGMENT_OFFSET);
            if (pos == 0)
            {
                return false;
            }
            if (pos < RecordFormat.FRAGMENT_OFFSET)
            {
                throw new EOFException();
            }
            return true;
        }

        /**
         * Discards any buffered data and returns to the header-sized buffer.
         */
        void reset()
        {
            buf = header;
            pos = 0;
        }

        /**
         * Ensures the buffer can hold <code>length</code> bytes, preserving the bytes read so far.
         */
        private void resize(int length)
        {
            if (buf.length < length)
            {
                byte[] tmp = new byte[length];
                System.arraycopy(buf, 0, tmp, 0, pos);
                buf = tmp;
            }
        }
    }

    /**
     * A 64-bit record sequence number counter that refuses to wrap around.
     */
    private static final class SequenceNumber
    {
        private long value = 0L;
        private boolean exhausted = false;

        /**
         * @return the value that the next call to {@link #nextValue(short)} would return.
         */
        synchronized long currentValue()
        {
            return value;
        }

        /**
         * Returns the current value and advances the counter. When the increment wraps to zero the
         * counter is marked exhausted, so the wrapped value is never handed out.
         *
         * @throws TlsFatalAlert with the given alert description if the counter is exhausted.
         */
        synchronized long nextValue(short alertDescription) throws TlsFatalAlert
        {
            if (exhausted)
            {
                throw new TlsFatalAlert(alertDescription, "Sequence numbers exhausted");
            }
            long result = value;
            if (++value == 0)
            {
                exhausted = true;
            }
            return result;
        }

        /**
         * Restarts the counter at zero and clears the exhausted state.
         */
        synchronized void reset()
        {
            this.value = 0L;
            this.exhausted = false;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy