All Downloads are FREE. The search and download features use the official Maven repository.

org.bouncycastle.tls.crypto.impl.ChaCha20Poly1305Cipher Maven / Gradle / Ivy

package org.bouncycastle.tls.crypto.impl;

import java.io.IOException;

import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.TlsMAC;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Pack;

/**
 * TLS record cipher for the ChaCha20/Poly1305 cipher suites specified in RFC 7905.
 * <p>
 * For each record the stream cipher is re-initialised with a nonce derived from the relevant
 * write IV and the record sequence number. The cipher is then run over 64 zero bytes followed by
 * the record data; the first 32 bytes of that output key the MAC, and the output from byte 64
 * onwards is the processed record data.
 * </p>
 */
public class ChaCha20Poly1305Cipher
    implements TlsCipher
{
    /** Number of zero bytes processed by the stream cipher ahead of the record data. */
    private static final int KEY_STREAM_PREFIX = 64;

    /** Number of leading cipher output bytes handed to the MAC as its key. */
    private static final int MAC_KEY_LENGTH = 32;

    /** Number of bytes set aside for the MAC at the end of a protected record. */
    private static final int MAC_LENGTH = 16;

    /** Boundary to which each text section fed to the MAC is zero-padded. */
    private static final int MAC_PAD_BOUNDARY = 16;

    /** Length in bytes of the per-record nonce. */
    private static final int NONCE_LENGTH = 12;

    /** Source of zero bytes for MAC padding. */
    private static final byte[] ZEROES = new byte[15];

    protected TlsCryptoParameters context;

    protected TlsMAC writeMac;
    protected TlsMAC readMac;

    protected TlsStreamCipherImpl encryptCipher;
    protected TlsStreamCipherImpl decryptCipher;

    protected byte[] encryptIV;
    protected byte[] decryptIV;

    /**
     * Derives the key block from the given context, splits it into client/server write keys and
     * IVs, and keys both stream ciphers according to whether this end is the server.
     *
     * @param context       the crypto parameters used for key block calculation and version lookup.
     * @param encryptCipher the stream cipher used for outgoing records.
     * @param decryptCipher the stream cipher used for incoming records.
     * @param writeMac      the MAC used for outgoing records.
     * @param readMac       the MAC used for incoming records.
     * @throws IOException a {@link TlsFatalAlert} with internal_error if
     *                     {@code TlsImplUtils.isTLSv12(context)} is false, or anything raised
     *                     while calculating the key block or keying the ciphers.
     */
    public ChaCha20Poly1305Cipher(TlsCryptoParameters context, TlsStreamCipherImpl encryptCipher, TlsStreamCipherImpl decryptCipher,
                                       TlsMAC writeMac, TlsMAC readMac)
        throws IOException
    {
        if (!TlsImplUtils.isTLSv12(context))
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.context = context;

        final int keyLength = 32;
        // TODO SecurityParameters.fixed_iv_length
        final int fixedIVLength = 12;
        // TODO SecurityParameters.record_iv_length = 0

        final int keyBlockLength = 2 * (keyLength + fixedIVLength);

        byte[] keyBlock = TlsImplUtils.calculateKeyBlock(context, keyBlockLength);

        // Layout: client key | server key | client IV | server IV
        final int clientKeyPos = 0;
        final int serverKeyPos = clientKeyPos + keyLength;
        final int clientIVPos = serverKeyPos + keyLength;
        final int serverIVPos = clientIVPos + fixedIVLength;
        final int endPos = serverIVPos + fixedIVLength;

        byte[] clientWriteKey = Arrays.copyOfRange(keyBlock, clientKeyPos, serverKeyPos);
        byte[] serverWriteKey = Arrays.copyOfRange(keyBlock, serverKeyPos, clientIVPos);
        byte[] clientWriteIV = Arrays.copyOfRange(keyBlock, clientIVPos, serverIVPos);
        byte[] serverWriteIV = Arrays.copyOfRange(keyBlock, serverIVPos, endPos);

        // Sanity check: the slices must consume the key block exactly.
        if (endPos != keyBlockLength)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        this.writeMac = writeMac;
        this.readMac = readMac;
        this.encryptCipher = encryptCipher;
        this.decryptCipher = decryptCipher;

        // A server writes with the server_write material and reads with the client_write
        // material; a client does the opposite.
        final boolean server = context.isServer();
        byte[] encryptKey = server ? serverWriteKey : clientWriteKey;
        byte[] decryptKey = server ? clientWriteKey : serverWriteKey;
        this.encryptIV = server ? serverWriteIV : clientWriteIV;
        this.decryptIV = server ? clientWriteIV : serverWriteIV;

        this.encryptCipher.setKey(encryptKey);
        this.encryptCipher.init(encryptIV);
        this.decryptCipher.setKey(decryptKey);
        this.decryptCipher.init(decryptIV);
    }

    /**
     * Returns the plaintext size that fits in the given ciphertext size, i.e. the ciphertext
     * limit less the 16 bytes reserved for the MAC.
     *
     * @param ciphertextLimit the maximum ciphertext length.
     * @return {@code ciphertextLimit - 16}; negative if the limit cannot even hold the MAC.
     */
    public int getPlaintextLimit(int ciphertextLimit)
    {
        return ciphertextLimit - MAC_LENGTH;
    }

    /**
     * Protects one record: encrypts {@code len} bytes of {@code plaintext} and appends the MAC
     * calculated over the additional data and the encrypted bytes.
     *
     * @param seqNo     the record sequence number.
     * @param type      the record content type.
     * @param plaintext buffer holding the data to protect.
     * @param offset    start of the data within {@code plaintext}.
     * @param len       number of bytes to protect.
     * @return a new array of {@code len + 16} bytes: the encrypted data followed by the MAC.
     * @throws IOException if the cipher or additional data handling raises one.
     */
    public byte[] encodePlaintext(long seqNo, short type, byte[] plaintext, int offset, int len) throws IOException
    {
        initRecord(encryptCipher, seqNo, encryptIV);

        byte[] work = applyKeyStream(encryptCipher, plaintext, offset, len);

        byte[] record = new byte[len + MAC_LENGTH];
        byte[] macKey = Arrays.copyOfRange(work, 0, MAC_KEY_LENGTH);
        System.arraycopy(work, KEY_STREAM_PREFIX, record, 0, len);

        // The working buffer still holds the MAC key bytes; clear it once we are done with it.
        Arrays.fill(work, (byte)0);

        writeMac.setKey(macKey);

        byte[] tag = calculateRecordMAC(writeMac, getAdditionalData(seqNo, type, len), record, 0, len);
        System.arraycopy(tag, 0, record, len, tag.length);

        return record;
    }

    /**
     * Unprotects one record: decrypts the leading {@code len - 16} bytes and compares the MAC
     * calculated over the additional data and the received encrypted bytes against the trailing
     * 16 bytes of the record.
     *
     * @param seqNo      the record sequence number.
     * @param type       the record content type.
     * @param ciphertext buffer holding the protected record.
     * @param offset     start of the record within {@code ciphertext}.
     * @param len        total length of the protected record, including the MAC.
     * @return a new array holding the decrypted data.
     * @throws IOException a {@link TlsFatalAlert} with decode_error if {@code len} is too short
     *                     to hold a MAC, or with bad_record_mac if the MACs do not match.
     */
    public byte[] decodeCiphertext(long seqNo, short type, byte[] ciphertext, int offset, int len) throws IOException
    {
        if (getPlaintextLimit(len) < 0)
        {
            throw new TlsFatalAlert(AlertDescription.decode_error);
        }

        initRecord(decryptCipher, seqNo, decryptIV);

        final int bodyLength = len - MAC_LENGTH;
        byte[] work = applyKeyStream(decryptCipher, ciphertext, offset, bodyLength);

        readMac.setKey(Arrays.copyOfRange(work, 0, MAC_KEY_LENGTH));

        byte[] aad = getAdditionalData(seqNo, type, bodyLength);
        byte[] expectedTag = calculateRecordMAC(readMac, aad, ciphertext, offset, bodyLength);
        byte[] receivedTag = Arrays.copyOfRange(ciphertext, offset + bodyLength, offset + len);

        byte[] plaintext = new byte[bodyLength];
        System.arraycopy(work, KEY_STREAM_PREFIX, plaintext, 0, bodyLength);

        // The working buffer still holds the MAC key bytes; clear it once we are done with it.
        Arrays.fill(work, (byte)0);

        if (Arrays.constantTimeAreEqual(expectedTag, receivedTag))
        {
            return plaintext;
        }

        throw new TlsFatalAlert(AlertDescription.bad_record_mac);
    }

    /**
     * Re-initialises the given cipher with the nonce for one record.
     *
     * @param cipher the stream cipher to initialise.
     * @param seqNo  the record sequence number.
     * @param iv     the write IV for the direction the cipher is used in.
     * @throws IOException if cipher initialisation raises one.
     */
    protected void initRecord(TlsStreamCipherImpl cipher, long seqNo, byte[] iv)
        throws IOException
    {
        cipher.init(calculateNonce(seqNo, iv));
    }

    /**
     * Builds the 12 byte per-record nonce: four zero bytes followed by the sequence number
     * (written at offset 4), XORed byte-wise with the first 12 bytes of the IV.
     *
     * @param seqNo the record sequence number.
     * @param iv    the write IV; at least 12 bytes are read.
     * @return a new 12 byte nonce.
     */
    protected byte[] calculateNonce(long seqNo, byte[] iv)
    {
        byte[] nonce = new byte[NONCE_LENGTH];
        TlsUtils.writeUint64(seqNo, nonce, NONCE_LENGTH - 8);

        for (int pos = 0; pos != NONCE_LENGTH; pos++)
        {
            nonce[pos] = (byte)(nonce[pos] ^ iv[pos]);
        }

        return nonce;
    }

    /**
     * Feeds the MAC, in order, with: the zero-padded additional data, the zero-padded text, the
     * additional data length and the text length; then returns the resulting MAC value.
     *
     * @param mac            the (already keyed) MAC to use.
     * @param additionalData the additional data to authenticate.
     * @param buf            buffer holding the text to authenticate.
     * @param off            start of the text within {@code buf}.
     * @param len            length of the text.
     * @return the value of {@code mac.calculateMAC()}.
     */
    protected byte[] calculateRecordMAC(TlsMAC mac, byte[] additionalData, byte[] buf, int off, int len)
    {
        final int additionalDataLength = additionalData.length;

        updateRecordMACText(mac, additionalData, 0, additionalDataLength);
        updateRecordMACText(mac, buf, off, len);

        updateRecordMACLength(mac, additionalDataLength);
        updateRecordMACLength(mac, len);

        return mac.calculateMAC();
    }

    /**
     * Feeds a length to the MAC as a little-endian 64 bit value, treating {@code len} as an
     * unsigned 32 bit quantity.
     *
     * @param mac the MAC to update.
     * @param len the length to encode.
     */
    protected void updateRecordMACLength(TlsMAC mac, int len)
    {
        long unsignedLen = len & 0xFFFFFFFFL;
        byte[] encoded = Pack.longToLittleEndian(unsignedLen);
        mac.update(encoded, 0, encoded.length);
    }

    /**
     * Feeds a run of bytes to the MAC and, when its length is not a multiple of 16, follows it
     * with enough zero bytes to reach the next multiple of 16.
     *
     * @param mac the MAC to update.
     * @param buf buffer holding the bytes.
     * @param off start of the bytes within {@code buf}.
     * @param len number of bytes.
     */
    protected void updateRecordMACText(TlsMAC mac, byte[] buf, int off, int len)
    {
        mac.update(buf, off, len);

        int remainder = len % MAC_PAD_BOUNDARY;
        if (remainder == 0)
        {
            return;
        }

        mac.update(ZEROES, 0, MAC_PAD_BOUNDARY - remainder);
    }

    /**
     * Builds the 13 byte additional data block for a record.
     *
     * @param seqNo the record sequence number, written at offset 0.
     * @param type  the record content type, written at offset 8.
     * @param len   the length value, written as a 16 bit quantity at offset 11.
     * @return a new 13 byte array; the version from {@code context.getServerVersion()} is
     *         written at offset 9.
     * @throws IOException if any of the encoding helpers raises one.
     */
    protected byte[] getAdditionalData(long seqNo, short type, int len) throws IOException
    {
        /*
         * additional_data = seq_num + TLSCompressed.type + TLSCompressed.version +
         * TLSCompressed.length
         */
        final int typePos = 8;
        final int versionPos = 9;
        final int lengthPos = 11;

        byte[] aad = new byte[13];
        TlsUtils.writeUint64(seqNo, aad, 0);
        TlsUtils.writeUint8(type, aad, typePos);
        TlsUtils.writeVersion(context.getServerVersion(), aad, versionPos);
        TlsUtils.writeUint16(len, aad, lengthPos);

        return aad;
    }

    /**
     * Runs the stream cipher in place over a fresh buffer consisting of 64 zero bytes followed by
     * a copy of the given input.
     *
     * @return the processed buffer, of length {@code 64 + inputLen}; the caller is responsible
     *         for clearing it.
     */
    private static byte[] applyKeyStream(TlsStreamCipherImpl cipher, byte[] input, int inputOff, int inputLen)
        throws IOException
    {
        // MAC key is from the zeros at the front.
        byte[] work = new byte[KEY_STREAM_PREFIX + inputLen];
        System.arraycopy(input, inputOff, work, KEY_STREAM_PREFIX, inputLen);

        cipher.doFinal(work, 0, work.length, work, 0);

        return work;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy