All downloads are free. Search and download functionality uses the official Maven repository.

io.undertow.protocols.ssl.ALPNHackSSLEngine Maven / Gradle / Ivy

There is a newer version: 62
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.protocols.ssl;

import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;

import java.io.ByteArrayOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

/**
 * SSLEngine wrapper that provides some super hacky ALPN support on JDK8.
 * <p>
 * Even though this is a nasty hack that relies on JDK internals it is still preferable to modifying the boot class path.
 * <p>
 * It is expected to work with all JDK8 versions, however this cannot be guaranteed if the SSL internals are changed
 * in an incompatible way. All reflective lookups happen once in the static initializer; if any of them fails the hack
 * is reported as unavailable through {@link #ENABLED} instead of failing class initialization.
 * <p>
 * On the client side the first outgoing handshake record is rewritten to advertise the configured application
 * protocols, and the ALPN extension is removed from the incoming ServerHello before the delegate sees it. On the server
 * side the application protocol is chosen from the incoming ClientHello and the outgoing ServerHello is replaced with
 * the one supplied by {@link ALPNHackServerByteArrayOutputStream}. In both cases the delegate's internal handshake hash
 * output stream is swapped out (presumably so the hashes can later be rebuilt via {@code regenerateHashes} to match
 * the bytes actually exchanged on the wire).
 * <p>
 * This class will go away once JDK8 is no longer in use.
 *
 * @author Stuart Douglas
 */
public class ALPNHackSSLEngine extends SSLEngine {

    /**
     * {@code true} if every reflective lookup into the JDK SSL internals succeeded and the hack has not been switched
     * off with the {@code io.undertow.disable-jdk8-alpn} system property.
     */
    public static final boolean ENABLED;

    private static final Field HANDSHAKER;
    private static final Field HANDSHAKER_PROTOCOL_VERSION;
    private static final Field HANDSHAKE_HASH;
    private static final Field HANDSHAKE_HASH_VERSION;
    private static final Method HANDSHAKE_HASH_UPDATE;
    private static final Method HANDSHAKE_HASH_PROTOCOL_DETERMINED;
    private static final Field HANDSHAKE_HASH_DATA;
    private static final Field HANDSHAKE_HASH_FIN_MD;

    private static final Class<?> SSL_ENGINE_IMPL_CLASS;

    static {

        boolean enabled = true;
        Field handshaker;
        Field handshakeHash;
        Field handshakeHashVersion;
        Field handshakeHashData;
        Field handshakeHashFinMd;
        Field protocolVersion;
        Method handshakeHashUpdate;
        Method handshakeHashProtocolDetermined;
        Class<?> sslEngineImplClass;
        try {
            Class<?> protocolVersionClass = Class.forName("sun.security.ssl.ProtocolVersion", true, ClassLoader.getSystemClassLoader());
            sslEngineImplClass = Class.forName("sun.security.ssl.SSLEngineImpl", true, ClassLoader.getSystemClassLoader());
            handshaker = sslEngineImplClass.getDeclaredField("handshaker");
            handshaker.setAccessible(true);
            handshakeHash = handshaker.getType().getDeclaredField("handshakeHash");
            handshakeHash.setAccessible(true);
            protocolVersion = handshaker.getType().getDeclaredField("protocolVersion");
            protocolVersion.setAccessible(true);
            handshakeHashVersion = handshakeHash.getType().getDeclaredField("version");
            handshakeHashVersion.setAccessible(true);
            handshakeHashUpdate = handshakeHash.getType().getDeclaredMethod("update", byte[].class, int.class, int.class);
            handshakeHashUpdate.setAccessible(true);
            handshakeHashProtocolDetermined = handshakeHash.getType().getDeclaredMethod("protocolDetermined", protocolVersionClass);
            handshakeHashProtocolDetermined.setAccessible(true);
            handshakeHashData = handshakeHash.getType().getDeclaredField("data");
            handshakeHashData.setAccessible(true);
            handshakeHashFinMd = handshakeHash.getType().getDeclaredField("finMD");
            handshakeHashFinMd.setAccessible(true);

        } catch (Exception e) {
            // the JDK internals are not what we expect (or are inaccessible): disable the hack rather than fail
            UndertowLogger.ROOT_LOGGER.debug("JDK8 ALPN Hack failed ", e);
            enabled = false;
            handshaker = null;
            handshakeHash = null;
            handshakeHashVersion = null;
            handshakeHashUpdate = null;
            handshakeHashProtocolDetermined = null;
            handshakeHashData = null;
            handshakeHashFinMd = null;
            protocolVersion = null;
            sslEngineImplClass = null;
        }
        ENABLED = enabled && !Boolean.getBoolean("io.undertow.disable-jdk8-alpn");
        HANDSHAKER = handshaker;
        HANDSHAKE_HASH = handshakeHash;
        HANDSHAKE_HASH_PROTOCOL_DETERMINED = handshakeHashProtocolDetermined;
        HANDSHAKE_HASH_VERSION = handshakeHashVersion;
        HANDSHAKE_HASH_UPDATE = handshakeHashUpdate;
        HANDSHAKE_HASH_DATA = handshakeHashData;
        HANDSHAKE_HASH_FIN_MD = handshakeHashFinMd;
        HANDSHAKER_PROTOCOL_VERSION = protocolVersion;
        SSL_ENGINE_IMPL_CLASS = sslEngineImplClass;
    }

    private final SSLEngine delegate;

    //ALPN Hack specific variables
    // set once the first inbound hello has been inspected for ALPN data
    private boolean unwrapHelloSeen = false;
    // set once the delegate has produced its first outbound bytes
    private boolean ourHelloSent = false;
    private ALPNHackServerByteArrayOutputStream alpnHackServerByteArrayOutputStream;
    private ALPNHackClientByteArrayOutputStream alpnHackClientByteArrayOutputStream;
    private List<String> applicationProtocols;
    private String selectedApplicationProtocol;
    // rewritten handshake bytes that did not fit into the caller's buffer; emitted by the next wrap call
    private ByteBuffer bufferedWrapData;

    public ALPNHackSSLEngine(SSLEngine delegate) {
        this.delegate = delegate;
    }

    /**
     * Returns whether the hack can be applied to the given engine.
     *
     * @param engine the engine that would be wrapped
     * @return {@code true} if the hack is {@link #ENABLED} and {@code engine} is a JDK {@code SSLEngineImpl}
     */
    public static boolean isEnabled(SSLEngine engine) {
        if(!ENABLED) {
            return false;
        }
        return SSL_ENGINE_IMPL_CLASS.isAssignableFrom(engine.getClass());
    }

    /**
     * Wraps via the delegate, rewriting the first outbound hello to add (client) or replace (server) ALPN data.
     * <p>
     * Rewritten bytes are written starting at the position the caller's buffer had on entry and never beyond its
     * original limit. Any excess is buffered and emitted by later calls, and the handshake status then asks the caller
     * to wrap again.
     */
    @Override
    public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException {
        if(bufferedWrapData != null) {
            // finish emitting a previously rewritten hello before the delegate produces anything else
            return flushBufferedWrapData(byteBuffer);
        }
        int pos = byteBuffer.position();
        int limit = byteBuffer.limit();
        SSLEngineResult res = delegate.wrap(byteBuffers, i, i1, byteBuffer);
        if(!ourHelloSent && res.bytesProduced() > 0) {
            if(delegate.getUseClientMode() && applicationProtocols != null && !applicationProtocols.isEmpty()) {
                ourHelloSent = true;
                alpnHackClientByteArrayOutputStream = replaceClientByteOutput(delegate);
                byte[] data = producedBytes(byteBuffer, pos);
                byte[] newData = ALPNHackClientHelloExplorer.rewriteClientHello(data, applicationProtocols);
                if(newData != null) {
                    // strip the 5 byte TLS record header; only the handshake message goes to the hash stream
                    byte[] clientHelloMessage = new byte[newData.length - 5];
                    System.arraycopy(newData, 5, clientHelloMessage, 0, clientHelloMessage.length);
                    alpnHackClientByteArrayOutputStream.setSentClientHello(clientHelloMessage);
                    res = replaceProducedData(byteBuffer, pos, limit, ByteBuffer.wrap(newData), res);
                }
            } else if (!delegate.getUseClientMode()) {
                if(selectedApplicationProtocol != null && alpnHackServerByteArrayOutputStream != null) {
                    byte[] newServerHello = alpnHackServerByteArrayOutputStream.getServerHello(); //this is the new server hello, it will be part of the first TLS plaintext record
                    if (newServerHello != null) {
                        // only look at what the delegate produced during this call, not at earlier caller data
                        ByteBuffer produced = byteBuffer.duplicate();
                        produced.flip();
                        produced.position(pos);
                        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(produced);
                        ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newServerHello, records);
                        res = replaceProducedData(byteBuffer, pos, limit, newData, res);
                    }
                }
            }
        }
        if(res.bytesProduced() > 0) {
            ourHelloSent = true;
        }
        return res;
    }

    /**
     * Copies the bytes written into {@code dst} since index {@code start} (up to its current position).
     */
    private static byte[] producedBytes(ByteBuffer dst, int start) {
        ByteBuffer view = dst.duplicate();
        view.flip();
        view.position(start);
        byte[] bytes = new byte[view.remaining()];
        view.get(bytes);
        return bytes;
    }

    /**
     * Replaces what the delegate just wrote into {@code dst} (from {@code pos} onwards) with {@code replacement}.
     * <p>
     * The caller's original limit is restored first, so nothing outside the region the caller offered is touched.
     * Bytes that do not fit are kept in {@link #bufferedWrapData}, and the returned status then requests another wrap.
     *
     * @param dst         the caller's destination buffer
     * @param pos         the position {@code dst} had before the delegate wrapped
     * @param limit       the limit {@code dst} had before the delegate wrapped
     * @param replacement the bytes to emit instead of what the delegate produced
     * @param res         the delegate's result
     * @return the delegate's result adjusted to report the replacement bytes actually written
     */
    private SSLEngineResult replaceProducedData(ByteBuffer dst, int pos, int limit, ByteBuffer replacement, SSLEngineResult res) {
        dst.limit(limit);
        dst.position(pos); //erase the data
        SSLEngineResult.HandshakeStatus handshakeStatus = res.getHandshakeStatus();
        int produced;
        if (replacement.remaining() > dst.remaining()) {
            ByteBuffer fitting = replacement.duplicate();
            fitting.limit(fitting.position() + dst.remaining());
            produced = fitting.remaining();
            dst.put(fitting);
            replacement.position(replacement.position() + produced);
            bufferedWrapData = replacement;
            handshakeStatus = statusWhileBuffered(handshakeStatus);
        } else {
            produced = replacement.remaining();
            dst.put(replacement);
        }
        return new SSLEngineResult(res.getStatus(), handshakeStatus, res.bytesConsumed(), produced);
    }

    /**
     * Emits as much of {@link #bufferedWrapData} as fits into {@code dst}.
     *
     * @return {@code BUFFER_OVERFLOW} if {@code dst} has no room; otherwise {@code OK} with {@code NEED_WRAP} while data
     * remains buffered, or the delegate's current handshake status once everything has been emitted
     */
    private SSLEngineResult flushBufferedWrapData(ByteBuffer dst) {
        if (!dst.hasRemaining()) {
            return new SSLEngineResult(SSLEngineResult.Status.BUFFER_OVERFLOW, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0);
        }
        ByteBuffer fitting = bufferedWrapData.duplicate();
        fitting.limit(fitting.position() + Math.min(dst.remaining(), fitting.remaining()));
        int produced = fitting.remaining();
        dst.put(fitting);
        bufferedWrapData.position(bufferedWrapData.position() + produced);
        if (bufferedWrapData.hasRemaining()) {
            return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, produced);
        }
        bufferedWrapData = null;
        return new SSLEngineResult(SSLEngineResult.Status.OK, delegate.getHandshakeStatus(), 0, produced);
    }

    /**
     * Handshake status to report while rewritten bytes are still buffered: statuses that would make the caller stop
     * wrapping are turned into {@code NEED_WRAP} so the buffered bytes are not stranded.
     */
    private static SSLEngineResult.HandshakeStatus statusWhileBuffered(SSLEngineResult.HandshakeStatus status) {
        if (status == SSLEngineResult.HandshakeStatus.NEED_UNWRAP || status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            return SSLEngineResult.HandshakeStatus.NEED_WRAP;
        }
        return status;
    }

    /**
     * Unwraps via the delegate, inspecting the first inbound hello for ALPN data.
     * <p>
     * Server mode: selects the first of our application protocols that the ClientHello offers. Client mode: removes
     * the ALPN extension from a TLS 1.2 ServerHello before the delegate sees it and records the selected protocol.
     */
    @Override
    public SSLEngineResult unwrap(ByteBuffer dataToUnwrap, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException {
        if(!unwrapHelloSeen) {
            if(!delegate.getUseClientMode() && applicationProtocols != null) {
                try {
                    List<String> result = ALPNHackClientHelloExplorer.exploreClientHello(dataToUnwrap.duplicate());
                    if(result != null) {
                        // our preference order wins: take the first of our protocols the client offered
                        for(String protocol : applicationProtocols) {
                            if(result.contains(protocol)) {
                                selectedApplicationProtocol = protocol;
                                break;
                            }
                        }
                    }
                    unwrapHelloSeen = true;
                } catch (BufferUnderflowException e) {
                    // not enough data buffered yet to parse the hello
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            } else if(delegate.getUseClientMode() && alpnHackClientByteArrayOutputStream != null) {
                if(!dataToUnwrap.hasRemaining()) {
                    return delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
                }
                try {
                    ByteBuffer dup = dataToUnwrap.duplicate();
                    int type = dup.get();
                    int major = dup.get();
                    int minor = dup.get();
                    if(type == 22 && major == 3 && minor == 3) {
                        //we only care about TLS 1.2
                        //split up the records, there may be multiple when doing a fast session resume
                        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(dataToUnwrap.duplicate());

                        ByteBuffer firstRecord = records.get(0); //this will be the handshake record

                        final AtomicReference<String> alpnResult = new AtomicReference<>();
                        ByteBuffer dupFirst = firstRecord.duplicate();
                        dupFirst.position(firstRecord.position() + 5); // skip the 5 byte record header
                        ByteBuffer firstLessFraming = dupFirst.duplicate();

                        byte[] result = ALPNHackServerHelloExplorer.removeAlpnExtensionsFromServerHello(dupFirst, alpnResult);
                        firstLessFraming.limit(dupFirst.position());
                        unwrapHelloSeen = true;
                        if (result != null) {
                            selectedApplicationProtocol = alpnResult.get();
                            int newFirstRecordLength = result.length + dupFirst.remaining();
                            byte[] newFirstRecord = new byte[newFirstRecordLength];
                            System.arraycopy(result, 0, newFirstRecord, 0, result.length);
                            dupFirst.get(newFirstRecord, result.length, dupFirst.remaining());

                            byte[] originalFirstRecord = new byte[firstLessFraming.remaining()];
                            firstLessFraming.get(originalFirstRecord);

                            // replace the buffered input with the records minus the ALPN extension
                            ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newFirstRecord, records);
                            dataToUnwrap.clear();
                            dataToUnwrap.put(newData);
                            dataToUnwrap.flip();
                            alpnHackClientByteArrayOutputStream.setReceivedServerHello(originalFirstRecord);
                        }
                    }
                } catch (BufferUnderflowException e) {
                    // not enough data buffered yet to parse the hello
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            }
        }
        SSLEngineResult res = delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
        if(!delegate.getUseClientMode() && selectedApplicationProtocol != null && alpnHackServerByteArrayOutputStream == null) {
            // swap in our hash stream (seeded with what has been hashed so far) so the next wrap can rewrite the ServerHello
            alpnHackServerByteArrayOutputStream = replaceServerByteOutput(delegate, selectedApplicationProtocol);
        }
        return res;
    }

    @Override
    public Runnable getDelegatedTask() {
        return delegate.getDelegatedTask();
    }

    @Override
    public void closeInbound() throws SSLException {
        delegate.closeInbound();
    }

    @Override
    public boolean isInboundDone() {
        return delegate.isInboundDone();
    }

    @Override
    public void closeOutbound() {
        delegate.closeOutbound();
    }

    @Override
    public boolean isOutboundDone() {
        return delegate.isOutboundDone();
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return delegate.getSupportedCipherSuites();
    }

    @Override
    public String[] getEnabledCipherSuites() {
        return delegate.getEnabledCipherSuites();
    }

    @Override
    public void setEnabledCipherSuites(String[] strings) {
        delegate.setEnabledCipherSuites(strings);
    }

    @Override
    public String[] getSupportedProtocols() {
        return delegate.getSupportedProtocols();
    }

    @Override
    public String[] getEnabledProtocols() {
        return delegate.getEnabledProtocols();
    }

    @Override
    public void setEnabledProtocols(String[] strings) {
        delegate.setEnabledProtocols(strings);
    }

    @Override
    public SSLSession getSession() {
        return delegate.getSession();
    }

    @Override
    public void beginHandshake() throws SSLException {
        delegate.beginHandshake();
    }

    /**
     * Returns {@code NEED_WRAP} while rewritten handshake bytes are still waiting to be emitted; otherwise the
     * delegate's handshake status.
     */
    @Override
    public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        if (bufferedWrapData != null) {
            return SSLEngineResult.HandshakeStatus.NEED_WRAP;
        }
        return delegate.getHandshakeStatus();
    }

    @Override
    public void setUseClientMode(boolean b) {
        delegate.setUseClientMode(b);
    }

    @Override
    public boolean getUseClientMode() {
        return delegate.getUseClientMode();
    }

    @Override
    public void setNeedClientAuth(boolean b) {
        delegate.setNeedClientAuth(b);
    }

    @Override
    public boolean getNeedClientAuth() {
        return delegate.getNeedClientAuth();
    }

    @Override
    public void setWantClientAuth(boolean b) {
        delegate.setWantClientAuth(b);
    }

    @Override
    public boolean getWantClientAuth() {
        return delegate.getWantClientAuth();
    }

    @Override
    public void setEnableSessionCreation(boolean b) {
        delegate.setEnableSessionCreation(b);
    }

    @Override
    public boolean getEnableSessionCreation() {
        return delegate.getEnableSessionCreation();
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @param applicationProtocols the protocols to advertise (client) or select from (server), in preference order
     */
    public void setApplicationProtocols(List<String> applicationProtocols) {
        this.applicationProtocols = applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the configured application protocols, or {@code null} if none were set
     */
    public List<String> getApplicationProtocols() {
        return applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the negotiated application protocol, or {@code null} if none has been selected (yet)
     */
    public String getSelectedApplicationProtocol() {
        return selectedApplicationProtocol;
    }


    /**
     * Replaces the handshake hash's internal {@code data} stream with an {@link ALPNHackServerByteArrayOutputStream}
     * seeded with the bytes the existing stream already holds.
     *
     * @throws SSLException if the reflective replacement fails
     */
    static ALPNHackServerByteArrayOutputStream replaceServerByteOutput(SSLEngine sslEngine, String selectedAlpnProtocol) throws SSLException {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ByteArrayOutputStream existing = (ByteArrayOutputStream) HANDSHAKE_HASH_DATA.get(hash);

            ALPNHackServerByteArrayOutputStream out = new ALPNHackServerByteArrayOutputStream(sslEngine, existing.toByteArray(), selectedAlpnProtocol);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            throw UndertowMessages.MESSAGES.failedToReplaceHashOutputStream(e);
        }
    }

    /**
     * Replaces the handshake hash's internal {@code data} stream with a fresh {@link ALPNHackClientByteArrayOutputStream}.
     *
     * @throws SSLException if the reflective replacement fails
     */
    static ALPNHackClientByteArrayOutputStream replaceClientByteOutput(SSLEngine sslEngine) throws SSLException {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);

            ALPNHackClientByteArrayOutputStream out = new ALPNHackClientByteArrayOutputStream(sslEngine);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            throw UndertowMessages.MESSAGES.failedToReplaceHashOutputStream(e);
        }
    }

    /**
     * Resets the engine's internal handshake hash and replays {@code hashBytes} into it.
     * <p>
     * Clears {@code data}, resets the hash version to {@code -1}, re-runs {@code protocolDetermined} with the
     * handshaker's protocol version, resets the {@code finMD} digest and then feeds each array to {@code update}.
     */
    static void regenerateHashes(SSLEngine sslEngineToHack, ByteArrayOutputStream data, byte[]... hashBytes) {
        //hack up the SSL engine internal state
        try {
            Object handshaker = HANDSHAKER.get(sslEngineToHack);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            data.reset();
            Object protocolVersion = HANDSHAKER_PROTOCOL_VERSION.get(handshaker);
            HANDSHAKE_HASH_VERSION.set(hash, -1);
            HANDSHAKE_HASH_PROTOCOL_DETERMINED.invoke(hash, protocolVersion);
            MessageDigest digest = (MessageDigest) HANDSHAKE_HASH_FIN_MD.get(hash);
            digest.reset();
            for (byte[] b : hashBytes) {
                HANDSHAKE_HASH_UPDATE.invoke(hash, b, 0, b.length);
            }
        } catch (Exception e) {
            throw UndertowMessages.MESSAGES.failedToReplaceHashOutputStreamOnWrite(e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy