All downloads are free. Search and download functionality uses the official Maven repository.

java-ec.com.unbound.common.crypto.ec.Curve Maven / Gradle / Ivy

Go to download

This is a collection of Java libraries that implement Unbound cryptographic classes for the Java provider, the PKCS#11 wrapper, cryptoki, and advapi.

There is a newer version: 42761
Show newest version
package com.unbound.common.crypto.ec;

import com.unbound.common.crypto.ec.math.UInt;

import java.math.BigInteger;
import java.security.spec.ECPoint;

abstract class Curve
{
  // Word-length constant; (521 + 31) / 32 == 17, i.e. what the constructor's length formula
  // yields for a 521-bit field. NOTE(review): its use is not visible in this view — confirm.
  static final int MAX_LEN = 17;
  // Field size in bits, as passed to the constructor.
  protected final int bits;
  // Number of 32-bit words needed to hold 'bits' bits: (bits + 31) / 32.
  protected final int length;
  // Constant supplied by the subclass together with MONT_ONE / MONT_RR.
  // NOTE(review): presumably the Montgomery constant -p^-1 mod 2^32 — confirm.
  protected final int invp;
  // Curve coefficient a as a small signed integer; dbl() supports only 0 and -3.
  protected final int a;
  // Curve coefficient b, field prime p and q (presumably the group order — confirm) as
  // arrays of 32-bit words; word 0 is the least significant (ONE is built with ONE[0] = 1).
  protected final int[] b;
  protected final int[] p;
  protected final int[] q;

  // The integer 1 as a word array (allocated in the constructor, ONE[0] = 1).
  protected final int[] ONE;
  // Supplied by the subclass. NOTE(review): presumably 1 and R^2 mod p in Montgomery form.
  protected final int[] MONT_ONE;
  protected final int[] MONT_RR;

  // BigInteger views of the parameters; aa is p + a when a is negative, otherwise a.
  protected final BigInteger aa;
  protected final BigInteger bb;
  protected final BigInteger pp;
  protected final BigInteger qq;
  // Base point built from the constructor's Gx / Gy.
  protected final Point G;

  /** Returns the secp256r1 curve instance supplied by {@code SecP256R1.getInstance()}. */
  public static Curve getSecP256R1()
  {
    final Curve curve = SecP256R1.getInstance();
    return curve;
  }
  /** Returns the secp384r1 curve instance supplied by {@code SecP384R1.getInstance()}. */
  public static Curve getSecP384R1()
  {
    final Curve curve = SecP384R1.getInstance();
    return curve;
  }
  /** Returns the secp521r1 curve instance supplied by {@code SecP521R1.getInstance()}. */
  public static Curve getSecP521R1()
  {
    final Curve curve = SecP521R1.getInstance();
    return curve;
  }
  /** Returns the secp256k1 curve instance supplied by {@code SecP256K1.getInstance()}. */
  public static Curve getSecP256K1()
  {
    final Curve curve = SecP256K1.getInstance();
    return curve;
  }

  protected Curve(
    int bits,
    int[] p,
    int invp,
    int[] MONT_ONE,
    int[] MONT_RR,
    int a,
    int[] b,
    int[] Gx,
    int[] Gy,
    int[] q)
  {
    this.bits = bits;
    this.length = (bits + 31) / 32;
    this.p = p;
    this.invp = invp;
    this.MONT_ONE = MONT_ONE;
    this.MONT_RR = MONT_RR;
    this.a = a;
    this.b = b;
    this.q = q;

    ONE = alloc(); ONE[0] = 1;

    pp = UInt.toBigInteger(length, p);
    aa = a<0 ? pp.add(BigInteger.valueOf(a)) : BigInteger.valueOf(a);
    bb = UInt.toBigInteger(length, b);
    qq = UInt.toBigInteger(length, q);
    G = new Point(this, Gx, Gy);
  }

  static final class IntPool
  {
    static final int BASE = 20;
    static final int MAX = 25;
    private final int[][] table = new int[MAX][];
    final long[] mp = new long[MAX];

    IntPool()
    {
      for (int i=0; i>> 32;
            A >>>= 32;

      for (int index=1; index>> 32;
        x    = (x & 0xffffffffL) + (rr[index] & 0xffffffffL);
        xHi += x >>> 32;
        x    = A + (x & 0xffffffffL);
        A    = xHi + (x >>> 32);

        x2   = mp[index];
        xHi  = x2 >>> 32;
        x    = (x & 0xffffffffL) + (x2 & 0xffffffffL);
        xHi += x >>> 32;
        x    = B + (x & 0xffffffffL);
        B    = xHi + (x >>> 32);

        rr[index-1] = (int)x;
      }

      B           += ((A + rs) & 0xffffffffL);
      rs           = (A >>> 32) + (B >>> 32);
      rr[length-1] = (int)B;
    }

    int cond = (int)rs;
    int[] r_minus_p = IntPool.alloc(pool, IntPool.BASE+3);
    if (r_minus_p==null) r_minus_p = alloc();
    int borrow = UInt.sub(length, r_minus_p, rr, p);

    cond |= 1-borrow;
    cond |= UInt.isZero(length, r_minus_p);

    UInt.cmov(length, r, cond, rr, r_minus_p);
  }

  protected void prepareMontMul(long m, long[] mp)
  {
    for (int i=0; i0)
    {
      mul(ZZZZ, ZZZZ, p.x, null);
      for (int i=0; i oct.length) return null;
      x = UInt.fromBytes(length, oct, offset+1, bytes);
      y = UInt.fromBytes(length, oct, offset+1+bytes, bytes);
    }
    else if (h==2 || h==3)
    {
      if (offset + 1 + bytes*2 > oct.length) return null;
      x = UInt.fromBytes(length, oct, offset+1, bytes);
      y = recoverY(x, h==3);
    }
    else return null;

    return new Point(this, x, y);
}

  /**
   * Recovers a y coordinate for the x coordinate {@code x} by solving
   * y^2 = x^3 + A*x + B (mod P), with A, B, P taken from getA(), getB() and getP().
   * Of the two square roots, the one whose lowest bit equals {@code bit0} is returned.
   *
   * @param x    x coordinate as 32-bit words
   * @param bit0 requested parity (lowest bit) of y
   * @return y as a word array of {@code length} words
   * @throws IllegalArgumentException propagated from Arithmetic.sqrtP when no square root exists
   */
  int[] recoverY(int[] x, boolean bit0)
  {
    final BigInteger xValue = UInt.toBigInteger(length, x);
    final BigInteger coefA = getA();
    final BigInteger coefB = getB();
    final BigInteger prime = getP();

    // Horner form of x^3 + A*x + B: ((x * x) + A) * x + B, reduced mod P.
    BigInteger rhs = xValue.multiply(xValue);
    rhs = rhs.add(coefA);
    rhs = rhs.multiply(xValue);
    rhs = rhs.add(coefB);
    rhs = rhs.mod(prime);

    BigInteger root = Arithmetic.sqrtP(rhs, prime);
    final boolean parityMatches = root.testBit(0) == bit0;
    if (!parityMatches)
    {
      // Switch to the other root P - y.
      root = prime.subtract(root).mod(prime);
    }
    return UInt.fromBigInteger(length, root);
  }

  /**
   * Point addition in Jacobian coordinates: r = p1 + p2.
   *
   * The general case follows the "add-2007-bl" addition formulas of the Explicit-Formulas
   * Database; the per-line comments give each step.
   *
   * Special cases, handled before the general formula:
   *   - p2 at infinity: r = p1; p1 at infinity: r = p2;
   *   - U1 == U2 and S1 != S2 (same affine x, different y): r = infinity();
   *   - U1 == U2 and S1 == S2 (same point): delegates to dbl(r, p1, pool).
   *
   * r may be the same object as p1 (add(Point, Point, IntPool) calls add(r, r, p, pool)).
   * Each output coordinate of r is first written at or after the last read of the matching
   * coordinate of p1, so do not reorder the statements below. NOTE(review): Z3 = Z1 + Z2 then
   * relies on the 3-operand field add() tolerating Z3 == Z1 — presumably fine, confirm.
   *
   * Scratch buffers use pool slots 9..18; dbl() uses slots 0..8, so the two are disjoint.
   * NOTE(review): elsewhere in this file the result of IntPool.alloc is null-checked; here it
   * is not, so a null pool is presumably unsupported.
   *
   * NOTE(review): the early returns and the U1 == U2 branch depend on the input values, so
   * this routine is not constant-time.
   *
   * @param r    destination point
   * @param p1   first addend
   * @param p2   second addend
   * @param pool scratch-buffer pool handed to the field operations
   */
  void add(Point r, Point p1, Point p2, IntPool pool)
  {
    if (p2.isInfinity()) { Point.copy(r, p1); return; }
    if (p1.isInfinity()) { Point.copy(r, p2); return; }

    int[] X1 = p1.x;
    int[] Y1 = p1.y;
    int[] Z1 = p1.z;

    int[] X2 = p2.x;
    int[] Y2 = p2.y;
    int[] Z2 = p2.z;

    int[] X3 = r.x;
    int[] Y3 = r.y;
    int[] Z3 = r.z;

    int[] Z1Z1 = IntPool.alloc(pool, 9);  sqr( Z1Z1, Z1         , pool); // Z1Z1 = Z1^2
    int[] Z2Z2 = IntPool.alloc(pool, 10); sqr( Z2Z2, Z2         , pool); // Z2Z2 = Z2^2
    int[] U1   = IntPool.alloc(pool, 11); mul( U1,   X1,   Z2Z2 , pool); // U1 = X1*Z2Z2
    int[] U2   = IntPool.alloc(pool, 12); mul( U2,   X2,   Z1Z1 , pool); // U2 = X2*Z1Z1
    int[] S1   = IntPool.alloc(pool, 13); mul( S1,   Y1,   Z2   , pool); // S1 = Y1*Z2*Z2Z2
                                                mul( S1,   Z2Z2       , pool);
    int[] S2   = IntPool.alloc(pool, 14); mul( S2,   Y2,   Z1   , pool);
                                                mul( S2,   Z1Z1       , pool); // S2 = Y2*Z1*Z1Z1

    // Equal X/Z^2: the inputs share an affine x, so either they cancel or they are equal.
    if (UInt.equ(length, U1, U2))
    {
      if (!UInt.equ(length, S1, S2)) Point.copy(r, infinity());
      else dbl(r, p1, pool);
      return;
    }

    // VTI is reused: it holds 2*H, then I, then V, then S1*J.
    // R holds S2-S1, then r = 2*(S2-S1).
    int[]  H   = IntPool.alloc(pool, 15);   sub( H,    U2,   U1 , pool);  // H = U2-U1
    int[]  VTI = IntPool.alloc(pool, 16);   add( VTI,  H,    H  , pool);
                                                  sqr( VTI            , pool);  // I = (2*H)^2
    int[]  J   = IntPool.alloc(pool, 17);   mul( J,    H,    VTI, pool);  // J = H*I
    int[]  R   = IntPool.alloc(pool, 18);   sub( R,    S2,   S1 , pool);
                                                  add( R,    R        , pool);  // r = 2*(S2-S1)
                                                  mul( VTI,  U1       , pool);  // V = U1*I
                                                  sqr( X3,   R        , pool);
                                                  sub( X3,   J        , pool);
                                                  sub( X3,   VTI      , pool);  // X3 = R^2 - J - 2*V
                                                  sub( X3,   VTI      , pool);
                                                  sub( Y3,   VTI,  X3 , pool);
                                                  mul( Y3,   R        , pool);
                                                  mul( VTI,  S1,   J  , pool);
                                                  sub( Y3,   VTI      , pool);
                                                  sub( Y3,   VTI      , pool);  // Y3 = R*(V-X3) - 2*S1*J
                                                  add( Z3,   Z1,   Z2 , pool);
                                                  sqr( Z3             , pool);
                                                  sub( Z3,   Z1Z1     , pool);
                                                  sub( Z3,   Z2Z2     , pool);
                                                  mul( Z3,   H        , pool);  // Z3 = ((Z1+Z2)^2 - Z1Z1 - Z2Z2) * H
  }

  /**
   * Point doubling in Jacobian coordinates: r = 2 * p.
   *
   * Only the curve coefficients a == 0 and a == -3 are supported; any other value throws
   * UnsupportedOperationException. The a == 0 branch follows the EFD "dbl-2009-l" formulas
   * and the a == -3 branch the "dbl-2001-b" formulas; the per-line comments give each step.
   *
   * r may be the same object as p (dbl(Point, IntPool) calls dbl(p, p, pool)). Outputs are
   * written only after the input coordinates they depend on have been read, so do not
   * reorder the statements below.
   *
   * There is no explicit infinity check. If Z == 0, both branches produce RZ == 0
   * (2*Y*Z, and (Y+Z)^2 - Y^2 - Z^2 respectively). NOTE(review): presumably Z == 0 is how
   * infinity is represented — confirm against Point.isInfinity().
   *
   * Scratch buffers use pool slots 0..3 (a == 0) or 4..8 (a == -3).
   *
   * @param r    destination point
   * @param p    point to double
   * @param pool scratch-buffer pool handed to the field operations
   * @throws UnsupportedOperationException if a is neither 0 nor -3
   */
  void dbl(Point r, Point p, IntPool pool)
  {
    int[] X = p.x;
    int[] Y = p.y;
    int[] Z = p.z;

    int[] RX = r.x;
    int[] RY = r.y;
    int[] RZ = r.z;

    if (a==0)
    {
      // Buffers are reused: AE holds A then E, BC8 holds B then 8*C, CF holds C then F.
      int[]  AE  = IntPool.alloc(pool, 0);      sqr     ( AE,  X       , pool);  // A = X^2
      int[]  BC8 = IntPool.alloc(pool, 1);      sqr     ( BC8, Y       , pool);  // B = Y^2
      int[]  CF  = IntPool.alloc(pool, 2);      sqr     ( CF,  BC8     , pool);  // C = B^2
      int[]  D   = IntPool.alloc(pool, 3);      add     ( D,   X,   BC8, pool);
                                                      mul8    ( BC8, CF      , pool);  // C8 = 8*C
                                                      sqr     ( D            , pool);
                                                      sub     ( D,   AE      , pool);
                                                      sub     ( D,   CF      , pool);
                                                      add     ( D,   D       , pool);  // D = 2 * ((X+B)^2 - A - C)
                                                      mul3    ( AE           , pool);  // E = 3 * A
                                                      sqr     ( CF,  AE      , pool);  // F = E^2
                                                      sub     ( RX,  CF,  D  , pool);
                                                      sub     ( RX,  D       , pool);  // RX = F - 2 * D
                                                      mul     ( RZ,  Y,   Z  , pool);
                                                      add     ( RZ,  RZ      , pool);  // RZ = 2 * Y * Z
                                                      sub     ( RY,  D,   RX , pool);
                                                      mul     ( RY,  AE      , pool);
                                                      sub     ( RY,  BC8     , pool);  // RY = E * (D - RX) - 8 * C
    }
    else if (a==-3)
    {
      // TEMP is reused: it holds X1-delta, then 8*beta, then 8*gamma^2.
      int[]  DELTA = IntPool.alloc(pool,4);     sqr     (DELTA,      Z           , pool);  // delta = Z1^2
      int[]  GAMMA = IntPool.alloc(pool,5);     sqr     (GAMMA,      Y           , pool);  // gamma = Y1^2
      int[]  BETA  = IntPool.alloc(pool,6);     mul     (BETA,       X,    GAMMA , pool);  // beta = X1*gamma
      int[]  TEMP  = IntPool.alloc(pool,7);     sub     (TEMP,       X,    DELTA , pool);
      int[]  ALPHA = IntPool.alloc(pool,8);     add     (ALPHA,      X,    DELTA , pool);
                                                      mul     (ALPHA,      TEMP        , pool);
                                                      mul3    (ALPHA                   , pool);  // alpha = 3*(X1-delta)*(X1+delta)
                                                      mul8    (TEMP,       BETA        , pool);
                                                      sqr     (RX,         ALPHA       , pool);
                                                      sub     (RX,         TEMP        , pool);  // X3 = alpha^2-8*beta
                                                      add     (RZ,         Y,    Z     , pool);
                                                      sqr     (RZ                      , pool);
                                                      sub     (RZ,         GAMMA       , pool);
                                                      sub     (RZ,         DELTA       , pool);  //Z3 = (Y1+Z1)^2-gamma-delta
                                                      sqr     (TEMP,       GAMMA       , pool);
                                                      mul8    (TEMP                    , pool);
                                                      mul4    (RY,         BETA        , pool);
                                                      sub     (RY,         RX          , pool);
                                                      mul     (RY,         ALPHA       , pool);
                                                      sub     (RY,         TEMP        , pool);  // Y3 = alpha*(4*beta-X3)-8*gamma^2
    }
    else throw new UnsupportedOperationException("Unsupported curve");
  }

  /**
   * Doubles {@code p} in place (p = 2 * p), using {@code p} as both input and output
   * of the three-argument dbl.
   */
  void dbl(Point p, IntPool pool)
  {
    this.dbl(p, p, pool);
  }

  /**
   * Accumulates {@code p} into {@code r} (r = r + p), using {@code r} as both output and
   * first addend of the four-argument add.
   */
  void add(Point r, Point p, IntPool pool)
  {
    this.add(r, r, p, pool);
  }

  /**
   * Scalar multiplication r = x * p using a fixed 4-bit window.
   *
   * A table of the multiples 0*p .. 15*p is built first. The scalar words are then
   * processed from x[length - 1] down to x[0], each word from its top nibble down. Every
   * nibble costs four doublings followed by one table addition.
   *
   * NOTE(review): the table index is the scalar nibble itself, and add() short-circuits on
   * infinity (entry 0 and the initial accumulator), so running time and memory access
   * depend on x. If x is secret, this is not constant-time.
   *
   * @param r destination point (written once at the end)
   * @param p point to multiply
   * @param x scalar as 32-bit words; at least {@code length} words are read
   */
  void mul(Point r, Point p, int[] x)
  {
    final IntPool pool = new IntPool();

    // table[i] == i * p. Even entries are doublings of table[i / 2]; odd entries are table[i - 1] + p.
    final Point[] table = new Point[16];
    table[0] = infinity();
    table[1] = p;
    for (int i = 2; i < table.length; i++)
    {
      final Point entry = new Point(this);
      if ((i & 1) == 0)
      {
        dbl(entry, table[i >> 1], pool);
      }
      else
      {
        add(entry, table[i - 1], p, pool);
      }
      table[i] = entry;
    }

    final Point acc = new Point(this);
    for (int word = length - 1; word >= 0; word--)
    {
      final int limb = x[word];
      for (int shift = 28; shift >= 0; shift -= 4)
      {
        dbl(acc, pool);
        dbl(acc, pool);
        dbl(acc, pool);
        dbl(acc, pool);
        add(acc, table[(limb >>> shift) & 0xf], pool);
      }
    }

    Point.copy(r, acc);
  }


  /**
   * Number-theoretic helpers on BigInteger used for point decompression: the Jacobi
   * symbol and square roots modulo a prime.
   */
  private static class Arithmetic
  {
    public static final BigInteger ZERO = BigInteger.ZERO;
    public static final BigInteger ONE = BigInteger.ONE;
    public static final BigInteger TWO = BigInteger.valueOf(2);

    // (2/n) indexed by n mod 8: +1 for n = 1, 7 and -1 for n = 3, 5 (even entries unused).
    private static final int[] jacobiTable = {0, 1, 0, -1, 0, -1, 0, 1};

    /**
     * Computes the Jacobi (Kronecker) symbol (A/B) with the binary algorithm.
     *
     * Negative arguments are accepted. When B == 0 the result is 1 if |A| == 1 and 0
     * otherwise. The result is 0 when A and B are both even, or when the gcd reached by
     * the main loop is not 1.
     *
     * @param A numerator
     * @param B denominator
     * @return -1, 0 or 1
     */
    public static int jacobi(BigInteger A, BigInteger B)
    {
      if (B.signum() == 0) return A.abs().equals(ONE) ? 1 : 0;
      if (!A.testBit(0) && !B.testBit(0)) return 0; // both even: common factor 2

      BigInteger a = A;
      BigInteger b = B;
      int k = 1;

      if (b.signum() < 0)
      {
        b = b.negate();
        if (a.signum() < 0) k = -1;
      }

      // Remove the factors of two from b; an odd count contributes (2/a).
      int twos = b.getLowestSetBit();
      b = b.shiftRight(twos);
      if ((twos & 1) != 0) k *= jacobiTable[a.intValue() & 7];

      if (a.signum() < 0)
      {
        if (b.testBit(1)) k = -k; // (-1/b) = -1 for b = 3 (mod 4)
        a = a.negate();
      }

      while (a.signum() != 0)
      {
        // Remove the factors of two from a; an odd count contributes (2/b).
        twos = a.getLowestSetBit();
        a = a.shiftRight(twos);
        if ((twos & 1) != 0) k *= jacobiTable[b.intValue() & 7];

        if (a.compareTo(b) < 0)
        {
          // Quadratic reciprocity: swapping flips the sign iff both are 3 (mod 4).
          final BigInteger tmp = a;
          a = b;
          b = tmp;
          if (a.testBit(1) && b.testBit(1)) k = -k;
        }
        a = a.subtract(b);
      }

      return b.equals(ONE) ? k : 0;
    }

    /**
     * Computes a square root of {@code a} modulo the prime {@code p}.
     *
     * The argument is first reduced into [0, p), so any integer is accepted, including
     * negative values and values of p or more. For p = 3 (mod 4) the root is
     * a^((p+1)/4) mod p; other primes use the Tonelli-Shanks algorithm.
     *
     * @param a value whose square root is wanted
     * @param p modulus; assumed to be 2 or an odd prime. Primality is not checked, and for
     *          composite p the Tonelli-Shanks loops may not terminate.
     * @return r in [0, p) with r^2 = a (mod p); which of the two roots is returned is unspecified
     * @throws IllegalArgumentException if a is not a quadratic residue modulo p
     * @throws ArithmeticException if p is not positive (from BigInteger.mod)
     */
    public static BigInteger sqrtP(BigInteger a, BigInteger p) throws IllegalArgumentException
    {
      // A full reduction is required: adding p once is not enough for a < -p, and a
      // nonzero multiple of p must be treated as 0.
      final BigInteger value = a.mod(p);
      if (value.signum() == 0) return ZERO;
      if (p.equals(TWO)) return value;

      // p = 3 (mod 4): a^((p+1)/4) is a root whenever a root exists.
      if (p.testBit(0) && p.testBit(1))
      {
        if (jacobi(value, p) == 1) return value.modPow(p.add(ONE).shiftRight(2), p);
        throw new IllegalArgumentException("No quadratic residue: " + value + ", " + p);
      }

      // Tonelli-Shanks: write p - 1 = 2^s * q with q odd, and let k = (q - 1) / 2.
      BigInteger q = p.subtract(ONE);
      int s = q.getLowestSetBit();
      q = q.shiftRight(s);
      final BigInteger k = q.shiftRight(1); // (q - 1) / 2 because q is odd

      BigInteger r = value.modPow(k, p);                      // a^k
      BigInteger n = r.multiply(r).multiply(value).mod(p);    // a^q
      r = r.multiply(value).mod(p);                           // a^(k+1), so r^2 = n * a
      if (n.equals(ONE)) return r;

      // Any quadratic non-residue z gives c = z^q, an element of order exactly 2^s.
      BigInteger z = TWO;
      while (jacobi(z, p) == 1) z = z.add(ONE);
      BigInteger c = z.modPow(q, p);

      // Invariant: r^2 = n * a (mod p); done once n == 1.
      while (n.compareTo(ONE) > 0)
      {
        // Smallest i with n^(2^i) = 1.
        int i = 0;
        for (BigInteger m = n; !m.equals(ONE); m = m.multiply(m).mod(p)) i++;

        final int t = s - i;
        // t == 0 means a is a non-residue. For prime p, t < 0 cannot happen; reject it
        // rather than spin.
        if (t <= 0) throw new IllegalArgumentException("No quadratic residue: " + value + ", " + p);

        c = c.modPow(ONE.shiftLeft(t - 1), p); // c^(2^(t-1))
        r = r.multiply(c).mod(p);
        c = c.multiply(c).mod(p);
        n = n.multiply(c).mod(p);
        s = i;
      }
      return r;
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy