All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.bouncycastle.tls.crypto.impl.jcajce.JceTlsSecret Maven / Gradle / Ivy

There is a newer version: 1.79
Show newest version
package org.bouncycastle.tls.crypto.impl.jcajce;

import java.security.GeneralSecurityException;
import java.security.MessageDigest;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import org.bouncycastle.tls.PRFAlgorithm;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.CryptoHashAlgorithm;
import org.bouncycastle.tls.crypto.TlsCryptoUtils;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.tls.crypto.impl.AbstractTlsCrypto;
import org.bouncycastle.tls.crypto.impl.AbstractTlsSecret;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Strings;

/**
 * JCE support class for handling TLS secrets and deriving key material and other secrets from them.
 */
public class JceTlsSecret
    extends AbstractTlsSecret
{
    /**
     * Converts an arbitrary {@link TlsSecret} into a {@link JceTlsSecret} usable with the given crypto.
     *
     * @param crypto the JCA crypto that adopts the copied secret data when a conversion is needed.
     * @param secret the secret to convert.
     * @return {@code secret} itself if it already is a {@link JceTlsSecret}; otherwise a new secret adopted by
     *         {@code crypto} from the data returned by {@code copyData}.
     * @throws IllegalArgumentException if {@code secret} is neither a {@link JceTlsSecret} nor an
     *         {@link AbstractTlsSecret}, so its data cannot be copied.
     */
    public static JceTlsSecret convert(JcaTlsCrypto crypto, TlsSecret secret)
    {
        if (secret instanceof JceTlsSecret)
        {
            return (JceTlsSecret)secret;
        }

        if (secret instanceof AbstractTlsSecret)
        {
            AbstractTlsSecret abstractTlsSecret = (AbstractTlsSecret)secret;

            return crypto.adoptLocalSecret(copyData(abstractTlsSecret));
        }

        throw new IllegalArgumentException("unrecognized TlsSecret - cannot copy data: " + secret.getClass().getName());
    }

    // Number of SSL3 magic mix constants ("A", "BB", ..., fifteen 'O's). The SSL 3.0 PRF uses one constant per
    // MD5 output block, so this also bounds the maximum SSL 3.0 PRF output length.
    private static final int SSL3_CONST_COUNT = 15;

    // SSL3 magic mix constants ("A", "BB", "CCC", ...), concatenated into a single array.
    private static final byte[] SSL3_CONST = generateSSL3Constants();

    /**
     * Builds the concatenation "A" + "BB" + "CCC" + ... for {@link #SSL3_CONST_COUNT} letters.
     *
     * @return a new array of length {@code n * (n + 1) / 2}, where {@code n == SSL3_CONST_COUNT}.
     */
    private static byte[] generateSSL3Constants()
    {
        int n = SSL3_CONST_COUNT;
        byte[] result = new byte[n * (n + 1) / 2];
        int pos = 0;
        for (int i = 0; i < n; ++i)
        {
            byte b = (byte)('A' + i);
            for (int j = 0; j <= i; ++j)
            {
                result[pos++] = b;
            }
        }
        return result;
    }

    protected final JcaTlsCrypto crypto;

    /**
     * @param crypto the JCA crypto used for all derivations from this secret.
     * @param data   the raw secret bytes; the array is stored by reference, not copied.
     */
    public JceTlsSecret(JcaTlsCrypto crypto, byte[] data)
    {
        super(data);

        this.crypto = crypto;
    }

    /**
     * Derives a new secret from this one using the given PRF.
     * <p>
     * The TLS 1.3 PRF algorithms map to HKDF-Expand-Label with the matching hash. All other algorithms are
     * handled by {@link #prf(int, String, byte[], int)}.
     * </p>
     *
     * @param prfAlgorithm a {@link PRFAlgorithm} constant.
     * @param label        the PRF/HKDF label.
     * @param seed         the PRF seed (HKDF context for TLS 1.3).
     * @param length       the number of bytes to derive.
     * @return the derived secret.
     * @throws RuntimeException wrapping any checked exception raised during derivation; unchecked exceptions
     *         (e.g. argument validation failures) propagate unchanged.
     */
    public synchronized TlsSecret deriveUsingPRF(int prfAlgorithm, String label, byte[] seed, int length)
    {
        checkAlive();

        try
        {
            switch (prfAlgorithm)
            {
            case PRFAlgorithm.tls13_hkdf_sha256:
                return TlsCryptoUtils.hkdfExpandLabel(this, CryptoHashAlgorithm.sha256, label, seed, length);
            case PRFAlgorithm.tls13_hkdf_sha384:
                return TlsCryptoUtils.hkdfExpandLabel(this, CryptoHashAlgorithm.sha384, label, seed, length);
            case PRFAlgorithm.tls13_hkdf_sm3:
                return TlsCryptoUtils.hkdfExpandLabel(this, CryptoHashAlgorithm.sm3, label, seed, length);
            default:
                return crypto.adoptLocalSecret(prf(prfAlgorithm, label, seed, length));
            }
        }
        catch (RuntimeException e)
        {
            // Don't hide unchecked exceptions (e.g. IllegalArgumentException) behind a generic wrapper.
            throw e;
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * HKDF-Expand (RFC 5869) using this secret as the PRK.
     * <p>
     * Computes {@code T(i) = HMAC(PRK, T(i-1) | info | i)} with {@code T(0)} empty, and concatenates the blocks
     * until {@code length} bytes are produced.
     * </p>
     *
     * @param cryptoHashAlgorithm a {@link CryptoHashAlgorithm} constant selecting the HMAC hash.
     * @param info                the HKDF info input.
     * @param length              the number of output bytes; values below 1 yield an empty secret.
     * @return the output keying material as a new secret.
     * @throws IllegalArgumentException if {@code length} exceeds 255 times the hash output size.
     */
    public synchronized TlsSecret hkdfExpand(int cryptoHashAlgorithm, byte[] info, int length)
    {
        if (length < 1)
        {
            return crypto.adoptLocalSecret(TlsUtils.EMPTY_BYTES);
        }

        int hashLen = TlsCryptoUtils.getHashOutputSize(cryptoHashAlgorithm);
        if (length > (255 * hashLen))
        {
            throw new IllegalArgumentException("'length' must be <= 255 * (output size of 'hashAlgorithm')");
        }

        checkAlive();

        byte[] prk = data;

        // Holds the current block T(i); it is secret-derived, so it is wiped in the finally block.
        byte[] t = new byte[hashLen];

        try
        {
            String algorithm = crypto.getHMACAlgorithmName(cryptoHashAlgorithm);
            Mac hmac = crypto.getHelper().createMac(algorithm);
            hmac.init(new SecretKeySpec(prk, 0, prk.length, algorithm));

            byte[] okm = new byte[length];
            byte counter = 0x00;

            int pos = 0;
            for (;;)
            {
                // T(i-1) (if any) was already fed at the end of the previous iteration.
                hmac.update(info, 0, info.length);
                hmac.update(++counter);
                hmac.doFinal(t, 0);

                int remaining = length - pos;
                if (remaining <= hashLen)
                {
                    System.arraycopy(t, 0, okm, pos, remaining);
                    break;
                }

                System.arraycopy(t, 0, okm, pos, hashLen);
                pos += hashLen;

                // doFinal() reset the MAC, so start the next block's input with T(i).
                hmac.update(t, 0, t.length);
            }

            return crypto.adoptLocalSecret(okm);
        }
        catch (GeneralSecurityException e)
        {
            throw new RuntimeException(e);
        }
        finally
        {
            Arrays.fill(t, (byte)0);
        }
    }

    /**
     * HKDF-Extract (RFC 5869) using this secret as the salt and {@code ikm} as the input keying material.
     * <p>
     * This secret is consumed: its data is detached before the computation and wiped afterwards, so this
     * instance cannot be used again (later calls to {@code checkAlive()} are expected to reject it).
     * </p>
     *
     * @param cryptoHashAlgorithm a {@link CryptoHashAlgorithm} constant selecting the HMAC hash.
     * @param ikm                 the input keying material; converted via {@link #convert}.
     * @return the pseudorandom key (PRK) as a new secret.
     */
    public synchronized TlsSecret hkdfExtract(int cryptoHashAlgorithm, TlsSecret ikm)
    {
        checkAlive();

        byte[] salt = data;
        this.data = null;

        try
        {
            String algorithm = crypto.getHMACAlgorithmName(cryptoHashAlgorithm);
            Mac hmac = crypto.getHelper().createMac(algorithm);
            hmac.init(new SecretKeySpec(salt, 0, salt.length, algorithm));

            convert(crypto, ikm).updateMac(hmac);

            byte[] prk = hmac.doFinal();

            return crypto.adoptLocalSecret(prk);
        }
        catch (GeneralSecurityException e)
        {
            throw new RuntimeException(e);
        }
        finally
        {
            // The consumed secret is no longer referenced by this object; don't leave it lying in memory.
            // SecretKeySpec copies its key bytes, so wiping here does not affect the initialized MAC.
            Arrays.fill(salt, (byte)0);
        }
    }

    protected AbstractTlsCrypto getCrypto()
    {
        return crypto;
    }

    /**
     * The TLS {@code P_hash} expansion function (RFC 2246 / RFC 5246 section 5).
     * <p>
     * Computes {@code A(0) = seed}, {@code A(i) = HMAC(secret, A(i-1))}, and fills {@code output} with
     * {@code HMAC(secret, A(1) + seed) + HMAC(secret, A(2) + seed) + ...}, truncating the last block.
     * </p>
     *
     * @param cryptoHashAlgorithm a {@link CryptoHashAlgorithm} constant selecting the HMAC hash.
     * @param secret              array containing the HMAC key.
     * @param secretOff           offset of the key within {@code secret}.
     * @param secretLen           length of the key.
     * @param seed                the seed (label + seed for the TLS PRFs).
     * @param output              buffer to fill completely with output.
     * @throws GeneralSecurityException if the MAC cannot be created or initialized.
     */
    protected void hmacHash(int cryptoHashAlgorithm, byte[] secret, int secretOff, int secretLen, byte[] seed,
        byte[] output) throws GeneralSecurityException
    {
        // The JCA HMAC name is "Hmac" + the digest name without hyphens (e.g. "SHA-256" -> "HmacSHA256").
        String digestName = crypto.getDigestName(cryptoHashAlgorithm).replace("-", "");
        String macName = "Hmac" + digestName;
        Mac mac = crypto.getHelper().createMac(macName);
        mac.init(new SecretKeySpec(secret, secretOff, secretLen, macName));

        byte[] a = seed;

        int macSize = mac.getMacLength();

        // Secret-derived scratch buffers, wiped in the finally block.
        byte[] b1 = new byte[macSize];
        byte[] b2 = new byte[macSize];

        try
        {
            int pos = 0;
            while (pos < output.length)
            {
                // A(i) = HMAC(secret, A(i-1))
                mac.update(a, 0, a.length);
                mac.doFinal(b1, 0);
                a = b1;

                // Output block i = HMAC(secret, A(i) + seed)
                mac.update(a, 0, a.length);
                mac.update(seed, 0, seed.length);
                mac.doFinal(b2, 0);
                System.arraycopy(b2, 0, output, pos, Math.min(macSize, output.length - pos));
                pos += macSize;
            }
        }
        finally
        {
            Arrays.fill(b1, (byte)0);
            Arrays.fill(b2, (byte)0);
        }
    }

    /**
     * Computes a (pre-TLS 1.3) PRF over this secret.
     *
     * @param prfAlgorithm a {@link PRFAlgorithm} constant: {@code ssl_prf_legacy} selects the SSL 3.0 scheme
     *                     (the label is ignored), {@code tls_prf_legacy} the TLS 1.0/1.1 MD5/SHA-1 PRF, and
     *                     anything else the TLS 1.2 PRF with the hash given by
     *                     {@link TlsCryptoUtils#getHashForPRF(int)}.
     * @param label        the ASCII PRF label.
     * @param seed         the PRF seed.
     * @param length       the number of output bytes.
     * @return the PRF output.
     * @throws GeneralSecurityException if a required digest or MAC is unavailable.
     */
    protected byte[] prf(int prfAlgorithm, String label, byte[] seed, int length)
        throws GeneralSecurityException
    {
        if (PRFAlgorithm.ssl_prf_legacy == prfAlgorithm)
        {
            return prf_SSL(seed, length);
        }

        byte[] labelSeed = Arrays.concatenate(Strings.toByteArray(label), seed);

        if (PRFAlgorithm.tls_prf_legacy == prfAlgorithm)
        {
            return prf_1_0(labelSeed, length);
        }

        return prf_1_2(prfAlgorithm, labelSeed, length);
    }

    /**
     * The SSL 3.0 key derivation: output blocks are
     * {@code MD5(secret + SHA1(const_i + secret + seed))}, with {@code const_i} being "A", "BB", "CCC", ...
     *
     * @param seed   the derivation seed.
     * @param length the number of output bytes.
     * @return the derived bytes.
     * @throws IllegalArgumentException if {@code length} needs more output blocks than there are SSL3 mix
     *         constants.
     * @throws GeneralSecurityException if MD5 or SHA-1 is unavailable.
     */
    protected byte[] prf_SSL(byte[] seed, int length)
        throws GeneralSecurityException
    {
        MessageDigest md5 = crypto.getHelper().createMessageDigest("MD5");
        MessageDigest sha1 = crypto.getHelper().createMessageDigest("SHA-1");

        int md5Size = md5.getDigestLength();
        int sha1Size = sha1.getDigestLength();

        // Each MD5 output block consumes one mix constant; fail clearly rather than overrun SSL3_CONST.
        int maxLength = SSL3_CONST_COUNT * md5Size;
        if (length > maxLength)
        {
            throw new IllegalArgumentException("'length' must be <= " + maxLength + " for the SSL 3.0 PRF");
        }

        // Scratch buffer for the inner SHA-1 and a truncated final MD5 block; wiped in the finally block.
        byte[] tmp = new byte[Math.max(md5Size, sha1Size)];
        byte[] result = new byte[length];

        try
        {
            int constLen = 1, constPos = 0, resultPos = 0;
            while (resultPos < length)
            {
                sha1.update(SSL3_CONST, constPos, constLen);
                constPos += constLen++;

                sha1.update(data, 0, data.length);
                sha1.update(seed, 0, seed.length);
                sha1.digest(tmp, 0, sha1Size);

                md5.update(data, 0, data.length);
                md5.update(tmp, 0, sha1Size);

                int remaining = length - resultPos;
                if (remaining < md5Size)
                {
                    // Final partial block: digest into scratch, then copy only what is needed.
                    md5.digest(tmp, 0, md5Size);
                    System.arraycopy(tmp, 0, result, resultPos, remaining);
                    resultPos += remaining;
                }
                else
                {
                    md5.digest(result, resultPos, md5Size);
                    resultPos += md5Size;
                }
            }
        }
        finally
        {
            Arrays.fill(tmp, (byte)0);
        }

        return result;
    }

    /**
     * The TLS 1.0/1.1 PRF: {@code P_MD5(S1, labelSeed) XOR P_SHA-1(S2, labelSeed)}, where S1 and S2 are the
     * first and last halves of this secret (sharing the middle byte when the length is odd).
     *
     * @param labelSeed the label concatenated with the seed.
     * @param length    the number of output bytes.
     * @return the PRF output.
     * @throws GeneralSecurityException if HMAC-MD5 or HMAC-SHA1 is unavailable.
     */
    protected byte[] prf_1_0(byte[] labelSeed, int length)
        throws GeneralSecurityException
    {
        int s_half = (data.length + 1) / 2;

        byte[] b1 = new byte[length];
        hmacHash(CryptoHashAlgorithm.md5, data, 0, s_half, labelSeed, b1);

        byte[] b2 = new byte[length];
        try
        {
            hmacHash(CryptoHashAlgorithm.sha1, data, data.length - s_half, s_half, labelSeed, b2);

            for (int i = 0; i < length; i++)
            {
                b1[i] ^= b2[i];
            }
        }
        finally
        {
            // b2 is secret-derived scratch; only b1 is handed back to the caller.
            Arrays.fill(b2, (byte)0);
        }
        return b1;
    }

    /**
     * The TLS 1.2 PRF: {@code P_hash(secret, labelSeed)} with the hash associated with {@code prfAlgorithm}.
     *
     * @param prfAlgorithm a {@link PRFAlgorithm} constant.
     * @param labelSeed    the label concatenated with the seed.
     * @param length       the number of output bytes.
     * @return the PRF output.
     * @throws GeneralSecurityException if the HMAC is unavailable.
     */
    protected byte[] prf_1_2(int prfAlgorithm, byte[] labelSeed, int length)
        throws GeneralSecurityException
    {
        int cryptoHashAlgorithm = TlsCryptoUtils.getHashForPRF(prfAlgorithm);
        byte[] result = new byte[length];
        hmacHash(cryptoHashAlgorithm, data, 0, data.length, labelSeed, result);
        return result;
    }

    /**
     * Feeds this secret's bytes into {@code mac} without exposing them to the caller.
     *
     * @param mac an initialized MAC to update.
     */
    protected synchronized void updateMac(Mac mac)
    {
        checkAlive();

        mac.update(data, 0, data.length);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy