All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.security.ssl.SSLUtils Maven / Gradle / Ivy

There is a newer version: 2.5.2.Final
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.ssl;

import java.net.IDN;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.Provider.Service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
import javax.net.ssl.StandardConstants;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

import org.wildfly.common.Assert;
import org.wildfly.security.OneTimeSecurityFactory;
import org.wildfly.security.SecurityFactory;
import org.wildfly.security.auth.server.SecurityIdentity;

import static org.wildfly.security.ssl.ElytronMessages.log;

/**
 * SSL factories and utilities.
 *
 * @author David M. Lloyd
 */
public final class SSLUtils {

    private static final String[] NO_STRINGS = new String[0];

    private SSLUtils() {}

    private static final String SERVICE_TYPE = SSLContext.class.getSimpleName();

    /**
     * The key used to store the authenticated {@link SecurityIdentity} onto the {@link SSLSession}.
     */
    public static final String SSL_SESSION_IDENTITY_KEY = "org.wildfly.security.ssl.identity";

    /**
     * Create an SSL context factory which locates the best context by searching the preferred providers in order using
     * the rules established in the given protocol selector.  If there are no matches, a factory is returned which
     *
     * @param protocolSelector the protocol selector
     * @param providerSupplier the provider supplier
     * @return the SSL context factory
     */
    public static SecurityFactory createSslContextFactory(ProtocolSelector protocolSelector, Supplier providerSupplier) {
        return createSslContextFactory(protocolSelector, providerSupplier, null);
    }

    /**
     * Create an SSL context factory which locates the best context by searching the preferred providers in order using
     * the rules established in the given protocol selector.  If there are no matches, a factory is returned which
     *
     * @param protocolSelector the protocol selector
     * @param providerSupplier the provider supplier
     * @param providerName the provider name to select, or {@code null} to allow any
     * @return the SSL context factory
     */
    public static SecurityFactory createSslContextFactory(ProtocolSelector protocolSelector, Supplier providerSupplier, String providerName) {
        final Map> preferredProviderByAlgorithm = new HashMap<>();

        // compile all the providers that support SSLContext.

        for (Provider provider : providerSupplier.get()) {
            // if a provider name was given, filter by it
            if (providerName != null && ! providerName.equals(provider.getName())) {
                continue;
            }
            Set services = provider.getServices();
            if (services != null) {
                for (Provider.Service service : services) {
                    if (SERVICE_TYPE.equals(service.getType())) {
                        String protocolName = service.getAlgorithm();
                        List providerList = preferredProviderByAlgorithm.computeIfAbsent(protocolName.toUpperCase(Locale.ENGLISH), s -> new ArrayList<>());
                        providerList.add(provider);

                        if (log.isTraceEnabled()) {
                            log.tracef("Provider %s was added for algorithm %s", provider, protocolName.toUpperCase(Locale.ENGLISH));
                        }
                    }
                }
            }
        }

        // now return a factory that will return the best match is can create.
        final String[] supportedProtocols = protocolSelector.evaluate(preferredProviderByAlgorithm.keySet().toArray(NO_STRINGS));
        if (log.isTraceEnabled()) {
            log.tracef("Supported protocols are: %s", Arrays.toString(supportedProtocols));
        }
        if (supportedProtocols.length > 0) {
            return () -> {
                for (String protocol : supportedProtocols) {
                    List providerList = preferredProviderByAlgorithm.getOrDefault(protocol.toUpperCase(Locale.ENGLISH), Collections.emptyList());
                    if (log.isTraceEnabled()) {
                        if (providerList.isEmpty()) {
                            log.tracef("No providers are available for protocol %s", protocol);
                        }
                    }
                    for (Provider provider : providerList) {
                        try {
                            if (log.isTraceEnabled()) {
                                log.tracef("Attempting to get an SSLContext instance using protocol %s and provider %s", protocol, provider);
                            }
                            return SSLContext.getInstance(protocol, provider);
                        } catch (NoSuchAlgorithmException ignored) {
                            if (log.isTraceEnabled()) {
                                log.tracef(ignored, "Provider %s has no such protocol %s", provider, protocol);
                            }
                        }
                    }
                }

                if (log.isTraceEnabled()) {
                    log.tracef("No %s provided by providers in %s: %s", SERVICE_TYPE, SSLUtils.class.getSimpleName(), Arrays.toString(providerSupplier.get()));
                }

                throw ElytronMessages.log.noAlgorithmForSslProtocol();
            };
        }

        if (log.isTraceEnabled()) {
            log.tracef("No %s provided by providers in %s: %s", SERVICE_TYPE, SSLUtils.class.getSimpleName(), Arrays.toString(providerSupplier.get()));
        }

        return SSLUtils::throwIt;
    }

    private static SSLContext throwIt() throws NoSuchAlgorithmException {
        throw ElytronMessages.log.noAlgorithmForSslProtocol();
    }

    /**
     * Create a simple security factory for SSL contexts.
     *
     * @param protocol the protocol name
     * @param provider the provider to use
     * @return the SSL context factory
     */
    public static SecurityFactory createSimpleSslContextFactory(String protocol, Provider provider) {
        return () -> SSLContext.getInstance(protocol, provider);
    }

    /**
     * Create a configured SSL context from an outside SSL context.
     *
     * @param original the original SSL context
     * @param sslConfigurator the SSL configurator
     * @return the configured SSL context
     */
    public static SSLContext createConfiguredSslContext(SSLContext original, final SSLConfigurator sslConfigurator) {
        return createConfiguredSslContext(original, sslConfigurator, true);
    }

    /**
     * Create a configured SSL context from an outside SSL context.
     *
     * @param original the original SSL context
     * @param sslConfigurator the SSL configurator
     * @param wrap should the resulting SSLEngine, SSLSocket, and SSLServerSocket instances be wrapped using the configurator.
     * @return the configured SSL context
     */
    public static SSLContext createConfiguredSslContext(SSLContext original, final SSLConfigurator sslConfigurator, final boolean wrap) {
        return new DelegatingSSLContext(new ConfiguredSSLContextSpi(original, sslConfigurator, wrap));
    }

    /**
     * Create a configured SSL context factory from an outside SSL context.  The returned factory will create new instances
     * for every call, so it might be necessary to wrap with a {@link OneTimeSecurityFactory} instance.
     *
     * @param originalFactory the original SSL context factory
     * @param sslConfigurator the SSL configurator
     * @return the configured SSL context
     */
    public static SecurityFactory createConfiguredSslContextFactory(SecurityFactory originalFactory, final SSLConfigurator sslConfigurator) {
        return () -> createConfiguredSslContext(originalFactory.create(), sslConfigurator);
    }

    private static final SecurityFactory DEFAULT_TRUST_MANAGER_SECURITY_FACTORY = new OneTimeSecurityFactory<>(() -> {
        final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init((KeyStore) null);
        for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) {
            if (trustManager instanceof X509TrustManager) {
                return (X509TrustManager) trustManager;
            }
        }
        throw ElytronMessages.log.noDefaultTrustManager();
    });

    /**
     * Get the platform's default X.509 trust manager security factory.  The factory caches the instance.
     *
     * @return the security factory for the default trust manager
     */
    public static SecurityFactory getDefaultX509TrustManagerSecurityFactory() {
        return DEFAULT_TRUST_MANAGER_SECURITY_FACTORY;
    }

    /**
     * Get a server SSL engine which dispatches to the appropriate SSL context based on the information in the
     * SSL greeting.
     *
     * @param selector the context selector to use (cannot be {@code null})
     * @return the SSL engine (not {@code null})
     */
    public static SSLEngine createSelectingSSLEngine(SSLContextSelector selector) {
        Assert.checkNotNullParam("selector", selector);
        return new SelectingServerSSLEngine(selector);
    }

    /**
     * Get a server SSL engine which dispatches to the appropriate SSL context based on the information in the
     * SSL greeting.
     *
     * @param selector the context selector to use (cannot be {@code null})
     * @param host the advisory host name
     * @param port the advisory port number
     * @return the SSL engine (not {@code null})
     */
    public static SSLEngine createSelectingSSLEngine(SSLContextSelector selector, String host, int port) {
        Assert.checkNotNullParam("selector", selector);
        return new SelectingServerSSLEngine(selector, host, port);
    }

    /**
     * Create an {@code SNIMatcher} which matches SNI host names that satisfy the given predicate.
     *
     * @param predicate the predicate (must not be {@code null})
     * @return the SNI matcher (not {@code null})
     */
    public static SNIMatcher createHostNamePredicateSNIMatcher(Predicate predicate) {
        Assert.checkNotNullParam("predicate", predicate);
        return new SNIMatcher(StandardConstants.SNI_HOST_NAME) {
            public boolean matches(final SNIServerName sniServerName) {
                return sniServerName instanceof SNIHostName && predicate.test((SNIHostName) sniServerName);
            }
        };
    }

    /**
     * Create an {@code SNIMatcher} which matches SNI host name strings that satisfy the given predicate.
     *
     * @param predicate the predicate (must not be {@code null})
     * @return the SNI matcher (not {@code null})
     * @see IDN
     */
    public static SNIMatcher createHostNameStringPredicateSNIMatcher(Predicate predicate) {
        Assert.checkNotNullParam("predicate", predicate);
        return new SNIMatcher(StandardConstants.SNI_HOST_NAME) {
            public boolean matches(final SNIServerName sniServerName) {
                return sniServerName instanceof SNIHostName && predicate.test(((SNIHostName) sniServerName).getAsciiName());
            }
        };
    }

    /**
     * Create an {@code SNIMatcher} which matches SNI host names that are equal to the given (ASCII) string.
     *
     * @param string the host name string (must not be {@code null})
     * @return the SNI matcher (not {@code null})
     * @see IDN
     */
    public static SNIMatcher createHostNameStringSNIMatcher(String string) {
        Assert.checkNotNullParam("string", string);
        return createHostNameStringPredicateSNIMatcher(string::equals);
    }

    /**
     * Create an {@code SNIMatcher} which matches SNI host name strings which end with the given suffix.
     *
     * @param suffix the suffix to match (must not be {@code null} or empty)
     * @return the SNI matcher (not {@code null})
     */
    public static SNIMatcher createHostNameSuffixSNIMatcher(String suffix) {
        Assert.checkNotNullParam("suffix", suffix);
        Assert.checkNotEmptyParam("suffix", suffix);
        final String finalSuffix = suffix.startsWith(".") ? suffix : "." + suffix;
        return createHostNameStringPredicateSNIMatcher(n -> n.endsWith(finalSuffix));
    }

    /**
     * Get a factory which produces SSL engines which dispatch to the appropriate SSL context based on the information
     * in the SSL greeting.
     *
     * @param selector the context selector to use (cannot be {@code null})
     * @return the SSL engine factory (not {@code null})
     */
    public static SecurityFactory createDispatchingSSLEngineFactory(SSLContextSelector selector) {
        Assert.checkNotNullParam("selector", selector);
        return () -> new SelectingServerSSLEngine(selector);
    }

    /**
     * Get the value of the given key from the SSL session, or a default value if the key is not set.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @param defaultValue the value to return if the key is not present
     * @return the session value or the default value
     */
    public static Object getOrDefault(SSLSession sslSession, String key, Object defaultValue) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        final Object value = sslSession.getValue(key);
        return value != null ? value : defaultValue;
    }

    /**
     * Put a value on the session if the value is not yet set.  This method is atomic with respect to other methods
     * on this class.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @param newValue the value to set (must not be {@code null})
     * @return the existing value, or {@code null} if the value was successfully set
     */
    public static Object putSessionValueIfAbsent(SSLSession sslSession, String key, Object newValue) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        Assert.checkNotNullParam("newValue", newValue);
        synchronized (sslSession) {
            final Object existing = sslSession.getValue(key);
            if (existing == null) {
                sslSession.putValue(key, newValue);
                return null;
            } else {
                return existing;
            }
        }
    }

    /**
     * Remove and return a value on the session.  This method is atomic with respect to other methods on this class.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @return the existing value, or {@code null} if no such value was set
     */
    public static Object removeSessionValue(SSLSession sslSession, String key) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        synchronized (sslSession) {
            final Object existing = sslSession.getValue(key);
            sslSession.removeValue(key);
            return existing;
        }
    }

    /**
     * Remove the given key-value pair on the session.  This method is atomic with respect to other methods on this class.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to remove (must not be {@code null})
     * @param value the value to remove (must not be {@code null})
     * @return {@code true} if the key/value pair was removed, {@code false} if the key was not present or the value was not equal to the given value
     */
    public static boolean removeSessionValue(SSLSession sslSession, String key, Object value) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        Assert.checkNotNullParam("value", value);
        synchronized (sslSession) {
            final Object existing = sslSession.getValue(key);
            if (Objects.equals(existing, value)) {
                sslSession.removeValue(key);
                return true;
            } else {
                return false;
            }
        }
    }

    /**
     * Replace the given key's value with a new value.  If there is no value for the given key, no action is performed.
     * This method is atomic with respect to other methods on this class.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @param newValue the value to set (must not be {@code null})
     * @return the existing value, or {@code null} if the value was not set
     */
    public static Object replaceSessionValue(SSLSession sslSession, String key, Object newValue) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        Assert.checkNotNullParam("newValue", newValue);
        synchronized (sslSession) {
            final Object existing = sslSession.getValue(key);
            if (existing != null) sslSession.putValue(key, newValue);
            return existing;
        }
    }

    /**
     * Replace the given key's value with a new value if (and only if) it is mapped to the given existing value.
     * This method is atomic with respect to other methods on this class.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @param oldValue the value to match (must not be {@code null})
     * @param newValue the value to set (must not be {@code null})
     * @return {@code true} if the value was matched and replaced, or {@code false} if the value did not match and no action was taken
     */
    public static boolean replaceSessionValue(SSLSession sslSession, String key, Object oldValue, Object newValue) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        Assert.checkNotNullParam("oldValue", oldValue);
        Assert.checkNotNullParam("newValue", newValue);
        synchronized (sslSession) {
            final Object existing = sslSession.getValue(key);
            if (Objects.equals(existing, oldValue)) {
                sslSession.putValue(key, newValue);
                return true;
            } else {
                return false;
            }
        }
    }

    /**
     * Get or compute the value for the given key, storing the computed value (if one is generated).  The function
     * must not generate a {@code null} value or an unspecified exception will result.
     *
     * @param sslSession the SSL session (must not be {@code null})
     * @param key the key to retrieve (must not be {@code null})
     * @param mappingFunction the function to apply to acquire the value (must not be {@code null})
     * @return the stored or new value (not {@code null})
     */
    public static  R computeIfAbsent(SSLSession sslSession, String key, Function mappingFunction) {
        Assert.checkNotNullParam("sslSession", sslSession);
        Assert.checkNotNullParam("key", key);
        Assert.checkNotNullParam("mappingFunction", mappingFunction);
        synchronized (sslSession) {
            final R existing = (R) sslSession.getValue(key);
            if (existing == null) {
                R newValue = mappingFunction.apply(key);
                Assert.assertNotNull(newValue);
                sslSession.putValue(key, newValue);
                return newValue;
            } else {
                return existing;
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy