All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.wildfly.security.dynamic.ssl.DynamicSSLSocketFactory Maven / Gradle / Ivy

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2024 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.wildfly.security.dynamic.ssl;

import org.wildfly.common.Assert;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * SSLSocketFactory that is being used by DynamicSSLContext.
 * <p>
 * Every peer-addressed {@code createSocket} overload asks the {@link DynamicSSLContextSPI} for the
 * {@link SSLContext} matching the peer's host and port and then delegates socket creation to that
 * context's own socket factory. The no-argument {@link #createSocket()} and
 * {@link #getDefaultCipherSuites()} have no peer information and therefore delegate to the
 * configured default socket factory.
 *
 * @author Diana Krepinska
 */
final class DynamicSSLSocketFactory extends SSLSocketFactory {

    private final DynamicSSLContextSPI dynamicSSLContextImpl;
    // Lazily computed by getSupportedCipherSuites(); volatile so the unsynchronized first read is safe.
    private volatile String[] intersectionCipherSuite;
    private final SSLSocketFactory configuredDefaultSslSocketFactory;

    /**
     * Creates the factory.
     *
     * @param configuredDefaultSslSocketFactory factory used when no peer information is available; must not be {@code null}
     * @param dynamicSSLContextImpl provider used to look up the SSL context for a given peer; must not be {@code null}
     */
    DynamicSSLSocketFactory(SSLSocketFactory configuredDefaultSslSocketFactory, DynamicSSLContextSPI dynamicSSLContextImpl) {
        super();
        Assert.assertNotNull(configuredDefaultSslSocketFactory);
        Assert.assertNotNull(dynamicSSLContextImpl);
        this.configuredDefaultSslSocketFactory = configuredDefaultSslSocketFactory;
        this.dynamicSSLContextImpl = dynamicSSLContextImpl;
    }

    /** Creates an unconnected socket using the configured default socket factory (no peer is known yet). */
    @Override
    public Socket createSocket() throws IOException {
        return configuredDefaultSslSocketFactory.createSocket();
    }

    @Override
    public Socket createSocket(InetAddress address, int port) throws IOException {
        return createSocketBasedOnPeerInfo(null, port, address, null, null, null, null);
    }

    @Override
    public Socket createSocket(String host, int port) throws IOException {
        return createSocketBasedOnPeerInfo(host, port, null, null, null, null, null);
    }

    @Override
    public Socket createSocket(String host, int port, InetAddress localAddress, int localPort) throws IOException {
        return createSocketBasedOnPeerInfo(host, port, null, localAddress, localPort, null, null);
    }

    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException {
        return createSocketBasedOnPeerInfo(null, port, address, localAddress, localPort, null, null);
    }

    @Override
    public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException {
        return createSocketBasedOnPeerInfo(host, port, null, null, null, socket, autoClose);
    }

    /** Returns the default cipher suites of the configured default socket factory. */
    @Override
    public String[] getDefaultCipherSuites() {
        return configuredDefaultSslSocketFactory.getDefaultCipherSuites();
    }

    /**
     * Returns the cipher suites that are supported by every configured SSL context.
     * The intersection is computed once, on first use, and cached.
     *
     * @return a fresh copy of the cached intersection, so callers cannot modify the cache
     */
    @Override
    public String[] getSupportedCipherSuites() {
        String[] val = intersectionCipherSuite;
        if (val == null) {
            synchronized (this) {
                // Re-read under the lock: another thread may have computed the value in the meantime.
                val = intersectionCipherSuite;
                if (val == null) {
                    val = intersectionCipherSuite = getIntersection();
                }
            }
        }
        return val.clone();
    }

    /**
     * Looks up the SSL context for the given peer and delegates to the matching
     * {@code createSocket} overload of that context's socket factory.
     *
     * @param hostname peer host name, or {@code null} when the peer is given as {@code address}
     * @param port peer port
     * @param address peer address, or {@code null} when the peer is given as {@code hostname}
     * @param localAddress local address to bind to; may be {@code null}
     * @param localPort local port to bind to, or {@code null} when the caller used an overload without local binding
     * @param socket existing socket to layer over, or {@code null}
     * @param autoClose whether to close {@code socket} together with the returned socket, or {@code null} when no socket was given
     * @return the socket created by the selected SSL context's socket factory
     * @throws IOException if the delegate factory fails, if the peer information cannot be expressed as a URI
     *         (reported as {@link UnknownHostException}), or if the SSL context lookup throws
     *         {@code DynamicSSLContextException}
     */
    private Socket createSocketBasedOnPeerInfo(String hostname, Integer port, InetAddress address, InetAddress localAddress, Integer localPort, Socket socket, Boolean autoClose) throws IOException {
        try {
            String peerHost = hostname != null ? hostname : address.getHostName();
            SSLContext sslContext = dynamicSSLContextImpl.getSSLContext(new URI(null, null, peerHost, port, null, null, null));
            if (sslContext == null) {
                throw ElytronMessages.log.configuredSSLContextIsNull();
            }
            SSLSocketFactory socketFactory = sslContext.getSocketFactory();
            // Delegating to ourselves would recurse forever.
            if (socketFactory instanceof DynamicSSLSocketFactory && socketFactory.equals(this)) {
                throw ElytronMessages.log.dynamicSSLContextCreatesLoop();
            }
            // resolve socket
            if (socket != null && autoClose != null) {
                return socketFactory.createSocket(socket, hostname, port, autoClose);
            }

            // A non-null localPort identifies the overloads with local binding. A null localAddress is
            // still passed through (it means "any local address") so the requested local port is kept.
            if (address != null) {
                return localPort == null
                        ? socketFactory.createSocket(address, port)
                        : socketFactory.createSocket(address, port, localAddress, localPort);
            }
            if (localPort != null) {
                return socketFactory.createSocket(hostname, port, localAddress, localPort);
            }
            return socketFactory.createSocket(hostname, port);
        } catch (URISyntaxException e) {
            // UnknownHostException has no cause-accepting constructor; attach the cause explicitly.
            UnknownHostException unknownHost = new UnknownHostException(e.getMessage());
            unknownHost.initCause(e);
            throw unknownHost;
        } catch (DynamicSSLContextException e) {
            throw new IOException(e);
        }
    }

    /**
     * Computes the cipher suites supported by the socket factory of every configured SSL context.
     *
     * @return the suites whose occurrence count equals the number of configured contexts;
     *         the order follows {@link HashMap} iteration order
     */
    private String[] getIntersection() {
        List<SSLContext> sslContexts;
        try {
            sslContexts = dynamicSSLContextImpl.getConfiguredSSLContexts();
        } catch (DynamicSSLContextException e) {
            throw ElytronMessages.log.unableToGetConfiguredSSLContexts();
        }
        if (sslContexts == null) {
            throw ElytronMessages.log.configuredSSLContextsAreNull();
        }
        // Count in how many contexts each suite appears.
        Map<String, Integer> counts = new HashMap<>();
        for (SSLContext sslContext : sslContexts) {
            for (String cipherSuite : sslContext.getSocketFactory().getSupportedCipherSuites()) {
                counts.merge(cipherSuite, 1, Integer::sum);
            }
        }
        // A suite is in the intersection when every context reported it.
        int contextCount = sslContexts.size();
        List<String> intersection = new ArrayList<>();
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() == contextCount) {
                intersection.add(entry.getKey());
            }
        }
        return intersection.toArray(new String[0]);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy