All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.sqlserver.jdbc.SQLServerCertificateUtils Maven / Gradle / Ivy

There is a newer version: 12.8.1.jre11
Show newest version
/*
 * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
 * available under the terms of the MIT License. See the LICENSE file in the project root for more information.
 */

package com.microsoft.sqlserver.jdbc;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.math.BigInteger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;

import org.bouncycastle.openssl.PEMDecryptorProvider;
import org.bouncycastle.openssl.PEMEncryptedKeyPair;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;


/**
 * Static helpers for TLS certificate handling: building {@link KeyManager}s from client certificate/key files on disk,
 * and validating a server's X.509 certificate (host name matching, or public-key comparison against a certificate
 * file).
 */
final class SQLServerCertificateUtils {

    private static final Logger logger = Logger.getLogger("com.microsoft.sqlserver.jdbc.SQLServerCertificateUtils");
    // Prefix for log messages: the fully qualified class name followed by ": ".
    private static final String logContext = SQLServerCertificateUtils.class.getName() + ": ";

    // PKCS#12 format
    private static final String PKCS12_ALG = "PKCS12";
    private static final String SUN_X_509 = "SunX509";
    // PKCS#8 format
    private static final String PEM_PRIVATE_START = "-----BEGIN PRIVATE KEY-----";
    private static final String PEM_PRIVATE_END = "-----END PRIVATE KEY-----";
    private static final String JAVA_KEY_STORE = "JKS";
    private static final String CLIENT_CERT = "client-cert";
    private static final String CLIENT_KEY = "client-key";
    // PKCS#1 format
    private static final String PEM_RSA_PRIVATE_START = "-----BEGIN RSA PRIVATE KEY-----";
    // PVK format
    private static final long PVK_MAGIC = 0xB0B5F11EL;
    // ASCII "RSA2": the header expected at the start of a (decrypted) PVK key blob.
    private static final byte[] RSA2_MAGIC = {82, 83, 65, 50};
    private static final String RC4_ALG = "RC4";
    private static final String RSA_ALG = "RSA";
    // Type tag of a dNSName entry in the lists returned by X509Certificate#getSubjectAlternativeNames().
    private static final int SAN_DNS_NAME = 2;

    /**
     * Builds key managers for client certificate authentication from files on disk.
     * 
     * @param certPath
     *        path to the client certificate; a PKCS#12 store when {@code keyPath} is empty, otherwise an X.509
     *        certificate file
     * @param keyPath
     *        path to a separate private key file (PKCS#8 PEM, PKCS#1 PEM or PVK), or null/empty for PKCS#12
     * @param keyPassword
     *        password protecting the key / key store; may be null
     * @return the key managers built from the certificate and key
     * @throws IOException
     *         if a file cannot be read
     * @throws GeneralSecurityException
     *         if the certificate or key cannot be parsed or loaded
     * @throws SQLServerException
     *         if a certificate file cannot be found, or a PVK file is malformed
     */
    static KeyManager[] getKeyManagerFromFile(String certPath, String keyPath,
            String keyPassword) throws IOException, GeneralSecurityException, SQLServerException {
        boolean hasSeparateKeyFile = keyPath != null && !keyPath.isEmpty();
        return hasSeparateKeyFile ? readPKCS8Certificate(certPath, keyPath, keyPassword)
                                  : readPKCS12Certificate(certPath, keyPassword);
    }

    /**
     * Extracts the common name (CN) from a distinguished name in canonical RFC 2253 form.
     * <p>
     * The parser errs on the side of caution: if it sees something it cannot parse correctly (e.g. an opening quote
     * without a closing one) it returns null, so the name will not match any host.
     * 
     * @param distinguishedName
     *        the DN to parse, expected in canonical (lower case) form
     * @return the common name, or null if no CN could be found or parsed
     */
    static String parseCommonName(String distinguishedName) {
        // The canonical form converts the entire name to lower case, so "cn=" is the only spelling to look for.
        // NOTE(review): this also matches "cn=" occurring inside another attribute's value; consider anchoring the
        // match to the start of an RDN.
        int cnStart = distinguishedName.indexOf("cn=");
        if (cnStart == -1) {
            return null;
        }
        String afterCn = distinguishedName.substring(cnStart + 3);

        // The value runs until the first comma or the end of the string. A comma inside quotes (e.g. cn="Foo, bar")
        // is not handled specially: the value is cut there, and the unbalanced quote then yields null below.
        // RFC 952 host names cannot contain commas, so such names could never match anyway.
        int cnEnd = afterCn.indexOf(',');
        String commonName = (cnEnd == -1) ? afterCn : afterCn.substring(0, cnEnd);

        // Strip surrounding quotes; a value that opens a quote without closing it is rejected.
        if (commonName.length() > 1 && commonName.charAt(0) == '"') {
            if (commonName.charAt(commonName.length() - 1) != '"') {
                return null;
            }
            commonName = commonName.substring(1, commonName.length() - 1);
        }
        return commonName;
    }

    /**
     * Checks whether a name taken from a certificate (subject CN or SAN dNSName) matches the host name.
     * <p>
     * A name containing {@code *} is treated as a wildcard pattern, unless it is an IDN ({@code xn--} prefix), in
     * which case it must match exactly. A wildcard never spans a period. A wildcard match is only accepted when the
     * host name contains at least two periods. Any other name must equal the host name exactly.
     * 
     * @param nameInCert
     *        server name in certificate; null means it could not be parsed
     * @param hostName
     *        host name being connected to
     * @return true if the name in the certificate is valid and matches hostName
     */
    static boolean validateServerName(String nameInCert, String hostName) {
        // Failed to get the common name from DN or empty CN
        if (null == nameInCert) {
            if (logger.isLoggable(Level.FINER)) {
                logger.finer(logContext + " Failed to parse the name from the certificate or name is empty.");
            }
            return false;
        }
        // We do not allow wildcards in IDNs (xn--).
        if (!nameInCert.startsWith("xn--") && nameInCert.contains("*")) {
            int hostIndex = 0, certIndex = 0, match = 0, startIndex = -1, periodCount = 0;
            while (hostIndex < hostName.length()) {
                // periodCount is used only as a "a period has been reached" flag. The wildcard branch below does
                // not advance hostIndex, so the same period can be counted twice.
                if ('.' == hostName.charAt(hostIndex)) {
                    periodCount++;
                }
                if (certIndex < nameInCert.length() && hostName.charAt(hostIndex) == nameInCert.charAt(certIndex)) {
                    hostIndex++;
                    certIndex++;
                } else if (certIndex < nameInCert.length() && '*' == nameInCert.charAt(certIndex)) {
                    // Remember where the wildcard is and first try letting it match the empty string.
                    startIndex = certIndex;
                    match = hostIndex;
                    certIndex++;
                } else if (startIndex != -1 && 0 == periodCount) {
                    // Backtrack: let the wildcard absorb one more host character. This is only allowed before any
                    // period has been reached, so a wildcard can never cover a period.
                    certIndex = startIndex + 1;
                    match++;
                    hostIndex = match;
                } else {
                    logFailMessage(nameInCert, hostName);
                    return false;
                }
            }
            // The whole pattern must be consumed. The host must have more than one period, so a wildcard cannot
            // cover a label directly below the top-level domain. Count the real periods, because periodCount may
            // be inflated (see above).
            if (nameInCert.length() == certIndex && countPeriods(hostName) > 1) {
                logSuccessMessage(nameInCert, hostName);
                return true;
            } else {
                logFailMessage(nameInCert, hostName);
                return false;
            }
        }
        // Verify that the name in certificate matches exactly with the host name
        if (!nameInCert.equals(hostName)) {
            logFailMessage(nameInCert, hostName);
            return false;
        }
        logSuccessMessage(nameInCert, hostName);
        return true;
    }

    /**
     * Validates that the server certificate was issued for hostName.
     * <p>
     * The subject CN is tried first. If it does not match, every dNSName Subject Alternative Name is tried (lower
     * cased with {@link Locale#ENGLISH}) until one matches.
     * 
     * @param cert
     *        X509 certificate presented by the server
     * @param hostName
     *        host name being connected to
     * @throws CertificateException
     *         if no name in the certificate matches hostName, or the SAN extension cannot be parsed
     */
    static void validateServerNameInCertificate(X509Certificate cert, String hostName) throws CertificateException {
        String nameInCertDN = cert.getSubjectX500Principal().getName("canonical");

        if (logger.isLoggable(Level.FINER)) {
            logger.finer(logContext + " Validating the server name:" + hostName);
            logger.finer(logContext + " The DN name in certificate:" + nameInCertDN);
        }

        // Last SAN dNSName examined; reported in the error message if validation fails.
        String dnsNameInSANCert = "";

        // The name in the cert is in RFC 2253 format; parse it to get the actual subject name.
        String subjectCN = parseCommonName(nameInCertDN);
        boolean isServerNameValidated = validateServerName(subjectCN, hostName);

        if (!isServerNameValidated) {
            Collection<List<?>> sanCollection = cert.getSubjectAlternativeNames();

            if (sanCollection != null) {
                // Find a subjectAlternativeName entry corresponding to a DNS name. The collection may contain more
                // than one name of the same type, so every dNSName entry is tried.
                for (List<?> sanEntry : sanCollection) {
                    if (sanEntry == null || sanEntry.size() < 2) {
                        if (logger.isLoggable(Level.FINER)) {
                            logger.finer(logContext + " found an invalid san entry: " + sanEntry);
                        }
                        continue;
                    }

                    // Per the X509Certificate docs, the first element is an Integer type tag, and for a dNSName the
                    // second element is a String.
                    Object key = sanEntry.get(0);
                    Object value = sanEntry.get(1);

                    if (logger.isLoggable(Level.FINER)) {
                        logger.finer(logContext + "Key: " + key + "; KeyClass:"
                                + (key != null ? key.getClass() : null) + ";value: " + value + "; valueClass:"
                                + (value != null ? value.getClass() : null));
                    }

                    if (!(key instanceof Integer) || ((Integer) key).intValue() != SAN_DNS_NAME) {
                        continue;
                    }

                    if (value instanceof String) {
                        // DNS names are case-insensitive (RFC 1034 preferred name syntax). Use the English locale
                        // to avoid the Turkish dotless-i problem. The subject DN needs no such conversion, because
                        // the canonical form is already lower case.
                        dnsNameInSANCert = ((String) value).toLowerCase(Locale.ENGLISH);

                        isServerNameValidated = validateServerName(dnsNameInSANCert, hostName);
                        if (isServerNameValidated) {
                            if (logger.isLoggable(Level.FINER)) {
                                logger.finer(logContext + " found a valid name in certificate: " + dnsNameInSANCert);
                            }
                            break;
                        }
                    }

                    if (logger.isLoggable(Level.FINER)) {
                        logger.finer(logContext + " the following name in certificate does not match the serverName: "
                                + value);
                    }
                }
            }
        }

        if (!isServerNameValidated) {
            MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_certNameFailed"));
            Object[] msgArgs = {hostName, dnsNameInSANCert};
            throw new CertificateException(form.format(msgArgs));
        }
    }

    /**
     * Validates the server certificate against a certificate stored in a file, by comparing their public keys.
     * 
     * @param cert
     *        X509 certificate presented by the server
     * @param certFile
     *        path to the certificate file to validate against
     * @throws CertificateException
     *         if the file cannot be read or parsed (the cause is attached), or if the public keys differ
     */
    static void validateServerCerticate(X509Certificate cert, String certFile) throws CertificateException {
        Certificate expectedCert;
        try (InputStream is = fileToStream(certFile)) {
            expectedCert = CertificateFactory.getInstance("X509").generateCertificate(is);
        } catch (Exception e) {
            MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_serverCertError"));
            Object[] msgArgs = {certFile, e.getMessage()};
            throw new CertificateException(form.format(msgArgs), e);
        }

        // The comparison is done outside the try, so a mismatch is reported once instead of being caught and
        // wrapped a second time by the handler above.
        if (!expectedCert.getPublicKey().equals(cert.getPublicKey())) {
            MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_serverCertError"));
            Object[] msgArgs = {certFile};
            throw new CertificateException(form.format(msgArgs));
        }
    }

    private static void logFailMessage(String nameInCert, String hostName) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer(logContext + " The name in certificate " + nameInCert + " does not match with the server name "
                    + hostName + ".");
        }
    }

    private static void logSuccessMessage(String nameInCert, String hostName) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer(logContext + " The name in certificate:" + nameInCert + " validated against server name "
                    + hostName + ".");
        }
    }

    /**
     * Counts the '.' characters in a string.
     */
    private static int countPeriods(String s) {
        int count = 0;
        for (int i = 0; i < s.length(); i++) {
            if ('.' == s.charAt(i)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Loads a PKCS#12 store containing both the client certificate and its private key.
     */
    private static KeyManager[] readPKCS12Certificate(String certPath,
            String keyPassword) throws NoSuchAlgorithmException, CertificateException, FileNotFoundException, IOException, UnrecoverableKeyException, KeyStoreException, SQLServerException {
        // KeyStore.load accepts a null password, so a missing password is passed through instead of causing an NPE.
        char[] password = (null == keyPassword) ? null : keyPassword.toCharArray();

        KeyStore keystore = KeyStore.getInstance(PKCS12_ALG);
        try (FileInputStream certStream = new FileInputStream(certPath)) {
            keystore.load(certStream, password);
        } catch (FileNotFoundException e) {
            throw new SQLServerException(SQLServerException.getErrString("R_clientCertError"), null, 0, e);
        }
        KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(SUN_X_509);
        keyManagerFactory.init(keystore, password);
        return keyManagerFactory.getKeyManagers();
    }

    /**
     * Loads an X.509 client certificate and a separate private key file, and wraps them in an in-memory key store.
     */
    private static KeyManager[] readPKCS8Certificate(String certPath, String keyPath,
            String keyPassword) throws IOException, GeneralSecurityException, SQLServerException {
        Certificate clientCertificate = loadCertificate(certPath);
        ((X509Certificate) clientCertificate).checkValidity();
        PrivateKey privateKey = loadPrivateKey(keyPath, keyPassword);

        // This password only protects the entry in the in-memory key store built below. An unencrypted key has no
        // password, so use an empty one rather than failing with an NPE.
        char[] storePassword = (null == keyPassword) ? new char[0] : keyPassword.toCharArray();

        KeyStore keyStore = KeyStore.getInstance(JAVA_KEY_STORE);
        keyStore.load(null, null);
        keyStore.setCertificateEntry(CLIENT_CERT, clientCertificate);
        keyStore.setKeyEntry(CLIENT_KEY, privateKey, storePassword, new Certificate[] {clientCertificate});

        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(keyStore, storePassword);
        return kmf.getKeyManagers();
    }

    /**
     * Decodes an unencrypted PKCS#8 PEM ("BEGIN PRIVATE KEY") RSA private key.
     */
    private static PrivateKey loadPrivateKeyFromPKCS8(
            String key) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
        StringBuilder sb = new StringBuilder(key);
        deleteFirst(sb, PEM_PRIVATE_START);
        deleteFirst(sb, PEM_PRIVATE_END);
        byte[] formattedKey = Base64.getDecoder().decode(sb.toString().replaceAll("\\s", ""));

        KeyFactory factory = KeyFactory.getInstance(RSA_ALG);
        return factory.generatePrivate(new PKCS8EncodedKeySpec(formattedKey));
    }

    /**
     * Removes the first occurrence of str from sb, if there is one.
     */
    private static void deleteFirst(StringBuilder sb, String str) {
        int i = sb.indexOf(str);
        if (i != -1) {
            sb.delete(i, i + str.length());
        }
    }

    /**
     * Decodes a PKCS#1 PEM ("BEGIN RSA PRIVATE KEY") private key with Bouncy Castle, decrypting it with keyPass when
     * it is encrypted.
     *
     * @throws InvalidKeySpecException
     *         if the key is encrypted but keyPass is null, or the PEM content is not a key pair
     */
    private static PrivateKey loadPrivateKeyFromPKCS1(String key,
            String keyPass) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
        SQLServerBouncyCastleLoader.loadBouncyCastle();
        try (PEMParser pemParser = new PEMParser(new StringReader(key))) {
            Object object = pemParser.readObject();
            JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider("BC");
            KeyPair kp;
            if (object instanceof PEMEncryptedKeyPair) {
                if (null == keyPass) {
                    throw new InvalidKeySpecException(
                            "The PKCS#1 private key is encrypted but no key password was provided.");
                }
                PEMDecryptorProvider decProv = new JcePEMDecryptorProviderBuilder().build(keyPass.toCharArray());
                kp = converter.getKeyPair(((PEMEncryptedKeyPair) object).decryptKeyPair(decProv));
            } else if (object instanceof PEMKeyPair) {
                kp = converter.getKeyPair((PEMKeyPair) object);
            } else {
                throw new InvalidKeySpecException("The PKCS#1 private key file does not contain an RSA key pair.");
            }
            return kp.getPrivate();
        }
    }

    /**
     * Parses a Microsoft PVK private key file, decrypting it (RC4 with a SHA-1 derived key) when it is marked as
     * encrypted, and rebuilds the RSA CRT private key from the little-endian blob.
     */
    private static PrivateKey loadPrivateKeyFromPVK(String keyPath,
            String keyPass) throws IOException, GeneralSecurityException, SQLServerException {
        File f = new File(keyPath);
        ByteBuffer buffer = ByteBuffer.allocate((int) f.length());

        try (FileInputStream in = new FileInputStream(f); FileChannel channel = in.getChannel()) {
            // Load the file into the buffer. A single read() is not guaranteed to transfer everything.
            while (buffer.hasRemaining() && channel.read(buffer) != -1) {
                // keep reading until the buffer is full or EOF is reached
            }
            // The casts to Buffer keep the calls binary-compatible with Java 8 (no covariant overrides there).
            ((Buffer) buffer.order(ByteOrder.LITTLE_ENDIAN)).rewind();

            long magic = buffer.getInt() & 0xFFFFFFFFL;
            if (PVK_MAGIC != magic) {
                // NOTE(review): makeFromDriverError is expected to throw; execution should not continue past here.
                SQLServerException.makeFromDriverError(null, magic, SQLServerResource.getResource("R_pvkHeaderError"),
                        "", false);
            }

            ((Buffer) buffer).position(((Buffer) buffer).position() + 8); // skip reserved and keytype
            boolean encrypted = buffer.getInt() != 0;
            int saltLength = buffer.getInt();
            int keyLength = buffer.getInt();
            byte[] salt = new byte[saltLength];
            buffer.get(salt);

            ((Buffer) buffer).position(((Buffer) buffer).position() + 8); // skip btype(1b), version(1b), reserved(2b),
                                                                          // and keyalg(4b)

            byte[] key = new byte[keyLength - 8];
            buffer.get(key);

            if (encrypted) {
                MessageDigest digest = MessageDigest.getInstance("SHA1");
                digest.update(salt);
                if (null != keyPass) {
                    // NOTE(review): uses the platform default charset for the password bytes.
                    digest.update(keyPass.getBytes());
                }
                byte[] hash = digest.digest();
                key = getSecretKeyFromHash(key, hash);
            }

            ByteBuffer buff = ByteBuffer.wrap(key).order(ByteOrder.LITTLE_ENDIAN);
            ((Buffer) buff).position(RSA2_MAGIC.length); // skip the header

            int byteLength = buff.getInt() / 8; // bit length of the modulus -> byte length
            BigInteger publicExponent = BigInteger.valueOf(buff.getInt());
            BigInteger modulus = getBigInteger(buff, byteLength);
            BigInteger prime1 = getBigInteger(buff, byteLength / 2);
            BigInteger prime2 = getBigInteger(buff, byteLength / 2);
            BigInteger primeExponent1 = getBigInteger(buff, byteLength / 2);
            BigInteger primeExponent2 = getBigInteger(buff, byteLength / 2);
            BigInteger crtCoefficient = getBigInteger(buff, byteLength / 2);
            BigInteger privateExponent = getBigInteger(buff, byteLength);

            RSAPrivateCrtKeySpec spec = new RSAPrivateCrtKeySpec(modulus, publicExponent, privateExponent, prime1,
                    prime2, primeExponent1, primeExponent2, crtCoefficient);
            KeyFactory factory = KeyFactory.getInstance(RSA_ALG);
            return factory.generatePrivate(spec);
        }
    }

    /**
     * Reads an X.509 certificate from a file.
     */
    private static Certificate loadCertificate(
            String certificatePem) throws IOException, GeneralSecurityException, SQLServerException {
        CertificateFactory certificateFactory = CertificateFactory.getInstance("X509");
        try (InputStream certStream = fileToStream(certificatePem)) {
            return certificateFactory.generateCertificate(certStream);
        }
    }

    /**
     * Loads a private key, choosing the format from the PEM header: PKCS#8, PKCS#1, or otherwise PVK.
     */
    private static PrivateKey loadPrivateKey(String privateKeyPemPath,
            String privateKeyPassword) throws GeneralSecurityException, IOException, SQLServerException {
        String privateKeyPem = getStringFromFile(privateKeyPemPath);

        if (privateKeyPem.contains(PEM_PRIVATE_START)) { // PKCS#8 format
            return loadPrivateKeyFromPKCS8(privateKeyPem);
        } else if (privateKeyPem.contains(PEM_RSA_PRIVATE_START)) { // PKCS#1 format
            return loadPrivateKeyFromPKCS1(privateKeyPem, privateKeyPassword);
        } else {
            return loadPrivateKeyFromPVK(privateKeyPemPath, privateKeyPassword);
        }
    }

    /**
     * Returns true if b starts with the "RSA2" magic bytes. Arrays shorter than the magic never match.
     */
    private static boolean startsWithMagic(byte[] b) {
        if (b.length < RSA2_MAGIC.length) {
            return false;
        }
        for (int i = 0; i < RSA2_MAGIC.length; i++) {
            if (b[i] != RSA2_MAGIC[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Decrypts a PVK key blob with an RC4 key derived from the first 16 bytes of keyHash. If the result does not
     * start with the RSA2 magic, it retries with all but the first 5 bytes of the hash zeroed (presumably the legacy
     * "weak" PVK key variant - confirm).
     *
     * @throws SQLServerException
     *         if neither key yields a blob starting with the RSA2 magic
     */
    private static byte[] getSecretKeyFromHash(byte[] originalKey,
            byte[] keyHash) throws GeneralSecurityException, SQLServerException {
        SecretKey key = new SecretKeySpec(keyHash, 0, 16, RC4_ALG);
        byte[] decrypted = decryptSecretKey(key, originalKey);
        if (startsWithMagic(decrypted)) {
            return decrypted;
        }

        // Couldn't find the magic with the full key; retry with the trimmed key.
        Arrays.fill(keyHash, 5, keyHash.length, (byte) 0);
        key = new SecretKeySpec(keyHash, 0, 16, RC4_ALG);
        decrypted = decryptSecretKey(key, originalKey);
        if (startsWithMagic(decrypted)) {
            return decrypted;
        }

        SQLServerException.makeFromDriverError(null, originalKey, SQLServerResource.getResource("R_pvkParseError"), "",
                false);
        return null; // not reached if makeFromDriverError throws; present to satisfy the compiler
    }

    /**
     * Decrypts encoded with key, using the cipher named by the key's algorithm.
     */
    private static byte[] decryptSecretKey(SecretKey key, byte[] encoded) throws GeneralSecurityException {
        Cipher cipher = Cipher.getInstance(key.getAlgorithm());
        cipher.init(Cipher.DECRYPT_MODE, key);
        return cipher.doFinal(encoded);
    }

    /**
     * Reads length little-endian bytes from buffer and returns them as a non-negative BigInteger.
     */
    private static BigInteger getBigInteger(ByteBuffer buffer, int length) {
        // One extra leading zero byte keeps the two's-complement value non-negative.
        byte[] array = new byte[length + 1];
        // Write in reverse because the buffer is little-endian and BigInteger expects big-endian.
        for (int i = 0; i < length; i++) {
            array[array.length - 1 - i] = buffer.get();
        }
        return new BigInteger(array);
    }

    /**
     * Reads a whole file into memory and returns it as an in-memory stream.
     *
     * @throws SQLServerException
     *         (R_clientCertError, with the original exception as cause) if the file does not exist
     */
    private static InputStream fileToStream(String fname) throws IOException, SQLServerException {
        try (FileInputStream fis = new FileInputStream(fname); DataInputStream dis = new DataInputStream(fis)) {
            byte[] bytes = new byte[dis.available()];
            dis.readFully(bytes);
            return new ByteArrayInputStream(bytes);
        } catch (FileNotFoundException e) {
            throw new SQLServerException(SQLServerException.getErrString("R_clientCertError"), null, 0, e);
        }
    }

    /**
     * Reads a whole file as a String using the platform default charset (sufficient for ASCII PEM content).
     */
    private static String getStringFromFile(String filePath) throws IOException {
        return new String(Files.readAllBytes(Paths.get(filePath)));
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy