All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.bouncycastle.pqc.crypto.xmss.BDS Maven / Gradle / Ivy

package org.bouncycastle.pqc.crypto.xmss;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;

/**
 * BDS.
 */
public final class BDS
    implements Serializable
{
    private static final long serialVersionUID = 1L;
    
    private transient WOTSPlus wotsPlus;

    private final int treeHeight;
    private final List treeHashInstances;
    private int k;
    private XMSSNode root;
    private List authenticationPath;
    private Map> retain;
    private Stack stack;

    private Map keep;
    private int index;
    private boolean used;

    /**
     * Place holder BDS for when state is exhausted.
     *
     * @param params tree parameters
     * @param index the index that has been reached.
     */
    BDS(XMSSParameters params, int index)
    {
        this(params.getWOTSPlus(), params.getHeight(), params.getK());
        this.index = index;
        this.used = true;
    }

    /**
     * Set up constructor.
     *
     * @param params tree parameters
     * @param publicSeed public seed for tree
     * @param secretKeySeed secret seed for tree
     * @param otsHashAddress hash address
     */
    BDS(XMSSParameters params, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress)
    {
        this(params.getWOTSPlus(), params.getHeight(), params.getK());
        this.initialize(publicSeed, secretKeySeed, otsHashAddress);
    }

    /**
     * Set up constructor for a tree where the original BDS state was lost.
     *
     * @param params tree parameters
     * @param publicSeed public seed for tree
     * @param secretKeySeed secret seed for tree
     * @param otsHashAddress hash address
     * @param index index counter for the state to be at.
     */
    BDS(XMSSParameters params, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress, int index)
    {
        this(params.getWOTSPlus(), params.getHeight(), params.getK());

        this.initialize(publicSeed, secretKeySeed, otsHashAddress);

        while (this.index < index)
        {
            this.nextAuthenticationPath(publicSeed, secretKeySeed, otsHashAddress);
            this.used = false;
        }
    }

    private BDS(WOTSPlus wotsPlus, int treeHeight, int k)
    {
        this.wotsPlus = wotsPlus;
        this.treeHeight = treeHeight;
        this.k = k;
        if (k > treeHeight || k < 2 || ((treeHeight - k) % 2) != 0)
        {
            throw new IllegalArgumentException("illegal value for BDS parameter k");
        }
        authenticationPath = new ArrayList();
        retain = new TreeMap>();
        stack = new Stack();

        treeHashInstances = new ArrayList();
        for (int height = 0; height < (treeHeight - k); height++)
        {
            treeHashInstances.add(new BDSTreeHash(height));
        }

        keep = new TreeMap();
        index = 0;
        this.used = false;
    }

    private BDS(BDS last, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress)
    {
        this.wotsPlus = last.wotsPlus;
        this.treeHeight = last.treeHeight;
        this.k = last.k;
        this.root = last.root;
        this.authenticationPath = new ArrayList(last.authenticationPath);
        this.retain = last.retain;
        this.stack = (Stack)last.stack.clone();
        this.treeHashInstances = last.treeHashInstances;
        this.keep = new TreeMap(last.keep);
        this.index = last.index;

        this.nextAuthenticationPath(publicSeed, secretKeySeed, otsHashAddress);

        last.used = true;
    }
    
    public BDS getNextState(byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress)
    {
        return new BDS(this, publicSeed, secretKeySeed, otsHashAddress);
    }

    private void initialize(byte[] publicSeed, byte[] secretSeed, OTSHashAddress otsHashAddress)
    {
        if (otsHashAddress == null)
        {
            throw new NullPointerException("otsHashAddress == null");
        }
        /* prepare addresses */
        LTreeAddress lTreeAddress = (LTreeAddress)new LTreeAddress.Builder()
            .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
            .build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
            .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
            .build();

		/* iterate indexes */
        for (int indexLeaf = 0; indexLeaf < (1 << treeHeight); indexLeaf++)
        {
			/* generate leaf */
            otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder()
                .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
                .withOTSAddress(indexLeaf).withChainAddress(otsHashAddress.getChainAddress())
                .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
                .build();
			/*
			 * import WOTSPlusSecretKey as its needed to calculate the public
			 * key on the fly
			 */
            wotsPlus.importKeys(wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed);
            WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
            lTreeAddress = (LTreeAddress)new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
                .withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(indexLeaf)
                .withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
                .withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
            XMSSNode node = XMSSNodeUtil.lTree(wotsPlus, wotsPlusPublicKey, lTreeAddress);

            hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
                .withLayerAddress(hashTreeAddress.getLayerAddress())
                .withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeIndex(indexLeaf)
                .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
            while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight())
            {
				/* add to authenticationPath if leafIndex == 1 */
                int indexOnHeight = ((int)Math.floor(indexLeaf / (1 << node.getHeight())));
                if (indexOnHeight == 1)
                {
                    authenticationPath.add(node.clone());
                }
				/* store next right authentication node */
                if (indexOnHeight == 3 && node.getHeight() < (treeHeight - k))
                {
                    treeHashInstances.get(node.getHeight()).setNode(node.clone());
                }
                if (indexOnHeight >= 3 && (indexOnHeight & 1) == 1 && node.getHeight() >= (treeHeight - k)
                    && node.getHeight() <= (treeHeight - 2))
                {
                    if (retain.get(node.getHeight()) == null)
                    {
                        LinkedList queue = new LinkedList();
                        queue.add(node.clone());
                        retain.put(node.getHeight(), queue);
                    }
                    else
                    {
                        retain.get(node.getHeight()).add(node.clone());
                    }
                }
                hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
                    .withLayerAddress(hashTreeAddress.getLayerAddress())
                    .withTreeAddress(hashTreeAddress.getTreeAddress())
                    .withTreeHeight(hashTreeAddress.getTreeHeight())
                    .withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
                    .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
                node = XMSSNodeUtil.randomizeHash(wotsPlus, stack.pop(), node, hashTreeAddress);
                node = new XMSSNode(node.getHeight() + 1, node.getValue());
                hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
                    .withLayerAddress(hashTreeAddress.getLayerAddress())
                    .withTreeAddress(hashTreeAddress.getTreeAddress())
                    .withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
                    .withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())
                    .build();
            }
			/* push to stack */
            stack.push(node);
        }
        root = stack.pop();
    }

    private void nextAuthenticationPath(byte[] publicSeed, byte[] secretSeed, OTSHashAddress otsHashAddress)
    {
        if (otsHashAddress == null)
        {
            throw new NullPointerException("otsHashAddress == null");
        }
        if (used)
        {
            throw new IllegalStateException("index already used");
        }
        if (index > ((1 << treeHeight) - 2))
        {
            throw new IllegalStateException("index out of bounds");
        }
		/* prepare addresses */
        LTreeAddress lTreeAddress = (LTreeAddress)new LTreeAddress.Builder()
            .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
            .build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
            .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
            .build();

		/* determine tau */
        int tau = XMSSUtil.calculateTau(index, treeHeight);
    	/* parent of leaf on height tau+1 is a left node */
        if (((index >> (tau + 1)) & 1) == 0 && (tau < (treeHeight - 1)))
        {
            keep.put(tau, authenticationPath.get(tau).clone());
        }
		/* leaf is a left node */
        if (tau == 0)
        {
            otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder()
                .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
                .withOTSAddress(index).withChainAddress(otsHashAddress.getChainAddress())
                .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
                .build();
			/*
			 * import WOTSPlusSecretKey as its needed to calculate the public
			 * key on the fly
			 */
            wotsPlus.importKeys(wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed);
            WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
            lTreeAddress = (LTreeAddress)new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
                .withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(index)
                .withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
                .withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
            XMSSNode node = XMSSNodeUtil.lTree(wotsPlus, wotsPlusPublicKey, lTreeAddress);
            authenticationPath.set(0, node);
        }
        else
        {
			/* add new left node on height tau to authentication path */
            hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder()
                .withLayerAddress(hashTreeAddress.getLayerAddress())
                .withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeHeight(tau - 1)
                .withTreeIndex(index >> tau).withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
            XMSSNode node = XMSSNodeUtil.randomizeHash(wotsPlus, authenticationPath.get(tau - 1), keep.get(tau - 1), hashTreeAddress);
            node = new XMSSNode(node.getHeight() + 1, node.getValue());
            authenticationPath.set(tau, node);
            keep.remove(tau - 1);

			/* add new right nodes to authentication path */
            for (int height = 0; height < tau; height++)
            {
                if (height < (treeHeight - k))
                {
                    authenticationPath.set(height, treeHashInstances.get(height).getTailNode());
                }
                else
                {
                    authenticationPath.set(height, retain.get(height).removeFirst());
                }
            }

			/* reinitialize treehash instances */
            int minHeight = Math.min(tau, treeHeight - k);
            for (int height = 0; height < minHeight; height++)
            {
                int startIndex = index + 1 + (3 * (1 << height));
                if (startIndex < (1 << treeHeight))
                {
                    treeHashInstances.get(height).initialize(startIndex);
                }
            }
        }
 
		/* update treehash instances */
        for (int i = 0; i < (treeHeight - k) >> 1; i++)
        {
            BDSTreeHash treeHash = getBDSTreeHashInstanceForUpdate();
            if (treeHash != null)
            {
                treeHash.update(stack, wotsPlus, publicSeed, secretSeed, otsHashAddress);
            }
        }

        index++;
    }

    boolean isUsed()
    {
        return used;
    }

    private BDSTreeHash getBDSTreeHashInstanceForUpdate()
    {
        BDSTreeHash ret = null;
        for (BDSTreeHash treeHash : treeHashInstances)
        {
            if (treeHash.isFinished() || !treeHash.isInitialized())
            {
                continue;
            }
            if (ret == null)
            {
                ret = treeHash;
                continue;
            }
            if (treeHash.getHeight() < ret.getHeight())
            {
                ret = treeHash;
                continue;
            }
            if (treeHash.getHeight() == ret.getHeight())
            {
                if (treeHash.getIndexLeaf() < ret.getIndexLeaf())
                {
                    ret = treeHash;
                }
            }
        }
        return ret;
    }

    protected void validate()
    {
        if (authenticationPath == null)
        {
            throw new IllegalStateException("authenticationPath == null");
        }
        if (retain == null)
        {
            throw new IllegalStateException("retain == null");
        }
        if (stack == null)
        {
            throw new IllegalStateException("stack == null");
        }
        if (treeHashInstances == null)
        {
            throw new IllegalStateException("treeHashInstances == null");
        }
        if (keep == null)
        {
            throw new IllegalStateException("keep == null");
        }
        if (!XMSSUtil.isIndexValid(treeHeight, index))
        {
            throw new IllegalStateException("index in BDS state out of bounds");
        }
    }

    protected int getTreeHeight()
    {
        return treeHeight;
    }

    protected XMSSNode getRoot()
    {
        return root.clone();
    }

    protected List getAuthenticationPath()
    {
        List authenticationPath = new ArrayList();
        for (XMSSNode node : this.authenticationPath)
        {
            authenticationPath.add(node.clone());
        }
        return authenticationPath;
    }

    protected void setXMSS(XMSSParameters xmss)
    {
        if (treeHeight != xmss.getHeight())
        {
            throw new IllegalStateException("wrong height");
        }

        this.wotsPlus = xmss.getWOTSPlus();
    }

    protected int getIndex()
    {
        return index;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy