All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.bouncycastle.pqc.crypto.gmss.GMSSRootSig Maven / Gradle / Ivy

There is a newer version: 1.2.2.1-jre17
Show newest version
package org.bouncycastle.pqc.crypto.gmss;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.pqc.crypto.gmss.util.GMSSRandom;
import org.bouncycastle.util.encoders.Hex;


/**
 * This class implements the distributed signature generation of the Winternitz
 * one-time signature scheme (OTSS), described in C.Dods, N.P. Smart, and M.
 * Stam, "Hash Based Digital Signature Schemes", LNCS 3796, pages 96–115,
 * 2005. The class is used by the GMSS classes.
 */
public class GMSSRootSig
{

    /**
     * The hash function used by the OTS. It hashes the message in initSign(),
     * iterates the hash chain on the current key element in oneStep(), and is
     * handed to the GMSSRandom instance.
     */
    private Digest messDigestOTS;

    /**
     * mdsize: output length of the digest in bytes (getDigestSize()).
     * keysize: number of signature elements, i.e. messagesize plus the number
     * of w-bit chunks needed for the checksum; signing is complete once
     * 'counter' reaches this value.
     */
    private int mdsize, keysize;

    /**
     * The key element currently being processed: obtained from
     * gmssRandom.nextSeed(seed) and then hashed repeatedly in oneStep().
     */
    private byte[] privateKeyOTS;

    /**
     * The digest of the message being signed. Note that the (8 % w == 0)
     * branch of oneStep() shifts its bytes in place while consuming them.
     */
    private byte[] hash;

    /**
     * The signature bytes; element number 'counter' occupies mdsize bytes
     * starting at offset counter * mdsize.
     */
    private byte[] sign;

    /**
     * The Winternitz parameter: number of bits taken from the digest (or the
     * checksum) per signature element.
     */
    private int w;

    /**
     * The source of randomness for OTS private key generation
     */
    private GMSSRandom gmssRandom;

    /**
     * Number of w-bit chunks of the message digest:
     * ceil(8 * mdsize / w).
     */
    private int messagesize;

    /**
     * Precalculated chunk mask (1 << w) - 1.
     */
    private int k;

    /**
     * Status of the distributed signing:
     * r       - bit offset into 'hash' (only used by the w > 8 branch),
     * test    - remaining hash iterations for the current element
     *           (branches with w <= 8),
     * counter - number of signature elements completed so far,
     * ii      - byte index into 'hash' (reused as a scratch byte counter in
     *           the w > 8 branch).
     */
    private int r, test, counter, ii;

    /**
     * Long-valued status of the distributed signing:
     * test8 - remaining hash iterations for the current element in the
     *         w > 8 branch,
     * big8  - accumulator holding digest (or checksum) bits that have been
     *         loaded but not yet consumed.
     */
    private long test8, big8;

    /**
     * Upper bound on the number of oneStep() invocations performed by a
     * single updateSign() call.
     */
    private int steps;

    /**
     * The checksum value, (messagesize << w) minus the sum of all message
     * chunk values; consumed w bits at a time while signing.
     */
    private int checksum;

    /**
     * The height of the tree. Set by the constructors and persisted via
     * getStatInt().
     */
    private int height;

    /**
     * The current internal OTS seed, passed to gmssRandom.nextSeed() each
     * time a new key element is needed (presumably advanced in place by that
     * call - confirm against GMSSRandom).
     */
    private byte[] seed;

    /**
     * Restores a GMSSRootSig from the state previously exported through
     * {@link #getStatInt()} and {@link #getStatByte()}.
     * <p>
     * The byte arrays are stored by reference, not copied.
     *
     * @param digest   the hash function used by the OTS and by the PRNG
     * @param statByte status byte arrays, in this order: current private key
     *                 element, seed, message hash, signature buffer and the
     *                 16-byte little-endian encoding of test8 followed by big8
     * @param statInt  status ints, in this order: counter, test, ii, r, steps,
     *                 keysize, height, w, checksum
     */
    public GMSSRootSig(Digest digest, byte[][] statByte, int[] statInt)
    {
        this.messDigestOTS = digest;
        this.gmssRandom = new GMSSRandom(digest);

        // integer part of the saved state
        this.counter = statInt[0];
        this.test = statInt[1];
        this.ii = statInt[2];
        this.r = statInt[3];
        this.steps = statInt[4];
        this.keysize = statInt[5];
        this.height = statInt[6];
        this.w = statInt[7];
        this.checksum = statInt[8];

        // values derived from the digest and the Winternitz parameter
        this.mdsize = digest.getDigestSize();
        this.k = (1 << this.w) - 1;
        double chunkCount = Math.ceil((double)(this.mdsize << 3) / (double)this.w);
        this.messagesize = (int)chunkCount;

        // byte array part of the saved state
        this.privateKeyOTS = statByte[0];
        this.seed = statByte[1];
        this.hash = statByte[2];
        this.sign = statByte[3];

        // the two long counters are packed little-endian into one array
        byte[] packedLongs = statByte[4];
        this.test8 = readLittleEndianLong(packedLongs, 0);
        this.big8 = readLittleEndianLong(packedLongs, 8);
    }

    /**
     * Decodes eight bytes, least significant byte first, into a long.
     *
     * @param src    the array to read from
     * @param offset index of the least significant byte
     * @return the decoded value
     */
    private static long readLittleEndianLong(byte[] src, int offset)
    {
        long value = 0;
        for (int i = 0; i < 8; i++)
        {
            value |= (long)(src[offset + i] & 0xff) << (i << 3);
        }
        return value;
    }

    /**
     * Creates a fresh GMSSRootSig: sets up the PRNG and the values that
     * depend only on the digest and the Winternitz parameter. No signing state
     * is initialised here.
     *
     * @param digest the hash function used by the OTS and by the PRNG
     * @param w      the Winternitz parameter
     * @param height the height of the tree
     */
    public GMSSRootSig(Digest digest, int w, int height)
    {
        this.messDigestOTS = digest;
        this.gmssRandom = new GMSSRandom(digest);

        this.w = w;
        this.height = height;
        this.mdsize = digest.getDigestSize();

        // mask selecting the lowest w bits
        this.k = (1 << w) - 1;

        // number of w-bit chunks needed to cover the whole digest
        double chunkCount = Math.ceil((double)(this.mdsize << 3) / (double)w);
        this.messagesize = (int)chunkCount;
    }

    /**
     * This method initializes the distributed sigature calculation. Variables
     * are reseted and necessary steps are calculated
     *
     * @param seed0   the initial OTSseed
     * @param message the massage which will be signed
     */
    public void initSign(byte[] seed0, byte[] message)
    {

        // create hash of message m
        this.hash = new byte[mdsize];
        messDigestOTS.update(message, 0, message.length);
        this.hash = new byte[messDigestOTS.getDigestSize()];
        messDigestOTS.doFinal(this.hash, 0);

        // variables for calculation of steps
        byte[] messPart = new byte[mdsize];
        System.arraycopy(hash, 0, messPart, 0, mdsize);
        int checkPart = 0;
        int sumH = 0;
        int checksumsize = getLog((messagesize << w) + 1);

        // ------- calculation of necessary steps ------
        if (8 % w == 0)
        {
            int dt = 8 / w;
            // message part
            for (int a = 0; a < mdsize; a++)
            {
                // count necessary hashs in 'sumH'
                for (int b = 0; b < dt; b++)
                {
                    sumH += messPart[a] & k;
                    messPart[a] = (byte)(messPart[a] >>> w);
                }
            }
            // checksum part
            this.checksum = (messagesize << w) - sumH;
            checkPart = checksum;
            // count necessary hashs in 'sumH'
            for (int b = 0; b < checksumsize; b += w)
            {
                sumH += checkPart & k;
                checkPart >>>= w;
            }
        } // end if ( 8 % w == 0 )
        else if (w < 8)
        {
            long big8;
            int ii = 0;
            int dt = mdsize / w;

            // first d*w bytes of hash (main message part)
            for (int i = 0; i < dt; i++)
            {
                big8 = 0;
                for (int j = 0; j < w; j++)
                {
                    big8 ^= (messPart[ii] & 0xff) << (j << 3);
                    ii++;
                }
                // count necessary hashs in 'sumH'
                for (int j = 0; j < 8; j++)
                {
                    sumH += (int)(big8 & k);
                    big8 >>>= w;
                }
            }
            // rest of message part
            dt = mdsize % w;
            big8 = 0;
            for (int j = 0; j < dt; j++)
            {
                big8 ^= (messPart[ii] & 0xff) << (j << 3);
                ii++;
            }
            dt <<= 3;
            // count necessary hashs in 'sumH'
            for (int j = 0; j < dt; j += w)
            {
                sumH += (int)(big8 & k);
                big8 >>>= w;
            }
            // checksum part
            this.checksum = (messagesize << w) - sumH;
            checkPart = checksum;
            // count necessary hashs in 'sumH'
            for (int i = 0; i < checksumsize; i += w)
            {
                sumH += checkPart & k;
                checkPart >>>= w;
            }
        }// end if(w<8)
        else if (w < 57)
        {
            long big8;
            int r = 0;
            int s, f, rest, ii;

            // first a*w bits of hash where a*w <= 8*mdsize < (a+1)*w (main
            // message part)
            while (r <= ((mdsize << 3) - w))
            {
                s = r >>> 3;
                rest = r % 8;
                r += w;
                f = (r + 7) >>> 3;
                big8 = 0;
                ii = 0;
                for (int j = s; j < f; j++)
                {
                    big8 ^= (messPart[j] & 0xff) << (ii << 3);
                    ii++;
                }
                big8 >>>= rest;
                // count necessary hashs in 'sumH'
                sumH += (big8 & k);

            }
            // rest of message part
            s = r >>> 3;
            if (s < mdsize)
            {
                rest = r % 8;
                big8 = 0;
                ii = 0;
                for (int j = s; j < mdsize; j++)
                {
                    big8 ^= (messPart[j] & 0xff) << (ii << 3);
                    ii++;
                }

                big8 >>>= rest;
                // count necessary hashs in 'sumH'
                sumH += (big8 & k);
            }
            // checksum part
            this.checksum = (messagesize << w) - sumH;
            checkPart = checksum;
            // count necessary hashs in 'sumH'
            for (int i = 0; i < checksumsize; i += w)
            {
                sumH += (checkPart & k);
                checkPart >>>= w;
            }
        }// end if(w<57)

        // calculate keysize
        this.keysize = messagesize
            + (int)Math.ceil((double)checksumsize / (double)w);

        // calculate steps: 'keysize' times PRNG, 'sumH' times hashing,
        // (1<steps steps of distributed signature
     * calculaion
     *
     * @return true if signature is generated completly, else false
     */
    /**
     * Advances the distributed signature computation by at most
     * <code>steps</code> single steps.
     *
     * @return true as soon as all <code>keysize</code> signature elements
     *         have been produced during this call, false if the step budget
     *         ran out first
     */
    public boolean updateSign()
    {
        int remaining = steps;

        while (remaining > 0)
        {
            remaining--;

            // still work to do: derive the next key element or hash once more
            if (counter < keysize)
            {
                oneStep();
            }

            // every element has been written to the signature
            if (counter == keysize)
            {
                return true;
            }
        }

        // signature not finished yet
        return false;
    }

    /**
     * Returns the signature buffer that is filled step by step by
     * updateSign(). The internal array itself is returned, not a copy.
     *
     * @return the one-time signature bytes
     */
    public byte[] getSig()
    {
        return this.sign;
    }

    /**
     * Performs a single step of the distributed signature generation.
     * <p>
     * A step either starts a new signature element (derives the next key
     * element from the seed and reads the next chunk of the message hash or
     * checksum to learn how many hash iterations are required) or applies one
     * more hash iteration to the current element. Whenever the remaining
     * iteration count is zero at the end of the step, the element is copied
     * into the signature at offset counter * mdsize and counter is advanced.
     * <p>
     * Three branches exist, selected by the Winternitz parameter w: w divides
     * 8, other w below 8, and other w below 57. For any other w the method
     * does nothing.
     */
    private void oneStep()
    {
        // -------- case: w divides 8 ----------
        if (8 % w == 0)
        {
            if (test == 0)
            {
                // start a new element: get the current OTS private key element
                this.privateKeyOTS = gmssRandom.nextSeed(seed);

                if (ii < mdsize)
                { // main message part: take the lowest w bits of the current byte
                    test = hash[ii] & k;
                    // Consumes the chunk in place. The promotion to int
                    // sign-extends, so ones may be shifted into the top of the
                    // byte; they are never read because only 8 / w chunks are
                    // taken from each byte before ii advances.
                    hash[ii] = (byte)(hash[ii] >>> w);
                }
                else
                { // checksum part: take the lowest w bits of the checksum
                    test = checksum & k;
                    checksum >>>= w;
                }
            }
            else if (test > 0)
            { // hash the private key element once; 'test' counts the
                // iterations still outstanding (one per step)
                messDigestOTS.update(privateKeyOTS, 0, privateKeyOTS.length);
                privateKeyOTS = new byte[messDigestOTS.getDigestSize()];
                messDigestOTS.doFinal(privateKeyOTS, 0);
                test--;
            }
            if (test == 0)
            { // all hashes done (or none needed): copy the result into the
                // signature array
                System.arraycopy(privateKeyOTS, 0, sign, counter * mdsize,
                    mdsize);
                counter++;

                if (counter % (8 / w) == 0)
                { // all chunks of this byte consumed: advance the index into
                    // the message hash
                    ii++;
                }
            }

        }// ----- end case: w divides 8 -----
        // ---------- case: w < 8 and w does not divide 8 ----------------
        else if (w < 8)
        {

            if (test == 0)
            {
                if (counter % 8 == 0 && ii < mdsize)
                { // after every 8th completed element: reload big8 from the
                    // message hash
                    big8 = 0;
                    // NOTE(review): in both loops below the shifted operand is
                    // an int, so Java only uses the low five bits of the shift
                    // distance (j >= 4 wraps around) and a set bit 31 is
                    // sign-extended when XORed into the long. initSign() counts
                    // steps with the same expression, and changing it would
                    // change the signature bytes - confirm against the
                    // verifier before touching it.
                    if (counter < ((mdsize / w) << 3))
                    {// main message part: load w bytes (w*8 bits, i.e. 8
                        // chunks of w bits) each time
                        for (int j = 0; j < w; j++)
                        {
                            big8 ^= (hash[ii] & 0xff) << (j << 3);
                            ii++;
                        }
                    }
                    else
                    { // rest of message part (once): the mdsize % w trailing bytes
                        for (int j = 0; j < mdsize % w; j++)
                        {
                            big8 ^= (hash[ii] & 0xff) << (j << 3);
                            ii++;
                        }
                    }
                }
                if (counter == messagesize)
                { // checksum part (once): continue with the checksum bits
                    big8 = checksum;
                }

                // number of hash iterations for this element: lowest w bits
                test = (int)(big8 & k);
                // generate current OTS private key element
                this.privateKeyOTS = gmssRandom.nextSeed(seed);

            }
            else if (test > 0)
            { // hash the private key element once; 'test' counts the
                // iterations still outstanding (one per step)
                messDigestOTS.update(privateKeyOTS, 0, privateKeyOTS.length);
                privateKeyOTS = new byte[messDigestOTS.getDigestSize()];
                messDigestOTS.doFinal(privateKeyOTS, 0);
                test--;
            }
            if (test == 0)
            { // all hashes done (or none needed): copy the result into the
                // signature array and drop the consumed chunk from big8
                System.arraycopy(privateKeyOTS, 0, sign, counter * mdsize,
                    mdsize);
                big8 >>>= w;
                counter++;
            }

        }// ------- end case: w < 8 --------------------------------
        // --------- case: remaining w below 57 (9..56 for positive w) ---------
        // NOTE(review): k is computed as an int ((1 << w) - 1), so for
        // w >= 32 it is not a w-bit mask - confirm which w values are
        // actually supported.
        else if (w < 57)
        {

            if (test8 == 0)
            {
                int s, f, rest;
                big8 = 0;
                // the field ii is reused here as a scratch byte counter
                ii = 0;
                // r is the bit offset of the next chunk: split it into a byte
                // index (s) and a bit offset inside that byte (rest)
                rest = r % 8;
                s = r >>> 3;
                // --- message part ---
                if (s < mdsize)
                {
                    if (r <= ((mdsize << 3) - w))
                    { // a full w-bit chunk is still available
                        r += w;
                        f = (r + 7) >>> 3;
                    }
                    else
                    { // rest of message part (once): fewer than w bits left
                        f = mdsize;
                        r += w;
                    }
                    // gather bytes s..f-1 of the message hash, which contain
                    // at least the next chunk, into 'big8'
                    // NOTE(review): the shifted operand is an int, so shift
                    // distances of 32 and more wrap around (see the w < 8
                    // branch); initSign() uses the same expression.
                    for (int i = s; i < f; i++)
                    {
                        big8 ^= (hash[i] & 0xff) << (ii << 3);
                        ii++;
                    }
                    // drop the bits on the right side that were already used
                    // by the previous chunk
                    big8 >>>= rest;
                    test8 = (big8 & k);
                }
                // --- checksum part ---
                else
                {
                    test8 = (checksum & k);
                    checksum >>>= w;
                }
                // generate current OTS private key element
                this.privateKeyOTS = gmssRandom.nextSeed(seed);

            }
            else if (test8 > 0)
            { // hash the private key element once; 'test8' counts the
                // iterations still outstanding (one per step)
                messDigestOTS.update(privateKeyOTS, 0, privateKeyOTS.length);
                privateKeyOTS = new byte[messDigestOTS.getDigestSize()];
                messDigestOTS.doFinal(privateKeyOTS, 0);
                test8--;
            }
            if (test8 == 0)
            { // all hashes done (or none needed): copy the result into the
                // signature array
                System.arraycopy(privateKeyOTS, 0, sign, counter * mdsize,
                    mdsize);
                counter++;
            }

        }
    }

    /**
     * Returns the smallest integer <code>log</code> with
     * <code>log &gt;= 1</code> and <code>2^log &gt;= intValue</code>. For
     * <code>intValue &gt; 2</code> this is the base 2 logarithm of intValue
     * rounded up; for every <code>intValue &lt;= 2</code> (including zero and
     * negative values) the result is 1.
     *
     * @param intValue an integer
     * @return the rounded-up base 2 logarithm of intValue, but never less
     *         than 1
     */
    public int getLog(int intValue)
    {
        int log = 1;
        // A long is required here: with an int the power would overflow to a
        // negative value and then to 0 for intValue > 2^30, and the loop would
        // never terminate.
        long power = 2;
        while (power < intValue)
        {
            power <<= 1;
            log++;
        }
        return log;
    }

    /**
     * Returns the byte array part of the signing state, in the layout
     * expected by the restoring constructor.
     * <p>
     * Rows 0 to 3 are the internal arrays themselves (not copies): current
     * private key element, seed, message hash and signature buffer. Any of
     * them may be null if it has not been set yet. Row 4 is a freshly built
     * 16-byte array holding test8 and big8 (see {@link #getStatLong()}).
     *
     * @return statBytes
     */
    public byte[][] getStatByte()
    {
        // Build the jagged array directly instead of allocating five
        // mdsize-byte rows that would be replaced immediately.
        return new byte[][]{
            privateKeyOTS,
            seed,
            hash,
            sign,
            this.getStatLong()
        };
    }

    /**
     * Returns the int part of the signing state, in the layout expected by
     * the restoring constructor: counter, test, ii, r, steps, keysize, height,
     * w, checksum.
     *
     * @return statInt
     */
    public int[] getStatInt()
    {
        return new int[]{
            counter,
            test,
            ii,
            r,
            steps,
            keysize,
            height,
            w,
            checksum
        };
    }

    /**
     * Packs the two long status values into a byte array so that they can be
     * stored in the statByte array.
     *
     * @return 16 bytes: test8 in bytes 0-7 followed by big8 in bytes 8-15,
     *         each with the least significant byte first
     */
    public byte[] getStatLong()
    {
        byte[] packed = new byte[16];

        for (int i = 0; i < 8; i++)
        {
            int shift = i << 3;
            // the narrowing cast keeps exactly the lowest eight bits
            packed[i] = (byte)(test8 >> shift);
            packed[8 + i] = (byte)(big8 >> shift);
        }

        return packed;
    }

    /**
     * Returns a string representation of the instance: big8 followed by two
     * blanks, then the nine status ints and the five status byte arrays
     * (hex encoded), each followed by a single blank.
     * <p>
     * A status byte array that has not been set yet is rendered as
     * <code>null</code> instead of causing a NullPointerException.
     *
     * @return a string representation of the instance
     */
    public String toString()
    {
        int[] statInt = this.getStatInt();
        byte[][] statByte = this.getStatByte();

        StringBuffer out = new StringBuffer();
        out.append(this.big8).append("  ");

        for (int i = 0; i < statInt.length; i++)
        {
            out.append(statInt[i]).append(' ');
        }

        for (int i = 0; i < statByte.length; i++)
        {
            byte[] part = statByte[i];
            if (part == null)
            {
                // e.g. the key, seed, hash and signature of a freshly
                // constructed instance
                out.append("null");
            }
            else
            {
                out.append(new String(Hex.encode(part)));
            }
            out.append(' ');
        }

        return out.toString();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy