All downloads are free. Search and download functionality uses the official Maven repository.

io.undertow.protocols.ssl.ALPNHackSSLEngine Maven / Gradle / Ivy

Go to download

This artifact provides a single JAR that contains all classes required to use remote Jakarta Enterprise Beans and Jakarta Messaging, including all dependencies. It is intended for those not using Maven; Maven users should instead import the Jakarta Enterprise Beans and Jakarta Messaging BOMs (shaded JARs cause many problems with Maven, because it is very easy to inadvertently end up with different versions of classes on the class path).

There is a newer version: 35.0.0.Beta1
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.protocols.ssl;

import io.undertow.UndertowLogger;

import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

/**
 * {@link SSLEngine} wrapper that provides some super hacky ALPN support on JDK8.
 * <p>
 * Even though this is a nasty hack that relies on JDK internals it is still preferable to modifying the boot class path.
 * <p>
 * It is expected to work with all JDK8 versions, however this cannot be guaranteed if the SSL internals are changed
 * in an incompatible way. The required {@code sun.security.ssl} internals are looked up reflectively once, in the
 * static initializer; if any lookup fails the hack is disabled (see {@link #ENABLED}).
 * <p>
 * In outline: the first handshake message this engine sends or receives is inspected and, where needed, rewritten to
 * add or strip the ALPN extension, and the delegate's handshake-hash data stream is swapped for an ALPN-aware stream
 * that is handed the relevant hello bytes (presumably so the handshake hash matches what was actually exchanged on
 * the wire).
 * <p>
 * This class will go away once JDK8 is no longer in use.
 *
 * @author Stuart Douglas
 */
public class ALPNHackSSLEngine extends SSLEngine {

    /**
     * Whether the hack can be used: {@code true} only if every required JDK internal was found and made accessible,
     * and the {@code io.undertow.disable-jdk8-alpn} system property is not set to {@code true}.
     */
    public static final boolean ENABLED;

    // Reflective handles into sun.security.ssl internals. All are null if the lookup in the static initializer failed.
    private static final Field HANDSHAKER;
    private static final Field HANDSHAKER_PROTOCOL_VERSION;
    private static final Field HANDSHAKE_HASH;
    private static final Field HANDSHAKE_HASH_VERSION;
    private static final Method HANDSHAKE_HASH_UPDATE;
    private static final Method HANDSHAKE_HASH_PROTOCOL_DETERMINED;
    private static final Field HANDSHAKE_HASH_DATA;
    private static final Field HANDSHAKE_HASH_FIN_MD;

    private static final Class<?> SSL_ENGINE_IMPL_CLASS;

    static {

        boolean enabled = true;
        Field handshaker;
        Field handshakeHash;
        Field handshakeHashVersion;
        Field handshakeHashData;
        Field handshakeHashFinMd;
        Field protocolVersion;
        Method handshakeHashUpdate;
        Method handshakeHashProtocolDetermined;
        Class<?> sslEngineImplClass;
        try {
            Class<?> protocolVersionClass = Class.forName("sun.security.ssl.ProtocolVersion", true, ClassLoader.getSystemClassLoader());
            sslEngineImplClass = Class.forName("sun.security.ssl.SSLEngineImpl", true, ClassLoader.getSystemClassLoader());
            handshaker = sslEngineImplClass.getDeclaredField("handshaker");
            handshaker.setAccessible(true);
            handshakeHash = handshaker.getType().getDeclaredField("handshakeHash");
            handshakeHash.setAccessible(true);
            protocolVersion = handshaker.getType().getDeclaredField("protocolVersion");
            protocolVersion.setAccessible(true);
            handshakeHashVersion = handshakeHash.getType().getDeclaredField("version");
            handshakeHashVersion.setAccessible(true);
            handshakeHashUpdate = handshakeHash.getType().getDeclaredMethod("update", byte[].class, int.class, int.class);
            handshakeHashUpdate.setAccessible(true);
            handshakeHashProtocolDetermined = handshakeHash.getType().getDeclaredMethod("protocolDetermined", protocolVersionClass);
            handshakeHashProtocolDetermined.setAccessible(true);
            handshakeHashData = handshakeHash.getType().getDeclaredField("data");
            handshakeHashData.setAccessible(true);
            handshakeHashFinMd = handshakeHash.getType().getDeclaredField("finMD");
            handshakeHashFinMd.setAccessible(true);

        } catch (Exception e) {
            // The internals are missing or inaccessible (e.g. a JDK whose SSL implementation differs from JDK8):
            // run without the hack rather than failing class initialization.
            UndertowLogger.ROOT_LOGGER.debug("JDK8 ALPN Hack failed ", e);
            enabled = false;
            handshaker = null;
            handshakeHash = null;
            handshakeHashVersion = null;
            handshakeHashUpdate = null;
            handshakeHashProtocolDetermined = null;
            handshakeHashData = null;
            handshakeHashFinMd = null;
            protocolVersion = null;
            sslEngineImplClass = null;
        }
        ENABLED = enabled && !Boolean.getBoolean("io.undertow.disable-jdk8-alpn");
        HANDSHAKER = handshaker;
        HANDSHAKE_HASH = handshakeHash;
        HANDSHAKE_HASH_PROTOCOL_DETERMINED = handshakeHashProtocolDetermined;
        HANDSHAKE_HASH_VERSION = handshakeHashVersion;
        HANDSHAKE_HASH_UPDATE = handshakeHashUpdate;
        HANDSHAKE_HASH_DATA = handshakeHashData;
        HANDSHAKE_HASH_FIN_MD = handshakeHashFinMd;
        HANDSHAKER_PROTOCOL_VERSION = protocolVersion;
        SSL_ENGINE_IMPL_CLASS = sslEngineImplClass;
    }

    private final SSLEngine delegate;

    //ALPN Hack specific variables
    /** Set once the peer's first hello has been inspected (server) or its ServerHello rewritten (client, TLS 1.2). */
    private boolean unwrapHelloSeen = false;
    /** Set once the delegate has produced its first outbound bytes, i.e. our hello has been written. */
    private boolean ourHelloSent = false;
    /** Server side: replacement handshake-hash stream, installed after a protocol has been selected. */
    private ALPNHackServerByteArrayOutputStream alpnHackServerByteArrayOutputStream;
    /** Client side: replacement handshake-hash stream, installed when our ClientHello is sent; null if that failed. */
    private ALPNHackClientByteArrayOutputStream alpnHackClientByteArrayOutputStream;
    private List<String> applicationProtocols;
    private String selectedApplicationProtocol;
    /** Tail of a rewritten ServerHello that did not fit in the caller's buffer; flushed by the next wrap call. */
    private ByteBuffer bufferedWrapData;

    public ALPNHackSSLEngine(SSLEngine delegate) {
        this.delegate = delegate;
    }

    /**
     * Returns whether the ALPN hack can be applied to the given engine.
     *
     * @param engine the engine that would be wrapped
     * @return {@code true} if the hack is {@link #ENABLED} and {@code engine} is a JDK {@code SSLEngineImpl}
     */
    public static boolean isEnabled(SSLEngine engine) {
        if (!ENABLED) {
            return false;
        }
        return SSL_ENGINE_IMPL_CLASS.isAssignableFrom(engine.getClass());
    }

    /**
     * Wraps outbound data via the delegate, rewriting the first handshake message this side sends so that it
     * carries ALPN information.
     * <p>
     * If a rewritten ServerHello did not fit in the destination buffer on a previous call, the remainder is drained
     * first and the delegate is not invoked.
     */
    @Override
    public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException {
        if (bufferedWrapData != null) {
            return drainBufferedWrapData(byteBuffer);
        }
        int pos = byteBuffer.position();
        int limit = byteBuffer.limit();
        SSLEngineResult res = delegate.wrap(byteBuffers, i, i1, byteBuffer);
        if (!ourHelloSent && res.bytesProduced() > 0) {
            if (delegate.getUseClientMode() && applicationProtocols != null && !applicationProtocols.isEmpty()) {
                ourHelloSent = true;
                rewriteClientHello(byteBuffer, pos);
            } else if (!delegate.getUseClientMode()) {
                res = rewriteServerHello(res, byteBuffer, pos, limit);
            }
        }
        if (res.bytesProduced() > 0) {
            ourHelloSent = true;
        }
        return res;
    }

    /**
     * Copies as much of {@link #bufferedWrapData} into {@code dst} as fits. Bytes that do not fit stay buffered for
     * the next call instead of causing a {@link java.nio.BufferOverflowException}.
     *
     * @return {@code BUFFER_OVERFLOW} if {@code dst} has no room at all, otherwise {@code OK} with the bytes copied
     */
    private SSLEngineResult drainBufferedWrapData(ByteBuffer dst) {
        int produced = Math.min(bufferedWrapData.remaining(), dst.remaining());
        if (produced == 0) {
            return new SSLEngineResult(SSLEngineResult.Status.BUFFER_OVERFLOW, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0);
        }
        ByteBuffer chunk = bufferedWrapData.duplicate();
        chunk.limit(chunk.position() + produced);
        dst.put(chunk);
        bufferedWrapData.position(bufferedWrapData.position() + produced);
        if (!bufferedWrapData.hasRemaining()) {
            bufferedWrapData = null;
        }
        return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, produced);
    }

    /**
     * Client side: installs the replacement handshake-hash stream and rewrites the ClientHello the delegate just
     * wrote into {@code dst} (bytes {@code [pos, dst.position())}) so that it advertises
     * {@link #applicationProtocols}.
     * <p>
     * NOTE: as before, the rewritten record is written with the limit reset to the buffer's capacity (the previous
     * {@code clear()} semantics), so the reported {@code bytesProduced} is not adjusted for the size change.
     */
    private void rewriteClientHello(ByteBuffer dst, int pos) {
        alpnHackClientByteArrayOutputStream = replaceClientByteOutput(delegate);
        if (alpnHackClientByteArrayOutputStream == null) {
            // The hash stream could not be replaced (already logged). Sending a modified hello would leave the
            // handshake hash out of step with the wire bytes, so send the delegate's hello unchanged.
            return;
        }
        ByteBuffer written = dst.duplicate();
        written.limit(written.position());
        written.position(pos);
        byte[] data = new byte[written.remaining()];
        written.get(data);
        byte[] newData = ALPNHackClientHelloExplorer.rewriteClientHello(data, applicationProtocols);
        if (newData != null) {
            // Strip the 5-byte TLS record header before handing the handshake message to the hash stream.
            byte[] clientHelloMessage = new byte[newData.length - 5];
            System.arraycopy(newData, 5, clientHelloMessage, 0, clientHelloMessage.length);
            alpnHackClientByteArrayOutputStream.setSentClientHello(clientHelloMessage);
            dst.position(pos);
            dst.limit(dst.capacity());
            dst.put(newData);
        }
    }

    /**
     * Server side: if a protocol was selected during unwrap, replaces the records the delegate just wrote into
     * {@code dst} with records built around the ServerHello captured by the replacement hash stream.
     * <p>
     * Bytes that do not fit between {@code pos} and {@code limit} are kept in {@link #bufferedWrapData}, and
     * {@code NEED_WRAP} is reported so the caller comes back to flush them.
     *
     * @param res the delegate's wrap result
     * @param dst the caller's destination buffer
     * @param pos the position of {@code dst} before the delegate wrote into it
     * @param limit the limit of {@code dst} before the delegate wrote into it
     * @return the result to report, with {@code bytesProduced} matching what was actually written to {@code dst}
     */
    private SSLEngineResult rewriteServerHello(SSLEngineResult res, ByteBuffer dst, int pos, int limit) {
        if (selectedApplicationProtocol == null || alpnHackServerByteArrayOutputStream == null) {
            return res;
        }
        // This is the new server hello; it will be part of the first TLS plaintext record.
        byte[] newServerHello = alpnHackServerByteArrayOutputStream.getServerHello();
        if (newServerHello == null) {
            return res;
        }
        ByteBuffer written = dst.duplicate();
        written.limit(written.position());
        written.position(pos);
        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(written);
        ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newServerHello, records);
        // Erase what the delegate wrote by restoring the caller's original window.
        dst.position(pos);
        dst.limit(limit);
        if (newData.remaining() > dst.remaining()) {
            int fullLimit = newData.limit();
            newData.limit(newData.position() + dst.remaining());
            SSLEngineResult adjusted = new SSLEngineResult(res.getStatus(), SSLEngineResult.HandshakeStatus.NEED_WRAP,
                    res.bytesConsumed(), newData.remaining());
            dst.put(newData);
            newData.limit(fullLimit);
            bufferedWrapData = newData;
            return adjusted;
        }
        SSLEngineResult adjusted = new SSLEngineResult(res.getStatus(), res.getHandshakeStatus(), res.bytesConsumed(), newData.remaining());
        dst.put(newData);
        return adjusted;
    }

    /**
     * Unwraps inbound data via the delegate, inspecting (server) or rewriting (client) the peer's first hello so
     * ALPN can be negotiated.
     * <p>
     * Server side: the ClientHello is parsed for offered protocols, and the first entry of
     * {@link #applicationProtocols} that the client also offers becomes the selected protocol. After the delegate
     * has consumed the data, the replacement handshake-hash stream is installed so that {@link #wrap} can rewrite
     * the ServerHello.
     * <p>
     * Client side: for a TLS 1.2 handshake record, the ALPN extension is removed from the ServerHello before the
     * delegate sees it, the selected protocol is recorded, and the original (unmodified) ServerHello bytes are
     * handed to the replacement hash stream.
     * <p>
     * If the hello is incomplete, {@code BUFFER_UNDERFLOW} is returned without invoking the delegate.
     */
    @Override
    public SSLEngineResult unwrap(ByteBuffer dataToUnwrap, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException {
        if (!unwrapHelloSeen) {
            if (!delegate.getUseClientMode() && applicationProtocols != null) {
                try {
                    List<String> result = ALPNHackClientHelloExplorer.exploreClientHello(dataToUnwrap.duplicate());
                    if (result != null) {
                        for (String protocol : applicationProtocols) {
                            if (result.contains(protocol)) {
                                selectedApplicationProtocol = protocol;
                                break;
                            }
                        }
                    }
                    unwrapHelloSeen = true;
                } catch (BufferUnderflowException e) {
                    // The ClientHello has not fully arrived yet; ask the caller for more data.
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            } else if (delegate.getUseClientMode() && alpnHackClientByteArrayOutputStream != null) {
                if (!dataToUnwrap.hasRemaining()) {
                    return delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
                }
                try {
                    ByteBuffer dup = dataToUnwrap.duplicate();
                    int type = dup.get();
                    int major = dup.get();
                    int minor = dup.get();
                    if (type == 22 && major == 3 && minor == 3) {
                        //we only care about TLS 1.2
                        //split up the records, there may be multiple when doing a fast session resume
                        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(dataToUnwrap.duplicate());

                        ByteBuffer firstRecord = records.get(0); //this will be the handshake record

                        final AtomicReference<String> alpnResult = new AtomicReference<>();
                        ByteBuffer dupFirst = firstRecord.duplicate();
                        // Skip the 5-byte record header.
                        dupFirst.position(firstRecord.position() + 5);
                        ByteBuffer firstLessFraming = dupFirst.duplicate();

                        byte[] result = ALPNHackServerHelloExplorer.removeAlpnExtensionsFromServerHello(dupFirst, alpnResult);
                        // The explorer advanced dupFirst past the ServerHello; that bounds the original hello bytes.
                        firstLessFraming.limit(dupFirst.position());
                        unwrapHelloSeen = true;
                        if (result != null) {
                            selectedApplicationProtocol = alpnResult.get();
                            int newFirstRecordLength = result.length + dupFirst.remaining();
                            byte[] newFirstRecord = new byte[newFirstRecordLength];
                            System.arraycopy(result, 0, newFirstRecord, 0, result.length);
                            dupFirst.get(newFirstRecord, result.length, dupFirst.remaining());

                            byte[] originalFirstRecord = new byte[firstLessFraming.remaining()];
                            firstLessFraming.get(originalFirstRecord);

                            ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newFirstRecord, records);
                            // Replace the inbound bytes with the re-framed records for the delegate to consume.
                            dataToUnwrap.clear();
                            dataToUnwrap.put(newData);
                            dataToUnwrap.flip();
                            alpnHackClientByteArrayOutputStream.setReceivedServerHello(originalFirstRecord);
                        }
                    }
                } catch (BufferUnderflowException e) {
                    // The ServerHello has not fully arrived yet; ask the caller for more data.
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            }
        }
        SSLEngineResult res = delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
        if (!delegate.getUseClientMode() && selectedApplicationProtocol != null && alpnHackServerByteArrayOutputStream == null) {
            alpnHackServerByteArrayOutputStream = replaceServerByteOutput(delegate, selectedApplicationProtocol);
        }
        return res;
    }

    // The remaining SSLEngine methods forward directly to the delegate (except getHandshakeStatus, see below).

    @Override
    public Runnable getDelegatedTask() {
        return delegate.getDelegatedTask();
    }

    @Override
    public void closeInbound() throws SSLException {
        delegate.closeInbound();
    }

    @Override
    public boolean isInboundDone() {
        return delegate.isInboundDone();
    }

    @Override
    public void closeOutbound() {
        delegate.closeOutbound();
    }

    @Override
    public boolean isOutboundDone() {
        return delegate.isOutboundDone();
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return delegate.getSupportedCipherSuites();
    }

    @Override
    public String[] getEnabledCipherSuites() {
        return delegate.getEnabledCipherSuites();
    }

    @Override
    public void setEnabledCipherSuites(String[] strings) {
        delegate.setEnabledCipherSuites(strings);
    }

    @Override
    public String[] getSupportedProtocols() {
        return delegate.getSupportedProtocols();
    }

    @Override
    public String[] getEnabledProtocols() {
        return delegate.getEnabledProtocols();
    }

    @Override
    public void setEnabledProtocols(String[] strings) {
        delegate.setEnabledProtocols(strings);
    }

    @Override
    public SSLSession getSession() {
        return delegate.getSession();
    }

    @Override
    public void beginHandshake() throws SSLException {
        delegate.beginHandshake();
    }

    /**
     * Returns the delegate's handshake status, except that {@code NEED_WRAP} is reported while part of a rewritten
     * ServerHello is still buffered, so the caller keeps calling {@link #wrap} until it has all been sent.
     */
    @Override
    public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        if (bufferedWrapData != null) {
            return SSLEngineResult.HandshakeStatus.NEED_WRAP;
        }
        return delegate.getHandshakeStatus();
    }

    @Override
    public void setUseClientMode(boolean b) {
        delegate.setUseClientMode(b);
    }

    @Override
    public boolean getUseClientMode() {
        return delegate.getUseClientMode();
    }

    @Override
    public void setNeedClientAuth(boolean b) {
        delegate.setNeedClientAuth(b);
    }

    @Override
    public boolean getNeedClientAuth() {
        return delegate.getNeedClientAuth();
    }

    @Override
    public void setWantClientAuth(boolean b) {
        delegate.setWantClientAuth(b);
    }

    @Override
    public boolean getWantClientAuth() {
        return delegate.getWantClientAuth();
    }

    @Override
    public void setEnableSessionCreation(boolean b) {
        delegate.setEnableSessionCreation(b);
    }

    @Override
    public boolean getEnableSessionCreation() {
        return delegate.getEnableSessionCreation();
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @param applicationProtocols the protocols to offer (client) or accept (server), in order of preference
     */
    public void setApplicationProtocols(List<String> applicationProtocols) {
        this.applicationProtocols = applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the configured application protocols, or {@code null} if none were set
     */
    public List<String> getApplicationProtocols() {
        return applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the negotiated application protocol, or {@code null} if none has been selected
     */
    public String getSelectedApplicationProtocol() {
        return selectedApplicationProtocol;
    }


    /**
     * Swaps the delegate's handshake-hash data stream for an {@link ALPNHackServerByteArrayOutputStream}, seeded with
     * the bytes already written to the existing stream.
     *
     * @return the installed stream, or {@code null} if the reflective replacement failed (logged at debug level)
     */
    static ALPNHackServerByteArrayOutputStream replaceServerByteOutput(SSLEngine sslEngine, String selectedAlpnProtocol) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ByteArrayOutputStream existing = (ByteArrayOutputStream) HANDSHAKE_HASH_DATA.get(hash);

            ALPNHackServerByteArrayOutputStream out = new ALPNHackServerByteArrayOutputStream(sslEngine, existing.toByteArray(), selectedAlpnProtocol);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Swaps the delegate's handshake-hash data stream for a fresh {@link ALPNHackClientByteArrayOutputStream}.
     *
     * @return the installed stream, or {@code null} if the reflective replacement failed (logged at debug level)
     */
    static ALPNHackClientByteArrayOutputStream replaceClientByteOutput(SSLEngine sslEngine) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);

            ALPNHackClientByteArrayOutputStream out = new ALPNHackClientByteArrayOutputStream(sslEngine);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Resets the delegate's handshake hash and re-feeds it with the given messages.
     * <p>
     * Clears {@code data}, sets the hash's {@code version} field back to -1, re-invokes {@code protocolDetermined}
     * with the handshaker's protocol version, resets the {@code finMD} digest, and then calls {@code update} once
     * per entry of {@code hashBytes}, in order.
     *
     * @param sslEngineToHack the delegate engine (expected to be a {@code sun.security.ssl.SSLEngineImpl})
     * @param data the hash data stream to clear before re-feeding
     * @param hashBytes the handshake messages to feed into the hash, in order
     * @throws RuntimeException wrapping the cause if any reflective operation fails
     */
    static void regenerateHashes(SSLEngine sslEngineToHack, ByteArrayOutputStream data, byte[]... hashBytes) {
        //hack up the SSL engine internal state
        try {
            Object handshaker = HANDSHAKER.get(sslEngineToHack);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            data.reset();
            Object protocolVersion = HANDSHAKER_PROTOCOL_VERSION.get(handshaker);
            // Presumably -1 marks the version as undetermined so that protocolDetermined re-initializes the hash.
            HANDSHAKE_HASH_VERSION.set(hash, -1);
            HANDSHAKE_HASH_PROTOCOL_DETERMINED.invoke(hash, protocolVersion);
            MessageDigest digest = (MessageDigest) HANDSHAKE_HASH_FIN_MD.get(hash);
            digest.reset();
            for (byte[] b : hashBytes) {
                HANDSHAKE_HASH_UPDATE.invoke(hash, b, 0, b.length);
            }
        } catch (Exception e) {
            throw new RuntimeException("Failed to regenerate JDK8 ALPN handshake hashes", e);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy