All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.activemq.transport.nio.NIOSSLTransport Maven / Gradle / Ivy

There is a newer version: 6.1.2
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.activemq.transport.nio;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;

import org.apache.activemq.command.ConnectionInfo;
import org.apache.activemq.openwire.OpenWireFormat;
import org.apache.activemq.thread.TaskRunnerFactory;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NIOSSLTransport extends NIOTransport {

    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);

    // TLS options copied onto the SSLEngine when initializeStreams() creates the
    // engine itself; each one is applied only when it is set (true / non-null).
    protected boolean needClientAuth;
    protected boolean wantClientAuth;
    protected String[] enabledCipherSuites;
    protected String[] enabledProtocols;

    // Context used to create the engine; initializeStreams() falls back to
    // SSLContext.getDefault() when this is still null.
    protected SSLContext sslContext;
    // Engine used by secureRead() to unwrap inbound data and handed to the
    // NIOOutputStream; either supplied to the constructor or created in initializeStreams().
    protected SSLEngine sslEngine;
    // Session taken from the engine; re-read in finishHandshake() once the handshake is done.
    protected SSLSession sslSession;

    // Set to true at the start of doHandshake() and cleared by finishHandshake().
    protected volatile boolean handshakeInProgress = false;
    // Result status of the most recent SSLEngine.unwrap() performed by secureRead().
    protected SSLEngineResult.Status status = null;
    // Handshake status recorded by secureRead() after each unwrap and by initializeStreams().
    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
    // Created in doStart(), used by doInit() to schedule the first read, shut down in doStop().
    protected TaskRunnerFactory taskRunnerFactory;

    /**
     * Creates a transport by passing every argument straight through to the
     * {@link NIOTransport} constructor; no SSL engine is set up here, so
     * {@link #initializeStreams()} will create one.
     *
     * @param wireFormat     forwarded to the superclass constructor
     * @param socketFactory  forwarded to the superclass constructor
     * @param remoteLocation forwarded to the superclass constructor
     * @param localLocation  forwarded to the superclass constructor
     * @throws UnknownHostException if thrown by the superclass constructor
     * @throws IOException          if thrown by the superclass constructor
     */
    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
        super(wireFormat, socketFactory, remoteLocation, localLocation);
    }

    /**
     * Creates a transport around an existing socket and, optionally, an SSL engine
     * that was set up by the caller.
     *
     * <p>When {@code engine} is non-null its current session is captured, and
     * {@link #initializeStreams()} will neither create a new engine nor run the
     * handshake itself.
     *
     * @param wireFormat  forwarded to the superclass constructor
     * @param socket      forwarded to the superclass constructor
     * @param engine      engine to use, or {@code null} to have one created later
     * @param initBuffer  forwarded to the superclass constructor
     * @param inputBuffer buffer stored as this transport's encrypted input buffer
     * @throws IOException if thrown by the superclass constructor
     */
    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
            ByteBuffer inputBuffer) throws IOException {
        super(wireFormat, socket, initBuffer);
        this.inputBuffer = inputBuffer;
        this.sslEngine = engine;
        if (null != engine) {
            this.sslSession = engine.getSession();
        }
    }

    /**
     * Sets the context from which {@link #initializeStreams()} creates the SSL engine.
     * If none is set by then, {@code SSLContext.getDefault()} is used instead.
     *
     * @param sslContext the context to use
     */
    public void setSslContext(SSLContext sslContext) {
        this.sslContext = sslContext;
    }

    // Set to true by initializeStreams() when an engine was already present on entry;
    // in that case engine creation and the handshake are skipped there.
    volatile boolean hasSslEngine = false;

    /**
     * Switches the socket channel to non-blocking mode, prepares the SSL engine and
     * the output stream, and registers for read events.
     *
     * <p>If no engine was present on entry, one is created from the SSL context
     * (the default context when none was set), configured in server mode with the
     * cipher suite / protocol / client-auth settings of this transport, and the
     * handshake is run before the selector listener is registered. If an engine was
     * already present, creation and handshake are skipped.
     *
     * <p>Finally {@link #doInit()} schedules the first read; selector callbacks wait
     * on {@link #initialized} before they read.
     *
     * @throws IOException wrapping any failure; before it is thrown the output stream
     *         and the superclass streams are closed on a best-effort basis
     */
    @Override
    protected void initializeStreams() throws IOException {
        if (sslEngine != null) {
            hasSslEngine = true;
        }
        NIOOutputStream outputStream = null;
        try {
            channel = socket.getChannel();
            channel.configureBlocking(false);

            if (sslContext == null) {
                sslContext = SSLContext.getDefault();
            }

            String remoteHost = null;
            int remotePort = -1;

            try {
                URI remoteAddress = new URI(this.getRemoteAddress());
                remoteHost = remoteAddress.getHost();
                remotePort = remoteAddress.getPort();
            } catch (Exception e) {
                // Best effort only: without host/port the engine is created without peer hints below.
                LOG.trace("Could not determine remote host and port for the SSL engine.", e);
            }

            // initialize engine, the initial sslSession we get will need to be
            // updated once the ssl handshake process is completed.
            if (!hasSslEngine) {
                if (remoteHost != null && remotePort != -1) {
                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
                } else {
                    sslEngine = sslContext.createSSLEngine();
                }

                sslEngine.setUseClientMode(false);
                if (enabledCipherSuites != null) {
                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
                }

                if (enabledProtocols != null) {
                    sslEngine.setEnabledProtocols(enabledProtocols);
                }

                if (wantClientAuth) {
                    sslEngine.setWantClientAuth(wantClientAuth);
                }

                if (needClientAuth) {
                    sslEngine.setNeedClientAuth(needClientAuth);
                }

                sslSession = sslEngine.getSession();

                // A freshly allocated buffer is already empty and ready for channel reads.
                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
            }

            outputStream = new NIOOutputStream(channel);
            outputStream.setEngine(sslEngine);
            this.dataOut = new DataOutputStream(outputStream);
            this.buffOut = outputStream;

            // If the sslEngine was not passed in, then handshake
            if (!hasSslEngine) {
                sslEngine.beginHandshake();
            }
            handshakeStatus = sslEngine.getHandshakeStatus();
            if (!hasSslEngine) {
                doHandshake();
            }

            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
                @Override
                public void onSelect(SelectorSelection selection) {
                    try {
                        // Hold back selector-driven reads until the initial read from doInit() is done.
                        initialized.await();
                    } catch (InterruptedException error) {
                        onException(IOExceptionSupport.create(error));
                    }
                    serviceRead();
                }

                @Override
                public void onError(SelectorSelection selection, Throwable error) {
                    if (error instanceof IOException) {
                        onException((IOException) error);
                    } else {
                        onException(IOExceptionSupport.create(error));
                    }
                }
            });
            doInit();

        } catch (Exception e) {
            // Close both independently so that a failure in one does not skip the other.
            try {
                if (outputStream != null) {
                    outputStream.close();
                }
            } catch (Exception ex) {
                LOG.trace("Failed to close the output stream after an initialization error.", ex);
            }
            try {
                super.closeStreams();
            } catch (Exception ex) {
                LOG.trace("Failed to close the transport streams after an initialization error.", ex);
            }
            throw new IOException(e);
        }
    }

    /**
     * Released once the initial read scheduled by {@link #doInit()} has run, and also
     * by {@link #doStop(ServiceStopper)}; selector callbacks registered in
     * {@link #initializeStreams()} wait on it before reading.
     */
    protected final CountDownLatch initialized = new CountDownLatch(1);

    /**
     * Schedules the first {@link #serviceRead()} on the task runner and releases
     * {@link #initialized} afterwards.
     *
     * @throws Exception if the task cannot be submitted
     */
    protected void doInit() throws Exception {
        taskRunnerFactory.execute(new Runnable() {

            @Override
            public void run() {
                // Need to start in new thread to let startup finish first.
                // We can trigger a read because we know the channel is ready since the SSL
                // handshake already happened.
                try {
                    serviceRead();
                } finally {
                    // Always release the latch, even if the read fails unexpectedly, so that
                    // selector callbacks waiting on it are never blocked forever.
                    initialized.countDown();
                }
            }
        });
    }

    // Only used for the auto transport to abort the openwire init method early if already initialized
    boolean openWireInititialized = false;

    /**
     * Feeds the bytes held in {@code initBuffer} into the command parser, once, when
     * the wire format is OpenWire.
     *
     * <p>Nothing happens when there is no init buffer, when this has already been
     * done, when the wire format is not {@link OpenWireFormat}, or when the buffer is
     * empty after flipping.
     *
     * @throws Exception if processing the buffered bytes fails
     */
    protected void doOpenWireInit() throws Exception {
        // Done late (from serviceRead) to let wire format negotiation happen first.
        if (initBuffer == null || openWireInititialized) {
            return;
        }
        if (!(this.wireFormat instanceof OpenWireFormat)) {
            return;
        }

        initBuffer.buffer.flip();
        if (!initBuffer.buffer.hasRemaining()) {
            return;
        }

        nextFrameSize = -1;
        receiveCounter += initBuffer.readSize;
        // Two passes on purpose: with nextFrameSize at -1 the first call of this class's
        // processCommand only reads the size prefix, the second one consumes the payload.
        processCommand(initBuffer.buffer);
        processCommand(initBuffer.buffer);
        initBuffer.buffer.clear();
        openWireInititialized = true;
    }

    /**
     * Completes a running handshake: clears the in-progress flag, resets the frame
     * parser, refreshes the session from the engine and registers a selector listener
     * that reads on select and reports errors through {@code onException}.
     *
     * <p>Does nothing when no handshake is in progress.
     *
     * @throws Exception if registering with the selector manager fails
     */
    protected void finishHandshake() throws Exception {
        if (!handshakeInProgress) {
            return;
        }

        handshakeInProgress = false;
        nextFrameSize = -1;

        // Ask the engine again now that the handshake is over: the session obtained
        // before would still report 'SSL_NULL_WITH_NULL_NULL' as its cipher suite.
        sslSession = sslEngine.getSession();

        // From here on, readable events on the socket drive serviceRead().
        selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
            @Override
            public void onSelect(SelectorSelection selection) {
                serviceRead();
            }

            @Override
            public void onError(SelectorSelection selection, Throwable error) {
                final IOException reported;
                if (error instanceof IOException) {
                    reported = (IOException) error;
                } else {
                    reported = IOExceptionSupport.create(error);
                }
                onException(reported);
            }
        });
    }

    /**
     * Reads and decrypts whatever is available on the channel and feeds the plaintext
     * to {@link #processCommand(ByteBuffer)} until no more data can be read.
     *
     * <p>A pending handshake is continued first, then any buffered OpenWire init data
     * is processed. End of stream is reported as an {@link EOFException} through
     * {@code onException} and closes the selection. Every other failure is reported
     * through {@code onException} as well; nothing is thrown to the caller by the
     * read logic itself.
     */
    @Override
    public void serviceRead() {
        try {
            if (handshakeInProgress) {
                doHandshake();
            }

            doOpenWireInit();

            // Start out "fully consumed" so the first loop pass triggers a read.
            final ByteBuffer appData = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
            appData.position(appData.limit());

            for (;;) {
                if (!appData.hasRemaining()) {
                    final int decoded = secureRead(appData);

                    if (decoded == -1) {
                        // channel is closed, cleanup
                        onException(new EOFException());
                        selection.close();
                        return;
                    }

                    if (decoded == 0) {
                        return;
                    }

                    receiveCounter += decoded;
                }

                final boolean unwrapOk = status == SSLEngineResult.Status.OK;
                if (unwrapOk && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                    processCommand(appData);
                }
            }
        } catch (IOException e) {
            onException(e);
        } catch (Throwable e) {
            onException(IOExceptionSupport.create(e));
        }
    }

    /**
     * Consumes decrypted bytes and assembles length-prefixed frames from them.
     *
     * <p>Each call does one of two things, depending on {@code nextFrameSize}:
     * <ul>
     * <li>{@code -1}: read the 4-byte frame size (possibly spread over several calls),
     * validate it and allocate {@code currentBuffer} for size prefix plus payload;</li>
     * <li>otherwise: copy payload bytes into {@code currentBuffer}; once the frame is
     * complete it is unmarshalled with the wire format and passed to
     * {@link #doConsume(Object)}, and the parser is reset for the next frame.</li>
     * </ul>
     * Bytes left in {@code plain} are picked up by the next call.
     *
     * @param plain decrypted data positioned at the next unread byte
     * @throws IOException if the announced frame size is negative, too large to
     *         allocate, or exceeds the OpenWire maximum frame size
     * @throws Exception if unmarshalling the frame fails
     */
    protected void processCommand(ByteBuffer plain) throws Exception {

        // The frame size prefix is an int, i.e. 4 bytes (not Integer.SIZE, which counts bits).
        final int sizePrefixBytes = Integer.SIZE / Byte.SIZE;

        // Are we waiting for the next Command or are we building on the current one
        if (nextFrameSize == -1) {

            // We can get small packets that don't give us enough for the frame size,
            // so collect the size prefix byte by byte across calls.
            if (plain.remaining() < sizePrefixBytes) {
                if (currentBuffer == null) {
                    currentBuffer = ByteBuffer.allocate(sizePrefixBytes);
                }

                // Go until we fill the integer sized current buffer.
                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
                    currentBuffer.put(plain.get());
                }

                // Not enough data yet to work out the next frame size.
                if (currentBuffer.hasRemaining()) {
                    return;
                } else {
                    currentBuffer.flip();
                    nextFrameSize = currentBuffer.getInt();
                }

            } else {

                // Either we are completing a previous read of the next frame size or it is
                // fully contained in plain already.
                if (currentBuffer != null) {

                    // Finish the frame size integer read and get from the current buffer.
                    while (currentBuffer.hasRemaining()) {
                        currentBuffer.put(plain.get());
                    }

                    currentBuffer.flip();
                    nextFrameSize = currentBuffer.getInt();

                } else {
                    nextFrameSize = plain.getInt();
                }
            }

            if (wireFormat instanceof OpenWireFormat) {
                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
                if (nextFrameSize > maxFrameSize) {
                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
                }
            }

            // Reject sizes that cannot be allocated instead of failing inside ByteBuffer:
            // negative values and values where adding the prefix would overflow an int.
            if (nextFrameSize < 0 || nextFrameSize > Integer.MAX_VALUE - sizePrefixBytes) {
                throw new IOException("Frame size of " + nextFrameSize + " bytes is invalid");
            }

            // now we got the data, lets reallocate and store the size for the marshaler.
            // if there's more data in plain, then the next call will start processing it.
            currentBuffer = ByteBuffer.allocate(nextFrameSize + sizePrefixBytes);
            currentBuffer.putInt(nextFrameSize);

        } else {
            // If its all in one read then we can just take it all, otherwise take only
            // the current frame size and the next iteration starts a new command.
            if (currentBuffer != null) {
                if (currentBuffer.remaining() >= plain.remaining()) {
                    currentBuffer.put(plain);
                } else {
                    byte[] fill = new byte[currentBuffer.remaining()];
                    plain.get(fill);
                    currentBuffer.put(fill);
                }

                // Either we have enough data for a new command or we have to wait for some more.
                if (currentBuffer.hasRemaining()) {
                    return;
                } else {
                    currentBuffer.flip();
                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
                    doConsume(command);
                    nextFrameSize = -1;
                    currentBuffer = null;
                }
            }
        }
    }

    /**
     * Reads encrypted bytes from the channel when needed and unwraps them into
     * {@code plain}.
     *
     * <p>Side effects: updates {@link #status} and {@link #handshakeStatus} with the
     * result of the last unwrap, and calls {@link #finishHandshake()} when that unwrap
     * reports {@code FINISHED}.
     *
     * @param plain destination for decrypted data; it is cleared before unwrapping and
     *              flipped for reading before returning normally
     * @return the number of decrypted bytes now readable from {@code plain};
     *         {@code 0} when the channel had nothing to read and the engine is not
     *         waiting to unwrap; {@code -1} on end of stream or when the engine
     *         reports {@code CLOSED}
     * @throws Exception if reading from the channel or unwrapping fails
     */
    protected int secureRead(ByteBuffer plain) throws Exception {

        // Read from the channel when the input buffer is at position 0, or has no space
        // left, or when the previous unwrap reported BUFFER_UNDERFLOW.
        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
            int bytesRead = channel.read(inputBuffer);

            // Nothing new arrived and the engine is not asking for an unwrap: nothing to do.
            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
                return 0;
            }

            if (bytesRead == -1) {
                sslEngine.closeInbound();
                // Report end of stream unless buffered bytes remain that did not underflow
                // last time; in that case fall through and unwrap them.
                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                    return -1;
                }
            }
        }

        plain.clear();

        // Switch the input buffer from filling to draining for the unwrap.
        inputBuffer.flip();
        SSLEngineResult res;
        // Keep unwrapping while the engine succeeds, still wants more handshake input
        // and has not produced any application data.
        do {
            res = sslEngine.unwrap(inputBuffer, plain);
        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
                && res.bytesProduced() == 0);

        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
            finishHandshake();
        }

        status = res.getStatus();
        handshakeStatus = res.getHandshakeStatus();

        // TODO deal with BUFFER_OVERFLOW

        // Note: on CLOSED this returns before the input buffer is compacted.
        if (status == SSLEngineResult.Status.CLOSED) {
            sslEngine.closeInbound();
            return -1;
        }

        // Keep any unconsumed encrypted bytes for the next call and expose the plaintext.
        inputBuffer.compact();
        plain.flip();

        return plain.remaining();
    }

    /**
     * Drives the SSL handshake until the engine reports {@code FINISHED} or
     * {@code NOT_HANDSHAKING}, then calls {@link #finishHandshake()}.
     *
     * <p>While the engine needs more input and the last unwrap underflowed, this
     * waits on a temporary selector for the channel to become readable, for at most
     * {@code getSoTimeout()} as passed to {@code Selector.select}. The temporary
     * selector and its key are released before returning.
     *
     * @throws SocketTimeoutException if a positive socket timeout elapses while waiting
     *         for handshake data
     * @throws Exception if reading, wrapping or a delegated task fails
     */
    protected void doHandshake() throws Exception {
        handshakeInProgress = true;
        Selector readSelector = null;
        SelectionKey readKey = null;
        boolean canRead = true;
        try {
            for (;;) {
                final HandshakeStatus engineState = sslEngine.getHandshakeStatus();
                switch (engineState) {
                    case NEED_UNWRAP:
                        if (canRead) {
                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
                        }
                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                            // Not enough handshake data yet: wait for the channel to become readable.
                            final long waitStart = System.currentTimeMillis();
                            if (readSelector != null) {
                                readKey.interestOps(SelectionKey.OP_READ);
                            } else {
                                readSelector = Selector.open();
                                readKey = channel.register(readSelector, SelectionKey.OP_READ);
                            }
                            final int readyKeys = readSelector.select(this.getSoTimeout());
                            final boolean nothingReady = readyKeys == 0;
                            if (nothingReady && this.getSoTimeout() > 0
                                    && (System.currentTimeMillis() - waitStart) >= this.getSoTimeout()) {
                                throw new SocketTimeoutException("Timeout during handshake");
                            }
                            canRead = readKey.isReadable();
                        }
                        break;
                    case NEED_TASK:
                        // Run every delegated task the engine hands out, on this thread.
                        for (Runnable delegated = sslEngine.getDelegatedTask();
                                delegated != null;
                                delegated = sslEngine.getDelegatedTask()) {
                            delegated.run();
                        }
                        break;
                    case NEED_WRAP:
                        // Writing an empty buffer lets the output stream emit pending handshake data.
                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
                        break;
                    case FINISHED:
                    case NOT_HANDSHAKING:
                        finishHandshake();
                        return;
                }
            }
        } finally {
            if (readKey != null) {
                try {
                    readKey.cancel();
                } catch (Exception ignore) {
                }
            }
            if (readSelector != null) {
                try {
                    readSelector.close();
                } catch (Exception ignore) {
                }
            }
        }
    }

    /**
     * Creates the task runner factory for this transport and then runs the
     * superclass start logic.
     *
     * @throws Exception if the superclass start logic fails
     */
    @Override
    protected void doStart() throws Exception {
        // Created before super.doStart() — presumably because that path reaches doInit(),
        // which needs the factory; NOTE(review): confirm against the superclass.
        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
        // no need to init as we can delay that until demand (eg in doHandshake)
        super.doStart();
    }

    /**
     * Stops the transport: releases anything waiting on {@link #initialized}, shuts
     * down the task runner, closes the channel and runs the superclass stop logic.
     *
     * <p>The superclass stop logic runs even when shutting down the task runner or
     * closing the channel fails, so its cleanup is never skipped.
     *
     * @param stopper passed on to the superclass stop logic
     * @throws Exception if closing the channel or the superclass stop logic fails
     */
    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        // Unblock selector callbacks that are still waiting for the initial read.
        initialized.countDown();

        try {
            if (taskRunnerFactory != null) {
                taskRunnerFactory.shutdownNow();
                taskRunnerFactory = null;
            }
            if (channel != null) {
                channel.close();
                channel = null;
            }
        } finally {
            super.doStop(stopper);
        }
    }

    /**
     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
     *
     * @param command
     *            The Command coming in.
     */
    @Override
    public void doConsume(Object command) {
        // Attach the peer certificate chain to connection requests before passing them on.
        if (command instanceof ConnectionInfo) {
            ((ConnectionInfo) command).setTransportContext(getPeerCertificates());
        }
        super.doConsume(command);
    }

    /**
     * Returns the certificate chain the peer presented on this connection.
     *
     * @return the peer certificate chain of the engine's current session, or
     *         {@code null} when no SSL engine exists yet, the engine has no session,
     *         or the peer has not been verified
     */
    @Override
    public X509Certificate[] getPeerCertificates() {
        // The engine is only created in initializeStreams() unless one was passed in.
        final SSLEngine engine = sslEngine;
        if (engine == null) {
            return null;
        }

        final SSLSession session = engine.getSession();
        if (session == null) {
            return null;
        }

        try {
            return (X509Certificate[]) session.getPeerCertificates();
        } catch (SSLPeerUnverifiedException e) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Failed to get peer certificates.", e);
            }
            return null;
        }
    }

    /**
     * @return whether client authentication is required on engines created by this transport
     */
    public boolean isNeedClientAuth() {
        return needClientAuth;
    }

    /**
     * @param needClientAuth {@code true} to call {@code setNeedClientAuth(true)} on the
     *        engine that {@link #initializeStreams()} creates
     */
    public void setNeedClientAuth(boolean needClientAuth) {
        this.needClientAuth = needClientAuth;
    }

    /**
     * @return whether client authentication is requested on engines created by this transport
     */
    public boolean isWantClientAuth() {
        return wantClientAuth;
    }

    /**
     * @param wantClientAuth {@code true} to call {@code setWantClientAuth(true)} on the
     *        engine that {@link #initializeStreams()} creates
     */
    public void setWantClientAuth(boolean wantClientAuth) {
        this.wantClientAuth = wantClientAuth;
    }

    /**
     * @return the configured cipher suites (the stored array itself, not a copy), or
     *         {@code null} when none were set
     */
    public String[] getEnabledCipherSuites() {
        return enabledCipherSuites;
    }

    /**
     * @param enabledCipherSuites cipher suites applied to the engine that
     *        {@link #initializeStreams()} creates; {@code null} leaves the engine's own setting
     */
    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
        this.enabledCipherSuites = enabledCipherSuites;
    }

    /**
     * @return the configured protocols (the stored array itself, not a copy), or
     *         {@code null} when none were set
     */
    public String[] getEnabledProtocols() {
        return enabledProtocols;
    }

    /**
     * @param enabledProtocols protocols applied to the engine that
     *        {@link #initializeStreams()} creates; {@code null} leaves the engine's own setting
     */
    public void setEnabledProtocols(String[] enabledProtocols) {
        this.enabledProtocols = enabledProtocols;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy