All downloads are free. The search and download features use the official Maven repository.

com.dyadicsec.provider.SecretKeyCipher Maven / Gradle / Ivy

package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;

import static com.dyadicsec.cryptoki.CK.*;

import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import java.security.InvalidKeyException;
import java.security.ProviderException;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidAlgorithmParameterException;
import java.security.AlgorithmParameters;
import java.security.SecureRandom;
import java.security.GeneralSecurityException;
import java.security.KeyStoreException;
import java.security.Key;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import java.util.LinkedList;


/**
 * Created by saar.peer on 29-Jun-16.
 */
public class SecretKeyCipher extends CipherSpi
{

  // Key supplied to the most recent engineInit() call.
  private SecretKey secretKey = null;

  // Key type this cipher instance is bound to; fixed by the constructor.
  private final int keyType;
  // True when initialized with Cipher.WRAP_MODE or Cipher.UNWRAP_MODE.
  private boolean wrap = false;
  // True when initialized with Cipher.ENCRYPT_MODE or Cipher.WRAP_MODE.
  private boolean encrypt = true;
  // Set by engineSetMode() for XTS, SIV and WRAP: input is accumulated in
  // 'buffer' by the update calls and processed in one shot by doFinal.
  private boolean singleOp = false;
  // Parameters from the last engineInit() (possibly generated there).
  private AlgorithmParameterSpec paramSpec = null;
  // Key attributes handed to initForUnwrap() by engineUnwrap().
  private KeyParameters unwrapKeyParams = null;
  // Set by engineSetMode() for CCM, GCM and SIV: engineUpdateAAD() is accepted.
  private boolean aad = false;
  // True when PKCS5PADDING was selected through engineSetPadding().
  private boolean padding = false;
  // True once ensureInit() has completed since the last engineInit().
  private boolean initialized = false;
  // Accumulated input for single-operation modes (see singleOp).
  private byte[] buffer = null;
  // Accumulated additional authenticated data for the non-SIV AAD modes.
  private byte[] auth = null;
  // AAD chunks collected by engineUpdateAAD() when the mode is SIV.
  private LinkedList siv_headers = null;
  // Session returned by encryptInit()/decryptInit() in initSession().
  private Session session = null;
  // Mechanism built by prepareMechanism() for the current mode and parameters.
  private CK_MECHANISM mechanism = null;

  // Selected cipher mode; 0 means engineSetMode() has not selected one.
  private int mode = 0;
  private static final int ECB = 1;
  private static final int CBC = 2;
  private static final int CCM = 3;
  private static final int GCM = 4;
  private static final int OFB = 5;
  private static final int CFB = 6;
  private static final int CTR = 7;
  private static final int XTS = 8;
  private static final int SIV = 9;
  // Selected by the "WRAP" mode name; maps to the CKM_AES_KEY_WRAP mechanisms.
  private static final int NIST = 10;

  // Shared empty array used as the "no data" value.
  private final static byte[] B0 = new byte[0];

  /**
   * Creates a cipher bound to a single key type.
   *
   * @param keyType key type constant stored for later comparison in
   *                checkValidAlg() and engineInit()
   */
  SecretKeyCipher(int keyType)
  {
    this.keyType = keyType;
  }

  /**
   * Verifies that this cipher is bound to the given key type.
   *
   * @param keyType key type required by the mode being selected
   * @throws NoSuchAlgorithmException if this cipher's key type differs
   */
  private void checkValidAlg(int keyType) throws NoSuchAlgorithmException
  {
    boolean matches = (keyType == this.keyType);
    if (!matches)
    {
      throw new NoSuchAlgorithmException("Mode not supported");
    }
  }


  /**
   * Selects the cipher mode by name (case-insensitive) and records whether the
   * mode accepts AAD and whether it must be processed as a single operation.
   *
   * @param mode one of CCM, GCM, ECB, CBC, CTR, OFB64, OFB128, CFB64, CFB128,
   *             XTS, SIV or WRAP
   * @throws NoSuchAlgorithmException if the name is unknown or the mode is not
   *                                  available for this cipher's key type
   */
  @Override
  protected void engineSetMode(String mode) throws NoSuchAlgorithmException
  {
    final String name = mode.toUpperCase();

    if (name.equalsIgnoreCase("CCM"))
    {
      selectMode(CCM, CKK_AES, false, true);
      return;
    }
    if (name.equalsIgnoreCase("GCM"))
    {
      selectMode(GCM, CKK_AES, false, true);
      return;
    }
    if (name.equalsIgnoreCase("ECB"))
    {
      // ECB and CBC are available for every key type: no check needed.
      this.mode = ECB;
      return;
    }
    if (name.equalsIgnoreCase("CBC"))
    {
      this.mode = CBC;
      return;
    }
    if (name.equalsIgnoreCase("CTR"))
    {
      selectMode(CTR, CKK_AES, false, false);
      return;
    }
    if (name.equalsIgnoreCase("OFB64"))
    {
      selectMode(OFB, CKK_DES3, false, false);
      return;
    }
    if (name.equalsIgnoreCase("OFB128"))
    {
      selectMode(OFB, CKK_AES, false, false);
      return;
    }
    if (name.equalsIgnoreCase("CFB64"))
    {
      selectMode(CFB, CKK_DES3, false, false);
      return;
    }
    if (name.equalsIgnoreCase("CFB128"))
    {
      selectMode(CFB, CKK_AES, false, false);
      return;
    }
    if (name.equalsIgnoreCase("XTS"))
    {
      selectMode(XTS, DYCKK_AES_XTS, true, false);
      return;
    }
    if (name.equalsIgnoreCase("SIV"))
    {
      selectMode(SIV, DYCKK_AES_SIV, true, true);
      return;
    }
    if (name.equalsIgnoreCase("WRAP"))
    {
      selectMode(NIST, CKK_AES, true, false);
      return;
    }

    throw new NoSuchAlgorithmException("Mode not supported: " + name);
  }

  /**
   * Records the mode, then validates the key type, then raises the requested
   * flags. The mode is stored before the check, so it stays set even when the
   * check throws; the flags are only raised when the check passes.
   */
  private void selectMode(int modeId, int requiredKeyType, boolean bufferWholeInput, boolean acceptsAad)
          throws NoSuchAlgorithmException
  {
    this.mode = modeId;
    checkValidAlg(requiredKeyType);
    if (bufferWholeInput) singleOp = true;
    if (acceptsAad) aad = true;
  }

  /**
   * Selects the padding scheme. NOPADDING is always accepted; PKCS5PADDING is
   * accepted only when the current mode is CBC or WRAP.
   *
   * @param padding padding name, case-insensitive
   * @throws NoSuchPaddingException for any other name or combination
   */
  @Override
  protected void engineSetPadding(String padding) throws NoSuchPaddingException
  {
    if (padding.equalsIgnoreCase("NOPADDING"))
    {
      this.padding = false;
      return;
    }

    boolean pkcs5 = padding.equalsIgnoreCase("PKCS5PADDING");
    boolean modeAllowsPadding = (mode == CBC) || (mode == NIST);
    if (!pkcs5 || !modeAllowsPadding)
    {
      throw new NoSuchPaddingException("padding not supported");
    }
    this.padding = true;
  }

  /**
   * @return 8 for CKK_DES3 keys, 16 for every other key type
   */
  @Override
  protected int engineGetBlockSize()
  {
    if (keyType == CKK_DES3)
    {
      return 8;
    }
    return 16;
  }

  /**
   * Returns a buffer size sufficient for the output of the next update/doFinal
   * call given {@code inputLen} more input bytes.
   *
   * When a session is open, the exact size is obtained by a length query. When
   * no session exists (AAD modes before the first data call, wrap/unwrap mode,
   * or after a completed operation) or the single-operation buffer has not
   * been allocated, a conservative upper bound from getOutBufLen() is returned
   * instead of failing with a NullPointerException. CipherSpi allows the
   * returned value to be larger than the actual output.
   *
   * @param inputLen number of additional input bytes
   * @return required output buffer length in bytes
   */
  @Override
  protected int engineGetOutputSize(int inputLen)
  {
    // Single-operation modes process everything buffered so far plus the new input.
    int pending = (singleOp && buffer != null) ? buffer.length : 0;
    int total = pending + inputLen;

    if (session == null)
    {
      return getOutBufLen(total);
    }

    // Null output buffer: ask the session for the length only.
    return encdecLen(new byte[total], 0, total);
  }

  /**
   * Returns the IV carried by the current parameters.
   *
   * Besides IvParameterSpec, the IV is also extracted from GCMParameterSpec
   * and CCMParameterSpec, so callers of Cipher.getIV() can retrieve it in the
   * AEAD modes as well.
   *
   * @return the IV, or null when no parameters are set or they carry no IV
   */
  @Override
  protected byte[] engineGetIV()
  {
    if (paramSpec instanceof IvParameterSpec)
    {
      return ((IvParameterSpec) paramSpec).getIV();
    }
    if (paramSpec instanceof GCMParameterSpec)
    {
      return ((GCMParameterSpec) paramSpec).getIV();
    }
    if (paramSpec instanceof CCMParameterSpec)
    {
      // Copy defensively: it is not visible here whether getIV() already
      // returns a fresh array.
      byte[] iv = ((CCMParameterSpec) paramSpec).getIV();
      return iv == null ? null : iv.clone();
    }
    return null;
  }

  /**
   * Returns the key size in bits.
   *
   * @param key provider key (size taken from getBitSize()) or any other key
   *            (size derived from its encoded length)
   * @return key size in bits
   * @throws InvalidKeyException if the provider key cannot report its size or
   *                             the key has no encoding
   */
  @Override
  protected int engineGetKeySize(Key key) throws InvalidKeyException
  {
    if (!(key instanceof SecretKey))
    {
      byte[] raw = key.getEncoded();
      if (raw == null)
      {
        throw new InvalidKeyException("Invalid key value");
      }
      return raw.length * 8;
    }

    try
    {
      return ((SecretKey) key).getBitSize();
    }
    catch (KeyStoreException e)
    {
      throw new InvalidKeyException(e);
    }
  }

  /**
   * Encodes the current parameters as an AlgorithmParameters object obtained
   * from the SunJCE provider.
   *
   * A GCMParameterSpec is encoded with the "GCM" parameters algorithm, because
   * the "AES"/"DESede" implementations only accept IvParameterSpec. All other
   * specs keep using "DESede" for CKK_DES3 keys and "AES" otherwise.
   *
   * @return the encoded parameters, or null when no parameters are set
   * @throws ProviderException if the parameters cannot be encoded
   */
  @Override
  protected AlgorithmParameters engineGetParameters()
  {
    if (paramSpec == null) return null;

    final String algorithm;
    if (paramSpec instanceof GCMParameterSpec)
    {
      algorithm = "GCM";
    }
    else
    {
      algorithm = keyType == CKK_DES3 ? "DESede" : "AES";
    }

    try
    {
      AlgorithmParameters params = AlgorithmParameters.getInstance(algorithm, "SunJCE");
      params.init(paramSpec);
      return params;
    }
    catch (GeneralSecurityException e)
    {
      throw new ProviderException("Could not encode parameters", e);
    }
  }

  /**
   * Maps the selected mode, key type and padding flag to a mechanism type.
   *
   * @return the mechanism type constant, or -1 when no known mode is selected
   */
  private int getMechanismType()
  {
    final boolean aes = (keyType == CKK_AES);

    // Modes that exist for one key family only.
    if (mode == CTR) return CKM_AES_CTR;
    if (mode == GCM) return CKM_AES_GCM;
    if (mode == CCM) return CKM_AES_CCM;
    if (mode == XTS) return DYCKM_AES_XTS;
    if (mode == SIV) return DYCKM_AES_SIV;

    if (mode == NIST)
    {
      return padding ? CKM_AES_KEY_WRAP_PAD : CKM_AES_KEY_WRAP;
    }

    // Modes with an AES variant and a DES variant.
    if (mode == OFB) return aes ? CKM_AES_OFB : CKM_DES_OFB64;
    if (mode == CFB) return aes ? CKM_AES_CFB128 : CKM_DES_CFB64;
    if (mode == ECB) return aes ? CKM_AES_ECB : CKM_DES3_ECB;

    if (mode == CBC)
    {
      if (padding)
      {
        return aes ? CKM_AES_CBC_PAD : CKM_DES3_CBC_PAD;
      }
      return aes ? CKM_AES_CBC : CKM_DES3_CBC;
    }

    return -1;
  }

  /**
   * Produces an IV-parameterized mechanism for the current mode, reusing the
   * existing mechanism object when it already has the required type.
   *
   * @param spec source of the IV bytes
   * @return the mechanism carrying the IV
   */
  private CK_MECHANISM getMechIV(IvParameterSpec spec)
  {
    final int wantedType = getMechanismType();
    final boolean reusable = (mechanism != null) && (mechanism.getType() == wantedType);

    if (!reusable)
    {
      return new CK_MECHANISM(wantedType, spec.getIV());
    }

    mechanism.setBuffer(spec.getIV());
    return mechanism;
  }

  /**
   * Builds the SIV mechanism from the AAD headers collected so far.
   *
   * The headers are copied with toArray(T[]) so that the result really is a
   * byte[][]; the no-argument toArray() returns an Object[] that cannot be
   * cast to byte[][] and fails with ClassCastException at runtime.
   *
   * @return SIV mechanism parameters; built with null headers when none were supplied
   */
  private CK_MECHANISM getMechSIV()
  {
    byte[][] headers = null;
    if (siv_headers != null)
    {
      // siv_headers is a raw list that only ever receives byte[] elements.
      headers = (byte[][]) siv_headers.toArray(new byte[siv_headers.size()][]);
    }
    return new DYCK_AES_SIV_PARAMS(headers);
  }


  /**
   * Produces the GCM mechanism from the given spec and the accumulated AAD,
   * re-initializing the existing mechanism object when it is already of type
   * CKM_AES_GCM.
   *
   * @param spec supplies the IV and tag length
   * @return the GCM mechanism
   */
  private CK_MECHANISM getMechGCM(GCMParameterSpec spec)
  {
    final boolean reusable = (mechanism != null) && (mechanism.getType() == CKM_AES_GCM);

    if (!reusable)
    {
      return new CK_GCM_PARAMS(spec.getIV(), auth, spec.getTLen());
    }

    CK_GCM_PARAMS existing = (CK_GCM_PARAMS) mechanism;
    existing.init(spec.getIV(), auth, spec.getTLen());
    return mechanism;
  }

  /**
   * Builds a new CCM mechanism from the spec's data size, IV and tag size plus
   * the AAD accumulated in {@code auth}.
   *
   * @param spec CCM parameters supplied at init time
   * @return a freshly created CK_CCM_PARAMS
   */
  private CK_MECHANISM getMechCCM(CCMParameterSpec spec)
  {
    return new CK_CCM_PARAMS(spec.getDataSize(), spec.getIV(), auth, spec.getTagSize());
  }

  /**
   * Runs ensureInit() from methods that cannot declare checked exceptions,
   * rethrowing its checked failures as ProviderException.
   */
  private void ensureInitOperation()
  {
    try
    {
      ensureInit();
    }
    catch (InvalidKeyException | InvalidAlgorithmParameterException e)
    {
      throw new ProviderException(e);
    }
  }

  /**
   * Opens an encrypt or decrypt session on the key with the prepared mechanism.
   *
   * @throws InvalidAlgorithmParameterException when the failure code is
   *         CKR_ARGUMENTS_BAD, CKR_MECHANISM_INVALID or CKR_MECHANISM_PARAM_INVALID
   * @throws InvalidKeyException for every other failure code
   */
  private void initSession() throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    try
    {
      if (encrypt)
      {
        session = secretKey.pkcs11Key.encryptInit(mechanism);
      }
      else
      {
        session = secretKey.pkcs11Key.decryptInit(mechanism);
      }
    }
    catch (CKException e)
    {
      final int rv = e.getRV();
      final boolean parameterProblem = rv == CKR_ARGUMENTS_BAD
              || rv == CKR_MECHANISM_INVALID
              || rv == CKR_MECHANISM_PARAM_INVALID;

      if (parameterProblem)
      {
        throw new InvalidAlgorithmParameterException(e);
      }
      throw new InvalidKeyException(e);
    }
  }

  /**
   * Hands the current session, if any, back to the key's slot and forgets it.
   */
  private void releaseSession()
  {
    if (session != null)
    {
      secretKey.pkcs11Key.getSlot().releaseSession(session);
      session = null;
    }
  }


  /**
   * Rebuilds {@code mechanism} for the selected mode from the current
   * parameters. The field is cleared first, so it stays null when building
   * fails.
   *
   * @throws InvalidAlgorithmParameterException if no supported mode is selected
   */
  private void prepareMechanism() throws InvalidAlgorithmParameterException
  {
    mechanism = null;

    final CK_MECHANISM prepared;
    if (mode == SIV)
    {
      prepared = getMechSIV();
    }
    else if (mode == GCM)
    {
      prepared = getMechGCM((GCMParameterSpec) paramSpec);
    }
    else if (mode == CCM)
    {
      prepared = getMechCCM((CCMParameterSpec) paramSpec);
    }
    else if (mode == CTR || mode == CBC || mode == OFB || mode == CFB || mode == XTS || mode == NIST)
    {
      // All of these are parameterized by an IV only.
      prepared = getMechIV((IvParameterSpec) paramSpec);
    }
    else if (mode == ECB)
    {
      prepared = new CK_MECHANISM(getMechanismType());
    }
    else
    {
      throw new InvalidAlgorithmParameterException("Invalid PKCS#11 mechanism");
    }

    mechanism = prepared;
  }

  /**
   * Performs the deferred part of initialization exactly once per init: drops
   * buffered data, closes any existing session, rebuilds the mechanism, opens
   * a session (unless in wrap/unwrap mode) and resets the accumulation buffers.
   *
   * @throws InvalidKeyException                if the session cannot be opened with the key
   * @throws InvalidAlgorithmParameterException if the mechanism or its parameters are rejected
   */
  private void ensureInit() throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    if (!initialized)
    {
      buffer = null;
      if (session != null)
      {
        session.close();
      }
      session = null;

      // Must run before 'auth' is cleared below: AEAD mechanisms capture the AAD.
      prepareMechanism();

      // Wrap/unwrap pass the mechanism to the key directly; no session is needed.
      if (!wrap)
      {
        initSession();
      }

      if (singleOp)
      {
        buffer = B0;
      }
      if (aad)
      {
        auth = B0;
      }

      initialized = true;
    }
  }

  /**
   * Initializes the cipher with a key and optional parameters.
   *
   * Validates the key and its type, records the operation direction, resolves
   * unwrap key attributes, validates (or generates) the mode-specific
   * parameters and, for modes without AAD, completes initialization right
   * away. For AAD modes the session is opened lazily so that AAD supplied
   * after init can still be included in the mechanism.
   *
   * Changes from the earlier revision: an unset "ukc.provider.fastUnwrap"
   * system property no longer causes a NullPointerException, unwrap key
   * attributes from a previous init are discarded, and an unreachable
   * NIST check in the IV-generation branch was removed.
   *
   * @param opmode                 one of the Cipher.*_MODE constants
   * @param key                    provider SecretKey matching this cipher's key type
   * @param algorithmParameterSpec mode parameters; may be null where a default
   *                               or generated value is possible
   * @param secureRandom           randomness source used to generate an IV when encrypting
   * @throws InvalidKeyException                if the key is not usable with this cipher
   * @throws InvalidAlgorithmParameterException if the parameters do not fit the mode
   */
  @Override
  protected void engineInit(int opmode, Key key, AlgorithmParameterSpec algorithmParameterSpec, SecureRandom secureRandom)
          throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    initialized = false;

    if (!(key instanceof SecretKey)) throw new InvalidKeyException("Invalid key type");
    secretKey = (SecretKey) key;

    try
    {
      secretKey.save();
      if (secretKey.getKeyType() != keyType) throw new InvalidKeyException("Invalid key type");
    }
    catch (KeyStoreException e)
    {
      throw new InvalidKeyException(e);
    }

    wrap = (opmode == Cipher.WRAP_MODE) || (opmode == Cipher.UNWRAP_MODE);
    encrypt = (opmode == Cipher.ENCRYPT_MODE) || (opmode == Cipher.WRAP_MODE);
    paramSpec = algorithmParameterSpec;

    if (opmode == Cipher.UNWRAP_MODE)
    {
      // Do not carry unwrap attributes over from a previous init.
      unwrapKeyParams = null;

      if (paramSpec instanceof KeyGenSpec)
      {
        // KeyGenSpec bundles the unwrap attributes with the real cipher parameters.
        unwrapKeyParams = ((KeyGenSpec) paramSpec).params;
        paramSpec = ((KeyGenSpec) paramSpec).original;
      }
      else if ("1".equals(System.getProperty("ukc.provider.fastUnwrap")))
      {
        // Constant-first comparison: the property is usually not set at all.
        unwrapKeyParams = new KeyParameters();
        unwrapKeyParams.setToken(false);
      }
    }

    switch (mode)
    {
      case SIV:
        if (!wrap) throw new InvalidAlgorithmParameterException("SIV doesn't support encrypt/decrypt");
        break;

      case GCM:
        if (!(paramSpec instanceof GCMParameterSpec))
          throw new InvalidAlgorithmParameterException("GCMParameterSpec required");
        break;

      case CCM:
        if (!(paramSpec instanceof CCMParameterSpec))
          throw new InvalidAlgorithmParameterException("CCMParameterSpec required");
        break;

      case NIST:
        if (paramSpec == null)
        {
          // No IV supplied: pass an empty one to the mechanism.
          paramSpec = new IvParameterSpec(new byte[0]);
        }
        else
        {
          if (!(paramSpec instanceof IvParameterSpec))
            throw new InvalidAlgorithmParameterException("IvParameterSpec required");
          int ivLen = ((IvParameterSpec) paramSpec).getIV().length;
          if (ivLen != 8) throw new InvalidAlgorithmParameterException("Invalid IV length");
        }
        break;

      case CTR:
      case CBC:
      case OFB:
      case CFB:
      case XTS:
        if (wrap && mode == XTS) throw new InvalidAlgorithmParameterException("XTS doesn't support wrap/unwrap");
        if (paramSpec == null && encrypt)
        {
          // Encrypting without an IV: generate a block-sized random one.
          if (secureRandom == null) throw new InvalidAlgorithmParameterException("Can't generate IV");
          byte[] iv = new byte[engineGetBlockSize()];
          secureRandom.nextBytes(iv);
          paramSpec = new IvParameterSpec(iv);
        }
        if (!(paramSpec instanceof IvParameterSpec))
          throw new InvalidAlgorithmParameterException("IvParameterSpec required");
        if (((IvParameterSpec) paramSpec).getIV().length != engineGetBlockSize())
          throw new InvalidAlgorithmParameterException("Invalid IV length");
        break;
    }

    if (aad)
    {
      // Defer opening the session until the first data call so AAD can still be added.
      auth = B0;
      siv_headers = null;
    }
    else ensureInit();
  }

  /**
   * Initializes the cipher from encoded AlgorithmParameters by converting them
   * to the spec class the current mode expects (CCMParameterSpec for CCM,
   * GCMParameterSpec for GCM, IvParameterSpec otherwise) and delegating to the
   * spec-based init.
   *
   * @param opmode              one of the Cipher.*_MODE constants
   * @param key                 key to initialize with
   * @param algorithmParameters encoded parameters, or null for none
   * @param secureRandom        randomness source passed through to the spec-based init
   * @throws InvalidKeyException                if the key is not usable with this cipher
   * @throws InvalidAlgorithmParameterException if the parameters cannot be
   *         converted (the conversion failure is kept as the cause) or do not fit the mode
   */
  @Override
  protected void engineInit(int opmode, Key key, AlgorithmParameters algorithmParameters, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    final Class<? extends AlgorithmParameterSpec> specClass;
    switch (mode)
    {
      case CCM:
        specClass = CCMParameterSpec.class;
        break;
      case GCM:
        specClass = GCMParameterSpec.class;
        break;
      default:
        specClass = IvParameterSpec.class;
        break;
    }

    AlgorithmParameterSpec spec = null;
    if (algorithmParameters != null)
    {
      try
      {
        spec = algorithmParameters.getParameterSpec(specClass);
      }
      catch (InvalidParameterSpecException ipse)
      {
        // Keep the original failure as the cause for diagnostics.
        throw new InvalidAlgorithmParameterException("Wrong parameter", ipse);
      }
    }

    engineInit(opmode, key, spec, secureRandom);
  }

  /**
   * Initializes the cipher without explicit parameters by delegating to the
   * spec-based init with a null spec.
   *
   * @throws InvalidKeyException if the key is rejected, or wrapping the
   *         InvalidAlgorithmParameterException raised when the mode cannot
   *         work without parameters
   */
  @Override
  protected void engineInit(int opmode, Key key, SecureRandom secureRandom) throws InvalidKeyException
  {
    // Typed null selects the AlgorithmParameterSpec overload.
    final AlgorithmParameterSpec noParams = null;
    try
    {
      engineInit(opmode, key, noParams, secureRandom);
    }
    catch (InvalidAlgorithmParameterException paramFailure)
    {
      throw new InvalidKeyException(paramFailure);
    }
  }

  /**
   * Appends {@code inLen} bytes of {@code in}, starting at {@code inOffset},
   * to the single-operation accumulation buffer.
   */
  private void updateSingleOp(byte[] in, int inOffset, int inLen)
  {
    final int used = buffer.length;
    // copyOf keeps the existing bytes and zero-fills the tail that is overwritten next.
    byte[] grown = Arrays.copyOf(buffer, used + inLen);
    System.arraycopy(in, inOffset, grown, used, inLen);
    buffer = grown;
  }

  /**
   * Supplies additional authenticated data.
   *
   * In SIV mode each call contributes one separate header; in the other AAD
   * modes the bytes are appended to a single AAD buffer.
   *
   * The SIV header is copied from {@code [offset, offset + len)}. The previous
   * code passed {@code len} as the end index of Arrays.copyOfRange, which
   * copied the wrong range (or threw) for any non-zero offset.
   *
   * @param src    array containing the AAD
   * @param offset index of the first AAD byte
   * @param len    number of AAD bytes
   * @throws IllegalStateException if the selected mode does not accept AAD
   */
  @Override
  protected void engineUpdateAAD(byte[] src,
                                 int offset,
                                 int len) throws IllegalStateException, UnsupportedOperationException
  {
    if (!aad) throw new IllegalStateException("Cipher does not accept AAD");

    if (mode == SIV)
    {
      if (siv_headers == null) siv_headers = new LinkedList();
      // copyOfRange takes an exclusive END index, not a length.
      siv_headers.add(Arrays.copyOfRange(src, offset, offset + len));
      return;
    }

    int oldSize = auth.length;
    byte[] newBuffer = new byte[oldSize + len];
    if (oldSize > 0) System.arraycopy(auth, 0, newBuffer, 0, oldSize);
    System.arraycopy(src, offset, newBuffer, oldSize, len);
    auth = newBuffer;
  }

  /**
   * Processes a chunk of input and returns whatever output is available.
   * Single-operation modes only buffer the input and return an empty array.
   *
   * @return produced bytes; empty when there is no input or the mode buffers
   */
  @Override
  protected byte[] engineUpdate(byte[] in, int inOffset, int inLen)
  {
    if (in == null || inLen == 0)
    {
      return B0;
    }

    ensureInitOperation();

    if (!singleOp)
    {
      final int expected = encdecUpdateLen(in, inOffset, inLen);
      final byte[] out = new byte[expected];
      final int produced = encdecUpdate(in, inOffset, inLen, out, 0);
      // Trim when the session produced fewer bytes than it announced.
      return produced == expected ? out : Arrays.copyOf(out, produced);
    }

    updateSingleOp(in, inOffset, inLen);
    return B0;
  }

  /**
   * Processes a chunk of input, writing available output into {@code out} at
   * {@code outOffset}. Single-operation modes only buffer the input.
   *
   * @return number of bytes written to {@code out}
   * @throws ShortBufferException if {@code out} cannot hold the announced output
   */
  @Override
  protected int engineUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException
  {
    if (in == null || inLen == 0)
    {
      return 0;
    }

    ensureInitOperation();

    if (!singleOp)
    {
      final int needed = encdecUpdateLen(in, inOffset, inLen);
      final int available = out.length - outOffset;
      if (needed > available)
      {
        throw new ShortBufferException();
      }
      return encdecUpdate(in, inOffset, inLen, out, outOffset);
    }

    updateSingleOp(in, inOffset, inLen);
    return 0;
  }

  /**
   * Estimates an output buffer size without querying the session: the input
   * length plus three blocks when encrypting, or one block when decrypting.
   */
  private int getOutBufLen(int inLen)
  {
    int blockLen = 16;
    if (keyType == CKK_DES3)
    {
      blockLen = 8;
    }
    // Encrypt headroom per the original note: block + data + block + tag.
    final int extraBlocks = encrypt ? 3 : 1;
    return inLen + blockLen * extraBlocks;
  }


  /**
   * Finishes the operation and returns the output as a new array. In
   * single-operation modes the final chunk is appended to the buffered input
   * and the whole buffer is processed at once.
   *
   * @return the produced bytes, trimmed to the actual output length
   */
  @Override
  protected byte[] engineDoFinal(byte[] in, int inOffset, int inLen)
          throws IllegalBlockSizeException, BadPaddingException, AEADBadTagException
  {
    ensureInitOperation();

    byte[] data = in;
    int start = inOffset;
    int count = inLen;
    if (singleOp)
    {
      engineUpdate(in, inOffset, inLen);
      data = buffer;
      start = 0;
      count = buffer.length;
    }

    // Over-allocate from the estimate instead of asking the session for the size.
    final int capacity = getOutBufLen(count);
    final byte[] out = new byte[capacity == 0 ? 1 : capacity];

    final int produced = encdec(data, start, count, out, 0);
    return produced == capacity ? out : Arrays.copyOf(out, produced);
  }

  /**
   * Finishes the operation, writing the output into {@code out} at
   * {@code outOffset}. In single-operation modes the final chunk is appended
   * to the buffered input and the whole buffer is processed at once.
   *
   * @return number of bytes written
   * @throws ShortBufferException if {@code out} cannot hold the announced output
   */
  @Override
  protected int engineDoFinal(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
          throws ShortBufferException, IllegalBlockSizeException, BadPaddingException, AEADBadTagException
  {
    ensureInitOperation();

    byte[] data = in;
    int start = inOffset;
    int count = inLen;
    if (singleOp)
    {
      boolean hasFinalChunk = (in != null) && (inLen > 0);
      if (hasFinalChunk)
      {
        engineUpdate(in, inOffset, inLen);
      }
      data = buffer;
      start = 0;
      count = buffer.length;
    }

    final int needed = encdecLen(data, start, count);
    final int available = out.length - outOffset;
    if (needed > available)
    {
      throw new ShortBufferException();
    }
    return encdec(data, start, count, out, outOffset);
  }

  /**
   * Calls the session's one-shot encrypt/decrypt with a null output buffer and
   * returns the value it reports.
   *
   * @throws ProviderException if the session call fails
   */
  private int encdecLen(byte[] in, int inOffset, int inLen)
  {
    try
    {
      if (encrypt)
      {
        return session.encrypt(in, inOffset, inLen, null, 0);
      }
      return session.decrypt(in, inOffset, inLen, null, 0);
    }
    catch (CKException e)
    {
      throw new ProviderException(e);
    }
  }

  /**
   * Runs the one-shot encrypt/decrypt on the session, translates failure codes
   * to JCE exceptions and always releases the session afterwards.
   *
   * After the session is released, the per-operation state is reset so the
   * cipher can be reused as CipherSpi requires: buffered single-operation
   * input is dropped, and outside wrap mode the cipher is marked as not
   * initialized so the next data call re-runs ensureInit() and opens a fresh
   * session. Without this, a second operation dereferenced the released
   * (null) session or re-processed stale buffered input. In wrap mode the
   * flag is left alone because engineWrap() opens its own session before
   * calling doFinal and ensureInit() would discard it.
   *
   * @return number of bytes written to {@code out}
   * @throws IllegalBlockSizeException on CKR_DATA_LEN_RANGE / CKR_ENCRYPTED_DATA_LEN_RANGE
   * @throws AEADBadTagException       on invalid data in an AAD mode
   * @throws BadPaddingException       on invalid data in any other mode
   * @throws ProviderException         on every other failure code
   */
  private int encdec(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
          throws IllegalBlockSizeException, BadPaddingException, AEADBadTagException
  {
    try
    {
      return encrypt ? session.encrypt(in, inOffset, inLen, out, outOffset) : session.decrypt(in, inOffset, inLen, out, outOffset);
    }
    catch (CKException e)
    {
      int rv = e.getRV();
      if (rv == CKR_DATA_LEN_RANGE || rv == CKR_ENCRYPTED_DATA_LEN_RANGE) throw new IllegalBlockSizeException();
      if (rv == CKR_DATA_INVALID || rv == CKR_ENCRYPTED_DATA_INVALID)
      {
        if (aad) throw new AEADBadTagException();
        throw new BadPaddingException();
      }
      throw new ProviderException(e);
    }
    finally
    {
      releaseSession();
      // The operation is over (successfully or not): forget its buffered input.
      if (singleOp) buffer = B0;
      // Force lazy re-initialization on reuse; see the method comment for wrap mode.
      if (!wrap) initialized = false;
    }
  }

  /**
   * Calls the session's multi-part update with a null output buffer and
   * returns the value it reports.
   *
   * @throws ProviderException if the session call fails
   */
  private int encdecUpdateLen(byte[] in, int inOffset, int inLen)
  {
    try
    {
      if (encrypt)
      {
        return session.encryptUpdate(in, inOffset, inLen, null, 0);
      }
      return session.decryptUpdate(in, inOffset, inLen, null, 0);
    }
    catch (CKException e)
    {
      throw new ProviderException(e);
    }
  }

  /**
   * Feeds one chunk to the session's multi-part update, writing output into
   * {@code out} at {@code outOffset}.
   *
   * @return the value reported by the session call
   * @throws ProviderException if the session call fails
   */
  private int encdecUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
  {
    try
    {
      if (encrypt)
      {
        return session.encryptUpdate(in, inOffset, inLen, out, outOffset);
      }
      return session.decryptUpdate(in, inOffset, inLen, out, outOffset);
    }
    catch (CKException e)
    {
      throw new ProviderException(e);
    }
  }

  /**
   * Wraps a key. Provider keys (secret, RSA private, EC private) are saved and
   * wrapped through the key's own wrap call; any other key is wrapped by
   * encrypting its encoded form.
   *
   * @param key the key to wrap
   * @return the wrapped key bytes
   * @throws InvalidKeyException       if the key cannot be saved, wrapped or encoded,
   *                                   or the mechanism cannot be prepared
   * @throws IllegalBlockSizeException if the encoded key length does not fit the cipher
   */
  @Override
  protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException
  {
    CKKey target = null;
    try
    {
      if (key instanceof SecretKey)
      {
        SecretKey secret = (SecretKey) key;
        secret.save();
        target = secret.pkcs11Key;
      }
      else if (key instanceof RSAPrivateKey)
      {
        RSAPrivateKey rsa = (RSAPrivateKey) key;
        rsa.save();
        target = rsa.pkcs11Key;
      }
      else if (key instanceof ECPrivateKey)
      {
        ECPrivateKey ec = (ECPrivateKey) key;
        ec.save();
        target = ec.pkcs11Key;
      }
    }
    catch (KeyStoreException e)
    {
      throw new InvalidKeyException(e);
    }

    try
    {
      prepareMechanism();
    }
    catch (InvalidAlgorithmParameterException e)
    {
      throw new InvalidKeyException(e);
    }

    if (target == null)
    {
      return wrapEncodedKey(key);
    }

    try
    {
      return secretKey.pkcs11Key.wrap(mechanism, target, 0);
    }
    catch (CKException e)
    {
      throw new InvalidKeyException(e);
    }
  }

  /**
   * Wraps a non-provider key by encrypting its encoded bytes in a session
   * opened just for this call.
   */
  private byte[] wrapEncodedKey(Key key) throws IllegalBlockSizeException, InvalidKeyException
  {
    final byte[] encoded = key.getEncoded();
    if (encoded == null || encoded.length == 0)
    {
      throw new InvalidKeyException("Cannot get an encoding of the key to be wrapped");
    }

    try
    {
      initSession();
    }
    catch (InvalidAlgorithmParameterException e)
    {
      throw new ProviderException(e);
    }

    try
    {
      return engineDoFinal(encoded, 0, encoded.length);
    }
    catch (BadPaddingException e)
    {
      throw new ProviderException(e);
    }
    finally
    {
      releaseSession();
    }
  }

  /**
   * Unwraps a key into a provider key object.
   *
   * @param wrappedKey          the wrapped key bytes
   * @param wrappedKeyAlgorithm algorithm of the wrapped key ("RSA" or "EC" for
   *                            private keys; mapped by SecretKey.algToKeyType for secret keys)
   * @param wrappedKeyType      Cipher.PRIVATE_KEY or Cipher.SECRET_KEY
   * @return the unwrapped key
   * @throws InvalidKeyException if the mechanism cannot be prepared or the key
   *                             type/algorithm is not supported
   */
  @Override
  protected Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm, int wrappedKeyType) throws InvalidKeyException, NoSuchAlgorithmException
  {
    try
    {
      prepareMechanism();
    }
    catch (InvalidAlgorithmParameterException e)
    {
      throw new InvalidKeyException(e);
    }

    final UnwrapInfo info = new UnwrapInfo(mechanism, secretKey.pkcs11Key, wrappedKey);

    if (wrappedKeyType == Cipher.PRIVATE_KEY)
    {
      if (wrappedKeyAlgorithm.equalsIgnoreCase("RSA"))
      {
        return new RSAPrivateKey().initForUnwrap(info, unwrapKeyParams);
      }
      if (wrappedKeyAlgorithm.equalsIgnoreCase("EC"))
      {
        return new ECPrivateKey().initForUnwrap(info, unwrapKeyParams);
      }
      throw new InvalidKeyException("Unsupported wrappedKeyAlgorithm " + wrappedKeyAlgorithm);
    }

    if (wrappedKeyType == Cipher.SECRET_KEY)
    {
      return new SecretKey().initForUnwrap(info, SecretKey.algToKeyType(wrappedKeyAlgorithm), unwrapKeyParams);
    }

    throw new InvalidKeyException("Unsupported wrappedKeyType");
  }

  /** Cipher variant bound to the CKK_AES key type. */
  public static final class AES extends SecretKeyCipher
  {
    public AES()
    {
      super(CKK_AES);
    }
  }

  /** Cipher variant bound to the DYCKK_AES_XTS key type. */
  public static final class AESXTS extends SecretKeyCipher
  {
    public AESXTS()
    {
      super(DYCKK_AES_XTS);
    }
  }

  /** Cipher variant bound to the DYCKK_AES_SIV key type. */
  public static final class AESSIV extends SecretKeyCipher
  {
    public AESSIV()
    {
      super(DYCKK_AES_SIV);
    }
  }

  /** Cipher variant bound to the CKK_DES3 key type. */
  public static final class DES3 extends SecretKeyCipher
  {
    public DES3()
    {
      super(CKK_DES3);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy