All downloads are free. Search and download functionality uses the official Maven repository.

org.bouncycastle.tls.DeferredHash Maven / Gradle / Ivy

package org.bouncycastle.tls;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Enumeration;
import java.util.Hashtable;

import org.bouncycastle.tls.crypto.CryptoHashAlgorithm;
import org.bouncycastle.tls.crypto.TlsHash;
import org.bouncycastle.util.Integers;

/**
 * Buffers handshake input until the hash algorithm(s) to use are determined.
 * <p>
 * Input passed to {@link #update(byte[], int, int)} is first collected in an internal buffer. Once the
 * set of tracked hash algorithms has been sealed (and buffering has not been forced), the buffered
 * data is fed into every tracked hash and the buffer is dropped; from then on input goes straight to
 * the tracked hashes.
 */
class DeferredHash
    implements TlsHandshakeHash
{
    /**
     * Upper bound on the number of tracked hashes for which buffering is abandoned after sealing.
     * With more tracked hashes than this, the buffer is kept and hashes are computed on demand.
     */
    protected static final int BUFFERING_HASH_LIMIT = 4;

    /** Source of the handshake security parameters and of the crypto used to create hashes. */
    protected TlsContext context;

    /** Buffered input; {@code null} once buffering has stopped. */
    private DigestInputBuffer buf;
    /** Tracked hashes, keyed by boxed CryptoHashAlgorithm value. */
    private Hashtable hashes;
    /** When set, {@link #checkStopBuffering()} never drops the buffer. */
    private boolean forceBuffering;
    /** Set once the collection of tracked hash algorithms is final. */
    private boolean sealed;

    /**
     * Creates an instance that is buffering, tracks no hashes, and is not yet sealed.
     *
     * @param context the TLS context used to look up security parameters and create hashes.
     */
    DeferredHash(TlsContext context)
    {
        this.context = context;
        this.buf = new DigestInputBuffer();
        this.hashes = new Hashtable();
        this.forceBuffering = false;
        this.sealed = false;
    }

    /**
     * Copies all input buffered so far to the given stream.
     *
     * @param output the stream to receive the buffered input.
     * @throws IOException if writing to {@code output} fails.
     * @throws IllegalStateException if this instance is no longer buffering.
     */
    public void copyBufferTo(OutputStream output)
        throws IOException
    {
        if (buf == null)
        {
            // If you see this, you need to call forceBuffering() before sealHashAlgorithms()
            throw new IllegalStateException("Not buffering");
        }

        buf.copyInputTo(output);
    }

    /**
     * Requests that the input buffer be retained even after the hash algorithms are sealed.
     *
     * @throws IllegalStateException if the hash algorithms have already been sealed.
     */
    public void forceBuffering()
    {
        if (sealed)
        {
            throw new IllegalStateException("Too late to force buffering");
        }

        this.forceBuffering = true;
    }

    /**
     * Ensures that the hash(es) required by the PRF recorded in the handshake security parameters are
     * being tracked: MD5 and SHA-1 for the legacy PRFs, otherwise the PRF's own hash algorithm.
     */
    public void notifyPRFDetermined()
    {
        SecurityParameters securityParameters = context.getSecurityParametersHandshake();

        switch (securityParameters.getPRFAlgorithm())
        {
        case PRFAlgorithm.ssl_prf_legacy:
        case PRFAlgorithm.tls_prf_legacy:
        {
            checkTrackingHash(CryptoHashAlgorithm.md5);
            checkTrackingHash(CryptoHashAlgorithm.sha1);
            break;
        }
        default:
        {
            checkTrackingHash(securityParameters.getPRFCryptoHashAlgorithm());
            break;
        }
        }
    }

    /**
     * Starts tracking the given hash algorithm, if it is not tracked already.
     *
     * @param cryptoHashAlgorithm the CryptoHashAlgorithm value to track.
     * @throws IllegalStateException if the hash algorithms have already been sealed.
     */
    public void trackHashAlgorithm(int cryptoHashAlgorithm)
    {
        if (sealed)
        {
            throw new IllegalStateException("Too late to track more hash algorithms");
        }

        checkTrackingHash(cryptoHashAlgorithm);
    }

    /**
     * Marks the set of tracked hash algorithms as final and, where permitted, stops buffering.
     *
     * @throws IllegalStateException if already sealed.
     */
    public void sealHashAlgorithms()
    {
        if (sealed)
        {
            throw new IllegalStateException("Already sealed");
        }

        this.sealed = true;
        checkStopBuffering();
    }

    /**
     * Reduces the tracked hashes to only those required by the PRF and stops buffering.
     * <p>
     * Each retained hash is replaced by a clone that has been brought up to date with any input
     * still held in the buffer. Afterwards the buffer is dropped, forced buffering is cleared and
     * this instance is sealed.
     *
     * @throws IllegalStateException if a hash required by the PRF is not being tracked.
     */
    public void stopTracking()
    {
        SecurityParameters securityParameters = context.getSecurityParametersHandshake();

        // Build the replacement table first so that state is untouched if a required hash is missing.
        Hashtable newHashes = new Hashtable();
        switch (securityParameters.getPRFAlgorithm())
        {
        case PRFAlgorithm.ssl_prf_legacy:
        case PRFAlgorithm.tls_prf_legacy:
        {
            cloneHash(newHashes, CryptoHashAlgorithm.md5);
            cloneHash(newHashes, CryptoHashAlgorithm.sha1);
            break;
        }
        default:
        {
            cloneHash(newHashes, securityParameters.getPRFCryptoHashAlgorithm());
            break;
        }
        }

        this.buf = null;
        this.hashes = newHashes;
        this.forceBuffering = false;
        this.sealed = true;
    }

    /**
     * Returns an independent hash for the PRF that has absorbed all input seen so far.
     * <p>
     * For the legacy PRFs the result combines clones of the MD5 and SHA-1 hashes; otherwise it is a
     * clone of the PRF's hash. The tracked hashes themselves are not finalised.
     *
     * @return a new hash instance reflecting the input so far.
     * @throws IllegalStateException if a hash required by the PRF is not being tracked.
     */
    public TlsHash forkPRFHash()
    {
        checkStopBuffering();

        SecurityParameters securityParameters = context.getSecurityParametersHandshake();

        TlsHash prfHash;
        switch (securityParameters.getPRFAlgorithm())
        {
        case PRFAlgorithm.ssl_prf_legacy:
        case PRFAlgorithm.tls_prf_legacy:
        {
            TlsHash md5Hash = cloneHash(CryptoHashAlgorithm.md5);
            TlsHash sha1Hash = cloneHash(CryptoHashAlgorithm.sha1);
            prfHash = new CombinedHash(context, md5Hash, sha1Hash);
            break;
        }
        default:
        {
            prfHash = cloneHash(securityParameters.getPRFCryptoHashAlgorithm());
            break;
        }
        }

        // Still buffering: the tracked hashes have not seen the buffered input yet, so apply it here.
        if (buf != null)
        {
            buf.updateDigest(prfHash);
        }

        return prfHash;
    }

    /**
     * Calculates the hash of all input seen so far for the given algorithm, using a clone so that
     * the tracked hash can continue to be updated.
     *
     * @param cryptoHashAlgorithm the CryptoHashAlgorithm value of the hash wanted.
     * @return the calculated hash value.
     * @throws IllegalStateException if the algorithm is not being tracked.
     */
    public byte[] getFinalHash(int cryptoHashAlgorithm)
    {
        TlsHash hash = (TlsHash)hashes.get(box(cryptoHashAlgorithm));
        if (hash == null)
        {
            throw new IllegalStateException("CryptoHashAlgorithm." + cryptoHashAlgorithm + " is not being tracked");
        }

        checkStopBuffering();

        hash = hash.cloneHash();
        // Still buffering: apply the buffered input to the clone only.
        if (buf != null)
        {
            buf.updateDigest(hash);
        }

        return hash.calculateHash();
    }

    /**
     * Adds input: appended to the buffer while buffering, otherwise fed to every tracked hash.
     *
     * @param input the array containing the input.
     * @param inOff the offset into {@code input} at which the data starts.
     * @param len   the number of bytes of input.
     */
    public void update(byte[] input, int inOff, int len)
    {
        if (buf != null)
        {
            buf.write(input, inOff, len);
            return;
        }

        Enumeration e = hashes.elements();
        while (e.hasMoreElements())
        {
            TlsHash hash = (TlsHash)e.nextElement();
            hash.update(input, inOff, len);
        }
    }

    /**
     * Not supported on this class.
     *
     * @throws IllegalStateException always; use {@link #forkPRFHash()} instead.
     */
    public byte[] calculateHash()
    {
        throw new IllegalStateException("Use 'forkPRFHash' to get a definite hash");
    }

    /**
     * Not supported on this class.
     *
     * @throws IllegalStateException always.
     */
    public TlsHash cloneHash()
    {
        throw new IllegalStateException("attempt to clone a DeferredHash");
    }

    /**
     * Discards the input seen so far: resets the buffer while buffering, otherwise resets every
     * tracked hash.
     */
    public void reset()
    {
        if (buf != null)
        {
            buf.reset();
            return;
        }

        Enumeration e = hashes.elements();
        while (e.hasMoreElements())
        {
            TlsHash hash = (TlsHash)e.nextElement();
            hash.reset();
        }
    }

    /**
     * Boxes a CryptoHashAlgorithm value for use as a key in the hash table.
     */
    protected Integer box(int cryptoHashAlgorithm)
    {
        return Integers.valueOf(cryptoHashAlgorithm);
    }

    /**
     * Stops buffering if that is currently allowed: buffering must not be forced, the algorithms
     * must be sealed, a buffer must exist, and no more than {@link #BUFFERING_HASH_LIMIT} hashes may
     * be tracked. The buffered input is fed to every tracked hash before the buffer is dropped.
     */
    protected void checkStopBuffering()
    {
        if (!forceBuffering && sealed && buf != null && hashes.size() <= BUFFERING_HASH_LIMIT)
        {
            Enumeration e = hashes.elements();
            while (e.hasMoreElements())
            {
                TlsHash hash = (TlsHash)e.nextElement();
                buf.updateDigest(hash);
            }

            this.buf = null;
        }
    }

    /**
     * Ensures the given algorithm is tracked; see {@link #checkTrackingHash(Integer)}.
     */
    protected void checkTrackingHash(int cryptoHashAlgorithm)
    {
        checkTrackingHash(box(cryptoHashAlgorithm));
    }

    /**
     * Ensures the given algorithm is tracked, creating a new hash for it via the context's crypto
     * if there is none yet. An existing hash is left untouched.
     */
    protected void checkTrackingHash(Integer cryptoHashAlgorithm)
    {
        if (!hashes.containsKey(cryptoHashAlgorithm))
        {
            TlsHash hash = context.getCrypto().createHash(cryptoHashAlgorithm.intValue());
            hashes.put(cryptoHashAlgorithm, hash);
        }
    }

    /**
     * Clones the tracked hash for the given algorithm; see {@link #cloneHash(Integer)}.
     */
    protected TlsHash cloneHash(int cryptoHashAlgorithm)
    {
        return cloneHash(box(cryptoHashAlgorithm));
    }

    /**
     * Clones the tracked hash for the given algorithm. Buffered input is NOT applied to the clone.
     *
     * @param cryptoHashAlgorithm the boxed CryptoHashAlgorithm value.
     * @return a clone of the tracked hash.
     * @throws IllegalStateException if the algorithm is not being tracked.
     */
    protected TlsHash cloneHash(Integer cryptoHashAlgorithm)
    {
        TlsHash hash = (TlsHash)hashes.get(cryptoHashAlgorithm);
        if (hash == null)
        {
            // Report the same way as getFinalHash rather than failing with a bare NullPointerException.
            throw new IllegalStateException("CryptoHashAlgorithm." + cryptoHashAlgorithm + " is not being tracked");
        }

        return hash.cloneHash();
    }

    /**
     * Clones the tracked hash into {@code newHashes}; see {@link #cloneHash(Hashtable, Integer)}.
     */
    protected void cloneHash(Hashtable newHashes, int cryptoHashAlgorithm)
    {
        cloneHash(newHashes, box(cryptoHashAlgorithm));
    }

    /**
     * Clones the tracked hash for the given algorithm, applies any still-buffered input to the
     * clone, and stores it in {@code newHashes} under the same key.
     *
     * @param newHashes           the table to receive the clone.
     * @param cryptoHashAlgorithm the boxed CryptoHashAlgorithm value.
     * @throws IllegalStateException if the algorithm is not being tracked.
     */
    protected void cloneHash(Hashtable newHashes, Integer cryptoHashAlgorithm)
    {
        TlsHash hash = cloneHash(cryptoHashAlgorithm);
        if (buf != null)
        {
            buf.updateDigest(hash);
        }
        newHashes.put(cryptoHashAlgorithm, hash);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy