All downloads are free. Search and download functionality uses the official Maven repository.

org.fusesource.hawtdispatch.transport.SslProtocolCodec Maven / Gradle / Ivy

/**
 * Copyright (C) 2012 FuseSource, Inc.
 * http://fusesource.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.fusesource.hawtdispatch.transport;

import org.fusesource.hawtdispatch.Task;

import javax.net.ssl.*;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.ScatteringByteChannel;
import java.nio.channels.WritableByteChannel;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;

/**
 * Implements the SSL protocol as a WrappingProtocolCodec.  Useful for when
 * you want to switch to the SSL protocol on a regular TCP Transport.
 */
public class SslProtocolCodec implements WrappingProtocolCodec, SecuredSession {

    // Raw (encrypted) channels of the wrapped transport; captured in setTransport().
    private ReadableByteChannel readChannel;
    private WritableByteChannel writeChannel;

    /**
     * Client-certificate policy applied by {@link #server(ClientAuth)}.
     */
    public enum ClientAuth {
        /** Calls {@code engine.setWantClientAuth(true)}. */
        WANT,
        /** Calls {@code engine.setNeedClientAuth(true)}. */
        NEED,
        /** Calls {@code engine.setWantClientAuth(false)}. */
        NONE
    }

    private SSLContext sslContext;
    private SSLEngine engine;

    // Encrypted bytes read from readChannel. Normally in "read mode" (flipped); while
    // readUnderflow is true it is in "write mode", holding a partial record and waiting
    // for more network bytes.
    private ByteBuffer readBuffer;
    private boolean readUnderflow;

    // Encrypted bytes produced by engine.wrap(). In "write mode" while being filled and in
    // "read mode" while writeFlushing is true (i.e. being drained to writeChannel).
    private ByteBuffer writeBuffer;
    private boolean writeFlushing;

    // Decrypted bytes that did not fit in the caller's buffer; handed out before any
    // further network bytes are unwrapped. Null when empty.
    private ByteBuffer readOverflowBuffer;
    Transport transport;

    int lastReadSize;
    int lastWriteSize;
    long readCounter;
    long writeCounter;

    ProtocolCodec next;


    public SslProtocolCodec() {
    }

    public ProtocolCodec getNext() {
        return next;
    }

    /**
     * Sets the codec layered on top of this one and connects it to the decrypted channels.
     *
     * @param next the inner codec; may be null
     */
    public void setNext(ProtocolCodec next) {
        this.next = next;
        initNext();
    }

    /**
     * Gives the inner codec (if any) a transport view whose read/write channels are the
     * SSL channels of this codec, so it sees plain-text bytes.
     */
    private void initNext() {
        if (next != null) {
            this.next.setTransport(new TransportFilter(transport) {
                public ReadableByteChannel getReadChannel() {
                    return sslReadChannel;
                }
                public WritableByteChannel getWriteChannel() {
                    return sslWriteChannel;
                }
            });
        }
    }

    /**
     * Sets the SSL context used to create the engine. Must be called before
     * {@link #client()} or {@link #server(ClientAuth)}.
     */
    public void setSSLContext(SSLContext ctx) {
        assert engine == null;
        this.sslContext = ctx;
    }

    /**
     * Creates the engine in client mode and begins the handshake.
     *
     * @return this codec
     */
    public SslProtocolCodec client() throws Exception {
        initializeEngine();
        engine.setUseClientMode(true);
        engine.beginHandshake();
        return this;
    }

    /**
     * Creates the engine in server mode, applies the client-auth policy and begins the handshake.
     *
     * @param clientAuth client certificate policy
     * @return this codec
     */
    public SslProtocolCodec server(ClientAuth clientAuth) throws Exception {
        initializeEngine();
        engine.setUseClientMode(false);
        switch (clientAuth) {
            case WANT: engine.setWantClientAuth(true); break;
            case NEED: engine.setNeedClientAuth(true); break;
            case NONE: engine.setWantClientAuth(false); break;
        }
        engine.beginHandshake();
        return this;
    }

    /**
     * Creates the SSL engine (from {@code SSLContext.getDefault()} when no context was set)
     * and allocates direct network buffers sized to the session's packet buffer size.
     * The read buffer starts flipped, i.e. empty and in read mode.
     */
    protected void initializeEngine() throws Exception {
        assert engine == null;
        if (sslContext == null) {
            sslContext = SSLContext.getDefault();
        }
        engine = sslContext.createSSLEngine();
        SSLSession session = engine.getSession();
        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
        readBuffer.flip();
        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
    }


    /**
     * @return the engine's session, or null if the engine has not been created yet
     */
    public SSLSession getSSLSession() {
        return engine == null ? null : engine.getSession();
    }

    /**
     * Returns the peer's X.509 certificates (non-X.509 certificates are skipped).
     *
     * @return the certificates, or null if there is no engine yet or the peer is unverified
     */
    public X509Certificate[] getPeerX509Certificates() {
        if (engine == null) {
            return null;
        }
        try {
            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
            for (Certificate c : engine.getSession().getPeerCertificates()) {
                if (c instanceof X509Certificate) {
                    rc.add((X509Certificate) c);
                }
            }
            return rc.toArray(new X509Certificate[rc.size()]);
        } catch (SSLPeerUnverifiedException e) {
            return null;
        }
    }

    SSLReadChannel sslReadChannel = new SSLReadChannel();
    SSLWriteChannel sslWriteChannel = new SSLWriteChannel();

    /**
     * Attaches this codec to the underlying (encrypted) transport and re-wires the inner codec.
     */
    public void setTransport(Transport transport) {
        this.transport = transport;
        this.readChannel = transport.getReadChannel();
        this.writeChannel = transport.getWriteChannel();
        initNext();
    }

    /**
     * Advances the handshake one step, based on {@code engine.getHandshakeStatus()}.
     * Does nothing until any pending encrypted output has been flushed.
     * <ul>
     * <li>NEED_TASK: runs the engine's delegated task on the transport's blocking executor,
     * then re-enters this method on the transport's dispatch queue if both raw channels are
     * still open. An IOException there is reported via {@code onTransportFailure}.</li>
     * <li>NEED_WRAP / NEED_UNWRAP: wraps or unwraps with an empty application buffer.</li>
     * <li>FINISHED / NOT_HANDSHAKING: drains inbound data and signals {@code onRefill}.</li>
     * </ul>
     *
     * @throws EOFException if the peer disconnects while an unwrap is needed
     * @throws IOException on channel or SSL errors
     */
    public void handshake() throws IOException {
        if (!transportFlush()) {
            return;
        }
        switch (engine.getHandshakeStatus()) {
            case NEED_TASK:
                final Runnable task = engine.getDelegatedTask();
                if (task != null) {
                    transport.getBlockingExecutor().execute(new Task() {
                        public void run() {
                            task.run();
                            transport.getDispatchQueue().execute(new Task() {
                                public void run() {
                                    if (readChannel.isOpen() && writeChannel.isOpen()) {
                                        try {
                                            handshake();
                                        } catch (IOException e) {
                                            transport.getTransportListener().onTransportFailure(e);
                                        }
                                    }
                                }
                            });
                        }
                    });
                }
                break;

            case NEED_WRAP:
                secure_write(ByteBuffer.allocate(0));
                break;

            case NEED_UNWRAP:
                if (secure_read(ByteBuffer.allocate(0)) == -1) {
                    throw new EOFException("Peer disconnected during ssl handshake");
                }
                break;

            case FINISHED:
            case NOT_HANDSHAKING:
                transport.drainInbound();
                transport.getTransportListener().onRefill();
                break;

            default:
                System.err.println("Unexpected ssl engine handshake status: " + engine.getHandshakeStatus());
                break;
        }
    }

    /**
     * Writes as much of {@link #writeBuffer} to the raw write channel as it will accept.
     *
     * @return true if fully flushed (writeBuffer is then empty and in write mode)
     * @throws IOException on channel errors
     */
    protected boolean transportFlush() throws IOException {
        while (true) {
            if (writeFlushing) {
                lastWriteSize = writeChannel.write(writeBuffer);
                if (lastWriteSize > 0) {
                    writeCounter += lastWriteSize;
                }
                if (!writeBuffer.hasRemaining()) {
                    writeBuffer.clear();
                    writeFlushing = false;
                    return true;
                } else {
                    return false;
                }
            } else {
                if (writeBuffer.position() != 0) {
                    writeBuffer.flip();
                    writeFlushing = true;
                } else {
                    return true;
                }
            }
        }
    }

    /**
     * Decrypts network bytes into {@code plain}. Also drives the handshake when the engine
     * needs an unwrap (called with an empty buffer in that case).
     *
     * @return bytes placed in {@code plain}, or -1 if the peer closed and nothing was read
     */
    private int secure_read(ByteBuffer plain) throws IOException {
        int rc = 0;
        while (plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP) {
            if (readOverflowBuffer != null) {
                if (plain.hasRemaining()) {
                    // Drain the overflow buffer before trying to pull down any more network bytes.
                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
                    plain.put(readOverflowBuffer.array(),
                            readOverflowBuffer.arrayOffset() + readOverflowBuffer.position(), size);
                    readOverflowBuffer.position(readOverflowBuffer.position() + size);
                    if (!readOverflowBuffer.hasRemaining()) {
                        readOverflowBuffer = null;
                    }
                    rc += size;
                } else {
                    return rc;
                }
            } else if (readUnderflow) {
                lastReadSize = readChannel.read(readBuffer);
                if (lastReadSize == -1) {  // peer closed the socket
                    return rc == 0 ? -1 : rc;
                }
                if (lastReadSize == 0) {  // no data available right now
                    return rc;
                }
                readCounter += lastReadSize;
                // Read in some more data; perhaps now we can unwrap.
                readUnderflow = false;
                readBuffer.flip();
            } else {
                SSLEngineResult result = engine.unwrap(readBuffer, plain);
                rc += result.bytesProduced();
                if (result.getStatus() == BUFFER_OVERFLOW) {
                    // The caller's buffer is too small for the record: decrypt into a private
                    // buffer and hand it out on subsequent iterations/calls.
                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
                    result = engine.unwrap(readBuffer, readOverflowBuffer);
                    if (readOverflowBuffer.position() == 0) {
                        readOverflowBuffer = null;
                    } else {
                        readOverflowBuffer.flip();
                    }
                }
                switch (result.getStatus()) {
                    case CLOSED:
                        if (rc == 0) {
                            engine.closeInbound();
                            return -1;
                        } else {
                            return rc;
                        }
                    case OK:
                        if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
                            handshake();
                        }
                        break;
                    case BUFFER_UNDERFLOW:
                        readBuffer.compact();
                        if (!readBuffer.hasRemaining()) {
                            // Buffer is full yet holds no complete record: without growing it,
                            // every later read() would return 0 and we would never progress.
                            growReadBuffer();
                        }
                        readUnderflow = true;
                        break;
                    case BUFFER_OVERFLOW:
                        throw new AssertionError("Unexpected case.");
                }
            }
        }
        return rc;
    }

    /**
     * Replaces {@link #readBuffer} (in write mode and full) with a larger direct buffer
     * containing the same bytes, left in write mode. The new capacity is the larger of the
     * session's current packet buffer size and twice the old capacity, so it always grows.
     */
    private void growReadBuffer() {
        int size = Math.max(engine.getSession().getPacketBufferSize(), readBuffer.capacity() * 2);
        ByteBuffer larger = ByteBuffer.allocateDirect(size);
        readBuffer.flip();
        larger.put(readBuffer);
        readBuffer = larger;
    }

    /**
     * Encrypts bytes from {@code plain} and flushes them to the raw channel. Also drives the
     * handshake when the engine needs a wrap (called with an empty buffer in that case).
     *
     * @return number of plain-text bytes consumed
     * @throws IOException on channel errors, or if the engine's outbound side is closed while
     *         {@code plain} still has data
     */
    private int secure_write(ByteBuffer plain) throws IOException {
        if (!transportFlush()) {
            // Can't write any more until writeBuffer has been fully flushed out.
            return 0;
        }
        int rc = 0;
        while (plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_WRAP) {
            // writeBuffer is empty here: transportFlush() has just returned true.
            SSLEngineResult result = engine.wrap(plain, writeBuffer);
            if (result.getStatus() == BUFFER_OVERFLOW) {
                // Record does not fit; since writeBuffer is empty it can be replaced by a
                // larger one and the wrap retried. (Previously only assert-ed, which spun
                // forever when assertions were disabled.)
                int size = Math.max(engine.getSession().getPacketBufferSize(), writeBuffer.capacity() * 2);
                writeBuffer = ByteBuffer.allocateDirect(size);
                continue;
            }
            if (result.getStatus() == SSLEngineResult.Status.CLOSED && result.bytesProduced() == 0) {
                // Nothing can ever be produced again; looping would spin forever.
                if (plain.hasRemaining()) {
                    throw new IOException("Cannot write: the SSL engine's outbound side is closed");
                }
                return rc;
            }
            rc += result.bytesConsumed();
            if (!transportFlush()) {
                break;
            }
        }
        if (plain.remaining() == 0 && engine.getHandshakeStatus() != NOT_HANDSHAKING) {
            handshake();
        }
        return rc;
    }

    /**
     * Plain-text read view over the encrypted read channel.
     */
    public class SSLReadChannel implements ScatteringByteChannel {

        public int read(ByteBuffer plain) throws IOException {
            if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
                handshake();
            }
            return secure_read(plain);
        }

        public boolean isOpen() {
            return readChannel.isOpen();
        }

        public void close() throws IOException {
            readChannel.close();
        }

        /**
         * Fills the buffers in order, stopping at the first one that is not filled completely.
         *
         * @return bytes read, or -1 if end-of-stream was hit before any byte was read
         */
        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
            if (offset < 0 || length < 0 || offset > dsts.length - length) {
                throw new IndexOutOfBoundsException();
            }
            long rc = 0;
            for (int i = 0; i < length; i++) {
                ByteBuffer dst = dsts[offset + i];
                if (dst.hasRemaining()) {
                    int n = read(dst);
                    if (n < 0) {
                        // End of stream: report bytes already transferred, or -1 if none.
                        return rc == 0 ? -1 : rc;
                    }
                    rc += n;
                }
                if (dst.hasRemaining()) {
                    return rc;
                }
            }
            return rc;
        }

        public long read(ByteBuffer[] dsts) throws IOException {
            return read(dsts, 0, dsts.length);
        }
    }

    /**
     * Plain-text write view over the encrypted write channel.
     */
    public class SSLWriteChannel implements GatheringByteChannel {

        public int write(ByteBuffer plain) throws IOException {
            if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
                handshake();
            }
            return secure_write(plain);
        }

        public boolean isOpen() {
            return writeChannel.isOpen();
        }

        public void close() throws IOException {
            writeChannel.close();
        }

        /**
         * Writes the buffers in order, stopping at the first one that is not fully written.
         *
         * @return number of plain-text bytes consumed
         */
        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
            if (offset < 0 || length < 0 || offset > srcs.length - length) {
                throw new IndexOutOfBoundsException();
            }
            long rc = 0;
            for (int i = 0; i < length; i++) {
                ByteBuffer src = srcs[offset + i];
                if (src.hasRemaining()) {
                    rc += write(src);
                }
                if (src.hasRemaining()) {
                    return rc;
                }
            }
            return rc;
        }

        public long write(ByteBuffer[] srcs) throws IOException {
            return write(srcs, 0, srcs.length);
        }
    }

    /**
     * Appends already-received encrypted bytes after any unconsumed bytes in the read buffer,
     * so they are unwrapped by the next read. The buffer is left untouched if they don't fit.
     *
     * @throws IllegalStateException if the read buffer has no room for {@code buffer}
     */
    public void unread(byte[] buffer) {
        if (readUnderflow) {
            // readBuffer is already in write mode, holding a partial record.
            if (readBuffer.remaining() < buffer.length) {
                throw new IllegalStateException("Cannot unread now");
            }
        } else {
            // readBuffer is in read mode: check the space compact() would free *before*
            // switching modes, so a failed unread cannot corrupt the buffer state.
            if (readBuffer.capacity() - readBuffer.remaining() < buffer.length) {
                throw new IllegalStateException("Cannot unread now");
            }
            readBuffer.compact();
        }
        readBuffer.put(buffer);
        readBuffer.flip();
        readUnderflow = false;
    }

    public Object read() throws IOException {
        return next.read();
    }

    public ProtocolCodec.BufferState write(Object value) throws IOException {
        return next.write(value);
    }

    public ProtocolCodec.BufferState flush() throws IOException {
        return next.flush();
    }

    public boolean full() {
        return next.full();
    }

    public long getWriteCounter() {
        return writeCounter;
    }

    public long getLastWriteSize() {
        return lastWriteSize;
    }

    public long getReadCounter() {
        return readCounter;
    }

    public long getLastReadSize() {
        return lastReadSize;
    }

    public int getReadBufferSize() {
        return readBuffer.capacity();
    }

    public int getWriteBufferSize() {
        return writeBuffer.capacity();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy