All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.bouncycastle.tls.crypto.impl.TlsAEADCipher Maven / Gradle / Ivy

Go to download

The Bouncy Castle Java APIs for TLS, including a JSSE provider. The APIs are designed primarily to be used in conjunction with the BC FIPS provider. The APIs may also be used with other providers, although if they are used in a FIPS context, it is the responsibility of the user to ensure that any other providers used are FIPS certified and used appropriately.

There is a newer version: 2.0.19
Show newest version
package org.bouncycastle.tls.crypto.impl;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;

import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.ContentType;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.SecurityParameters;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.TlsCryptoUtils;
import org.bouncycastle.tls.crypto.TlsDecodeResult;
import org.bouncycastle.tls.crypto.TlsEncodeResult;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.util.Arrays;

/**
 * A generic TLS 1.2 AEAD cipher.
 */
public class TlsAEADCipher
    implements TlsCipher
{
    public static final int AEAD_CCM = 1;
    public static final int AEAD_CHACHA20_POLY1305 = 2;
    public static final int AEAD_GCM = 3;

    private static final int NONCE_RFC5288 = 1;
    private static final int NONCE_RFC7905 = 2;

    public static final byte[] EPOCH_1 = {0x00, 0x01};

    private static final Class fipsNonceGeneratorClass = lookup("org.bouncycastle.crypto.fips.FipsNonceGenerator");

    protected final TlsCryptoParameters cryptoParams;
    protected final int keySize;
    protected final int macSize;
    protected final int fixed_iv_length;
    protected final int record_iv_length;

    protected final TlsAEADCipherImpl decryptCipher, encryptCipher;
    protected final byte[] decryptNonce, encryptNonce;

    protected final boolean isTLSv13;
    protected final int nonceMode;
    protected final AEADNonceGenerator encryptNonceGenerator;

    public TlsAEADCipher(TlsCryptoParameters cryptoParams, TlsAEADCipherImpl encryptCipher, TlsAEADCipherImpl decryptCipher,
                         int keySize, int macSize, int aeadType)
        throws IOException
    {
        final SecurityParameters securityParameters = cryptoParams.getSecurityParametersHandshake();
        final ProtocolVersion negotiatedVersion = securityParameters.getNegotiatedVersion();

        if (!TlsImplUtils.isTLSv12(negotiatedVersion))
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.isTLSv13 = TlsImplUtils.isTLSv13(negotiatedVersion);
        this.nonceMode = getNonceMode(isTLSv13, aeadType);

        switch (nonceMode)
        {
        case NONCE_RFC5288:
            this.fixed_iv_length = 4;
            this.record_iv_length = 8;
            break;
        case NONCE_RFC7905:
            this.fixed_iv_length = 12;
            this.record_iv_length = 0;
            break;
        default:
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.cryptoParams = cryptoParams;
        this.keySize = keySize;
        this.macSize = macSize;

        this.decryptCipher = decryptCipher;
        this.encryptCipher = encryptCipher;

        this.decryptNonce = new byte[fixed_iv_length];
        this.encryptNonce = new byte[fixed_iv_length];

        final boolean isServer = cryptoParams.isServer();
        if (isTLSv13)
        {
            this.encryptNonceGenerator = null;
            rekeyCipher(securityParameters, decryptCipher, decryptNonce, !isServer);
            rekeyCipher(securityParameters, encryptCipher, encryptNonce, isServer);
            return;
        }

        int keyBlockSize = (2 * keySize) + (2 * fixed_iv_length);
        byte[] keyBlock = TlsImplUtils.calculateKeyBlock(cryptoParams, keyBlockSize);
        int pos = 0;

        if (isServer)
        {
            decryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            encryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;

            System.arraycopy(keyBlock, pos, decryptNonce, 0, fixed_iv_length);
            pos += fixed_iv_length;
            System.arraycopy(keyBlock, pos, encryptNonce, 0, fixed_iv_length);
            pos += fixed_iv_length;
        }
        else
        {
            encryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            decryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;

            System.arraycopy(keyBlock, pos, encryptNonce, 0, fixed_iv_length);
            pos += fixed_iv_length;
            System.arraycopy(keyBlock, pos, decryptNonce, 0, fixed_iv_length);
            pos += fixed_iv_length;
        }

        if (keyBlockSize != pos)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        int nonceLength = fixed_iv_length + record_iv_length;

        // NOTE: Ensure dummy nonce is not part of the generated sequence(s)
        byte[] dummyNonce = new byte[nonceLength];
        dummyNonce[0] = (byte)~encryptNonce[0];
        dummyNonce[1] = (byte)~decryptNonce[1];

        encryptCipher.init(dummyNonce, macSize, null);
        decryptCipher.init(dummyNonce, macSize, null);

        if (AEAD_GCM == aeadType && null != fipsNonceGeneratorClass)
        {
            int counterBits = 64;
            byte[] baseNonce = Arrays.copyOf(encryptNonce, nonceLength);

            if (negotiatedVersion.isDTLS())
            {
                counterBits = 48;
                baseNonce[baseNonce.length - 8] ^= EPOCH_1[0];
                baseNonce[baseNonce.length - 7] ^= EPOCH_1[1];
            }

            this.encryptNonceGenerator = new BCFipsAEADNonceGenerator(baseNonce, counterBits);
        }
        else
        {
            this.encryptNonceGenerator = null;
        }
    }

    public int getCiphertextDecodeLimit(int plaintextLimit)
    {
        return plaintextLimit + macSize + record_iv_length + (isTLSv13 ? 1 : 0);
    }

    public int getCiphertextEncodeLimit(int plaintextLength, int plaintextLimit)
    {
        int innerPlaintextLimit = plaintextLength;
        if (isTLSv13)
        {
            // TODO[tls13] Add support for padding
            int maxPadding = 0;

            innerPlaintextLimit = 1 + Math.min(plaintextLimit, plaintextLength + maxPadding);
        }

        return innerPlaintextLimit + macSize + record_iv_length;
    }

    public int getPlaintextLimit(int ciphertextLimit)
    {
        return ciphertextLimit - macSize - record_iv_length - (isTLSv13 ? 1 : 0);
    }

    public TlsEncodeResult encodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion,
                                           int headerAllocation, byte[] plaintext, int plaintextOffset, int plaintextLength)
        throws IOException
    {
        byte[] nonce = new byte[encryptNonce.length + record_iv_length];

        if (null != encryptNonceGenerator)
        {
            encryptNonceGenerator.generateNonce(nonce);
        }
        else
        {
            switch (nonceMode)
            {
            case NONCE_RFC5288:
                System.arraycopy(encryptNonce, 0, nonce, 0, encryptNonce.length);
                // RFC 5288/6655: The nonce_explicit MAY be the 64-bit sequence number.
                TlsUtils.writeUint64(seqNo, nonce, encryptNonce.length);
                break;
            case NONCE_RFC7905:
                TlsUtils.writeUint64(seqNo, nonce, nonce.length - 8);
                for (int i = 0; i < encryptNonce.length; ++i)
                {
                    nonce[i] ^= encryptNonce[i];
                }
                break;
            default:
                throw new TlsFatalAlert(AlertDescription.internal_error);
            }
        }

        int extraLength = (isTLSv13 ? 1 : 0);

        // TODO[tls13] If we support adding padding to TLSInnerPlaintext, this will need review
        int encryptionLength = encryptCipher.getOutputSize(plaintextLength + extraLength);
        int ciphertextLength = record_iv_length + encryptionLength;

        byte[] output = new byte[headerAllocation + ciphertextLength];
        int outputPos = headerAllocation;

        if (record_iv_length != 0)
        {
            System.arraycopy(nonce, nonce.length - record_iv_length, output, outputPos, record_iv_length);
            outputPos += record_iv_length;
        }

        short recordType = isTLSv13 ? ContentType.application_data : contentType;

        byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);

        try
        {
            System.arraycopy(plaintext, plaintextOffset, output, outputPos, plaintextLength);
            if (isTLSv13)
            {
                output[outputPos + plaintextLength] = (byte)contentType;
            }

            encryptCipher.init(nonce, macSize, additionalData);
            outputPos += encryptCipher.doFinal(output, outputPos, plaintextLength + extraLength, output, outputPos);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error, e);
        }

        if (outputPos != output.length)
        {
            // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        return new TlsEncodeResult(output, 0, output.length, recordType);
    }

    public TlsDecodeResult decodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion,
                                            byte[] ciphertext, int ciphertextOffset, int ciphertextLength)
        throws IOException
    {
        if (getPlaintextLimit(ciphertextLength) < 0)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }

        byte[] nonce = new byte[decryptNonce.length + record_iv_length];

        switch (nonceMode)
        {
        case NONCE_RFC5288:
            System.arraycopy(decryptNonce, 0, nonce, 0, decryptNonce.length);
            System.arraycopy(ciphertext, ciphertextOffset, nonce, nonce.length - record_iv_length, record_iv_length);
            break;
        case NONCE_RFC7905:
            TlsUtils.writeUint64(seqNo, nonce, nonce.length - 8);
            for (int i = 0; i < decryptNonce.length; ++i)
            {
                nonce[i] ^= decryptNonce[i];
            }
            break;
        default:
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        int encryptionOffset = ciphertextOffset + record_iv_length;
        int encryptionLength = ciphertextLength - record_iv_length;
        int plaintextLength = decryptCipher.getOutputSize(encryptionLength);

        byte[] additionalData = getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);

        int outputPos;
        try
        {
            decryptCipher.init(nonce, macSize, additionalData);
            outputPos = decryptCipher.doFinal(ciphertext, encryptionOffset, encryptionLength, ciphertext,
                encryptionOffset);
        }
        catch (RuntimeException e)
        {
            throw new TlsFatalAlert(AlertDescription.bad_record_mac, e);
        }

        if (outputPos != plaintextLength)
        {
            // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction.
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        short contentType = recordType;
        if (isTLSv13)
        {
            // Strip padding and read true content type from TLSInnerPlaintext
            int pos = plaintextLength;
            for (; ; )
            {
                if (--pos < 0)
                {
                    throw new TlsFatalAlert(AlertDescription.unexpected_message);
                }

                byte octet = ciphertext[encryptionOffset + pos];
                if (0 != octet)
                {
                    contentType = (short)(octet & 0xFF);
                    plaintextLength = pos;
                    break;
                }
            }
        }

        return new TlsDecodeResult(ciphertext, encryptionOffset, plaintextLength, contentType);
    }

    public void rekeyDecoder()
        throws IOException
    {
        rekeyCipher(cryptoParams.getSecurityParametersConnection(), decryptCipher, decryptNonce, !cryptoParams.isServer());
    }

    public void rekeyEncoder()
        throws IOException
    {
        rekeyCipher(cryptoParams.getSecurityParametersConnection(), encryptCipher, encryptNonce, cryptoParams.isServer());
    }

    public boolean usesOpaqueRecordType()
    {
        return isTLSv13;
    }

    protected byte[] getAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion,
                                       int ciphertextLength, int plaintextLength)
        throws IOException
    {
        if (isTLSv13)
        {
            /*
             * TLSCiphertext.opaque_type || TLSCiphertext.legacy_record_version || TLSCiphertext.length
             */
            byte[] additional_data = new byte[5];
            TlsUtils.writeUint8(recordType, additional_data, 0);
            TlsUtils.writeVersion(recordVersion, additional_data, 1);
            TlsUtils.writeUint16(ciphertextLength, additional_data, 3);
            return additional_data;
        }
        else
        {
            /*
             * seq_num + TLSCompressed.type + TLSCompressed.version + TLSCompressed.length
             */
            byte[] additional_data = new byte[13];
            TlsUtils.writeUint64(seqNo, additional_data, 0);
            TlsUtils.writeUint8(recordType, additional_data, 8);
            TlsUtils.writeVersion(recordVersion, additional_data, 9);
            TlsUtils.writeUint16(plaintextLength, additional_data, 11);
            return additional_data;
        }
    }

    protected void rekeyCipher(SecurityParameters securityParameters, TlsAEADCipherImpl cipher, byte[] nonce,
                               boolean serverSecret)
        throws IOException
    {
        if (!isTLSv13)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        TlsSecret secret = serverSecret
            ? securityParameters.getTrafficSecretServer()
            : securityParameters.getTrafficSecretClient();

        // TODO[tls13] For early data, have to disable server->client
        if (null == secret)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        setup13Cipher(cipher, nonce, secret, securityParameters.getPRFCryptoHashAlgorithm());
    }

    protected void setup13Cipher(TlsAEADCipherImpl cipher, byte[] nonce, TlsSecret secret, int cryptoHashAlgorithm)
        throws IOException
    {
        byte[] key = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "key", TlsUtils.EMPTY_BYTES, keySize).extract();
        byte[] iv = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "iv", TlsUtils.EMPTY_BYTES, fixed_iv_length).extract();

        cipher.setKey(key, 0, keySize);
        System.arraycopy(iv, 0, nonce, 0, fixed_iv_length);

        // NOTE: Ensure dummy nonce is not part of the generated sequence(s)
        iv[0] ^= 0x80;
        cipher.init(iv, macSize, null);
    }

    private static int getNonceMode(boolean isTLSv13, int aeadType)
        throws IOException
    {
        switch (aeadType)
        {
        case AEAD_CCM:
        case AEAD_GCM:
            return isTLSv13 ? NONCE_RFC7905 : NONCE_RFC5288;

        case AEAD_CHACHA20_POLY1305:
            return NONCE_RFC7905;

        default:
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }
    }

    static Class lookup(final String className)
    {
        if (null == className)
        {
            return null;
        }

        return AccessController.doPrivileged(new PrivilegedAction>()
        {
            public Class run()
            {
                try
                {
                    ClassLoader classLoader = TlsAEADCipher.class.getClassLoader();
                    Class clazz = (null == classLoader)
                        ? Class.forName(className)
                        : classLoader.loadClass(className);
                    return clazz;
                }
                catch (Exception e)
                {
                }

                return null;
            }
        });
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy