All downloads are free. Search and download functionality uses the official Maven repository.

org.bouncycastle.crypto.agreement.SM2KeyExchange Maven / Gradle / Ivy

Go to download

The Bouncy Castle Crypto package is a Java implementation of cryptographic algorithms. This jar contains the JCE provider and the lightweight API for the Bouncy Castle Cryptography APIs for Java 1.8 and later, with debug enabled.

The newest version!
package org.bouncycastle.crypto.agreement;

import java.math.BigInteger;

import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoServicesRegistrar;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ParametersWithID;
import org.bouncycastle.crypto.params.SM2KeyExchangePrivateParameters;
import org.bouncycastle.crypto.params.SM2KeyExchangePublicParameters;
import org.bouncycastle.math.ec.ECAlgorithms;
import org.bouncycastle.math.ec.ECFieldElement;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Memoable;
import org.bouncycastle.util.Pack;

/**
 * SM2 Key Exchange protocol - based on https://tools.ietf.org/html/draft-shen-sm2-ecdsa-02
 * <p>
 * Usage: call {@link #init(CipherParameters)} with this party's static and ephemeral key material, then
 * {@link #calculateKey(int, CipherParameters)} or
 * {@link #calculateKeyWithConfirmation(int, byte[], CipherParameters)} with the other party's public values.
 * </p>
 * <p>
 * Every calculation updates the single {@link Digest} held by the instance, and this class performs no
 * synchronization of its own.
 * </p>
 */
public class SM2KeyExchange
{
    /** Leading byte hashed into the S1 confirmation value. */
    private static final byte S1_PREFIX = (byte)0x02;
    /** Leading byte hashed into the S2 confirmation value. */
    private static final byte S2_PREFIX = (byte)0x03;

    private final Digest digest;

    private byte[] userID;
    private ECPrivateKeyParameters staticKey;
    private ECPoint staticPubPoint;
    private ECPoint ephemeralPubPoint;
    private ECDomainParameters ecParams;
    private int w;
    private ECPrivateKeyParameters ephemeralKey;
    private boolean initiator;

    /**
     * Create a key exchange that uses SM3 as its digest.
     */
    public SM2KeyExchange()
    {
        this(new SM3Digest());
    }

    /**
     * Create a key exchange that uses the passed in digest for Z values, the KDF and confirmation values.
     *
     * @param digest the digest to use; it is stored and updated directly by this object.
     */
    public SM2KeyExchange(Digest digest)
    {
        this.digest = digest;
    }

    /**
     * Initialise the exchange with this party's private parameters.
     *
     * @param privParam a {@link SM2KeyExchangePrivateParameters}, optionally wrapped in a
     *                  {@link ParametersWithID} carrying this party's user ID. Without the wrapper an
     *                  empty user ID is used.
     * @throws ClassCastException if the (unwrapped) parameters are not SM2KeyExchangePrivateParameters.
     */
    public void init(
        CipherParameters privParam)
    {
        SM2KeyExchangePrivateParameters baseParam;

        if (privParam instanceof ParametersWithID)
        {
            baseParam = (SM2KeyExchangePrivateParameters)((ParametersWithID)privParam).getParameters();
        }
        else
        {
            baseParam = (SM2KeyExchangePrivateParameters)privParam;
        }

        userID = extractID(privParam);

        initiator = baseParam.isInitiator();
        staticKey = baseParam.getStaticPrivateKey();
        ephemeralKey = baseParam.getEphemeralPrivateKey();
        ecParams = staticKey.getParameters();
        staticPubPoint = baseParam.getStaticPublicPoint();
        ephemeralPubPoint = baseParam.getEphemeralPublicPoint();

        // bit position used by reduce(): x~ = 2^w + (x AND (2^w - 1))
        w = ecParams.getCurve().getFieldSize() / 2 - 1;

        CryptoServicesRegistrar.checkConstraints(Utils.getDefaultProperties("SM2KE", this.staticKey));
    }

    /**
     * Calculate the shared key without key confirmation.
     *
     * @param kLen     requested key length in bits; the result holds (kLen + 7) / 8 bytes.
     * @param pubParam a {@link SM2KeyExchangePublicParameters} for the other party, optionally wrapped in
     *                 a {@link ParametersWithID} carrying the other party's user ID.
     * @return the derived key bytes.
     * @throws IllegalStateException if init() has not been called, or the agreed point is the point at
     *                               infinity.
     */
    public byte[] calculateKey(int kLen, CipherParameters pubParam)
    {
        checkInitialised();

        SM2KeyExchangePublicParameters otherPub = extractPublicParameters(pubParam);
        byte[] otherUserID = extractID(pubParam);

        byte[] localZ = getZ(userID, staticPubPoint);
        byte[] remoteZ = getZ(otherUserID, otherPub.getStaticPublicKey().getQ());

        ECPoint u = calculateU(otherPub);

        // The initiator's Z value always goes into the KDF first, whichever side we are.
        if (initiator)
        {
            return kdf(u, localZ, remoteZ, kLen);
        }

        return kdf(u, remoteZ, localZ, kLen);
    }

    /**
     * Calculate the shared key together with the key confirmation values.
     * <p>
     * As initiator the passed in tag is checked against the locally computed S1 and the result is
     * { key, S2 }. As responder the passed in tag is not used and the result is { key, S1, S2 }.
     * </p>
     *
     * @param kLen            requested key length in bits; the key holds (kLen + 7) / 8 bytes.
     * @param confirmationTag the S1 value received from the responder; must be non-null when initiating.
     * @param pubParam        a {@link SM2KeyExchangePublicParameters} for the other party, optionally
     *                        wrapped in a {@link ParametersWithID} carrying the other party's user ID.
     * @return the key followed by the confirmation value(s) as described above.
     * @throws IllegalArgumentException if initiating and confirmationTag is null.
     * @throws IllegalStateException    if init() has not been called, the agreed point is the point at
     *                                  infinity, or (as initiator) the confirmation tag does not match.
     */
    public byte[][] calculateKeyWithConfirmation(int kLen, byte[] confirmationTag, CipherParameters pubParam)
    {
        checkInitialised();

        SM2KeyExchangePublicParameters otherPub = extractPublicParameters(pubParam);
        byte[] otherUserID = extractID(pubParam);

        if (initiator && confirmationTag == null)
        {
            throw new IllegalArgumentException("if initiating, confirmationTag must be set");
        }

        byte[] localZ = getZ(userID, staticPubPoint);
        byte[] remoteZ = getZ(otherUserID, otherPub.getStaticPublicKey().getQ());

        ECPoint u = calculateU(otherPub);
        ECPoint remoteEphemeral = otherPub.getEphemeralPublicKey().getQ();

        if (initiator)
        {
            byte[] key = kdf(u, localZ, remoteZ, kLen);

            byte[] inner = calculateInnerHash(u, localZ, remoteZ, ephemeralPubPoint, remoteEphemeral);

            byte[] s1 = calculateConfirmation(S1_PREFIX, u, inner);

            if (!Arrays.constantTimeAreEqual(s1, confirmationTag))
            {
                throw new IllegalStateException("confirmation tag mismatch");
            }

            return new byte[][]{ key, calculateConfirmation(S2_PREFIX, u, inner) };
        }

        // Responder: the initiator's values (the remote ones) come first in every hash.
        byte[] key = kdf(u, remoteZ, localZ, kLen);

        byte[] inner = calculateInnerHash(u, remoteZ, localZ, remoteEphemeral, ephemeralPubPoint);

        return new byte[][]{
            key,
            calculateConfirmation(S1_PREFIX, u, inner),
            calculateConfirmation(S2_PREFIX, u, inner)
        };
    }

    /**
     * Fail with a clear message, rather than a NullPointerException further down, when a calculation is
     * requested before init() has supplied the key material.
     */
    private void checkInitialised()
    {
        if (staticKey == null)
        {
            throw new IllegalStateException("SM2KeyExchange not initialised");
        }
    }

    /**
     * Return the SM2 public parameters from pubParam, looking inside a ParametersWithID wrapper if present.
     */
    private static SM2KeyExchangePublicParameters extractPublicParameters(CipherParameters pubParam)
    {
        if (pubParam instanceof ParametersWithID)
        {
            return (SM2KeyExchangePublicParameters)((ParametersWithID)pubParam).getParameters();
        }

        return (SM2KeyExchangePublicParameters)pubParam;
    }

    /**
     * Return the ID carried by a ParametersWithID wrapper, or an empty array when param is not wrapped.
     */
    private static byte[] extractID(CipherParameters param)
    {
        if (param instanceof ParametersWithID)
        {
            return ((ParametersWithID)param).getID();
        }

        return new byte[0];
    }

    /**
     * Compute the agreed point U = [k1]P + [k2]R, where P and R are the other party's static and
     * ephemeral public points, k1 = h * (d + x1~ * r) mod n and k2 = k1 * x2~ mod n.
     *
     * @throws IllegalStateException if U is the point at infinity.
     */
    private ECPoint calculateU(SM2KeyExchangePublicParameters otherPub)
    {
        // ecParams was taken from staticKey.getParameters() in init().
        ECPoint p1 = ECAlgorithms.cleanPoint(ecParams.getCurve(), otherPub.getStaticPublicKey().getQ());
        ECPoint p2 = ECAlgorithms.cleanPoint(ecParams.getCurve(), otherPub.getEphemeralPublicKey().getQ());

        BigInteger x1 = reduce(ephemeralPubPoint.getAffineXCoord().toBigInteger());
        BigInteger x2 = reduce(p2.getAffineXCoord().toBigInteger());
        BigInteger tA = staticKey.getD().add(x1.multiply(ephemeralKey.getD()));
        BigInteger k1 = ecParams.getH().multiply(tA).mod(ecParams.getN());
        BigInteger k2 = k1.multiply(x2).mod(ecParams.getN());

        ECPoint u = ECAlgorithms.sumOfTwoMultiplies(p1, k1, p2, k2).normalize();

        // The point at infinity has no affine coordinates to hash, so reject it here with a
        // meaningful error instead of failing later while feeding the KDF.
        if (u.isInfinity())
        {
            throw new IllegalStateException("Infinity is not a valid agreement value for SM2 key exchange");
        }

        return u;
    }

    /**
     * Derive key bytes as the concatenation of digest(xU || yU || za || zb || ct) for ct = 1, 2, ...
     * with ct encoded as a 4 byte big-endian counter.
     *
     * @param klen requested length in bits; (klen + 7) / 8 bytes are returned and no bit masking is
     *             applied when klen is not a multiple of 8.
     */
    private byte[] kdf(ECPoint u, byte[] za, byte[] zb, int klen)
    {
        int digestSize = digest.getDigestSize();
        byte[] buf = new byte[Math.max(4, digestSize)];
        byte[] rv = new byte[(klen + 7) / 8];

        // Nothing to produce: return before touching the digest. Absorbing the prefix below without a
        // following doFinal() would leave it in the digest and corrupt the next hash.
        if (rv.length == 0)
        {
            return rv;
        }

        Memoable memo = null;
        Memoable copy = null;

        if (digest instanceof Memoable)
        {
            // Absorb the fixed prefix once and snapshot it, so each round only has to add the counter.
            absorbKdfInput(u, za, zb);
            memo = (Memoable)digest;
            copy = memo.copy();
        }

        int off = 0;
        int ct = 0;

        while (off < rv.length)
        {
            if (memo != null)
            {
                memo.reset(copy);
            }
            else
            {
                absorbKdfInput(u, za, zb);
            }

            Pack.intToBigEndian(++ct, buf, 0);
            digest.update(buf, 0, 4);
            digest.doFinal(buf, 0);

            int copyLen = Math.min(digestSize, rv.length - off);
            System.arraycopy(buf, 0, rv, off, copyLen);
            off += copyLen;
        }

        return rv;
    }

    /**
     * Feed the counter-independent part of the KDF input (xU, yU, za, zb) into the digest.
     */
    private void absorbKdfInput(ECPoint u, byte[] za, byte[] zb)
    {
        addFieldElement(u.getAffineXCoord());
        addFieldElement(u.getAffineYCoord());
        digest.update(za, 0, za.length);
        digest.update(zb, 0, zb.length);
    }

    //x1~=2^w+(x1 AND (2^w-1))
    private BigInteger reduce(BigInteger x)
    {
        BigInteger lowMask = BigInteger.ONE.shiftLeft(w).subtract(BigInteger.ONE);

        return x.and(lowMask).setBit(w);
    }

    /**
     * Compute a confirmation value: digest(prefix || yU || inner).
     *
     * @param prefix S1_PREFIX or S2_PREFIX.
     */
    private byte[] calculateConfirmation(byte prefix, ECPoint u, byte[] inner)
    {
        digest.update(prefix);
        addFieldElement(u.getAffineYCoord());
        digest.update(inner, 0, inner.length);

        return digestDoFinal();
    }

    /**
     * Compute the inner hash shared by S1 and S2: digest(xU || za || zb || x1 || y1 || x2 || y2),
     * where (x1, y1) and (x2, y2) are the affine coordinates of p1 and p2.
     */
    private byte[] calculateInnerHash(ECPoint u, byte[] za, byte[] zb, ECPoint p1, ECPoint p2)
    {
        addFieldElement(u.getAffineXCoord());
        digest.update(za, 0, za.length);
        digest.update(zb, 0, zb.length);
        addFieldElement(p1.getAffineXCoord());
        addFieldElement(p1.getAffineYCoord());
        addFieldElement(p2.getAffineXCoord());
        addFieldElement(p2.getAffineYCoord());

        return digestDoFinal();
    }

    /**
     * Compute the Z value for a party: digest(ENTL || ID || a || b || xG || yG || xQ || yQ), where ENTL
     * is the bit length of the ID and Q is the party's public point.
     */
    private byte[] getZ(byte[] id, ECPoint pubPoint)
    {
        addUserID(id);

        addFieldElement(ecParams.getCurve().getA());
        addFieldElement(ecParams.getCurve().getB());
        addFieldElement(ecParams.getG().getAffineXCoord());
        addFieldElement(ecParams.getG().getAffineYCoord());
        addFieldElement(pubPoint.getAffineXCoord());
        addFieldElement(pubPoint.getAffineYCoord());

        return digestDoFinal();
    }

    /**
     * Feed the ID, preceded by its length in bits as two big-endian bytes, into the digest.
     * Only the low 16 bits of the bit length are written, so the length wraps for IDs of
     * 8192 bytes or more.
     */
    private void addUserID(byte[] id)
    {
        int bitLength = id.length * 8;

        digest.update((byte)(bitLength >>> 8));
        digest.update((byte)bitLength);
        digest.update(id, 0, id.length);
    }

    /**
     * Feed the encoding of a field element into the digest.
     */
    private void addFieldElement(ECFieldElement v)
    {
        byte[] encoded = v.getEncoded();
        digest.update(encoded, 0, encoded.length);
    }

    /**
     * Finalise the digest and return its output in a freshly allocated array.
     */
    private byte[] digestDoFinal()
    {
        byte[] result = new byte[digest.getDigestSize()];
        digest.doFinal(result, 0);
        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy