All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.bouncycastle.pqc.crypto.ntruprime.Utils Maven / Gradle / Ivy

Go to download

The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains the JCE provider and the lightweight API for the Bouncy Castle Cryptography APIs for Java 1.8 and later, with debug enabled.

The newest version!
package org.bouncycastle.pqc.crypto.ntruprime;

import java.security.SecureRandom;

import org.bouncycastle.crypto.StreamCipher;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.engines.AESEngine;
import org.bouncycastle.crypto.modes.SICBlockCipher;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.params.ParametersWithIV;

/**
 * Package-private collection of static helper routines used by the NTRU Prime code in this
 * package: random sampling, polynomial arithmetic with small coefficients, the recursive
 * encoding/decoding of coefficient arrays, and branch-free integer helpers.
 * <p>
 * Many routines select values with bit masks (0 or -1) rather than branching on data. Keep that
 * structure when modifying them.
 */
class Utils
{
    /**
     * Draws four bytes from {@code random} and assembles them, least significant byte first,
     * into a 32-bit value.
     *
     * @param random source of random bytes
     * @return the 32 random bits as an {@code int} (may be negative when read as signed)
     */
    protected static int getRandomUnsignedInteger(SecureRandom random)
    {
        byte[] c = new byte[4];
        random.nextBytes(c);
        return (bToUnsignedInt(c[0])
                + (bToUnsignedInt(c[1]) << 8)
                + (bToUnsignedInt(c[2]) << 16)
                + (bToUnsignedInt(c[3]) << 24));
    }

    /**
     * Fills {@code g} with random coefficients from {-1, 0, 1}.
     *
     * @param random source of randomness (one 32-bit draw per coefficient)
     * @param g      output array; every element is overwritten
     */
    protected static void getRandomSmallPolynomial(SecureRandom random, byte[] g)
    {
        // Take 30 random bits, scale by 3 and keep the top two bits of the (unsigned) product:
        // this yields 0, 1 or 2, which is then shifted to -1, 0 or 1.
        for (int i = 0; i < g.length; i++)
            g[i] = (byte)((((getRandomUnsignedInteger(random) & 0x3fffffff) * 3) >>> 30) - 1);
    }

    /**
     * Reduces {@code x} modulo {@code n} into the range starting at {@code -(n - 1) / 2}
     * (for odd {@code n} this is the symmetric range {@code [-(n-1)/2, (n-1)/2]}).
     *
     * @param x value to reduce
     * @param n modulus
     * @return the representative of {@code x} mod {@code n} in that range
     */
    protected static int getModFreeze(int x, int n)
    {
        return getSignedDivMod((x + ((n - 1) / 2)), n)[1] - ((n - 1) / 2);
    }

    /**
     * Runs a fixed number ({@code 2p - 1}) of mask-driven division steps on {@code g} with
     * coefficients reduced mod 3 and writes the resulting candidate reciprocal to {@code ginv}.
     *
     * @param g    input polynomial, {@code p} coefficients
     * @param ginv output array receiving {@code p} coefficients
     * @param p    number of coefficients
     * @return {@code true} when the loop finishes with {@code delta == 0}, which this routine
     *         uses as its invertibility indicator
     */
    protected static boolean isInvertiblePolynomialInR3(byte[] g, byte[] ginv, int p)
    {
        byte[] f = new byte[p + 1];
        byte[] h = new byte[p + 1];
        byte[] r = new byte[p + 1];
        byte[] v = new byte[p + 1];
        int loop, delta, sign, swap, t, i;

        r[0] = 1;
        // f has coefficients 1, -1, -1 at indices 0, p - 1 and p.
        f[0] = 1;
        f[p - 1] = -1;
        f[p] = -1;
        // h receives the coefficients of g in reversed order.
        for (i = 0; i < p; i++)
            h[p - 1 - i] = g[i];

        h[p] = 0;
        delta = 1;

        for (loop = 0; loop < (2 * p) - 1; loop++)
        {
            // Shift v up by one position (multiply by x).
            System.arraycopy(v, 0, v, 1, p);
            v[0] = 0;

            sign = -h[0] * f[0];
            // swap is all ones when delta > 0 and h[0] != 0, otherwise zero.
            swap = checkLessThanZero(-delta) & checkNotEqualToZero(h[0]);
            // Negate delta when swapping, then increment.
            delta ^= swap & (delta ^ -delta);
            delta += 1;

            // Masked exchange of (f, h) and (v, r).
            for (i = 0; i < p + 1; i++)
            {
                t = swap & (f[i] ^ h[i]);
                f[i] ^= t;
                h[i] ^= t;
                t = swap & (v[i] ^ r[i]);
                v[i] ^= t;
                r[i] ^= t;
            }

            for (i = 0; i < p + 1; i++)
                h[i] = (byte)getModFreeze(h[i] + sign * f[i], 3);
            for (i = 0; i < p + 1; i++)
                r[i] = (byte)getModFreeze(r[i] + sign * v[i], 3);

            // Shift h down by one position (divide by x).
            for (i = 0; i < p; i++)
                h[i] = h[i+1];
            h[p] = 0;
        }

        sign = f[0];
        for (i = 0; i < p; i++)
            ginv[i] = (byte)(sign * v[p - 1 - i]);

        return (delta == 0);
    }

    /**
     * Branch-free compare-exchange: afterwards {@code L[x]} holds the smaller and {@code L[y]}
     * the larger of the two entries, comparing them as unsigned 32-bit values.
     *
     * @param L array to operate on
     * @param x index receiving the minimum
     * @param y index receiving the maximum
     */
    protected static void minmax(int[] L, int x, int y)
    {
        int xi = L[x];
        int yi = L[y];
        int xy = xi ^ yi;
        int c = yi - xi;
        // Correct the sign bit of the difference when the operands differ in their top bit,
        // so that bit 31 of c is set exactly when yi < xi (unsigned).
        c ^= xy & (c ^ yi ^ 0x80000000);
        c = c >>> 31;
        c = -c;
        // c is now xy when a swap is needed, otherwise 0.
        c &= xy;
        L[x] = xi ^ c;
        L[y] = yi ^ c;
    }

    /**
     * Sorts the first {@code p} entries of {@code L} in place using a sequence of
     * {@link #minmax} compare-exchange steps whose positions depend only on {@code p}.
     *
     * @param L array to sort
     * @param p number of leading entries to sort; nothing happens when {@code p < 2}
     */
    protected static void cryptoSort(int[] L, int p)
    {
        int top, a, b, i;

        if (p < 2)
          return;

        // Largest power of two strictly below p.
        top = 1;
        while (top < p - top)
          top += top;

        for (a = top; a > 0; a = a >>> 1)
        {
            for (i = 0; i < p - a; i++)
                if ((i & a) == 0)
                    minmax(L, i, i + a);
            for (b = top; b > a; b = b >>> 1)
                for (i = 0; i < p - b; i++)
                    if ((i & a) == 0)
                        minmax(L, i + a, i + b);
        }
    }

    /**
     * Turns {@code p} 32-bit values into a polynomial with exactly {@code w} non-zero
     * coefficients (each -1 or 1) at positions determined by sorting.
     *
     * @param f output, {@code p} coefficients in {-1, 0, 1}
     * @param L input values; modified and sorted in place
     * @param p number of coefficients
     * @param w number of non-zero coefficients
     */
    protected static void sortGenerateShortPolynomial(byte[] f, int[] L, int p, int w)
    {
        // First w entries: clear bit 0, so the low two bits are 00 or 10 (decoded as -1 or 1).
        for (int i = 0; i < w; i++)
            L[i] = L[i] & -2;
        // Remaining entries: force the low two bits to 01 (decoded as 0).
        for (int i = w; i < p; i++)
            L[i] = (L[i] & -3) | 1;
        // Sorting on the remaining random upper bits permutes the tagged entries.
        cryptoSort(L, p);
        for (int i = 0; i < p; i++)
            f[i] = (byte)((L[i] & 3) - 1);
    }

    /**
     * Samples a random polynomial with {@code p} coefficients of which exactly {@code w} are
     * non-zero, by drawing {@code p} random 32-bit values and delegating to
     * {@link #sortGenerateShortPolynomial}.
     *
     * @param random source of randomness
     * @param f      output coefficients
     * @param p      number of coefficients
     * @param w      number of non-zero coefficients
     */
    protected static void getRandomShortPolynomial(SecureRandom random, byte[] f, int p, int w)
    {
        int[] L = new int[p];
        for (int i = 0; i < p; i++)
            L[i] = getRandomUnsignedInteger(random);
        sortGenerateShortPolynomial(f, L, p, w);
    }

    /**
     * Raises {@code x} to the power {@code q - 2}, reducing with {@link #getModFreeze} after
     * every multiplication.
     * <p>
     * NOTE(review): this equals the multiplicative inverse of {@code x} only when {@code q} is
     * prime and {@code x} is not a multiple of {@code q} - confirm callers guarantee this.
     *
     * @param x value to invert
     * @param q modulus
     * @return {@code x^(q-2)} reduced mod {@code q}
     */
    protected static int getInverseInRQ(int x, int q)
    {
        int ai = x;
        for (int i = 1; i < q - 2; i++)
            ai = getModFreeze(x * ai, q);
        return ai;
    }

    /**
     * Runs the same fixed-length, mask-driven division-step loop as
     * {@link #isInvertiblePolynomialInR3}, but with coefficients reduced mod {@code q} and with
     * the accumulator seeded with the inverse of 3, writing the result to {@code finv3}.
     * <p>
     * NOTE(review): from the name and the seed value this presumably yields 1/(3f); no
     * invertibility flag is returned, so the caller is assumed to supply an invertible f.
     *
     * @param finv3 output, {@code p} coefficients
     * @param f     input polynomial, {@code p} coefficients
     * @param p     number of coefficients
     * @param q     coefficient modulus
     */
    protected static void getOneThirdInverseInRQ(short[] finv3, byte[] f, int p, int q)
    {
        short[] h = new short[p + 1];
        short[] g = new short[p + 1];
        short[] r = new short[p + 1];
        short[] v = new short[p + 1];
        int loop, delta, scale, swap, h0, g0, t, i;

        r[0] = (short)getInverseInRQ(3, q);
        // h has coefficients 1, -1, -1 at indices 0, p - 1 and p.
        h[0] = 1;
        h[p - 1] = -1;
        h[p] = -1;
        // g receives the coefficients of f in reversed order.
        for (i = 0; i < p; i++)
            g[p - 1 - i] = f[i];

        g[p] = 0;
        delta = 1;

        for (loop = 0; loop < (2 * p) - 1; loop++)
        {
            // Shift v up by one position (multiply by x).
            System.arraycopy(v, 0, v, 1, p);
            v[0] = 0;

            // swap is all ones when delta > 0 and g[0] != 0, otherwise zero.
            swap = checkLessThanZero(-delta) & checkNotEqualToZero(g[0]);
            delta ^= swap & (delta ^ -delta);
            delta += 1;

            // Masked exchange of (h, g) and (v, r).
            for (i = 0; i < p + 1; i++)
            {
                t = swap & (h[i] ^ g[i]);
                h[i] ^= t;
                g[i] ^= t;
                t = swap & (v[i] ^ r[i]);
                v[i] ^= t;
                r[i] ^= t;
            }

            // Capture the leading values before g is overwritten in the loop below.
            h0 = h[0];
            g0 = g[0];
            for (i = 0; i < p + 1; i++)
                g[i] = (short)getModFreeze((h0 * g[i]) - (g0 * h[i]), q);
            for (i = 0; i < p + 1; i++)
                r[i] = (short)getModFreeze((h0 * r[i]) - (g0 * v[i]), q);

            // Shift g down by one position (divide by x).
            for (i = 0; i < p; i++)
                g[i] = g[i+1];
            g[p] = 0;
        }

        scale = getInverseInRQ(h[0], q);
        for (i = 0; i < p; i++)
            finv3[i] = (short)getModFreeze(scale * v[p - 1 - i], q);
    }

    /**
     * Multiplies {@code finv3} by the small polynomial {@code g} with coefficients reduced
     * mod {@code q}, then folds the product back to {@code p} coefficients.
     *
     * @param h     output, {@code p} coefficients
     * @param finv3 first factor, {@code p} coefficients
     * @param g     second factor, {@code p} coefficients
     * @param p     number of coefficients
     * @param q     coefficient modulus
     */
    protected static void multiplicationInRQ(short[] h, short[] finv3, byte[] g, int p, int q)
    {
        short[] fg = new short[p + p - 1];
        short result;
        int i, j;

        // Schoolbook product, lower half of the coefficients.
        for (i = 0; i < p; i++)
        {
            result = 0;
            for (j = 0; j <= i; j++)
                result = (short)getModFreeze(result + (finv3[j] * g[i - j]), q);
            fg[i] = result;
        }

        // Schoolbook product, upper half of the coefficients.
        for (i = p; i < p + p - 1; i++)
        {
            result = 0;
            for (j = i - p + 1; j < p; j++)
                result = (short)getModFreeze(result + (finv3[j] * g[i - j]), q);
            fg[i] = result;
        }

        // Fold each high coefficient i into positions i - p and i - p + 1
        // (i.e. apply x^p = x + 1).
        for (i = p + p - 2; i >= p; i--)
        {
            fg[i - p] = (short)getModFreeze(fg[i - p] + fg[i], q);
            fg[i - p + 1] = (short)getModFreeze(fg[i - p + 1] + fg[i], q);
        }

        for (i = 0; i < p; i++)
            h[i] = fg[i];
    }

    /**
     * Recursively packs the values {@code R[i]} (each paired with a bound {@code M[i]}) into
     * bytes. Adjacent pairs are merged into {@code R[i] + R[i+1] * M[i]}; low bytes are emitted
     * while the merged bound is at least 16384, and the remainders are encoded recursively.
     *
     * @param out   destination byte array
     * @param R     values to encode
     * @param M     bound associated with each value
     * @param len   number of values in {@code R}/{@code M} to process
     * @param start index in {@code out} at which to continue writing
     */
    private static void encode(byte[] out, short[] R, short[] M, int len, int start)
    {
        if (len == 1)
        {
            short r = R[0];
            short m = M[0];
            // Emit the final value byte by byte until its bound drops to 1.
            while (m > 1)
            {
                out[start++] = (byte)r;
                r = (short)(r >>> 8);
                m = (short)((m + 255) >>> 8);
            }
        }

        if (len > 1)
        {
            short[] R2 = new short[(len + 1) / 2];
            short[] M2 = new short[(len + 1) / 2];
            int i;

            for (i = 0; i < len - 1; i += 2)
            {
                int m0 = M[i];
                int r = R[i] + (R[i + 1] * m0);
                int m = M[i + 1] * m0;
                // Keep the merged bound below 16384 so it fits the next level's short arrays.
                while (m >= 16384)
                {
                    out[start++] = (byte)r;
                    r = r >>> 8;
                    m = (m + 255) >>> 8;
                }
                R2[i / 2] = (short)r;
                M2[i / 2] = (short)m;
            }

            // Odd element count: the last value is carried over unchanged.
            if (i < len)
            {
                R2[i / 2] = R[i];
                M2[i / 2] = M[i];
            }

            encode(out, R2, M2,(len + 1) / 2, start);
        }
    }

    /**
     * Encodes a polynomial with coefficients mod {@code q}: each coefficient is shifted by
     * {@code (q - 1) / 2} and the results are packed with bound {@code q}.
     *
     * @param enc output byte array
     * @param h   polynomial to encode, {@code p} coefficients
     * @param p   number of coefficients
     * @param q   coefficient modulus
     */
    protected static void getEncodedPolynomial(byte[] enc, short[] h, int p, int q)
    {
        short[] R = new short[p];
        short[] M = new short[p];

        for (int i = 0; i < p; i++)
            R[i] = (short)(h[i] + ((q - 1) / 2));
        for (int i = 0; i < p; i++)
            M[i] = (short)q;

        encode(enc, R, M, p, 0);
    }

    /**
     * Packs a small polynomial four coefficients per byte, two bits each (coefficient + 1,
     * lowest bits first), followed by one byte holding a single trailing coefficient.
     * <p>
     * Exactly {@code 4 * (p / 4) + 1} coefficients are read.
     * NOTE(review): this covers all of {@code sp} only when {@code p % 4 == 1} - confirm for
     * the parameter sets in use.
     *
     * @param encSP output, at least {@code p / 4 + 1} bytes
     * @param sp    coefficients in {-1, 0, 1}
     * @param p     number of coefficients
     */
    protected static void getEncodedSmallPolynomial(byte[] encSP, byte[] sp, int p)
    {
        byte x;
        int spIndex = 0;
        int encSPIndex = 0;
        for (int i = 0; i < p / 4; i++)
        {
            x = (byte)(sp[spIndex++] + 1);
            x += (byte)(sp[spIndex++] + 1) << 2;
            x += (byte)(sp[spIndex++] + 1) << 4;
            x += (byte)(sp[spIndex++] + 1) << 6;
            encSP[encSPIndex++] = x;
        }
        encSP[encSPIndex] = (byte)(sp[spIndex] + 1);
    }

    /**
     * Encrypts {@code in} into {@code out} with AES in SIC (counter) mode.
     * <p>
     * The number of bytes processed is {@code out.length}; {@code in} must be at least that
     * long.
     *
     * @param in    input bytes
     * @param out   output buffer, filled completely
     * @param nonce IV for the counter mode
     * @param key   AES key
     */
    private static void generateAES256CTRStream(byte[] in, byte[] out, byte[] nonce, byte[] key)
    {
        StreamCipher cipher = SICBlockCipher.newInstance(AESEngine.newInstance());
        cipher.init(true, new ParametersWithIV(new KeyParameter(key), nonce));
        cipher.processBytes(in, 0, out.length, out, 0);
    }

    /**
     * Deterministically expands key {@code k} into {@code L.length} 32-bit values by encrypting
     * an all-zero buffer with an all-zero 16-byte IV and reading the result as little-endian
     * 32-bit words.
     *
     * @param L output array, filled completely
     * @param k AES key used as the seed
     */
    protected static void expand(int[] L, byte[] k)
    {
        byte[] aesInput = new byte[L.length * 4];
        byte[] aesOutput = new byte[L.length * 4];
        byte[] nonce = new byte[16];
        generateAES256CTRStream(aesInput, aesOutput, nonce, k);
        for (int i = 0; i < L.length; i++)
            L[i] = (bToUnsignedInt(aesOutput[i * 4])
                    + (bToUnsignedInt(aesOutput[(i * 4) + 1]) << 8)
                    + (bToUnsignedInt(aesOutput[(i * 4) + 2]) << 16)
                    + (bToUnsignedInt(aesOutput[(i * 4) + 3]) << 24));
    }

    /**
     * Returns the remainder of {@code x}, read as an unsigned 32-bit value, divided by
     * {@code n}.
     */
    private static int getUnsignedMod(int x, int n)
    {
        return getUnsignedDivMod(x, n)[1];
    }

    /**
     * Derives a polynomial with {@code p} coefficients from {@code seed}: the seed is expanded
     * with {@link #expand} and every 32-bit word is reduced mod {@code q} and shifted by
     * {@code -(q - 1) / 2}.
     *
     * @param G    output coefficients
     * @param seed key passed to {@link #expand}
     * @param p    number of coefficients
     * @param q    coefficient modulus
     */
    protected static void generatePolynomialInRQFromSeed(short[] G, byte[] seed, int p, int q)
    {
        int[] L = new int[p];
        expand(L, seed);

        for (int i = 0; i < p; i++)
            G[i] = (short)(getUnsignedMod(L[i], q) - ((q - 1) / 2));
    }

    /**
     * Rounds every coefficient of {@code in} to a multiple of 3 by subtracting its
     * representative mod 3 in {-1, 0, 1}.
     *
     * @param out output; {@code out.length} coefficients are written
     * @param in  input coefficients
     */
    protected static void roundPolynomial(short[] out, short[] in)
    {
        for (int i = 0; i < out.length; i++)
            out[i] = (short)(in[i] - getModFreeze(in[i], 3));
    }

    /**
     * Encodes a rounded polynomial: each coefficient is shifted by {@code (q - 1) / 2},
     * divided by 3 via the multiply-and-shift {@code (x * 10923) >>> 15}, and packed with
     * bound {@code (q + 2) / 3}.
     *
     * @param out output byte array
     * @param in  rounded coefficients, {@code p} entries
     * @param p   number of coefficients
     * @param q   coefficient modulus
     */
    protected static void getRoundedEncodedPolynomial(byte[] out, short[] in, int p, int q)
    {
        short[] R = new short[p];
        short[] M = new short[p];

        for (int i = 0; i < p; i++)
        {
            R[i] = (short)(((in[i] + ((q - 1) / 2)) * 10923) >>> 15);
            M[i] = (short)((q + 2) / 3);
        }

        encode(out, R, M, p, 0);
    }

    /**
     * Computes SHA-512 over {@code prefix} followed by {@code hashInput}.
     *
     * @param prefix    bytes hashed first
     * @param hashInput bytes hashed after the prefix
     * @return the 64-byte digest
     */
    protected static byte[] getHashWithPrefix(byte[] prefix, byte[] hashInput)
    {
        byte[] hash = new byte[64];
        SHA512Digest hashDigest = new SHA512Digest();
        // Feed both parts directly instead of concatenating them first: same digest, but no
        // temporary copy of the (possibly secret) input is left on the heap.
        hashDigest.update(prefix, 0, prefix.length);
        hashDigest.update(hashInput, 0, hashInput.length);
        hashDigest.doFinal(hash, 0);
        return hash;
    }

    /**
     * Inverse of {@link #encode}: recursively unpacks {@code len} values with bounds {@code M}
     * from {@code S} into {@code out}.
     *
     * @param out    destination for the decoded values
     * @param S      encoded bytes
     * @param M      bound associated with each value
     * @param len    number of values to decode
     * @param start  index in {@code out} at which to start writing
     * @param sIndex index in {@code S} at which to start reading
     */
    private static void decode(short[] out, byte[] S, short[] M, int len, int start, int sIndex)
    {
        if (len == 1)
        {
            if (M[0] == 1)
                out[start] = 0;
            else if (M[0] <= 256)
                out[start] = (short)getUnsignedMod(bToUnsignedInt(S[sIndex]), M[0]);
            else
                // Both bytes must be read as unsigned: a sign-extended high byte (>= 0x80)
                // would turn the 16-bit value into a huge unsigned number before reduction.
                out[start] = (short)getUnsignedMod(bToUnsignedInt(S[sIndex]) + (bToUnsignedInt(S[sIndex + 1]) << 8), M[0]);
        }

        if (len > 1)
        {
            short[] R2 = new short[(len + 1) / 2];
            short[] M2 = new short[(len + 1) / 2];
            short[] bottomr = new short[len / 2];
            int[] bottomt = new int[len / 2];
            int i;

            // First pass: for every pair work out how many low bytes the encoder emitted
            // (2, 1 or 0), collect them, and compute the bound of the merged remainder.
            for (i = 0; i < len - 1; i += 2)
            {
                int m = M[i] * (int)M[i+1];
                if (m > (256 * 16383))
                {
                    bottomt[i / 2] = 256 * 256;
                    bottomr[i / 2] = (short)(bToUnsignedInt(S[sIndex]) + (256 * bToUnsignedInt(S[sIndex + 1])));
                    sIndex += 2;
                    M2[i / 2] = (short)((((m + 255) >>> 8) + 255) >>> 8);
                }
                else if (m >= 16384)
                {
                    bottomt[i / 2] = 256;
                    bottomr[i / 2] = (short)bToUnsignedInt(S[sIndex]);
                    sIndex += 1;
                    M2[i / 2] = (short)((m + 255) >>> 8);
                }
                else
                {
                    bottomt[i / 2] = 1;
                    bottomr[i / 2] = 0;
                    M2[i / 2] = (short)m;
                }
            }
            // Odd element count: the last bound is carried over unchanged.
            if (i < len)
                M2[i / 2] = M[i];

            decode(R2, S, M2,(len + 1) / 2, start, sIndex);

            // Second pass: recombine remainder and low bytes, then split each merged value
            // back into its two components.
            for (i = 0; i < len - 1; i += 2)
            {
                int r = sToUnsignedInt(bottomr[i / 2]);
                r += bottomt[i / 2] * sToUnsignedInt(R2[i / 2]);
                int[] r01 = getUnsignedDivMod(r, M[i]);
                out[start++] = (short)r01[1];
                out[start++] = (short)getUnsignedMod(r01[0], M[i + 1]);
            }
            if (i < len)
                out[start] = R2[i / 2];
        }
    }

    /**
     * Inverse of {@link #getEncodedPolynomial}: decodes {@code p} values with bound {@code q}
     * and shifts each by {@code -(q - 1) / 2}.
     *
     * @param h   output coefficients
     * @param enc encoded bytes
     * @param p   number of coefficients
     * @param q   coefficient modulus
     */
    protected static void getDecodedPolynomial(short[] h, byte[] enc, int p, int q)
    {
        short[] R = new short[p];
        short[] M = new short[p];

        for (int i = 0; i < p; i++)
            M[i] = (short)q;

        decode(R, enc, M, p, 0, 0);

        for (int i = 0; i < p; i++)
            h[i] = (short)(R[i] - ((q - 1) / 2));
    }

    /**
     * Fills {@code r} with random bits, one bit (0 or 1) per array element.
     *
     * @param random source of randomness
     * @param r      output array; every element is overwritten with 0 or 1
     */
    protected static void getRandomInputs(SecureRandom random, byte[] r)
    {
        // Round up so that a length that is not a multiple of 8 still has a seed byte for its
        // trailing bits; for multiples of 8 this is exactly r.length / 8.
        byte[] seed = new byte[(r.length + 7) / 8];
        random.nextBytes(seed);

        for (int i = 0; i < r.length; i++)
            r[i] = (byte)(1 & (seed[i >>> 3] >>> (i & 7)));
    }

    /**
     * Packs an array of bits (one per element of {@code in}) into bytes, least significant
     * bit first.
     * <p>
     * Bits are OR-ed into {@code out}, so {@code out} must be zeroed beforehand.
     *
     * @param out output, at least {@code (in.length + 7) / 8} bytes
     * @param in  bit values (0 or 1)
     */
    protected static void getEncodedInputs(byte[] out, byte[] in)
    {
        for (int i = 0; i < in.length; i++)
            out[i >>> 3] |= in[i] << (i & 7);
    }

    /**
     * Inverse of {@link #getRoundedEncodedPolynomial}: decodes {@code p} values with bound
     * {@code (q + 2) / 3}, multiplies each by 3 and shifts it by {@code -(q - 1) / 2}.
     *
     * @param h   output coefficients
     * @param enc encoded bytes
     * @param p   number of coefficients
     * @param q   coefficient modulus
     */
    protected static void getRoundedDecodedPolynomial(short[] h, byte[] enc, int p, int q)
    {
        short[] R = new short[p];
        short[] M = new short[p];

        for (int i = 0; i < p; i++)
            M[i] = (short)((q + 2) / 3);

        decode(R, enc, M, p, 0, 0);

        for (int i = 0; i < p; i++)
            h[i] = (short)((R[i] * 3) - ((q - 1) / 2));
    }

    /**
     * For each position computes
     * {@code (tau1 * (freeze(bA[i] + r[i] * ((q - 1) / 2), q) + tau0) + 16384) >>> 15}
     * and stores it as a byte.
     *
     * @param out  output; {@code out.length} entries are written
     * @param bA   polynomial coefficients
     * @param r    input bits (0 or 1)
     * @param q    coefficient modulus
     * @param tau0 additive constant
     * @param tau1 multiplicative constant
     */
    protected static void top(byte[] out, short[] bA, byte[] r, int q, int tau0, int tau1)
    {
        for (int i = 0; i < out.length; i++)
            out[i] = (byte)((tau1 * (getModFreeze(bA[i] + r[i] * ((q - 1) / 2), q) + tau0) + 16384) >>> 15);
    }

    /**
     * Packs pairs of values into bytes: {@code in[2i]} in the low nibble and
     * {@code in[2i + 1]} in the high nibble.
     *
     * @param out output; {@code out.length} bytes are written
     * @param in  input values, expected to fit in four bits each
     */
    protected static void getTopEncodedPolynomial(byte[] out, byte[] in)
    {
        for (int i = 0; i < out.length; i++)
            out[i] = (byte)(in[2 * i] + (in[(2 * i) + 1] << 4));
    }

    /**
     * Inverse of {@link #getEncodedSmallPolynomial}: unpacks two-bit fields (minus one) from
     * {@code encSP} into {@code sp}, four per byte plus one from the final byte.
     *
     * @param sp    output coefficients
     * @param encSP packed input
     * @param p     number of coefficients
     */
    protected static void getDecodedSmallPolynomial(byte[] sp, byte[] encSP, int p)
    {
        byte x;
        int spIndex = 0;
        int encSPIndex = 0;

        for (int i = 0; i < p / 4; i++)
        {
            x = encSP[encSPIndex++];
            // x >>>= 2 on a byte shifts in copies of the sign bit, but only the two lowest
            // bits are read after each shift and those always come from the original byte.
            sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
            sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
            sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
            sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1);
        }

        x = encSP[encSPIndex];
        sp[spIndex] = (byte)((bToUnsignedInt(x) & 3) - 1);
    }

    /**
     * Multiplies every coefficient of {@code in} by {@code scalar} and reduces it with
     * {@link #getModFreeze} mod {@code q}.
     *
     * @param out    output coefficients
     * @param in     input coefficients
     * @param scalar multiplier
     * @param q      coefficient modulus
     */
    protected static void scalarMultiplicationInRQ(short[] out, short[] in, int scalar, int q)
    {
        for (int i = 0; i < in.length; i++)
            out[i] = (short)getModFreeze(scalar * in[i], q);
    }

    /**
     * Reduces every coefficient of {@code in} to its representative mod 3 in {-1, 0, 1}.
     *
     * @param out output coefficients
     * @param in  input coefficients
     */
    protected static void transformRQToR3(byte[] out, short[] in)
    {
        for (int i = 0; i < in.length; i++)
            out[i] = (byte)getModFreeze(in[i], 3);
    }

    /**
     * Same as {@link #multiplicationInRQ} but with byte coefficients reduced mod 3.
     *
     * @param h     output, {@code p} coefficients
     * @param finv3 first factor, {@code p} coefficients
     * @param g     second factor, {@code p} coefficients
     * @param p     number of coefficients
     */
    protected static void multiplicationInR3(byte[] h, byte[] finv3, byte[] g, int p)
    {
        byte[] fg = new byte[p + p - 1];
        byte result;
        int i, j;

        // Schoolbook product, lower half of the coefficients.
        for (i = 0; i < p; i++)
        {
            result = 0;
            for (j = 0; j <= i; j++)
                result = (byte) getModFreeze(result + (finv3[j] * g[i - j]), 3);
            fg[i] = result;
        }

        // Schoolbook product, upper half of the coefficients.
        for (i = p; i < p + p - 1; i++)
        {
            result = 0;
            for (j = i - p + 1; j < p; j++)
                result = (byte)getModFreeze(result + (finv3[j] * g[i - j]), 3);
            fg[i] = result;
        }

        // Fold each high coefficient i into positions i - p and i - p + 1.
        for (i = p + p - 2; i >= p; i--)
        {
            fg[i - p] = (byte)getModFreeze(fg[i - p] + fg[i], 3);
            fg[i - p + 1] = (byte)getModFreeze(fg[i - p + 1] + fg[i], 3);
        }

        for (i = 0; i < p; i++)
            h[i] = fg[i];
    }

    /**
     * Copies {@code ev} to {@code r} when the number of odd entries in {@code ev} equals
     * {@code w}; otherwise writes a fixed fallback of {@code w} ones followed by zeros. The
     * selection is done with a mask rather than a branch.
     *
     * @param r  output, {@code p} coefficients
     * @param ev candidate polynomial
     * @param p  number of coefficients
     * @param w  required number of non-zero coefficients
     */
    protected static void checkForSmallPolynomial(byte[] r, byte[] ev, int p, int w)
    {
        int weight = 0;
        for (int i = 0; i != ev.length; i++)
        {
            // -1 and 1 both have their lowest bit set, 0 does not.
            weight += ev[i] & 1;
        }

        // mask is all ones when the weight is wrong, zero when it matches.
        int mask = checkNotEqualToZero(weight - w);
        for (int i = 0; i < w; i++)
            r[i] = (byte)(((ev[i] ^ 1) & ~mask) ^ 1);
        for (int i = w; i < p; i++)
            r[i] = (byte)(ev[i] & ~mask);
    }

    /**
     * Masked overwrite: when {@code mask} is all ones every byte of {@code encR} is replaced
     * by the corresponding byte of {@code rho}; when {@code mask} is zero {@code encR} is left
     * unchanged.
     *
     * @param encR array updated in place
     * @param rho  replacement bytes, at least {@code encR.length} long
     * @param mask 0 or -1
     */
    protected static void updateDiffMask(byte[] encR, byte[] rho, int mask)
    {
        for (int i = 0; i < encR.length; i++)
            encR[i] ^= mask & (encR[i] ^ rho[i]);
    }

    /**
     * Inverse of {@link #getTopEncodedPolynomial}: splits every byte of {@code in} into its
     * low nibble ({@code out[2i]}) and high nibble ({@code out[2i + 1]}), both in 0..15.
     *
     * @param out output, {@code 2 * in.length} entries
     * @param in  packed input
     */
    protected static void getTopDecodedPolynomial(byte[] out, byte[] in)
    {
        for (int i = 0; i < in.length; i++)
        {
            out[2 * i] = (byte)(in[i] & 15);
            // Convert to unsigned before shifting: shifting the sign-extended byte would turn
            // a high nibble of 8..15 into -8..-1 after the cast back to byte.
            out[(2 * i) + 1] = (byte)(bToUnsignedInt(in[i]) >>> 4);
        }
    }

    /**
     * For each position writes 1 when
     * {@code freeze(freeze(tau3 * T[i] - tau2, q) - aB[i] + 4 * w + 1, q)} is negative and 0
     * otherwise.
     *
     * @param out  output bits; {@code out.length} entries are written
     * @param aB   polynomial coefficients
     * @param T    decoded top values
     * @param q    coefficient modulus
     * @param w    weight parameter
     * @param tau2 additive constant
     * @param tau3 multiplicative constant
     */
    protected static void right(byte[] out, short[] aB, byte[] T, int q, int w, int tau2, int tau3)
    {
        for (int i = 0; i < out.length; i++)
            out[i] = (byte)(-checkLessThanZero(getModFreeze(getModFreeze((tau3 * T[i]) - tau2, q) - aB[i] + (4 * w) + 1, q)));
    }

    /**
     * Divides {@code dividend}, read as an unsigned 32-bit value, by {@code n} without using a
     * data-dependent division: the quotient is estimated twice with the precomputed reciprocal
     * {@code 2^31 / n} and then corrected with one masked subtraction.
     * <p>
     * NOTE(review): appears to assume a small positive {@code n} (all visible callers pass 3,
     * q, or a bound below 16384) - confirm before using it with other divisors.
     *
     * @param dividend value to divide (unsigned interpretation)
     * @param n        divisor
     * @return {@code {quotient, remainder}}
     * @throws IllegalStateException if a result does not fit in an {@code int}
     */
    private static int[] getUnsignedDivMod(int dividend, int n)
    {
        long x = iToUnsignedLong(dividend);
        long v = iToUnsignedLong(0x80000000);
        long q, qpart, mask;

        v /= n;
        q = 0;

        // First quotient estimate.
        qpart = (x * v) >>> 31;
        x -= qpart * n;
        q += qpart;

        // Second estimate on the remaining value.
        qpart = (x * v) >>> 31;
        x -= qpart * n;
        q += qpart;

        // Final correction: subtract n once more and undo it (via mask) if that went negative.
        x -= n;
        q += 1;
        mask = -(x >>> 63);
        x += mask & n;
        q += mask;

        return new int[]{toIntExact(q), toIntExact(x)};
    }

    /**
     * Signed division built on {@link #getUnsignedDivMod}: {@code x} is offset by 2^31 to make
     * it non-negative, divided, and the quotient/remainder of 2^31 itself are subtracted again.
     *
     * @param x dividend (signed)
     * @param n divisor
     * @return {@code {quotient, remainder}} with the remainder adjusted to be non-negative
     */
    private static int[] getSignedDivMod(int x, int n)
    {
        int q, r, mask;

        int[] div1 = getUnsignedDivMod(toIntExact(0x80000000 + iToUnsignedLong(x)), n);
        int[] div2 = getUnsignedDivMod(0x80000000, n);

        q = toIntExact(iToUnsignedLong(div1[0]) - iToUnsignedLong(div2[0]));
        r = toIntExact(iToUnsignedLong(div1[1]) - iToUnsignedLong(div2[1]));
        // If the remainder difference is negative, add n back and decrement the quotient.
        mask = -(r >>> 31);
        r += mask & n;
        q += mask;

        return new int[]{q, r};
    }

    /**
     * Returns -1 (all bits set) when {@code x} is negative and 0 otherwise.
     */
    private static int checkLessThanZero(int x)
    {
        return -(int)(x >>> 31);
    }

    /**
     * Returns -1 (all bits set) when {@code x} is non-zero and 0 otherwise.
     */
    private static int checkNotEqualToZero(int x)
    {
        long l = iToUnsignedLong(x);
        // Negating a positive 64-bit value sets bit 63; zero stays zero.
        l = -l;
        return -(int)(l >>> 63);
    }

    /**
     * Returns {@code b} as an unsigned value in 0..255.
     */
    static int bToUnsignedInt(byte b)
    {
        return b & 0xff;
    }

    /**
     * Returns {@code s} as an unsigned value in 0..65535.
     */
    static int sToUnsignedInt(short s)
    {
        return s & 0xffff;
    }
    
    /**
     * Returns {@code i} as an unsigned value in 0..2^32 - 1.
     */
    static long iToUnsignedLong(int i)
    {
        return i & 0xffffffffL;
    }

    /**
     * Narrows {@code l} to an {@code int}, rejecting values that do not fit.
     *
     * @param l value to narrow
     * @return {@code l} as an {@code int}
     * @throws IllegalStateException if {@code l} is outside the {@code int} range
     */
    static int toIntExact(long l)
    {
        int i = (int)l;

        if (i != l)
        {
            throw new IllegalStateException("value out of integer range");
        }
        return i;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy