All downloads are free. Search and download functionality uses the official Maven repository.

org.bouncycastle.pqc.crypto.xmss.WOTSPlus Maven / Gradle / Ivy

Go to download

The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains JCE provider and lightweight API for the Bouncy Castle Cryptography APIs for JDK 1.5 to JDK 1.8.

The newest version!
package org.bouncycastle.pqc.crypto.xmss;

import java.util.ArrayList;
import java.util.List;

/**
 * WOTS+.
 *
 */
public final class WOTSPlus {

	/**
	 * WOTS+ parameters.
	 */
	private final WOTSPlusParameters params;
	/**
	 * Randomization functions.
	 */
	private final KeyedHashFunctions khf;
	/**
	 * WOTS+ secret key seed.
	 */
	private byte[] secretKeySeed;
	/**
	 * WOTS+ public seed.
	 */
	private byte[] publicSeed;

	/**
	 * Constructs a new WOTS+ one-time signature system based on the given WOTS+
	 * parameters.
	 *
	 * @param params
	 *            Parameters for WOTSPlus object.
	 */
	protected WOTSPlus(WOTSPlusParameters params) {
		super();
		if (params == null) {
			throw new NullPointerException("params == null");
		}
		this.params = params;
		int n = params.getDigestSize();
		khf = new KeyedHashFunctions(params.getDigest(), n);
		secretKeySeed = new byte[n];
		publicSeed = new byte[n];
	}

	/**
	 * Import keys to WOTS+ instance.
	 *
	 * @param secretKeySeed
	 *            Secret key seed.
	 * @param publicSeed
	 *            Public seed.
	 */
	protected void importKeys(byte[] secretKeySeed, byte[] publicSeed) {
		if (secretKeySeed == null) {
			throw new NullPointerException("secretKeySeed == null");
		}
		if (secretKeySeed.length != params.getDigestSize()) {
			throw new IllegalArgumentException("size of secretKeySeed needs to be equal to size of digest");
		}
		if (publicSeed == null) {
			throw new NullPointerException("publicSeed == null");
		}
		if (publicSeed.length != params.getDigestSize()) {
			throw new IllegalArgumentException("size of publicSeed needs to be equal to size of digest");
		}
		this.secretKeySeed = secretKeySeed;
		this.publicSeed = publicSeed;
	}

	/**
	 * Creates a signature for the n-byte messageDigest.
	 *
	 * @param messageDigest
	 *            Digest to sign.
	 * @param otsHashAddress
	 *            OTS hash address for randomization.
	 * @return WOTS+ signature.
	 */
	protected WOTSPlusSignature sign(byte[] messageDigest, OTSHashAddress otsHashAddress) {
		if (messageDigest == null) {
			throw new NullPointerException("messageDigest == null");
		}
		if (messageDigest.length != params.getDigestSize()) {
			throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
		}
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		List baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1());
		/* create checksum */
		int checksum = 0;
		for (int i = 0; i < params.getLen1(); i++) {
			checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i);
		}
		checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8));
		int len2Bytes = (int) Math
				.ceil((double) (params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8);
		List baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes),
				params.getWinternitzParameter(), params.getLen2());

		/* msg || checksum */
		baseWMessage.addAll(baseWChecksum);

		/* create signature */
		byte[][] signature = new byte[params.getLen()][];
		for (int i = 0; i < params.getLen(); i++) {
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i)
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			signature[i] = chain(expandSecretKeySeed(i), 0, baseWMessage.get(i), otsHashAddress);
		}
		return new WOTSPlusSignature(params, signature);
	}

	/**
	 * Verifies signature on message.
	 *
	 * @param messageDigest
	 *            The digest that was signed.
	 * @param signature
	 *            Signature on digest.
	 * @param otsHashAddress
	 *            OTS hash address for randomization.
	 * @return true if signature was correct false else.
	 */
	protected boolean verifySignature(byte[] messageDigest, WOTSPlusSignature signature,
			OTSHashAddress otsHashAddress) {
		if (messageDigest == null) {
			throw new NullPointerException("messageDigest == null");
		}
		if (messageDigest.length != params.getDigestSize()) {
			throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
		}
		if (signature == null) {
			throw new NullPointerException("signature == null");
		}
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		byte[][] tmpPublicKey = getPublicKeyFromSignature(messageDigest, signature, otsHashAddress).toByteArray();
		/* compare values */
		return XMSSUtil.compareByteArray(tmpPublicKey, getPublicKey(otsHashAddress).toByteArray()) ? true : false;
	}

	/**
	 * Calculates a public key based on digest and signature.
	 *
	 * @param messageDigest
	 *            The digest that was signed.
	 * @param signature
	 *            Signarure on digest.
	 * @param otsHashAddress
	 *            OTS hash address for randomization.
	 * @return WOTS+ public key derived from digest and signature.
	 */
	protected WOTSPlusPublicKeyParameters getPublicKeyFromSignature(byte[] messageDigest, WOTSPlusSignature signature,
			OTSHashAddress otsHashAddress) {
		if (messageDigest == null) {
			throw new NullPointerException("messageDigest == null");
		}
		if (messageDigest.length != params.getDigestSize()) {
			throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
		}
		if (signature == null) {
			throw new NullPointerException("signature == null");
		}
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		List baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1());
		/* create checksum */
		int checksum = 0;
		for (int i = 0; i < params.getLen1(); i++) {
			checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i);
		}
		checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8));
		int len2Bytes = (int) Math
				.ceil((double) (params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8);
		List baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes),
				params.getWinternitzParameter(), params.getLen2());

		/* msg || checksum */
		baseWMessage.addAll(baseWChecksum);

		byte[][] publicKey = new byte[params.getLen()][];
		for (int i = 0; i < params.getLen(); i++) {
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i)
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			publicKey[i] = chain(signature.toByteArray()[i], baseWMessage.get(i),
					params.getWinternitzParameter() - 1 - baseWMessage.get(i), otsHashAddress);
		}
		return new WOTSPlusPublicKeyParameters(params, publicKey);
	}

	/**
	 * Computes an iteration of F on an n-byte input using outputs of PRF.
	 *
	 * @param startHash
	 *            Starting point.
	 * @param startIndex
	 *            Start index.
	 * @param steps
	 *            Steps to take.
	 * @param otsHashAddress
	 *            OTS hash address for randomization.
	 * @return Value obtained by iterating F for steps times on input startHash,
	 *         using the outputs of PRF.
	 */
	private byte[] chain(byte[] startHash, int startIndex, int steps, OTSHashAddress otsHashAddress) {
		int n = params.getDigestSize();
		if (startHash == null) {
			throw new NullPointerException("startHash == null");
		}
		if (startHash.length != n) {
			throw new IllegalArgumentException("startHash needs to be " + n + "bytes");
		}
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		if (otsHashAddress.toByteArray() == null) {
			throw new NullPointerException("otsHashAddress byte array == null");
		}
		if ((startIndex + steps) > params.getWinternitzParameter() - 1) {
			throw new IllegalArgumentException("max chain length must not be greater than w");
		}

		if (steps == 0) {
			return startHash;
		}

		byte[] tmp = chain(startHash, startIndex, steps - 1, otsHashAddress);
		otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress())
				.withHashAddress(startIndex + steps - 1).withKeyAndMask(0).build();
		byte[] key = khf.PRF(publicSeed, otsHashAddress.toByteArray());
		otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
				.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
				.withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress())
				.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(1).build();
		byte[] bitmask = khf.PRF(publicSeed, otsHashAddress.toByteArray());
		byte[] tmpMasked = new byte[n];
		for (int i = 0; i < n; i++) {
			tmpMasked[i] = (byte) (tmp[i] ^ bitmask[i]);
		}
		tmp = khf.F(key, tmpMasked);
		return tmp;
	}

	/**
	 * Obtain base w values from Input.
	 *
	 * @param messageDigest
	 *            Input data.
	 * @param w
	 *            Base.
	 * @param outLength
	 *            Length of output.
	 * @return outLength-length list of base w integers.
	 */
	private List convertToBaseW(byte[] messageDigest, int w, int outLength) {
		if (messageDigest == null) {
			throw new NullPointerException("msg == null");
		}
		if (w != 4 && w != 16) {
			throw new IllegalArgumentException("w needs to be 4 or 16");
		}
		int logW = XMSSUtil.log2(w);
		if (outLength > ((8 * messageDigest.length) / logW)) {
			throw new IllegalArgumentException("outLength too big");
		}

		ArrayList res = new ArrayList();
		for (int i = 0; i < messageDigest.length; i++) {
			for (int j = 8 - logW; j >= 0; j -= logW) {
				res.add((messageDigest[i] >> j) & (w - 1));
				if (res.size() == outLength) {
					return res;
				}
			}
		}
		return res;
	}

	/**
	 * Derive private key at index from secret key seed.
	 *
	 * @param index
	 *            Index.
	 * @return Private key at index.
	 */
	private byte[] expandSecretKeySeed(int index) {
		if (index < 0 || index >= params.getLen()) {
			throw new IllegalArgumentException("index out of bounds");
		}
		return khf.PRF(secretKeySeed, XMSSUtil.toBytesBigEndian(index, 32));
	}

	/**
	 * Getter parameters.
	 *
	 * @return params.
	 */
	protected WOTSPlusParameters getParams() {
		return params;
	}

	/**
	 * Getter keyed hash functions.
	 *
	 * @return keyed hash functions.
	 */
	protected KeyedHashFunctions getKhf() {
		return khf;
	}

	/**
	 * Getter secret key seed.
	 *
	 * @return secret key seed.
	 */
	protected byte[] getSecretKeySeed() {
		return XMSSUtil.cloneArray(getSecretKeySeed());
	}

	/**
	 * Getter public seed.
	 *
	 * @return public seed.
	 */
	protected byte[] getPublicSeed() {
		return XMSSUtil.cloneArray(publicSeed);
	}

	/**
	 * Getter private key.
	 *
	 * @return WOTS+ private key.
	 */
	protected WOTSPlusPrivateKeyParameters getPrivateKey() {
		byte[][] privateKey = new byte[params.getLen()][];
		for (int i = 0; i < privateKey.length; i++) {
			privateKey[i] = expandSecretKeySeed(i);
		}
		return new WOTSPlusPrivateKeyParameters(params, privateKey);
	}

	/**
	 * Calculates a new public key based on the state of secretKeySeed,
	 * publicSeed and otsHashAddress.
	 *
	 * @param otsHashAddress
	 *            OTS hash address for randomization.
	 * @return WOTS+ public key.
	 */
	protected WOTSPlusPublicKeyParameters getPublicKey(OTSHashAddress otsHashAddress) {
		if (otsHashAddress == null) {
			throw new NullPointerException("otsHashAddress == null");
		}
		byte[][] publicKey = new byte[params.getLen()][];
		/* derive public key from secretKeySeed */
		for (int i = 0; i < params.getLen(); i++) {
			otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder()
					.withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
					.withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i)
					.withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())
					.build();
			publicKey[i] = chain(expandSecretKeySeed(i), 0, params.getWinternitzParameter() - 1, otsHashAddress);
		}
		return new WOTSPlusPublicKeyParameters(params, publicKey);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy