All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.bouncycastle.crypto.engines.Salsa20Engine Maven / Gradle / Ivy

Go to download

The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains JCE provider and lightweight API for the Bouncy Castle Cryptography APIs for Java 1.8 and later with debug enabled.

The newest version!
package org.bouncycastle.crypto.engines;

import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoServicesRegistrar;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.MaxBytesExceededException;
import org.bouncycastle.crypto.OutputLengthException;
import org.bouncycastle.crypto.SkippingStreamCipher;
import org.bouncycastle.crypto.constraints.DefaultServiceProperties;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.params.ParametersWithIV;
import org.bouncycastle.util.Integers;
import org.bouncycastle.util.Pack;
import org.bouncycastle.util.Strings;

/**
 * Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005
 */
public class Salsa20Engine
    implements SkippingStreamCipher
{
    public final static int DEFAULT_ROUNDS = 20;

    /** Constants */
    private final static int STATE_SIZE = 16; // 16, 32 bit ints = 64 bytes

    private final static int[] TAU_SIGMA = Pack.littleEndianToInt(Strings.toByteArray("expand 16-byte k" + "expand 32-byte k"), 0, 8);

    protected void packTauOrSigma(int keyLength, int[] state, int stateOffset)
    {
        int tsOff = (keyLength - 16) / 4;
        state[stateOffset    ] = TAU_SIGMA[tsOff    ];
        state[stateOffset + 1] = TAU_SIGMA[tsOff + 1];
        state[stateOffset + 2] = TAU_SIGMA[tsOff + 2];
        state[stateOffset + 3] = TAU_SIGMA[tsOff + 3];
    }

    /** @deprecated */
    protected final static byte[]
        sigma = Strings.toByteArray("expand 32-byte k"),
        tau   = Strings.toByteArray("expand 16-byte k");

    protected int rounds;

    /*
     * variables to hold the state of the engine
     * during encryption and decryption
     */
    private int         index = 0;
    protected int[]     engineState = new int[STATE_SIZE]; // state
    protected int[]     x = new int[STATE_SIZE] ; // internal buffer
    private byte[]      keyStream   = new byte[STATE_SIZE * 4]; // expanded state, 64 bytes
    private boolean     initialised = false;

    /*
     * internal counter
     */
    private int cW0, cW1, cW2;

    /**
     * Creates a 20 round Salsa20 engine.
     */
    public Salsa20Engine()
    {
        this(DEFAULT_ROUNDS);
    }

    /**
     * Creates a Salsa20 engine with a specific number of rounds.
     * @param rounds the number of rounds (must be an even number).
     */
    public Salsa20Engine(int rounds)
    {
        if (rounds <= 0 || (rounds & 1) != 0)
        {
            throw new IllegalArgumentException("'rounds' must be a positive, even number");
        }

        this.rounds = rounds;
    }

    /**
     * initialise a Salsa20 cipher.
     *
     * @param forEncryption whether or not we are for encryption.
     * @param params the parameters required to set up the cipher.
     * @exception IllegalArgumentException if the params argument is
     * inappropriate.
     */
    public void init(
        boolean             forEncryption, 
        CipherParameters     params)
    {
        /* 
        * Salsa20 encryption and decryption is completely symmetrical, so the 'forEncryption' is
        * irrelevant for implementation purposes. (Like 90% of stream ciphers)
        */

        if (!(params instanceof ParametersWithIV))
        {
            throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must include an IV");
        }

        ParametersWithIV ivParams = (ParametersWithIV) params;

        byte[] iv = ivParams.getIV();
        if (iv == null || iv.length != getNonceSize())
        {
            throw new IllegalArgumentException(getAlgorithmName() + " requires exactly " + getNonceSize()
                    + " bytes of IV");
        }

        CipherParameters keyParam = ivParams.getParameters();
        if (keyParam == null)
        {
            if (!initialised)
            {
                throw new IllegalStateException(getAlgorithmName() + " KeyParameter can not be null for first initialisation");
            }

            setKey(null, iv);
        }
        else if (keyParam instanceof KeyParameter)
        {
            byte[] key = ((KeyParameter)keyParam).getKey();

            setKey(key, iv);

            CryptoServicesRegistrar.checkConstraints(new DefaultServiceProperties(
                        this.getAlgorithmName(), key.length * 8, params, Utils.getPurpose(forEncryption)));
        }
        else
        {
            throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must contain a KeyParameter (or null for re-init)");
        }

        reset();

        initialised = true;
    }

    protected int getNonceSize()
    {
        return 8;
    }

    public String getAlgorithmName()
    {
        String name = "Salsa20";
        if (rounds != DEFAULT_ROUNDS)
        {
            name += "/" + rounds;
        }
        return name;
    }

    public byte returnByte(byte in)
    {
        if (limitExceeded())
        {
            throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV");
        }

        byte out = (byte)(keyStream[index]^in);
        index = (index + 1) & 63;

        if (index == 0)
        {
            advanceCounter();
            generateKeyStream(keyStream);
        }

        return out;
    }

    protected void advanceCounter(long diff)
    {
        int hi = (int)(diff >>> 32);
        int lo = (int)diff;

        if (hi > 0)
        {
            engineState[9] += hi;
        }

        int oldState = engineState[8];

        engineState[8] += lo;

        if (oldState != 0 && engineState[8] < oldState)
        {
            engineState[9]++;
        }
    }

    protected void advanceCounter()
    {
        if (++engineState[8] == 0)
        {
            ++engineState[9];
        }
    }

    protected void retreatCounter(long diff)
    {
        int hi = (int)(diff >>> 32);
        int lo = (int)diff;

        if (hi != 0)
        {
            if ((engineState[9] & 0xffffffffL) >= (hi & 0xffffffffL))
            {
                engineState[9] -= hi;
            }
            else
            {
                throw new IllegalStateException("attempt to reduce counter past zero.");
            }
        }

        if ((engineState[8] & 0xffffffffL) >= (lo & 0xffffffffL))
        {
            engineState[8] -= lo;
        }
        else
        {
            if (engineState[9] != 0)
            {
                --engineState[9];
                engineState[8] -= lo;
            }
            else
            {
                throw new IllegalStateException("attempt to reduce counter past zero.");
            }
        }
    }

    protected void retreatCounter()
    {
        if (engineState[8] == 0 && engineState[9] == 0)
        {
            throw new IllegalStateException("attempt to reduce counter past zero.");
        }

        if (--engineState[8] == -1)
        {
            --engineState[9];
        }
    }

    public int processBytes(
        byte[]     in, 
        int     inOff, 
        int     len, 
        byte[]     out, 
        int     outOff)
    {
        if (!initialised)
        {
            throw new IllegalStateException(getAlgorithmName() + " not initialised");
        }

        if ((inOff + len) > in.length)
        {
            throw new DataLengthException("input buffer too short");
        }

        if ((outOff + len) > out.length)
        {
            throw new OutputLengthException("output buffer too short");
        }

        if (limitExceeded(len))
        {
            throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV");
        }

        for (int i = 0; i < len; i++)
        {
            out[i + outOff] = (byte)(keyStream[index] ^ in[i + inOff]);
            index = (index + 1) & 63;

            if (index == 0)
            {
                advanceCounter();
                generateKeyStream(keyStream);
            }
        }

        return len;
    }

    public long skip(long numberOfBytes)
    {
        if (numberOfBytes >= 0)
        {
            long remaining = numberOfBytes;

            if (remaining >= 64)
            {
                long count = remaining / 64;

                advanceCounter(count);

                remaining -= count * 64;
            }

            int oldIndex = index;

            index = (index + (int)remaining) & 63;

            if (index < oldIndex)
            {
                advanceCounter();
            }
        }
        else
        {
            long remaining = -numberOfBytes;

            if (remaining >= 64)
            {
                long count = remaining / 64;

                retreatCounter(count);

                remaining -= count * 64;
            }

            for (long i = 0; i < remaining; i++)
            {
                if (index == 0)
                {
                    retreatCounter();
                }

                index = (index - 1) & 63;
            }
        }

        generateKeyStream(keyStream);

        return numberOfBytes;
    }

    public long seekTo(long position)
    {
        reset();

        return skip(position);
    }

    public long getPosition()
    {
        return getCounter() * 64 + index;
    }

    public void reset()
    {
        index = 0;
        resetLimitCounter();
        resetCounter();

        generateKeyStream(keyStream);
    }

    protected long getCounter()
    {
        return ((long)engineState[9] << 32) | (engineState[8] & 0xffffffffL);
    }

    protected void resetCounter()
    {
        engineState[8] = engineState[9] = 0;
    }

    protected void setKey(byte[] keyBytes, byte[] ivBytes)
    {
        if (keyBytes != null)
        {
            if ((keyBytes.length != 16) && (keyBytes.length != 32))
            {
                throw new IllegalArgumentException(getAlgorithmName() + " requires 128 bit or 256 bit key");
            }

            int tsOff = (keyBytes.length - 16) / 4;
            engineState[0 ] = TAU_SIGMA[tsOff    ];
            engineState[5 ] = TAU_SIGMA[tsOff + 1];
            engineState[10] = TAU_SIGMA[tsOff + 2];
            engineState[15] = TAU_SIGMA[tsOff + 3];

            // Key
            Pack.littleEndianToInt(keyBytes, 0, engineState, 1, 4);
            Pack.littleEndianToInt(keyBytes, keyBytes.length - 16, engineState, 11, 4);
        }

        // IV
        Pack.littleEndianToInt(ivBytes, 0, engineState, 6, 2);
    }

    protected void generateKeyStream(byte[] output)
    {
        salsaCore(rounds, engineState, x);
        Pack.intToLittleEndian(x, output, 0);
    }

    /**
     * Salsa20 function
     *
     * @param   input   input data
     */    
    public static void salsaCore(int rounds, int[] input, int[] x)
    {
        if (input.length != 16)
        {
            throw new IllegalArgumentException();
        }
        if (x.length != 16)
        {
            throw new IllegalArgumentException();
        }
        if (rounds % 2 != 0)
        {
            throw new IllegalArgumentException("Number of rounds must be even");
        }

        int x00 = input[ 0];
        int x01 = input[ 1];
        int x02 = input[ 2];
        int x03 = input[ 3];
        int x04 = input[ 4];
        int x05 = input[ 5];
        int x06 = input[ 6];
        int x07 = input[ 7];
        int x08 = input[ 8];
        int x09 = input[ 9];
        int x10 = input[10];
        int x11 = input[11];
        int x12 = input[12];
        int x13 = input[13];
        int x14 = input[14];
        int x15 = input[15];

        for (int i = rounds; i > 0; i -= 2)
        {
            x04 ^= Integers.rotateLeft(x00 + x12, 7);
            x08 ^= Integers.rotateLeft(x04 + x00, 9);
            x12 ^= Integers.rotateLeft(x08 + x04, 13);
            x00 ^= Integers.rotateLeft(x12 + x08, 18);
            x09 ^= Integers.rotateLeft(x05 + x01, 7);
            x13 ^= Integers.rotateLeft(x09 + x05, 9);
            x01 ^= Integers.rotateLeft(x13 + x09, 13);
            x05 ^= Integers.rotateLeft(x01 + x13, 18);
            x14 ^= Integers.rotateLeft(x10 + x06, 7);
            x02 ^= Integers.rotateLeft(x14 + x10, 9);
            x06 ^= Integers.rotateLeft(x02 + x14, 13);
            x10 ^= Integers.rotateLeft(x06 + x02, 18);
            x03 ^= Integers.rotateLeft(x15 + x11, 7);
            x07 ^= Integers.rotateLeft(x03 + x15, 9);
            x11 ^= Integers.rotateLeft(x07 + x03, 13);
            x15 ^= Integers.rotateLeft(x11 + x07, 18);

            x01 ^= Integers.rotateLeft(x00 + x03, 7);
            x02 ^= Integers.rotateLeft(x01 + x00, 9);
            x03 ^= Integers.rotateLeft(x02 + x01, 13);
            x00 ^= Integers.rotateLeft(x03 + x02, 18);
            x06 ^= Integers.rotateLeft(x05 + x04, 7);
            x07 ^= Integers.rotateLeft(x06 + x05, 9);
            x04 ^= Integers.rotateLeft(x07 + x06, 13);
            x05 ^= Integers.rotateLeft(x04 + x07, 18);
            x11 ^= Integers.rotateLeft(x10 + x09, 7);
            x08 ^= Integers.rotateLeft(x11 + x10, 9);
            x09 ^= Integers.rotateLeft(x08 + x11, 13);
            x10 ^= Integers.rotateLeft(x09 + x08, 18);
            x12 ^= Integers.rotateLeft(x15 + x14, 7);
            x13 ^= Integers.rotateLeft(x12 + x15, 9);
            x14 ^= Integers.rotateLeft(x13 + x12, 13);
            x15 ^= Integers.rotateLeft(x14 + x13, 18);
        }

        x[ 0] = x00 + input[ 0];
        x[ 1] = x01 + input[ 1];
        x[ 2] = x02 + input[ 2];
        x[ 3] = x03 + input[ 3];
        x[ 4] = x04 + input[ 4];
        x[ 5] = x05 + input[ 5];
        x[ 6] = x06 + input[ 6];
        x[ 7] = x07 + input[ 7];
        x[ 8] = x08 + input[ 8];
        x[ 9] = x09 + input[ 9];
        x[10] = x10 + input[10];
        x[11] = x11 + input[11];
        x[12] = x12 + input[12];
        x[13] = x13 + input[13];
        x[14] = x14 + input[14];
        x[15] = x15 + input[15];
    }

    private void resetLimitCounter()
    {
        cW0 = 0;
        cW1 = 0;
        cW2 = 0;
    }

    private boolean limitExceeded()
    {
        if (++cW0 == 0)
        {
            if (++cW1 == 0)
            {
                return (++cW2 & 0x20) != 0;          // 2^(32 + 32 + 6)
            }
        }

        return false;
    }

    /*
     * this relies on the fact len will always be positive.
     */
    private boolean limitExceeded(int len)
    {
        cW0 += len;
        if (cW0 < len && cW0 >= 0)
        {
            if (++cW1 == 0)
            {
                return (++cW2 & 0x20) != 0;          // 2^(32 + 32 + 6)
            }
        }

        return false;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy