org.bouncycastle.pqc.crypto.xmss.BDS Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of bcprov-ext-jdk15on Show documentation
Show all versions of bcprov-ext-jdk15on Show documentation
The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains JCE provider and lightweight API for the Bouncy Castle Cryptography APIs for JDK 1.5 to JDK 1.8. Note: this package includes the NTRU encryption algorithms.
The newest version!
package org.bouncycastle.pqc.crypto.xmss;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;
/**
* BDS.
*
*/
public final class BDS implements Serializable {
private static final long serialVersionUID = 1L;
private final class TreeHash implements Serializable {
private static final long serialVersionUID = 1L;
private XMSSNode tailNode;
private final int initialHeight;
private int height;
private int nextIndex;
private boolean initialized;
private boolean finished;
private TreeHash(int initialHeight) {
super();
this.initialHeight = initialHeight;
initialized = false;
finished = false;
}
private void initialize(int nextIndex) {
tailNode = null;
height = initialHeight;
this.nextIndex = nextIndex;
initialized = true;
finished = false;
}
private void update(OTSHashAddress otsHashAddress) {
if (otsHashAddress == null) {
throw new NullPointerException("otsHashAddress == null");
}
if (finished || !initialized) {
throw new IllegalStateException("finished or not initialized");
}
/* prepare addresses */
otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.withOTSAddress(nextIndex).withChainAddress(otsHashAddress.getChainAddress())
.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
.build();
LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.withLTreeAddress(nextIndex).build();
HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.withTreeIndex(nextIndex).build();
/* calculate leaf node */
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight()
&& stack.peek().getHeight() != initialHeight) {
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight())
.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
node = new XMSSNode(node.getHeight() + 1, node.getValue());
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
.withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())
.build();
}
if (tailNode == null) {
tailNode = node;
} else {
if (tailNode.getHeight() == node.getHeight()) {
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight())
.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
node = xmss.randomizeHash(tailNode, node, hashTreeAddress);
node = new XMSSNode(tailNode.getHeight() + 1, node.getValue());
tailNode = node;
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
.withTreeIndex(hashTreeAddress.getTreeIndex())
.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
} else {
stack.push(node);
}
}
if (tailNode.getHeight() == initialHeight) {
finished = true;
} else {
height = node.getHeight();
nextIndex++;
}
}
private int getHeight() {
if (!initialized || finished) {
return Integer.MAX_VALUE;
}
return height;
}
private int getIndexLeaf() {
return nextIndex;
}
private void setNode(XMSSNode node) {
tailNode = node;
height = node.getHeight();
if (height == initialHeight) {
finished = true;
}
}
private boolean isFinished() {
return finished;
}
private boolean isInitialized() {
return initialized;
}
}
private transient XMSS xmss;
private transient WOTSPlus wotsPlus;
private final int treeHeight;
private int k;
private XMSSNode root;
private List authenticationPath;
private Map> retain;
private Stack stack;
private List treeHashInstances;
private Map keep;
private int index;
protected BDS(XMSS xmss) {
super();
if (xmss == null) {
throw new NullPointerException("xmss == null");
}
this.xmss = xmss;
wotsPlus = xmss.getWOTSPlus();
treeHeight = xmss.getParams().getHeight();
k = xmss.getParams().getK();
if (k > treeHeight || k < 2 || ((treeHeight - k) % 2) != 0) {
throw new IllegalArgumentException("illegal value for BDS parameter k");
}
authenticationPath = new ArrayList();
retain = new TreeMap>();
stack = new Stack();
initializeTreeHashInstances();
keep = new TreeMap();
index = 0;
}
private void initializeTreeHashInstances() {
treeHashInstances = new ArrayList();
for (int height = 0; height < (treeHeight - k); height++) {
treeHashInstances.add(new TreeHash(height));
}
}
protected XMSSNode initialize(OTSHashAddress otsHashAddress) {
if (otsHashAddress == null) {
throw new NullPointerException("otsHashAddress == null");
}
/* prepare addresses */
LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.build();
HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.build();
/* iterate indexes */
for (int indexLeaf = 0; indexLeaf < (1 << treeHeight); indexLeaf++) {
/* generate leaf */
otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.withOTSAddress(indexLeaf).withChainAddress(otsHashAddress.getChainAddress())
.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
.build();
/*
* import WOTSPlusSecretKey as its needed to calculate the public
* key on the fly
*/
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
lTreeAddress = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
.withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(indexLeaf)
.withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
.withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeIndex(indexLeaf)
.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight()) {
/* add to authenticationPath if leafIndex == 1 */
int indexOnHeight = ((int) Math.floor(indexLeaf / (1 << node.getHeight())));
if (indexOnHeight == 1) {
authenticationPath.add(node.clone());
}
/* store next right authentication node */
if (indexOnHeight == 3 && node.getHeight() < (treeHeight - k)) {
treeHashInstances.get(node.getHeight()).setNode(node.clone());
}
if (indexOnHeight >= 3 && (indexOnHeight & 1) == 1 && node.getHeight() >= (treeHeight - k)
&& node.getHeight() <= (treeHeight - 2)) {
if (retain.get(node.getHeight()) == null) {
LinkedList queue = new LinkedList();
queue.add(node.clone());
retain.put(node.getHeight(), queue);
} else {
retain.get(node.getHeight()).add(node.clone());
}
}
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight())
.withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2)
.withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
node = xmss.randomizeHash(stack.pop(), node, hashTreeAddress);
node = new XMSSNode(node.getHeight() + 1, node.getValue());
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress())
.withTreeHeight(hashTreeAddress.getTreeHeight() + 1)
.withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())
.build();
}
/* push to stack */
stack.push(node);
}
root = stack.pop();
return root.clone();
}
protected void nextAuthenticationPath(OTSHashAddress otsHashAddress) {
if (otsHashAddress == null) {
throw new NullPointerException("otsHashAddress == null");
}
if (index > ((1 << treeHeight) - 2)) {
throw new IllegalStateException("index out of bounds");
}
/* prepare addresses */
LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.build();
HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.build();
/* determine tau */
int tau = XMSSUtil.calculateTau(index, treeHeight);
/* parent of leaf on height tau+1 is a left node */
if (((index >> (tau + 1)) & 1) == 0 && (tau < (treeHeight - 1))) {
keep.put(tau, authenticationPath.get(tau).clone());
}
/* leaf is a left node */
if (tau == 0) {
otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
.withOTSAddress(index).withChainAddress(otsHashAddress.getChainAddress())
.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
.build();
/*
* import WOTSPlusSecretKey as its needed to calculate the public
* key on the fly
*/
wotsPlus.importKeys(xmss.getWOTSPlusSecretKey(otsHashAddress), xmss.getPublicSeed());
WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress);
lTreeAddress = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())
.withTreeAddress(lTreeAddress.getTreeAddress()).withLTreeAddress(index)
.withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex())
.withKeyAndMask(lTreeAddress.getKeyAndMask()).build();
XMSSNode node = xmss.lTree(wotsPlusPublicKey, lTreeAddress);
authenticationPath.set(0, node);
} else {
/* add new left node on height tau to authentication path */
hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
.withLayerAddress(hashTreeAddress.getLayerAddress())
.withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeHeight(tau - 1)
.withTreeIndex(index >> tau).withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
XMSSNode node = xmss.randomizeHash(authenticationPath.get(tau - 1), keep.get(tau - 1), hashTreeAddress);
node = new XMSSNode(node.getHeight() + 1, node.getValue());
authenticationPath.set(tau, node);
keep.remove(tau - 1);
/* add new right nodes to authentication path */
for (int height = 0; height < tau; height++) {
if (height < (treeHeight - k)) {
authenticationPath.set(height, treeHashInstances.get(height).tailNode.clone());
} else {
authenticationPath.set(height, retain.get(height).removeFirst());
}
}
/* reinitialize treehash instances */
int minHeight = Math.min(tau, treeHeight - k);
for (int height = 0; height < minHeight; height++) {
int startIndex = index + 1 + (3 * (1 << height));
if (startIndex < (1 << treeHeight)) {
treeHashInstances.get(height).initialize(startIndex);
}
}
}
/* update treehash instances */
for (int i = 0; i < (treeHeight - k) >> 1; i++) {
TreeHash treeHash = getTreeHashInstanceForUpdate();
if (treeHash != null) {
treeHash.update(otsHashAddress);
}
}
index++;
}
private TreeHash getTreeHashInstanceForUpdate() {
TreeHash ret = null;
for (TreeHash treeHash : treeHashInstances) {
if (treeHash.isFinished() || !treeHash.isInitialized()) {
continue;
}
if (ret == null) {
ret = treeHash;
continue;
}
if (treeHash.getHeight() < ret.getHeight()) {
ret = treeHash;
continue;
}
if (treeHash.getHeight() == ret.getHeight()) {
if (treeHash.getIndexLeaf() < ret.getIndexLeaf()) {
ret = treeHash;
}
}
}
return ret;
}
protected void validate() {
if (treeHeight != xmss.getParams().getHeight()) {
throw new IllegalStateException("wrong height");
}
if (authenticationPath == null) {
throw new IllegalStateException("authenticationPath == null");
}
if (retain == null) {
throw new IllegalStateException("retain == null");
}
if (stack == null) {
throw new IllegalStateException("stack == null");
}
if (treeHashInstances == null) {
throw new IllegalStateException("treeHashInstances == null");
}
if (keep == null) {
throw new IllegalStateException("keep == null");
}
if (!XMSSUtil.isIndexValid(treeHeight, index)) {
throw new IllegalStateException("index in BDS state out of bounds");
}
}
protected int getTreeHeight() {
return treeHeight;
}
protected XMSSNode getRoot() {
return root.clone();
}
protected List getAuthenticationPath() {
List authenticationPath = new ArrayList();
for (XMSSNode node : this.authenticationPath) {
authenticationPath.add(node.clone());
}
return authenticationPath;
}
protected void setXMSS(XMSS xmss) {
this.xmss = xmss;
this.wotsPlus = xmss.getWOTSPlus();
}
protected int getIndex() {
return index;
}
}