All downloads are free. Search and download functionality uses the official Maven repository.

org.bouncycastle.tls.crypto.impl.TlsAEADCipher Maven / Gradle / Ivy

package org.bouncycastle.tls.crypto.impl;

import java.io.IOException;

import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.ContentType;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.SecurityParameters;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.TlsCryptoUtils;
import org.bouncycastle.tls.crypto.TlsDecodeResult;
import org.bouncycastle.tls.crypto.TlsEncodeResult;
import org.bouncycastle.tls.crypto.TlsSecret;

/**
 * A generic AEAD record cipher for TLS.
 * <p>
 * The constructor rejects any negotiated version for which {@code TlsImplUtils.isTLSv12} is false and
 * then follows one of two keying paths: a key-block based setup when the negotiated version is not
 * TLS 1.3, or a traffic-secret based setup (see {@link #rekeyCipher}) when it is.
 * <p>
 * Two per-record nonce constructions are supported:
 * <ul>
 * <li>a 4 byte fixed IV followed by an 8 byte explicit part that is carried in the record
 * (internally {@code NONCE_RFC5288});</li>
 * <li>a 12 byte fixed IV XORed with the zero-padded 64-bit sequence number, with nothing carried in the
 * record (internally {@code NONCE_RFC7905}).</li>
 * </ul>
 */
public class TlsAEADCipher
    implements TlsCipher
{
    /** AEAD type selector: CCM. */
    public static final int AEAD_CCM = 1;
    /** AEAD type selector: ChaCha20-Poly1305. */
    public static final int AEAD_CHACHA20_POLY1305 = 2;
    /** AEAD type selector: GCM. */
    public static final int AEAD_GCM = 3;

    // Nonce = fixed IV || explicit 8 byte part sent with each record.
    private static final int NONCE_RFC5288 = 1;
    // Nonce = fixed IV XOR left-padded sequence number; no explicit part on the wire.
    private static final int NONCE_RFC7905 = 2;

    protected final TlsCryptoParameters cryptoParams;
    protected final int keySize;
    protected final int macSize;
    /** Length in bytes of the per-direction fixed IV (4 or 12, depending on the nonce mode). */
    protected final int fixed_iv_length;
    /** Length in bytes of the explicit nonce part carried in each record (8 or 0). */
    protected final int record_iv_length;

    protected final TlsAEADCipherImpl decryptCipher;
    protected final TlsAEADCipherImpl encryptCipher;
    protected final byte[] decryptNonce;
    protected final byte[] encryptNonce;

    protected final boolean isTLSv13;
    protected final int nonceMode;

    /**
     * Creates the record cipher and keys both directions.
     *
     * @param cryptoParams  source of the security parameters, the server/client role and (for the
     *                      non-TLS 1.3 path) the key block
     * @param encryptCipher AEAD implementation used for outgoing records
     * @param decryptCipher AEAD implementation used for incoming records
     * @param keySize       cipher key length in bytes
     * @param macSize       value passed to the AEAD implementations as their MAC size; also counted as
     *                      per-record overhead by the limit calculations
     * @param aeadType      one of {@link #AEAD_CCM}, {@link #AEAD_CHACHA20_POLY1305}, {@link #AEAD_GCM}
     * @throws IOException a {@link TlsFatalAlert} with {@code internal_error} if the negotiated version
     *                     is rejected, the AEAD type is unknown, or the key block is not consumed
     *                     exactly; anything thrown by the underlying key derivation
     */
    public TlsAEADCipher(TlsCryptoParameters cryptoParams, TlsAEADCipherImpl encryptCipher, TlsAEADCipherImpl decryptCipher,
        int keySize, int macSize, int aeadType) throws IOException
    {
        final SecurityParameters handshakeParams = cryptoParams.getSecurityParametersHandshake();
        final ProtocolVersion version = handshakeParams.getNegotiatedVersion();

        // NOTE(review): isTLSv12 presumably means "TLS 1.2 or later", since the TLS 1.3 path below
        // is only reachable after passing this check - confirm against TlsImplUtils.
        if (!TlsImplUtils.isTLSv12(version))
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.isTLSv13 = TlsImplUtils.isTLSv13(version);
        this.nonceMode = getNonceMode(isTLSv13, aeadType);

        if (NONCE_RFC5288 == nonceMode)
        {
            this.fixed_iv_length = 4;
            this.record_iv_length = 8;
        }
        else if (NONCE_RFC7905 == nonceMode)
        {
            this.fixed_iv_length = 12;
            this.record_iv_length = 0;
        }
        else
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.cryptoParams = cryptoParams;
        this.keySize = keySize;
        this.macSize = macSize;

        this.decryptCipher = decryptCipher;
        this.encryptCipher = encryptCipher;

        this.decryptNonce = new byte[fixed_iv_length];
        this.encryptNonce = new byte[fixed_iv_length];

        final boolean isServer = cryptoParams.isServer();

        if (isTLSv13)
        {
            // TLS 1.3: keys and IVs come from the traffic secrets. We decrypt with the peer's
            // secret and encrypt with our own.
            rekeyCipher(handshakeParams, decryptCipher, decryptNonce, !isServer);
            rekeyCipher(handshakeParams, encryptCipher, encryptNonce, isServer);
            return;
        }

        final int keyBlockSize = (2 * keySize) + (2 * fixed_iv_length);
        final byte[] keyBlock = TlsImplUtils.calculateKeyBlock(cryptoParams, keyBlockSize);

        // The key block is consumed as: first key, second key, first IV, second IV. When we are the
        // server the "first" entries key the decrypt side, otherwise they key the encrypt side.
        final TlsAEADCipherImpl firstCipher = isServer ? decryptCipher : encryptCipher;
        final TlsAEADCipherImpl secondCipher = isServer ? encryptCipher : decryptCipher;
        final byte[] firstNonce = isServer ? decryptNonce : encryptNonce;
        final byte[] secondNonce = isServer ? encryptNonce : decryptNonce;

        int pos = 0;

        firstCipher.setKey(keyBlock, pos, keySize);
        pos += keySize;
        secondCipher.setKey(keyBlock, pos, keySize);
        pos += keySize;

        System.arraycopy(keyBlock, pos, firstNonce, 0, fixed_iv_length);
        pos += fixed_iv_length;
        System.arraycopy(keyBlock, pos, secondNonce, 0, fixed_iv_length);
        pos += fixed_iv_length;

        // Sanity check: the whole key block must have been used.
        if (pos != keyBlockSize)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        // NOTE: Ensure dummy nonce is not part of the generated sequence(s)
        final byte[] dummyNonce = new byte[fixed_iv_length + record_iv_length];
        dummyNonce[0] = (byte)~encryptNonce[0];
        dummyNonce[1] = (byte)~decryptNonce[1];

        encryptCipher.init(dummyNonce, macSize, null);
        decryptCipher.init(dummyNonce, macSize, null);
    }

    /**
     * Computes the ciphertext size limit that corresponds to a plaintext limit on the decode side.
     *
     * @param plaintextLimit plaintext size limit
     * @return the plaintext limit plus the MAC size and explicit record IV length, plus one more byte
     *         under TLS 1.3
     */
    public int getCiphertextDecodeLimit(int plaintextLimit)
    {
        int limit = plaintextLimit + macSize + record_iv_length;
        if (isTLSv13)
        {
            // Room for the content type byte appended to the inner plaintext.
            limit += 1;
        }
        return limit;
    }

    /**
     * Computes the ciphertext size limit for encoding a plaintext of the given length.
     *
     * @param plaintextLength length of the plaintext to be encoded
     * @param plaintextLimit  plaintext size limit; only consulted under TLS 1.3
     * @return the (inner) plaintext size plus the MAC size and explicit record IV length
     */
    public int getCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
    {
        final int overhead = macSize + record_iv_length;

        if (!isTLSv13)
        {
            return plaintextLength + overhead;
        }

        // TODO[tls13] Add support for padding
        final int maxPadding = 0;

        // One extra byte for the content type that trails the inner plaintext.
        final int innerPlaintextLimit = 1 + Math.min(plaintextLimit, plaintextLength + maxPadding);

        return innerPlaintextLimit + overhead;
    }

    /**
     * Computes the plaintext size limit that corresponds to a ciphertext limit.
     *
     * @param ciphertextLimit ciphertext size limit
     * @return the ciphertext limit minus the MAC size and explicit record IV length, minus one more
     *         byte under TLS 1.3; may be negative if the ciphertext limit is smaller than the overhead
     */
    public int getPlaintextLimit(int ciphertextLimit)
    {
        int limit = ciphertextLimit - macSize - record_iv_length;
        if (isTLSv13)
        {
            limit -= 1;
        }
        return limit;
    }

    /**
     * Encrypts one record.
     * <p>
     * The output buffer is laid out as: {@code headerAllocation} untouched bytes, then the explicit
     * record IV (if the nonce mode has one), then the AEAD output. Under TLS 1.3 the real content type
     * is appended to the plaintext before encryption and the returned record type is
     * {@code ContentType.application_data}.
     *
     * @param seqNo            record sequence number, used for the nonce and (before TLS 1.3) the
     *                         additional data
     * @param contentType      content type of the plaintext
     * @param recordVersion    version to place in the additional data
     * @param headerAllocation number of bytes to reserve at the start of the output
     * @param plaintext        buffer holding the plaintext
     * @param plaintextOffset  offset of the plaintext in {@code plaintext}
     * @param plaintextLength  length of the plaintext
     * @return the encoded record, spanning the entire newly allocated output buffer
     * @throws IOException a {@link TlsFatalAlert} with {@code internal_error} if the cipher throws a
     *                     {@link RuntimeException}, the nonce mode is unknown, or the cipher did not
     *                     produce exactly the predicted number of bytes
     */
    public TlsEncodeResult encodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
        int headerAllocation, byte[] plaintext, int plaintextOffset, int plaintextLength) throws IOException
    {
        final byte[] nonce = new byte[encryptNonce.length + record_iv_length];

        if (NONCE_RFC5288 == nonceMode)
        {
            System.arraycopy(encryptNonce, 0, nonce, 0, encryptNonce.length);
            // RFC 5288/6655: The nonce_explicit MAY be the 64-bit sequence number.
            TlsUtils.writeUint64(seqNo, nonce, encryptNonce.length);
        }
        else if (NONCE_RFC7905 == nonceMode)
        {
            xorSequenceNonce(seqNo, encryptNonce, nonce);
        }
        else
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        // Under TLS 1.3 the content type byte is encrypted together with the plaintext.
        // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
        final int innerLength = isTLSv13 ? plaintextLength + 1 : plaintextLength;

        final int encryptionLength = encryptCipher.getOutputSize(innerLength);
        final int ciphertextLength = record_iv_length + encryptionLength;

        final byte[] output = new byte[headerAllocation + ciphertextLength];
        int outputPos = headerAllocation;

        if (record_iv_length != 0)
        {
            // The explicit part is the tail of the nonce.
            System.arraycopy(nonce, nonce.length - record_iv_length, output, outputPos, record_iv_length);
            outputPos += record_iv_length;
        }

        final short recordType;
        if (isTLSv13)
        {
            recordType = ContentType.application_data;
        }
        else
        {
            recordType = contentType;
        }

        final byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
            plaintextLength);

        try
        {
            // Stage the (inner) plaintext in the output buffer, then encrypt it in place.
            System.arraycopy(plaintext, plaintextOffset, output, outputPos, plaintextLength);
            if (isTLSv13)
            {
                output[outputPos + plaintextLength] = (byte)contentType;
            }

            encryptCipher.init(nonce, macSize, additionalData);
            outputPos += encryptCipher.doFinal(output, outputPos, innerLength, output, outputPos);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error, e);
        }

        if (output.length != outputPos)
        {
            // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        return new TlsEncodeResult(output, 0, output.length, recordType);
    }

    /**
     * Decrypts and authenticates one record, in place within the {@code ciphertext} buffer.
     * <p>
     * Under TLS 1.3, trailing zero bytes are stripped from the decrypted data and the last non-zero
     * byte is taken as the real content type.
     *
     * @param seqNo            record sequence number
     * @param recordType       record type from the record header
     * @param recordVersion    version from the record header
     * @param ciphertext       buffer holding the record payload; overwritten with the plaintext
     * @param ciphertextOffset offset of the payload in {@code ciphertext}
     * @param ciphertextLength length of the payload
     * @return a result referring to the plaintext inside {@code ciphertext}
     * @throws IOException a {@link TlsFatalAlert} with {@code decode_error} if the payload is shorter
     *                     than the fixed overhead, {@code bad_record_mac} if the cipher throws a
     *                     {@link RuntimeException}, {@code unexpected_message} if a TLS 1.3 record
     *                     decrypts to nothing but zero bytes, or {@code internal_error} for an unknown
     *                     nonce mode or an output size mismatch
     */
    public TlsDecodeResult decodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion,
        byte[] ciphertext, int ciphertextOffset, int ciphertextLength) throws IOException
    {
        if (getPlaintextLimit(ciphertextLength) < 0)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }

        final byte[] nonce = new byte[decryptNonce.length + record_iv_length];

        if (NONCE_RFC5288 == nonceMode)
        {
            // Fixed IV followed by the explicit part read from the start of the payload.
            System.arraycopy(decryptNonce, 0, nonce, 0, decryptNonce.length);
            System.arraycopy(ciphertext, ciphertextOffset, nonce, nonce.length - record_iv_length, record_iv_length);
        }
        else if (NONCE_RFC7905 == nonceMode)
        {
            xorSequenceNonce(seqNo, decryptNonce, nonce);
        }
        else
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        final int encryptionOffset = ciphertextOffset + record_iv_length;
        final int encryptionLength = ciphertextLength - record_iv_length;
        int plaintextLength = decryptCipher.getOutputSize(encryptionLength);

        final byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength,
            plaintextLength);

        final int written;
        try
        {
            decryptCipher.init(nonce, macSize, additionalData);
            written = decryptCipher.doFinal(ciphertext, encryptionOffset, encryptionLength, ciphertext,
                encryptionOffset);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.bad_record_mac, e);
        }

        if (plaintextLength != written)
        {
            // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        short contentType = recordType;
        if (isTLSv13)
        {
            // Strip padding and read true content type from TLSInnerPlaintext
            int pos = plaintextLength - 1;
            while (pos >= 0 && 0 == ciphertext[encryptionOffset + pos])
            {
                --pos;
            }

            if (pos < 0)
            {
                // Nothing but zero bytes: there is no content type to recover.
                throw new TlsFatalAlert(AlertDescription.unexpected_message);
            }

            contentType = (short)(ciphertext[encryptionOffset + pos] & 0xFF);
            plaintextLength = pos;
        }

        return new TlsDecodeResult(ciphertext, encryptionOffset, plaintextLength, contentType);
    }

    /**
     * Re-keys the decrypt side from the connection's current traffic secret of the peer.
     *
     * @throws IOException as thrown by {@link #rekeyCipher}
     */
    public void rekeyDecoder() throws IOException
    {
        final SecurityParameters connectionParams = cryptoParams.getSecurityParametersConnection();
        final boolean peerIsServer = !cryptoParams.isServer();

        rekeyCipher(connectionParams, decryptCipher, decryptNonce, peerIsServer);
    }

    /**
     * Re-keys the encrypt side from the connection's current traffic secret of this endpoint.
     *
     * @throws IOException as thrown by {@link #rekeyCipher}
     */
    public void rekeyEncoder() throws IOException
    {
        final SecurityParameters connectionParams = cryptoParams.getSecurityParametersConnection();
        final boolean weAreServer = cryptoParams.isServer();

        rekeyCipher(connectionParams, encryptCipher, encryptNonce, weAreServer);
    }

    /**
     * @return {@code true} exactly when the negotiated version is TLS 1.3
     */
    public boolean usesOpaqueRecordType()
    {
        return isTLSv13;
    }

    /**
     * Builds the AEAD additional data for one record.
     *
     * @param seqNo            record sequence number (only used before TLS 1.3)
     * @param recordType       record type to encode
     * @param recordVersion    record version to encode
     * @param ciphertextLength length field used under TLS 1.3
     * @param plaintextLength  length field used before TLS 1.3
     * @return 5 bytes (type, version, ciphertext length) under TLS 1.3, otherwise 13 bytes
     *         (sequence number, type, version, plaintext length)
     * @throws IOException if one of the {@code TlsUtils} writers fails
     */
    protected byte[] getAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion,
        int ciphertextLength, int plaintextLength) throws IOException
    {
        if (!isTLSv13)
        {
            /*
             * seq_num + TLSCompressed.type + TLSCompressed.version + TLSCompressed.length
             */
            final byte[] legacyData = new byte[13];
            TlsUtils.writeUint64(seqNo, legacyData, 0);
            TlsUtils.writeUint8(recordType, legacyData, 8);
            TlsUtils.writeVersion(recordVersion, legacyData, 9);
            TlsUtils.writeUint16(plaintextLength, legacyData, 11);
            return legacyData;
        }

        /*
         * TLSCiphertext.opaque_type || TLSCiphertext.legacy_record_version || TLSCiphertext.length
         */
        final byte[] headerData = new byte[5];
        TlsUtils.writeUint8(recordType, headerData, 0);
        TlsUtils.writeVersion(recordVersion, headerData, 1);
        TlsUtils.writeUint16(ciphertextLength, headerData, 3);
        return headerData;
    }

    /**
     * Re-keys one direction from a TLS 1.3 traffic secret.
     *
     * @param securityParameters source of the traffic secrets and the PRF hash algorithm
     * @param cipher             cipher to receive the new key
     * @param nonce              buffer to receive the new fixed IV
     * @param serverSecret       {@code true} to use the server traffic secret, {@code false} for the
     *                           client traffic secret
     * @throws IOException a {@link TlsFatalAlert} with {@code internal_error} if the negotiated version
     *                     is not TLS 1.3 or the selected secret is {@code null}
     */
    protected void rekeyCipher(SecurityParameters securityParameters, TlsAEADCipherImpl cipher, byte[] nonce,
        boolean serverSecret) throws IOException
    {
        if (!isTLSv13)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        final TlsSecret secret;
        if (serverSecret)
        {
            secret = securityParameters.getTrafficSecretServer();
        }
        else
        {
            secret = securityParameters.getTrafficSecretClient();
        }

        // TODO[tls13] For early data, have to disable server->client
        if (secret == null)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        setup13Cipher(cipher, nonce, secret, securityParameters.getPRFCryptoHashAlgorithm());
    }

    /**
     * Derives key and IV from a traffic secret using the {@code "key"} and {@code "iv"} labels, installs
     * the key in {@code cipher}, stores the IV in {@code nonce} and initialises the cipher with a
     * placeholder nonce.
     *
     * @param cipher              cipher to key
     * @param nonce               buffer receiving the first {@code fixed_iv_length} bytes of the IV
     * @param secret              traffic secret to expand
     * @param cryptoHashAlgorithm hash algorithm identifier passed to the expansion
     * @throws IOException if the expansion fails
     */
    protected void setup13Cipher(TlsAEADCipherImpl cipher, byte[] nonce, TlsSecret secret, int cryptoHashAlgorithm)
        throws IOException
    {
        final TlsSecret keySecret = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "key",
            TlsUtils.EMPTY_BYTES, keySize);
        final byte[] key = keySecret.extract();

        final TlsSecret ivSecret = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "iv",
            TlsUtils.EMPTY_BYTES, fixed_iv_length);
        final byte[] iv = ivSecret.extract();

        cipher.setKey(key, 0, keySize);
        System.arraycopy(iv, 0, nonce, 0, fixed_iv_length);

        // NOTE: Ensure dummy nonce is not part of the generated sequence(s)
        iv[0] = (byte)(iv[0] ^ 0x80);
        cipher.init(iv, macSize, null);
    }

    /**
     * Fills {@code nonce} for the XOR-based construction: the sequence number is written into the last
     * 8 bytes and the result is XORed with the fixed IV.
     */
    private static void xorSequenceNonce(long seqNo, byte[] fixedIV, byte[] nonce) throws IOException
    {
        TlsUtils.writeUint64(seqNo, nonce, nonce.length - 8);
        for (int i = 0; i < fixedIV.length; ++i)
        {
            nonce[i] ^= fixedIV[i];
        }
    }

    /**
     * Selects the nonce construction: ChaCha20-Poly1305 always uses the XOR-based one, while CCM and GCM
     * use it only under TLS 1.3 and otherwise use the explicit-nonce one.
     *
     * @throws IOException a {@link TlsFatalAlert} with {@code internal_error} for an unknown AEAD type
     */
    private static int getNonceMode(boolean isTLSv13, int aeadType) throws IOException
    {
        if (AEAD_CHACHA20_POLY1305 == aeadType)
        {
            return NONCE_RFC7905;
        }

        if (AEAD_CCM == aeadType || AEAD_GCM == aeadType)
        {
            return isTLSv13 ? NONCE_RFC7905 : NONCE_RFC5288;
        }

        throw new TlsFatalAlert(AlertDescription.internal_error);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy