All downloads are free. Search and download functionality uses the official Maven repository.

org.bouncycastle.pqc.crypto.xmss.BDS Maven / Gradle / Ivy

Go to download

The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains JCE provider and lightweight API for the Bouncy Castle Cryptography APIs for JDK 1.5 to JDK 1.8. Note: this package includes the NTRU encryption algorithms.

The newest version!
package org.bouncycastle.pqc.crypto.xmss;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;

/**
 * BDS.
 *
 */
public final class BDS implements Serializable {

	private static final long serialVersionUID = 1L;

	private final class TreeHash implements Serializable {

		private static final long serialVersionUID = 1L;

		private XMSSNode tailNode;
		private final int initialHeight;
		private int height;
		private int nextIndex;
		private boolean initialized;
		private boolean finished;

		private TreeHash(int initialHeight) {
			super();
			this.initialHeight = initialHeight;
			initialized = false;
			finished = false;
		}

		private void initialize(int nextIndex) {
			tailNode = null;
			height = initialHeight;
			this.nextIndex = nextIndex;
			initialized = true;
			finished = false;
		}

		private void update(OTSHashAddress otsHashAddress) {
			if (otsHashAddress == null) {
				throw new NullPointerException("otsHashAddress == null");
			}
			if (finished || !initialized) {
				throw new IllegalStateException("finished or not initialized");
			}
			/* prepare addresses */
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(nextIndex).withChainAddress(otsHashAddress.getChainAddress())
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withLTreeAddress(nextIndex).build();
			HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withTreeIndex(nextIndex).build();

			/* calculate leaf node */
			wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
			WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
			XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);

			while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight()
					&& stack.peek().getHeight() != initialHeight) {
				hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
						.withLayerAddress(hashTreeAddress.getLayerAddress())
						.withTreeAddress(hashTreeAddress.getTreeAddress())
						.withTreeHeight(hashTreeAddress.getTreeHeight())
						.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
						.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
				node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
				node = new XMSSNode(node.getHeight() + 1, node.getValue());
				hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
						.withLayerAddress(hashTreeAddress.getLayerAddress())
						.withTreeAddress(hashTreeAddress.getTreeAddress())
						.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
						.withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())
						.build();
			}

			if (tailNode == null) {
				tailNode = node;
			} else {
				if (tailNode.getHeight() == node.getHeight()) {
					hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
							.withLayerAddress(hashTreeAddress.getLayerAddress())
							.withTreeAddress(hashTreeAddress.getTreeAddress())
							.withTreeHeight(hashTreeAddress.getTreeHeight())
							.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
							.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
					node = xmss.randomizeHash(tailNode, node, hashTreeAddress);
					node = new XMSSNode(tailNode.getHeight() + 1, node.getValue());
					tailNode = node;
					hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
							.withLayerAddress(hashTreeAddress.getLayerAddress())
							.withTreeAddress(hashTreeAddress.getTreeAddress())
							.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
							.withTreeIndex(hashTreeAddress.getTreeIndex())
							.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
				} else {
					stack.push(node);
				}
			}

			if (tailNode.getHeight() == initialHeight) {
				finished = true;
			} else {
				height = node.getHeight();
				nextIndex++;
			}
		}

		private int getHeight() {
			if (!initialized || finished) {
				return Integer.MAX_VALUE;
			}
			return height;
		}

		private int getIndexLeaf() {
			return nextIndex;
		}

		private void setNode(XMSSNode node) {
			tailNode = node;
			height = node.getHeight();
			if (height == initialHeight) {
				finished = true;
			}
		}

		private boolean isFinished() {
			return finished;
		}

		private boolean isInitialized() {
			return initialized;
		}
	}

	private transient XMSS xmss;
	private transient WOTSPlus wotsPlus;
	private final int treeHeight;
	private int k;
	private XMSSNode root;
	private List authenticationPath;
	private Map> retain;
	private Stack stack;
	private List treeHashInstances;
	private Map keep;
	private int index;

	protected BDS(XMSS xmss) {
		super();
		if (xmss == null) {
			throw new NullPointerException("xmss == null");
		}
		this.xmss = xmss;
		wotsPlus = xmss.getWOTSPlus();
		treeHeight = xmss.getParams().getHeight();
		k = xmss.getParams().getK();
		if (k > treeHeight || k < 2 || ((treeHeight - k) % 2) != 0) {
			throw new IllegalArgumentException("illegal value for BDS parameter k");
		}
		authenticationPath = new ArrayList();
		retain = new TreeMap>();
		stack = new Stack();
		initializeTreeHashInstances();
		keep = new TreeMap();
		index = 0;
	}

	private void initializeTreeHashInstances() {
		treeHashInstances = new ArrayList();
		for (int height = 0; height < (treeHeight - k); height++) {
			treeHashInstances.add(new TreeHash(height));
		}
	}

	protected XMSSNode initialize(OTSHashAddress otsHashAddress) {
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		/* prepare addresses */
		LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.build();
		HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.build();

		/* iterate indexes */
		for (int indexLeaf = 0; indexLeaf < (1 << treeHeight); indexLeaf++) {
			/* generate leaf */
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(indexLeaf).withChainAddress(otsHashAddress.getChainAddress())
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			/*
			 * import WOTSPlusSecretKey as its needed to calculate the public
			 * key on the fly
			 */
			wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
			WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
			lTreeAddress = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
					.withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(indexLeaf)
					.withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
					.withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
			XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);

			hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
					.withLayerAddress(hashTreeAddress.getLayerAddress())
					.withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeIndex(indexLeaf)
					.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
			while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight()) {
				/* add to authenticationPath if leafIndex == 1 */
				int indexOnHeight = ((int) Math.floor(indexLeaf / (1 << node.getHeight())));
				if (indexOnHeight == 1) {
					authenticationPath.add(node.clone());
				}
				/* store next right authentication node */
				if (indexOnHeight == 3 && node.getHeight() < (treeHeight - k)) {
					treeHashInstances.get(node.getHeight()).setNode(node.clone());
				}
				if (indexOnHeight >= 3 && (indexOnHeight & 1) == 1 && node.getHeight() >= (treeHeight - k)
						&& node.getHeight() <= (treeHeight - 2)) {
					if (retain.get(node.getHeight()) == null) {
						LinkedList queue = new LinkedList();
						queue.add(node.clone());
						retain.put(node.getHeight(), queue);
					} else {
						retain.get(node.getHeight()).add(node.clone());
					}
				}
				hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
						.withLayerAddress(hashTreeAddress.getLayerAddress())
						.withTreeAddress(hashTreeAddress.getTreeAddress())
						.withTreeHeight(hashTreeAddress.getTreeHeight())
						.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
						.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
				node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
				node = new XMSSNode(node.getHeight() + 1, node.getValue());
				hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
						.withLayerAddress(hashTreeAddress.getLayerAddress())
						.withTreeAddress(hashTreeAddress.getTreeAddress())
						.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
						.withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())
						.build();
			}
			/* push to stack */
			stack.push(node);
		}
		root = stack.pop();
		return root.clone();
	}

	protected void nextAuthenticationPath(OTSHashAddress otsHashAddress) {
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		if (index > ((1 << treeHeight) - 2)) {
			throw new IllegalStateException("index out of bounds");
		}
		/* prepare addresses */
		LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.build();
		HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.build();

		/* determine tau */
		int tau = XMSSUtil.calculateTau(index, treeHeight);

		/* parent of leaf on height tau+1 is a left node */
		if (((index >> (tau + 1)) & 1) == 0 && (tau < (treeHeight - 1))) {
			keep.put(tau, authenticationPath.get(tau).clone());
		}
		/* leaf is a left node */
		if (tau == 0) {
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(index).withChainAddress(otsHashAddress.getChainAddress())
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			/*
			 * import WOTSPlusSecretKey as its needed to calculate the public
			 * key on the fly
			 */
			wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
			WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
			lTreeAddress = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
					.withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(index)
					.withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
					.withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
			XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
			authenticationPath.set(0, node);
		} else {
			/* add new left node on height tau to authentication path */
			hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
					.withLayerAddress(hashTreeAddress.getLayerAddress())
					.withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeHeight(tau - 1)
					.withTreeIndex(index >> tau).withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
			XMSSNode node = xmss.randomizeHash(authenticationPath.get(tau - 1), keep.get(tau - 1), hashTreeAddress);
			node = new XMSSNode(node.getHeight() + 1, node.getValue());
			authenticationPath.set(tau, node);
			keep.remove(tau - 1);

			/* add new right nodes to authentication path */
			for (int height = 0; height < tau; height++) {
				if (height < (treeHeight - k)) {
					authenticationPath.set(height, treeHashInstances.get(height).tailNode.clone());
				} else {
					authenticationPath.set(height, retain.get(height).removeFirst());
				}
			}

			/* reinitialize treehash instances */
			int minHeight = Math.min(tau, treeHeight - k);
			for (int height = 0; height < minHeight; height++) {
				int startIndex = index + 1 + (3 * (1 << height));
				if (startIndex < (1 << treeHeight)) {
					treeHashInstances.get(height).initialize(startIndex);
				}
			}
		}

		/* update treehash instances */
		for (int i = 0; i < (treeHeight - k) >> 1; i++) {
			TreeHash treeHash = getTreeHashInstanceForUpdate();
			if (treeHash != null) {
				treeHash.update(otsHashAddress);
			}
		}
		index++;
	}

	private TreeHash getTreeHashInstanceForUpdate() {
		TreeHash ret = null;
		for (TreeHash treeHash : treeHashInstances) {
			if (treeHash.isFinished() || !treeHash.isInitialized()) {
				continue;
			}
			if (ret == null) {
				ret = treeHash;
				continue;
			}
			if (treeHash.getHeight() < ret.getHeight()) {
				ret = treeHash;
				continue;
			}
			if (treeHash.getHeight() == ret.getHeight()) {
				if (treeHash.getIndexLeaf() < ret.getIndexLeaf()) {
					ret = treeHash;
				}
			}
		}
		return ret;
	}

	protected void validate() {
		if (treeHeight != xmss.getParams().getHeight()) {
			throw new IllegalStateException("wrong height");
		}
		if (authenticationPath == null) {
			throw new IllegalStateException("authenticationPath == null");
		}
		if (retain == null) {
			throw new IllegalStateException("retain == null");
		}
		if (stack == null) {
			throw new IllegalStateException("stack == null");
		}
		if (treeHashInstances == null) {
			throw new IllegalStateException("treeHashInstances == null");
		}
		if (keep == null) {
			throw new IllegalStateException("keep == null");
		}
		if (!XMSSUtil.isIndexValid(treeHeight, index)) {
			throw new IllegalStateException("index in BDS state out of bounds");
		}
	}

	protected int getTreeHeight() {
		return treeHeight;
	}

	protected XMSSNode getRoot() {
		return root.clone();
	}

	protected List getAuthenticationPath() {
		List authenticationPath = new ArrayList();
		for (XMSSNode node : this.authenticationPath) {
			authenticationPath.add(node.clone());
		}
		return authenticationPath;
	}

	protected void setXMSS(XMSS xmss) {
		this.xmss = xmss;
		this.wotsPlus = xmss.getWOTSPlus();
	}

	protected int getIndex() {
		return index;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy