All downloads are free. Search and download functionality uses the official Maven repository.

com.predic8.membrane.core.transport.ssl.StaticSSLContext Maven / Gradle / Ivy

There is a newer version: 5.6.0
Show newest version
/* Copyright 2016 predic8 GmbH, www.predic8.com

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

package com.predic8.membrane.core.transport.ssl;

import com.google.common.base.Objects;
import com.predic8.membrane.core.config.security.SSLParser;
import com.predic8.membrane.core.config.security.Store;
import com.predic8.membrane.core.resolver.ResolverMap;
import com.predic8.membrane.core.transport.TrustManagerWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.crypto.Cipher;
import javax.net.ssl.*;
import javax.validation.constraints.NotNull;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.*;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateParsingException;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPublicKey;
import java.util.*;

public class StaticSSLContext extends SSLContext {

    /** Colon-separated, lower-case hex SHA-256 digest that a loaded "membrane" certificate is compared against. */
    private static final String DEFAULT_CERTIFICATE_SHA256 = "c7:e3:fd:97:2f:d3:b9:4f:38:87:9c:45:32:70:b3:d8:c1:9f:d1:64:39:fc:48:5f:f4:a1:6a:95:b5:ca:08:f7";
    /** Set once the default-certificate warning has been logged, so it is only emitted one time. */
    private static boolean default_certificate_warned = false;
    /** True when the JVM reports a maximum AES key length of 128 bits or less. */
    private static boolean limitedStrength;

    private static final Logger log = LoggerFactory.getLogger(StaticSSLContext.class.getName());

    static {
        // Only override the ephemeral DH key size when it is unset or explicitly "legacy";
        // any other user-provided value is left untouched.
        String configuredDhKeySize = System.getProperty("jdk.tls.ephemeralDHKeySize");
        boolean overrideDhKeySize = "legacy".equals(configuredDhKeySize) || configuredDhKeySize == null;
        if (overrideDhKeySize) {
            System.setProperty("jdk.tls.ephemeralDHKeySize", "matched");
        }

        try {
            boolean restricted = Cipher.getMaxAllowedKeyLength("AES") <= 128;
            limitedStrength = restricted;
            if (restricted) {
                log.warn("Your Java Virtual Machine does not have unlimited strength cryptography. If it is legal in your country, we strongly advise installing the Java Cryptography Extension (JCE) Unlimited Strength Jurisdiction Policy Files.");
            }
        } catch (NoSuchAlgorithmException ignored) {
            // Best effort: without an AES provider the strength check is skipped and
            // limitedStrength keeps its default value.
        }
    }


    /** Configuration this context was created from; also the basis of {@link #equals(Object)}. */
    private final SSLParser sslParser;
    /**
     * DNS subject alternative names of the certificate belonging to the configured key.
     * Only assigned by the constructor that builds the context from key material; remains
     * {@code null} when no key is configured or when an existing JSSE context is wrapped.
     */
    private List<String> dnsNames;

    /** The underlying JSSE context that creates the socket factories. */
    private javax.net.ssl.SSLContext sslc;


    /**
     * Builds the underlying JSSE context from the given SSL configuration.
     *
     * <p>Key material comes either from a key store or from an inline PEM key with its
     * certificate chain (the two are mutually exclusive). Trust material comes either from
     * a trust store or from inline PEM certificates (also mutually exclusive). When no
     * trust material is configured, {@code null} trust managers are passed to
     * {@link javax.net.ssl.SSLContext#init}, i.e. the JVM defaults apply.
     *
     * @param sslParser        the SSL configuration
     * @param resourceResolver resolver used to load key stores and PEM resources
     * @param baseLocation     location that relative resource paths are resolved against
     * @throws RuntimeException wrapping any exception raised while loading key or trust
     *                          material or while initializing the JSSE context
     */
    public StaticSSLContext(SSLParser sslParser, ResolverMap resourceResolver, String baseLocation) {
        this.sslParser = sslParser;

        try {
            String algorithm = KeyManagerFactory.getDefaultAlgorithm();
            if (sslParser.getAlgorithm() != null)
                algorithm = sslParser.getAlgorithm();

            KeyManagerFactory kmf = null;
            // Also serves as the default type for the trust store and the in-memory stores below.
            String keyStoreType = "JKS";
            if (sslParser.getKeyStore() != null) {
                if (sslParser.getKeyStore().getKeyAlias() != null)
                    throw new InvalidParameterException("keyAlias is not yet supported.");
                char[] keyPass = "changeit".toCharArray();
                if (sslParser.getKeyStore().getKeyPassword() != null)
                    keyPass = sslParser.getKeyStore().getKeyPassword().toCharArray();

                if (sslParser.getKeyStore().getType() != null)
                    keyStoreType = sslParser.getKeyStore().getType();
                KeyStore ks = openKeyStore(sslParser.getKeyStore(), "JKS", keyPass, resourceResolver, baseLocation);
                kmf = KeyManagerFactory.getInstance(algorithm);
                kmf.init(ks, keyPass);

                Enumeration<String> aliases = ks.aliases();
                while (aliases.hasMoreElements()) {
                    String alias = aliases.nextElement();
                    if (ks.isKeyEntry(alias)) {
                        // first key is used by the KeyManagerFactory
                        dnsNames = getDNSNames(ks.getCertificate(alias));
                        break;
                    }
                }
            }
            if (sslParser.getKey() != null) {
                if (kmf != null)
                    throw new InvalidParameterException("<key> may not be used together with <keystore>.");

                KeyStore ks = KeyStore.getInstance(keyStoreType);
                ks.load(null, "".toCharArray());

                List<Certificate> certs = new ArrayList<>();

                for (com.predic8.membrane.core.config.security.Certificate cert : sslParser.getKey().getCertificates())
                    certs.add(PEMSupport.getInstance().parseCertificate(cert.get(resourceResolver, baseLocation)));
                if (certs.isEmpty())
                    throw new RuntimeException("At least one //ssl/key/certificate is required.");
                dnsNames = getDNSNames(certs.get(0));

                checkChainValidity(certs);
                Object key = PEMSupport.getInstance().parseKey(sslParser.getKey().getPrivate().get(resourceResolver, baseLocation));
                Key k = key instanceof Key ? (Key) key : ((KeyPair) key).getPrivate();
                if (k instanceof RSAPrivateCrtKey && certs.get(0).getPublicKey() instanceof RSAPublicKey) {
                    RSAPrivateCrtKey privkey = (RSAPrivateCrtKey) k;
                    RSAPublicKey pubkey = (RSAPublicKey) certs.get(0).getPublicKey();
                    boolean keyMatchesCertificate = privkey.getModulus().equals(pubkey.getModulus())
                            && privkey.getPublicExponent().equals(pubkey.getPublicExponent());
                    if (!keyMatchesCertificate)
                        log.warn("Certificate does not fit to key.");
                }

                // The entry must be protected with the same password that kmf.init() later uses
                // to recover it; otherwise recovery fails whenever a key password is configured.
                char[] keyPassword = "".toCharArray();
                if (sslParser.getKey().getPassword() != null)
                    keyPassword = sslParser.getKey().getPassword().toCharArray();

                ks.setKeyEntry("inlinePemKeyAndCertificate", k, keyPassword, certs.toArray(new Certificate[0]));

                kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
                kmf.init(ks, keyPassword);
            }

            TrustManagerFactory tmf = null;
            if (sslParser.getTrustStore() != null) {
                String trustAlgorithm = TrustManagerFactory.getDefaultAlgorithm();
                if (sslParser.getTrustStore().getAlgorithm() != null)
                    trustAlgorithm = sslParser.getTrustStore().getAlgorithm();
                KeyStore ks = openKeyStore(sslParser.getTrustStore(), keyStoreType, null, resourceResolver, baseLocation);
                tmf = TrustManagerFactory.getInstance(trustAlgorithm);
                tmf.init(ks);
            }
            if (sslParser.getTrust() != null) {
                if (tmf != null)
                    throw new InvalidParameterException("<trust> may not be used together with <truststore>.");

                KeyStore ks = KeyStore.getInstance(keyStoreType);
                ks.load(null, "".toCharArray());

                for (int j = 0; j < sslParser.getTrust().getCertificateList().size(); j++)
                    ks.setCertificateEntry("inlinePemCertificate" + j, PEMSupport.getInstance().parseCertificate(sslParser.getTrust().getCertificateList().get(j).get(resourceResolver, baseLocation)));

                tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
                tmf.init(ks);
            }

            TrustManager[] tms = tmf != null ? tmf.getTrustManagers() : null /* trust anyone: new TrustManager[] { new NullTrustManager() } */;
            if (sslParser.isIgnoreTimestampCheckFailure())
                tms = new TrustManager[] { new TrustManagerWrapper(tms, true) };

            String protocol = sslParser.getProtocol() != null ? sslParser.getProtocol() : "TLS";
            sslc = javax.net.ssl.SSLContext.getInstance(protocol);

            sslc.init(
                    kmf != null ? kmf.getKeyManagers() : null,
                    tms,
                    null);

        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        init(sslParser, sslc);
    }

    /**
     * Wraps an already initialized JSSE context instead of building one from key material.
     * {@code dnsNames} is not populated by this constructor.
     *
     * @param sslParser the SSL configuration associated with the context
     * @param sslc      the JSSE context to delegate to
     */
    public StaticSSLContext(SSLParser sslParser, javax.net.ssl.SSLContext sslc) {
        this.sslc = sslc;
        this.sslParser = sslParser;
        init(sslParser, sslc);
    }

    /**
     * Collects the DNS entries from a certificate's subject alternative names.
     *
     * @param certificate the certificate to inspect; anything that is not an
     *                    {@link X509Certificate} yields an empty list
     * @return a new mutable list of DNS names, empty when the certificate has none
     * @throws CertificateParsingException if the subject alternative names cannot be decoded
     */
    private List<String> getDNSNames(Certificate certificate) throws CertificateParsingException {
        // GeneralName type tag for dNSName entries, as reported by getSubjectAlternativeNames().
        final int dnsNameType = 2;

        List<String> names = new ArrayList<>();
        if (!(certificate instanceof X509Certificate))
            return names;

        Collection<List<?>> subjectAlternativeNames = ((X509Certificate) certificate).getSubjectAlternativeNames();
        if (subjectAlternativeNames == null)
            return names;

        for (List<?> entry : subjectAlternativeNames) {
            // Each entry is a pair: element 0 is the type tag, element 1 the value.
            Object type = entry.get(0);
            if (type instanceof Integer && ((Integer) type).intValue() == dnsNameType)
                names.add(entry.get(1).toString());
        }
        return names;
    }

    /**
     * Two static contexts are equal when they were created from equal SSL configurations.
     *
     * @param obj the object to compare with
     * @return {@code true} only if {@code obj} is a {@code StaticSSLContext} with an equal
     *         {@code sslParser}
     */
    @Override
    public boolean equals(Object obj) {
        // Check against the concrete type: the cast below would fail with a
        // ClassCastException for any other SSLContext subclass.
        if (!(obj instanceof StaticSSLContext))
            return false;
        StaticSSLContext other = (StaticSSLContext) obj;
        return Objects.equal(sslParser, other.sslParser);
    }

    /**
     * Hash code derived from the same field that {@link #equals(Object)} compares.
     *
     * @return hash code based on {@code sslParser}
     */
    @Override
    public int hashCode() {
        return Objects.hashCode(sslParser);
    }

    /**
     * Loads the key store described by {@code store} and warns once if it contains the
     * default certificate under the alias "membrane".
     *
     * @param store            key store configuration (type, password, provider, location)
     * @param defaultType      key store type used when {@code store} does not specify one
     * @param keyPass          fallback password used when {@code store} has no password; may be {@code null}
     * @param resourceResolver resolver used to open the key store location
     * @param baseLocation     location that the store's location is combined with
     * @return the loaded key store
     * @throws InvalidParameterException if neither the store nor {@code keyPass} supplies a password
     */
    private KeyStore openKeyStore(Store store, String defaultType, char[] keyPass, ResolverMap resourceResolver, String baseLocation) throws NoSuchAlgorithmException, CertificateException, FileNotFoundException, IOException, KeyStoreException, NoSuchProviderException {
        String type = store.getType() != null ? store.getType() : defaultType;
        char[] password = store.getPassword() != null ? store.getPassword().toCharArray() : keyPass;
        if (password == null)
            throw new InvalidParameterException("Password for key store is not set.");
        KeyStore ks;
        if (store.getProvider() != null)
            ks = KeyStore.getInstance(type, store.getProvider());
        else
            ks = KeyStore.getInstance(type);
        // KeyStore.load() does not close the stream it reads from, so close it here.
        try (java.io.InputStream in = resourceResolver.resolve(ResolverMap.combine(baseLocation, store.getLocation()))) {
            ks.load(in, password);
        }
        warnIfDefaultCertificate(ks);
        return ks;
    }

    /** Logs a one-time warning when the "membrane" entry matches the default certificate fingerprint. */
    private static void warnIfDefaultCertificate(KeyStore ks) throws KeyStoreException, CertificateException, NoSuchAlgorithmException {
        if (default_certificate_warned)
            return;
        java.security.cert.Certificate membraneCertificate = ks.getCertificate("membrane");
        if (membraneCertificate == null)
            return;

        byte[] digest = MessageDigest.getInstance("SHA-256").digest(membraneCertificate.getEncoded());
        // Render as colon-separated, two-digit lower-case hex to match DEFAULT_CERTIFICATE_SHA256.
        StringBuilder fingerprint = new StringBuilder(digest.length * 3);
        for (int i = 0; i < digest.length; i++) {
            if (i > 0)
                fingerprint.append(':');
            fingerprint.append(Character.forDigit((digest[i] >> 4) & 0xf, 16));
            fingerprint.append(Character.forDigit(digest[i] & 0xf, 16));
        }
        if (fingerprint.toString().equals(DEFAULT_CERTIFICATE_SHA256)) {
            log.warn("Using Membrane with the default certificate. This is highly discouraged! "
                    + "Please run the generate-ssl-keys script in the conf directory.");
            default_certificate_warned = true;
        }
    }

    /**
     * Applies the configured cipher suites, if any, to a server socket.
     * Does nothing when no cipher suites are configured.
     *
     * @param sslServerSocket the server socket to configure
     */
    public void applyCiphers(SSLServerSocket sslServerSocket) {
        if (ciphers == null)
            return;

        SSLParameters params = sslServerSocket.getSSLParameters();
        applyCipherOrdering(params);
        params.setCipherSuites(ciphers);
        sslServerSocket.setSSLParameters(params);
    }

    /**
     * Creates a TLS server socket configured with this context's ciphers, protocols and
     * client authentication mode.
     *
     * @param port        port to listen on
     * @param backlog     requested maximum length of the queue of incoming connections
     * @param bindAddress local address to bind to
     * @return the configured server socket
     * @throws IOException if the socket cannot be created
     */
    public ServerSocket createServerSocket(int port, int backlog, InetAddress bindAddress) throws IOException {
        SSLServerSocket sslss = (SSLServerSocket) sslc.getServerSocketFactory().createServerSocket(port, backlog, bindAddress);
        try {
            applyCiphers(sslss);
            sslss.setEnabledProtocols(protocols != null ? protocols : withoutLegacyProtocols(sslss.getEnabledProtocols()));
            // setNeedClientAuth() and setWantClientAuth() override each other, so only one may
            // be called: setNeedClientAuth(false) after setWantClientAuth(true) would switch
            // client authentication off entirely.
            if (needClientAuth)
                sslss.setNeedClientAuth(true);
            else
                sslss.setWantClientAuth(wantClientAuth);
            return sslss;
        } catch (RuntimeException e) {
            // Do not leave the port bound when the configuration is rejected.
            try {
                sslss.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Returns the accepted socket unchanged.
     *
     * @param socket the socket returned by {@code accept()}
     * @return {@code socket}
     */
    public Socket wrapAcceptedSocket(Socket socket) throws IOException {
        return socket;
    }

    /**
     * Layers a TLS client socket over an existing connected socket and applies SNI,
     * protocols and ciphers.
     *
     * @param socket         the connected socket to wrap; closed together with the returned socket
     * @param host           remote host name, also used as SNI name when {@code sniServerName} is {@code null}
     * @param port           remote port
     * @param connectTimeout not used by this method, since {@code socket} is already connected
     * @param sniServerName  SNI name to send; {@code null} falls back to {@code host}, an empty
     *                       string disables the explicit SNI setting
     * @return the TLS socket
     * @throws IOException if the TLS socket cannot be created
     */
    public Socket createSocket(Socket socket, String host, int port, int connectTimeout, @Nullable String sniServerName) throws IOException {
        SSLSocket ssls = (SSLSocket) sslc.getSocketFactory().createSocket(socket, host, port, true);
        applySNI(ssls, sniServerName, host);
        ssls.setEnabledProtocols(protocols != null ? protocols : withoutLegacyProtocols(ssls.getEnabledProtocols()));
        applyCiphers(ssls);
        return ssls;
    }

    /** Returns the given protocol names without "SSLv3" and "SSLv2Hello". */
    private static String[] withoutLegacyProtocols(String[] enabledProtocols) {
        List<String> kept = new ArrayList<>(enabledProtocols.length);
        for (String protocol : enabledProtocols) {
            if ("SSLv3".equals(protocol) || "SSLv2Hello".equals(protocol))
                continue;
            kept.add(protocol);
        }
        return kept.toArray(new String[0]);
    }

    /**
     * Connects to the given host and port and layers TLS over the connection.
     *
     * @param host           remote host name
     * @param port           remote port
     * @param connectTimeout connect timeout in milliseconds, as accepted by {@link Socket#connect(java.net.SocketAddress, int)}
     * @param sniServerName  SNI name passed on to the socket-wrapping overload
     * @return the TLS socket
     * @throws IOException if connecting or wrapping fails
     */
    public Socket createSocket(String host, int port, int connectTimeout, @Nullable String sniServerName) throws IOException {
        Socket s = new Socket();
        try {
            s.connect(new InetSocketAddress(host, port), connectTimeout);
            return createSocket(s, host, port, connectTimeout, sniServerName);
        } catch (IOException | RuntimeException e) {
            // The caller never sees the plain socket, so it has to be released here.
            try {
                s.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Binds to a local address, connects to the given host and port and layers TLS over
     * the connection.
     *
     * @param host           remote host name
     * @param port           remote port
     * @param addr           local address to bind to
     * @param localPort      local port to bind to
     * @param connectTimeout connect timeout in milliseconds, as accepted by {@link Socket#connect(java.net.SocketAddress, int)}
     * @param sniServerName  SNI name passed on to the socket-wrapping overload
     * @return the TLS socket
     * @throws IOException if binding, connecting or wrapping fails
     */
    public Socket createSocket(String host, int port, InetAddress addr, int localPort, int connectTimeout, @Nullable String sniServerName) throws IOException {
        Socket s = new Socket();
        try {
            s.bind(new InetSocketAddress(addr, localPort));
            s.connect(new InetSocketAddress(host, port), connectTimeout);
            return createSocket(s, host, port, connectTimeout, sniServerName);
        } catch (IOException | RuntimeException e) {
            // The caller never sees the plain socket, so it has to be released here.
            try {
                s.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Sets the SNI host name on a client socket.
     *
     * @param ssls          the socket to configure
     * @param sniServerName the name to send; {@code null} falls back to {@code defaultHost},
     *                      an empty string leaves the socket's SNI settings untouched
     * @param defaultHost   host name used when {@code sniServerName} is {@code null}
     */
    private void applySNI(@NotNull SSLSocket ssls, @Nullable String sniServerName, @NotNull String defaultHost) {
        if (sniServerName != null && sniServerName.isEmpty())
            return;
        String serverName = sniServerName != null ? sniServerName : defaultHost;

        // mvn complains here when not putting in "bytes" even though there is a constructor for "string".
        // The charset is given explicitly so the encoding does not depend on the platform default.
        SNIHostName name = new SNIHostName(serverName.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        List<SNIServerName> serverNames = Collections.<SNIServerName>singletonList(name);

        SSLParameters params = ssls.getSSLParameters();
        params.setServerNames(serverNames);
        ssls.setSSLParameters(params);
    }

    /**
     * @return the client socket factory of the underlying JSSE context
     */
    SSLSocketFactory getSocketFactory() {
        SSLSocketFactory factory = sslc.getSocketFactory();
        return factory;
    }

    /**
     * @return the DNS subject alternative names collected for the configured key's
     *         certificate; {@code null} when none were collected during construction
     */
    List<String> getDnsNames() {
        return dnsNames;
    }

    /**
     * Describes where the key store lives.
     *
     * @return the configured key store location, or the string {@code "null"} when no
     *         key store is configured
     */
    String getLocation() {
        if (sslParser.getKeyStore() == null)
            return "null";
        return sslParser.getKeyStore().getLocation();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy