com.microsoft.sqlserver.jdbc.KeyStoreProviderCommon Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mssql-jdbc Show documentation
Show all versions of mssql-jdbc Show documentation
Microsoft JDBC Driver for SQL Server.
/*
* Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
* available under the terms of the MIT License. See the LICENSE file in the project root for more information.
*/
package com.microsoft.sqlserver.jdbc;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.X509Certificate;
import java.text.MessageFormat;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
/**
 * Simple value holder that pairs an X.509 certificate with the key stored alongside it.
 *
 * <p>Both fields are package-visible and mutable; the constructor performs no validation, so either value may be
 * {@code null}.
 */
class CertificateDetails {

    /**
     * Creates a holder for the given certificate/key pair.
     *
     * @param certificate
     *        the X.509 certificate (may be {@code null})
     * @param privateKey
     *        the key paired with the certificate (may be {@code null})
     */
    CertificateDetails(X509Certificate certificate, Key privateKey) {
        this.privateKey = privateKey;
        this.certificate = certificate;
    }

    /** The X.509 certificate. */
    X509Certificate certificate;

    /** The key paired with {@link #certificate}; presumably its private key, as the field name suggests. */
    Key privateKey;
}
/**
 * Static helpers for validating key encryption arguments and for verifying and decrypting an encrypted column
 * encryption key with an X.509 certificate and the key paired with it.
 */
class KeyStoreProviderCommon {

    /** The only key encryption algorithm name accepted; compared case-insensitively after trimming. */
    static final String rsaEncryptionAlgorithmWithOAEP = "RSA_OAEP";

    /**
     * Version prefix of an encrypted column encryption key. Within this class only its length is used, to skip the
     * prefix while parsing; the version byte itself is not validated here.
     */
    static byte[] version = new byte[] {0x01};

    /**
     * Checks that {@code encryptionAlgorithm} names RSA-OAEP.
     *
     * @param encryptionAlgorithm
     *        the key encryption algorithm name supplied by the caller
     * @param isEncrypt
     *        selects the error message set: {@code true} for the public messages, {@code false} for the
     *        "Internal" variants (the decrypt path in this class passes {@code false})
     * @throws SQLServerException
     *         if the name is {@code null} or is not {@value #rsaEncryptionAlgorithmWithOAEP} (ignoring case and
     *         surrounding whitespace)
     */
    static void validateEncryptionAlgorithm(String encryptionAlgorithm, boolean isEncrypt) throws SQLServerException {
        String errString = isEncrypt ? "R_NullKeyEncryptionAlgorithm" : "R_NullKeyEncryptionAlgorithmInternal";
        if (null == encryptionAlgorithm) {
            throw new SQLServerException(null, SQLServerException.getErrString(errString), null, 0, false);
        }

        errString = isEncrypt ? "R_InvalidKeyEncryptionAlgorithm" : "R_InvalidKeyEncryptionAlgorithmInternal";
        if (!rsaEncryptionAlgorithmWithOAEP.equalsIgnoreCase(encryptionAlgorithm.trim())) {
            MessageFormat form = new MessageFormat(SQLServerException.getErrString(errString));
            Object[] msgArgs = {encryptionAlgorithm, rsaEncryptionAlgorithmWithOAEP};
            throw new SQLServerException(form.format(msgArgs), null);
        }
    }

    /**
     * Checks that a master key path is present and not blank.
     *
     * @param masterKeyPath
     *        the path to validate
     * @throws SQLServerException
     *         if the path is {@code null} or consists only of whitespace
     */
    static void validateNonEmptyMasterKeyPath(String masterKeyPath) throws SQLServerException {
        if (null == masterKeyPath || masterKeyPath.trim().length() == 0) {
            throw new SQLServerException(null, SQLServerException.getErrString("R_InvalidMasterKeyDetails"), null, 0,
                    false);
        }
    }

    /**
     * Verifies the signature of an encrypted column encryption key and returns the decrypted key.
     *
     * <p>The blob is parsed as:
     * {@code version | keyPathLength (2 bytes, little-endian, unsigned) | cipherTextLength (2 bytes, little-endian,
     * unsigned) | keyPath | cipherText | signature}. Everything preceding the signature is verified with
     * {@code SHA256withRSA} against the certificate's public key; only then is the cipher text decrypted with
     * RSA-OAEP using {@code certificateDetails.privateKey}.
     *
     * @param masterKeyPath
     *        master key path, used only in error messages
     * @param encryptionAlgorithm
     *        must name RSA-OAEP (see {@link #validateEncryptionAlgorithm})
     * @param encryptedColumnEncryptionKey
     *        the encrypted key blob
     * @param certificateDetails
     *        certificate (for signature verification) and key (for decryption)
     * @return the decrypted column encryption key
     * @throws SQLServerException
     *         if the blob is null, empty, truncated or malformed, if the signature does not verify, or if decryption
     *         fails
     */
    static byte[] decryptColumnEncryptionKey(String masterKeyPath, String encryptionAlgorithm,
            byte[] encryptedColumnEncryptionKey, CertificateDetails certificateDetails) throws SQLServerException {
        if (null == encryptedColumnEncryptionKey) {
            throw new SQLServerException(null, SQLServerException.getErrString("R_NullEncryptedColumnEncryptionKey"),
                    null, 0, false);
        } else if (0 == encryptedColumnEncryptionKey.length) {
            throw new SQLServerException(null, SQLServerException.getErrString("R_EmptyEncryptedColumnEncryptionKey"),
                    null, 0, false);
        }

        validateEncryptionAlgorithm(encryptionAlgorithm, false);

        int currentIndex = version.length;
        int keyPathLength = convertTwoBytesToUnsignedShort(encryptedColumnEncryptionKey, currentIndex);
        currentIndex += 2;

        int cipherTextLength = convertTwoBytesToUnsignedShort(encryptedColumnEncryptionKey, currentIndex);
        currentIndex += 2;

        // The key path itself is not needed for decryption; skip over it.
        currentIndex += keyPathLength;

        // The declared lengths come from the blob itself. If they run past the end of the buffer the key is
        // malformed; report it as such instead of letting arraycopy/array allocation fail with an unchecked
        // exception. (currentIndex may exceed the length here, making the right-hand side negative.)
        if (cipherTextLength > encryptedColumnEncryptionKey.length - currentIndex) {
            throw invalidCertificateSignature(masterKeyPath, null);
        }

        byte[] cipherText = new byte[cipherTextLength];
        System.arraycopy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);
        currentIndex += cipherTextLength;

        // Whatever follows the cipher text is the signature; everything before it is the signed data.
        int signatureLength = encryptedColumnEncryptionKey.length - currentIndex;
        byte[] signature = new byte[signatureLength];
        System.arraycopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength);

        byte[] signedData = new byte[currentIndex];
        System.arraycopy(encryptedColumnEncryptionKey, 0, signedData, 0, currentIndex);

        if (!verifyRSASignature(signedData, signature, certificateDetails.certificate, masterKeyPath)) {
            throw invalidCertificateSignature(masterKeyPath, null);
        }

        return decryptRSAOAEP(cipherText, certificateDetails);
    }

    /**
     * Decrypts {@code cipherText} with RSA/ECB/OAEPWithSHA-1AndMGF1Padding using
     * {@code certificateDetails.privateKey}.
     *
     * @throws SQLServerException
     *         wrapping any cryptographic failure
     */
    private static byte[] decryptRSAOAEP(byte[] cipherText,
            CertificateDetails certificateDetails) throws SQLServerException {
        try {
            Cipher rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-1AndMGF1Padding");
            rsa.init(Cipher.DECRYPT_MODE, certificateDetails.privateKey);
            // Single-shot: pass the whole input to doFinal rather than update()+doFinal(), so no output produced
            // by update() can be silently dropped.
            return rsa.doFinal(cipherText);
        } catch (InvalidKeyException | NoSuchAlgorithmException | NoSuchPaddingException | IllegalBlockSizeException
                | BadPaddingException e) {
            MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_CEKDecryptionFailed"));
            Object[] msgArgs = {e.getMessage()};
            throw new SQLServerException(form.format(msgArgs), e);
        }
    }

    /**
     * Verifies a {@code SHA256withRSA} signature against the certificate's public key.
     *
     * @param hash
     *        the data that was signed; despite the name this is not a pre-computed digest, because
     *        {@code SHA256withRSA} hashes its input itself
     * @param signature
     *        the signature bytes to check
     * @param certificate
     *        certificate whose public key verifies the signature
     * @param masterKeyPath
     *        used only in the error message
     * @return {@code true} if the signature is valid
     * @throws SQLServerException
     *         if verification cannot be performed (bad key, unsupported algorithm, malformed signature)
     */
    static boolean verifyRSASignature(byte[] hash, byte[] signature, X509Certificate certificate,
            String masterKeyPath) throws SQLServerException {
        try {
            Signature signVerify = Signature.getInstance("SHA256withRSA");
            signVerify.initVerify(certificate.getPublicKey());
            signVerify.update(hash);
            return signVerify.verify(signature);
        } catch (InvalidKeyException | NoSuchAlgorithmException | SignatureException e) {
            throw invalidCertificateSignature(masterKeyPath, e);
        }
    }

    /** Builds the {@code R_InvalidCertificateSignature} exception for {@code masterKeyPath}. */
    private static SQLServerException invalidCertificateSignature(String masterKeyPath, Throwable cause) {
        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidCertificateSignature"));
        Object[] msgArgs = {masterKeyPath};
        return new SQLServerException(form.format(msgArgs), cause);
    }

    /**
     * Reads two bytes at {@code index} as an unsigned little-endian 16-bit value (0..65535).
     *
     * @throws SQLServerException
     *         if fewer than two bytes are available at {@code index}
     */
    private static int convertTwoBytesToUnsignedShort(byte[] input, int index) throws SQLServerException {
        if (index + 1 >= input.length) {
            throw new SQLServerException(null, SQLServerException.getErrString("R_ByteToShortConversion"), null, 0,
                    false);
        }
        // Lengths in the blob are unsigned; a signed short would turn values >= 0x8000 negative.
        return Short.toUnsignedInt(ByteBuffer.wrap(input, index, 2).order(ByteOrder.LITTLE_ENDIAN).getShort());
    }
}