All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.coos.messaging.transport.SecureNioTCPTransport Maven / Gradle / Ivy

The newest version!
/**
 * COOS - Connected Objects Operating System (www.connectedobjects.org).
 *
 * Copyright (C) 2009 Telenor ASA and Tellu AS. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This library is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 * You may also contact one of the following for additional information:
 * Telenor ASA, Snaroyveien 30, N-1331 Fornebu, Norway (www.telenor.no)
 * Tellu AS, Hagalokkveien 13, N-1383 Asker, Norway (www.tellu.no)
 */
package org.coos.messaging.transport;

import org.coos.messaging.Channel;
import org.coos.messaging.Message;
import org.coos.messaging.Processor;
import org.coos.messaging.ProcessorException;
import org.coos.messaging.Service;
import org.coos.messaging.Transport;
import org.coos.messaging.impl.DefaultMessage;
import org.coos.messaging.impl.DefaultProcessor;
import org.coos.messaging.util.Log;
import org.coos.messaging.util.LogFactory;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;

import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;

import java.util.Collections;
import java.util.Hashtable;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;


/**
 * Represents one non-blocking TCP connection whose traffic is protected by an
 * {@link SSLEngine} running in server mode. Incoming application data is
 * split into packets by reading the packet length from the first 4 bytes;
 * outgoing messages are queued in a mailbox and encrypted one at a time.
 *
 * @author Morten Versvik, Tellu AS
 *
 */
public class SecureNioTCPTransport extends DefaultProcessor implements Transport, Service {
    private static final Log logger = LogFactory.getLog(SecureNioTCPTransport.class.getName());

    /**
     * Maximum accepted size of one packet, length prefix included. Leaves room
     * for large headers.
     */
    protected final static int MAX_LENGTH = (16 * 1024);

    /**
     * Maximum accepted size of a serialized message body; received messages
     * with a larger body are not forwarded.
     */
    protected final static int MAX_BODY_LENGTH = (8 * 1024);

    /**
     * Size of the packet reassembly buffer and initial size of the encrypted
     * output buffer.
     */
    private final static int BUFFER_SIZE = MAX_LENGTH * 2;

    /**
     * Index of the last byte of the length prefix. 1 int is 4 bytes.
     */
    private static final int SIZE_POSITION = 4 - 1;

    /**
     * Outgoing messages waiting to be written. Iteration must hold the list's
     * own monitor, as required by {@link Collections#synchronizedList}.
     */
    protected List<Message> mailbox = Collections.synchronizedList(new LinkedList<Message>());
    protected Processor transportProcessor;

    // netOutBuffer: encrypted bytes waiting to be written to the socket.
    // netInBuffer: encrypted bytes read from the socket.
    // inBuffer: decrypted application bytes waiting to be decoded.
    private ByteBuffer netOutBuffer, netInBuffer, inBuffer;
    SocketChannel sc;
    SecureNioTCPTransportManager tm;
    Channel channel;

    /**
     * Reassembly buffer for the packet currently being decoded.
     */
    private final byte[] buffer = new byte[BUFFER_SIZE];

    /**
     * Number of bytes of the current packet already stored in {@link #buffer}.
     */
    private int pos = 0;

    /**
     * Total length of the current packet (prefix included), or
     * {@link Integer#MAX_VALUE} while the length prefix is still incomplete.
     */
    private int newmsgLength = Integer.MAX_VALUE;
    SSLEngine sslEngine;

    // initDone becomes true once the engine has reported a FINISHED handshake.
    boolean initDone = false, usedTemporaryBuffer = false;

    /**
     * Serialized message currently being encrypted; may be only partially
     * consumed between two calls to {@link #handleWrite()}.
     */
    ByteBuffer toBeWritten = null;

    /**
     * Creates the transport for an accepted socket and prepares a server-mode
     * SSL engine from a JKS keystore.
     *
     * @param tm the manager that is told about write interest and disconnects
     * @param selector not used by this constructor
     * @param sc the connected socket channel
     * @param properties must contain the keys <code>keystore</code> (keystore
     *            file path), <code>keystorepass</code> and
     *            <code>keypassword</code>
     * @throws IOException if the socket options or the keystore file cannot be
     *             read
     * @throws NoSuchAlgorithmException if a required algorithm is unavailable
     * @throws KeyManagementException if the SSL context cannot be initialised
     * @throws KeyStoreException if no JKS keystore implementation exists
     * @throws CertificateException if a certificate in the keystore cannot be
     *             loaded
     * @throws UnrecoverableKeyException if the key cannot be recovered with
     *             the given key password
     */
    public SecureNioTCPTransport(SecureNioTCPTransportManager tm, Selector selector,
        SocketChannel sc, Hashtable<String, String> properties) throws IOException,
        NoSuchAlgorithmException, KeyManagementException, KeyStoreException, CertificateException,
        UnrecoverableKeyException {
        this.sc = sc;
        netInBuffer = ByteBuffer.allocateDirect(sc.socket().getReceiveBufferSize());
        inBuffer = ByteBuffer.allocateDirect(sc.socket().getReceiveBufferSize());
        netOutBuffer = ByteBuffer.allocateDirect(BUFFER_SIZE);

        this.tm = tm;

        String keystore = properties.get("keystore");
        char[] keystorepass = properties.get("keystorepass").toCharArray();
        char[] keypassword = properties.get("keypassword").toCharArray();

        // Fetch the keystore; the stream is closed even when loading fails.
        KeyStore ks = KeyStore.getInstance("JKS");

        FileInputStream fin = new FileInputStream(keystore);

        try {
            ks.load(fin, keystorepass);
        } finally {
            fin.close();
        }

        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(ks, keypassword);

        // Request the generic TLS context instead of the obsolete SSLv3 one,
        // so that the newest protocol version shared with the peer is used.
        SSLContext sslcontext = SSLContext.getInstance("TLS");
        sslcontext.init(kmf.getKeyManagers(), null, null);

        sslEngine = sslcontext.createSSLEngine();

        sslEngine.setUseClientMode(false);
        sslEngine.setNeedClientAuth(false);
        sslEngine.setWantClientAuth(false);

    }

    /**
     * Sets the processor that receives every message decoded from the socket.
     *
     * @param chainedProcessor the processor to hand received messages to
     */
    public void setChainedProcessor(Processor chainedProcessor) {
        this.transportProcessor = chainedProcessor;
    }

    /**
     * Hands a decoded message to the chained processor. A
     * {@link ProcessorException} is logged and not propagated.
     *
     * @param in the received message
     */
    public void receivedMessage(Message in) {

        try {
            transportProcessor.processMessage(in);
        } catch (ProcessorException e) {
            logger.error("Caught unhandled exception with message: " + in, e);
        }
    }

    /**
     * @return the number of outgoing messages currently queued
     */
    public int getMailboxSize() {
        return mailbox.size();
    }

    /**
     * Removes and returns the first queued outgoing message.
     *
     * @return the message at the head of the mailbox
     * @throws IndexOutOfBoundsException if the mailbox is empty
     */
    public Message getMessage() {
        return mailbox.remove(0);
    }

    /**
     * Sets the channel that is disconnected by {@link #removeTransport()}.
     *
     * @param channel the owning channel
     */
    public void setChannel(Channel channel) {
        this.channel = channel;
    }

    /**
     * Queues an outgoing message and signals write interest to the manager.
     * A message without a priority header is appended. A message with a
     * priority header is inserted in front of the first queued message whose
     * priority value is larger, or appended when there is no such message.
     *
     * @param msg the message to send
     * @throws ProcessorException declared by the processor contract; not
     *             thrown by this implementation
     * @throws NumberFormatException if a priority header is not an integer
     */
    public void processMessage(Message msg) throws ProcessorException {
        String priStr = msg.getHeader(Message.PRIORITY);

        if (priStr != null) {
            int pri = Integer.parseInt(priStr);

            // Iterating a synchronizedList requires holding its monitor; the
            // same lock also keeps the computed index valid for the insert.
            synchronized (mailbox) {
                int idx = 0;

                for (Message message : mailbox) {
                    String pr = message.getHeader(Message.PRIORITY);

                    if ((pr != null) && (pri < Integer.parseInt(pr))) {
                        break;
                    }

                    idx++;
                }

                // idx equals mailbox.size() when no queued message has a
                // larger priority value, so the message is appended instead
                // of being lost.
                mailbox.add(idx, msg);
            }
        } else {
            mailbox.add(msg);
        }

        synchronized (this) {
            this.notify();
        }

        // A new message was added, notify the writer.
        if (sc != null)
            tm.readyWrite(sc);
    }

    /**
     * Does nothing.
     */
    public void start() throws Exception {
    }

    /**
     * Does nothing.
     */
    public void stop() throws Exception {
    }

    /**
     * Disconnects the channel, if one has been set.
     */
    public void removeTransport() {

        if (channel != null)
            channel.disconnect();
    }

    /**
     * Consumes bytes from the given buffer until it is empty or one packet has
     * been fully reassembled. Partial packets are kept in internal state, so
     * the next call continues where this one stopped.
     *
     * @param socketBuffer decrypted bytes, in readable mode
     * @return the complete packet including its 4-byte length prefix, or
     *         <code>null</code> if more data is needed
     * @throws IOException if the declared length is negative or the packet
     *             would exceed {@link #MAX_LENGTH}
     */
    public byte[] decode(ByteBuffer socketBuffer) throws IOException {

        // Reads until the buffer is empty or until a packet
        // is fully reassembled.
        while (socketBuffer.hasRemaining()) {

            // Copies into the temporary buffer
            byte data = socketBuffer.get();

            if (pos >= buffer.length) {
                // Defensive guard; the length validation below keeps pos
                // inside the reassembly buffer.
                throw new IOException("Packet too big. Maximum size allowed: " + BUFFER_SIZE +
                    " bytes.");
            }

            buffer[pos] = data;

            if (pos == SIZE_POSITION) { // Got size parameter

                // Big-endian int, the same encoding DataInputStream.readInt uses.
                int declaredLength = ByteBuffer.wrap(buffer, 0, SIZE_POSITION + 1).getInt();

                if (declaredLength < 0)
                    throw new IOException("Invalid packet, negative declared size " +
                        declaredLength);

                // Computed as long so that a huge declared length cannot
                // overflow and slip past the limit.
                long totalLength = (long) declaredLength + SIZE_POSITION + 1;

                if (totalLength > MAX_LENGTH)
                    throw new IOException("Packet too big, declared size " + totalLength);

                newmsgLength = (int) totalLength;
            }

            pos++;

            // Check if it is the final byte of a packet.
            if (pos == newmsgLength) {
                newmsgLength = Integer.MAX_VALUE;

                // The current packet is fully reassembled. Return it
                byte[] newBuffer = new byte[pos];
                System.arraycopy(buffer, 0, newBuffer, 0, pos);
                pos = 0;

                return newBuffer;
            }
        }

        // No packet was reassembled. There is not enough data. Wait
        // for more data to arrive.
        return null;
    }

    /**
     * Reads encrypted bytes from the socket, advances the SSL handshake when
     * one is in progress, and otherwise decrypts, decodes and dispatches every
     * complete message that is available.
     *
     * @return -1 if the socket was found disconnected, 1 otherwise
     * @throws Exception if reading, decrypting or packet decoding fails
     */
    synchronized public int decodeFromSocket() throws Exception {
        int readBytes;

        try {
            readBytes = sc.read(netInBuffer);

            if (readBytes == -1) {
                tm.socketDisconnected(sc);

                return -1;
            }
        } catch (ClosedChannelException e) {
            tm.socketDisconnected(sc);

            return -1;
        }

        netInBuffer.flip();

        SSLEngineResult ser = sslEngine.unwrap(netInBuffer, inBuffer);

        if (ser.getHandshakeStatus() == HandshakeStatus.FINISHED) {
            initDone = true;
        }

        if (ser.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
            ExecutorService exec = Executors.newSingleThreadExecutor();
            netInBuffer.compact();

            try {
                Runnable task;

                while ((task = sslEngine.getDelegatedTask()) != null) {
                    exec.execute(task);
                }

                // Queued behind the delegated tasks on the same single
                // thread, so decoding resumes once they have completed.
                exec.execute(new Runnable() {
                        public void run() {

                            try {
                                decodeFromSocket();
                            } catch (Exception e) {
                                logger.error("Unhandled exception in tasks", e);
                            }
                        }
                    });
            } finally {
                // Already submitted work still runs; the worker thread is
                // released afterwards instead of being leaked.
                exec.shutdown();
            }

            return 1;
        }

        if (ser.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) {
            netInBuffer.compact();
            tm.readyWrite(sc);

            return 1;
        }

        if (ser.getStatus() == Status.BUFFER_UNDERFLOW) { // Resize input to
                                                          // have more room
            netInBuffer.compact();

            int netsize = sslEngine.getSession().getPacketBufferSize();

            if (netsize > netInBuffer.capacity()) {
                ByteBuffer newBuffer = ByteBuffer.allocateDirect(netsize);
                netInBuffer.flip();
                newBuffer.put(netInBuffer);
                netInBuffer = newBuffer;
            }

            return 1;
        } else if (ser.getStatus() == Status.BUFFER_OVERFLOW) { // Resize output
            netInBuffer.compact();

            int readsize = sslEngine.getSession().getApplicationBufferSize();
            ByteBuffer newBuffer = ByteBuffer.allocateDirect(readsize + inBuffer.position());
            inBuffer.flip();
            newBuffer.put(inBuffer);
            inBuffer = newBuffer;

            return 1;
        }

        if (ser.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) {
            netInBuffer.compact();

            return decodeFromSocket();
        }

        // Try unwrapping more from the same packet.

        boolean readMore = true;

        while (readMore && initDone) {
            ser = sslEngine.unwrap(netInBuffer, inBuffer);

            if (ser.getStatus() == Status.BUFFER_UNDERFLOW)
                readMore = false;
            else if (ser.getStatus() == Status.CLOSED) {
                // A closed engine makes no further progress; without this
                // exit the loop would spin forever while holding the monitor.
                readMore = false;
            } else if (ser.getStatus() == Status.BUFFER_OVERFLOW) {

                while (ser.getStatus() == Status.BUFFER_OVERFLOW) {
                    int readsize = sslEngine.getSession().getApplicationBufferSize();
                    ByteBuffer newBuffer = ByteBuffer.allocateDirect(readsize +
                            inBuffer.capacity());
                    inBuffer.flip();
                    newBuffer.put(inBuffer);
                    inBuffer = newBuffer;
                    ser = sslEngine.unwrap(netInBuffer, inBuffer);
                }
            }

        }

        netInBuffer.compact();

        inBuffer.flip();

        byte[] bb = null;

        do {
            bb = decode(inBuffer);

            if (bb == null) {

                // Partial packet received. Must wait for more data. All the
                // contents
                // of inBuffer were processed by the protocol decoder. We can
                // delete it and prepare for more data.
                inBuffer.clear();
                inBuffer.flip();

            } else {
                // A packet was reassembled.

                Message msg;

                try {
                    msg = new DefaultMessage(new DataInputStream(new ByteArrayInputStream(bb)));

                    // Messages whose body exceeds MAX_BODY_LENGTH are not
                    // forwarded.
                    if ((msg.getSerializedBody() == null) ||
                            (msg.getSerializedBody().length <= MAX_BODY_LENGTH))
                        receivedMessage(msg);

                } catch (Exception e) {
                    logger.warn("Unhandled exception handling message", e);
                }
                // The netInBuffer might still have some data left. Perhaps
                // the beginning of another packet. So don't clear it. Next
                // time reading is activated, we start by processing the
                // netInBuffer
                // again.
            }
        } while (bb != null);

        inBuffer.compact();

        return 1;
    }

    /**
     * Writes pending data to the socket. If encrypted bytes are left over from
     * an earlier call they are written first. Otherwise the next handshake
     * record (before the handshake is finished) or the next piece of
     * application data is encrypted and written. When nothing remains to be
     * sent, or the handshake needs to read first, the manager is told that
     * writing is done.
     *
     * @throws Exception if encrypting or writing fails
     */
    public void handleWrite() throws Exception { // Netoutbuffer starts in
                                                 // fillable mode
        netOutBuffer.flip(); // Readable mode

        if (!netOutBuffer.hasRemaining()) {
            netOutBuffer.clear(); // Fillable mode

            SSLEngineResult ser = null;

            if (initDone) {

                if (mailbox.isEmpty() && ((toBeWritten == null) || !toBeWritten.hasRemaining())) {
                    tm.doneWrite(sc);

                    return;
                } else {

                    if ((toBeWritten != null) && toBeWritten.hasRemaining()) {
                        // Continue with the partially encrypted message.
                        ser = sslEngine.wrap(toBeWritten, netOutBuffer); // Filled
                                                                         // buffer
                    } else {
                        Message msg = getMessage();

                        byte[] data = msg.serialize();

                        toBeWritten = ByteBuffer.wrap(data);
                        ser = sslEngine.wrap(toBeWritten, netOutBuffer); // Filled
                                                                         // buffer
                    }
                }
            } else {

                // SslEngine wants to send more headers?
                ser = sslEngine.wrap(ByteBuffer.allocate(0), netOutBuffer);

                if (ser.getStatus() == Status.BUFFER_OVERFLOW) {
                    netOutBuffer = ByteBuffer.allocateDirect(sslEngine.getSession()
                            .getPacketBufferSize());

                    handleWrite(); // Retry encryption

                    return;
                }
            }

            if (ser.getStatus() == Status.BUFFER_OVERFLOW) {
                netOutBuffer = ByteBuffer.allocateDirect(netOutBuffer.capacity() * 2);

                handleWrite(); // Retry encryption

                return;
            }

            if (ser.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                ExecutorService exec = Executors.newSingleThreadExecutor();

                try {
                    Runnable task;

                    while ((task = sslEngine.getDelegatedTask()) != null) {
                        exec.execute(task);
                    }
                } finally {
                    // Submitted tasks still run; the worker thread is released
                    // afterwards instead of being leaked.
                    exec.shutdown();
                }
            }

            if (ser.getHandshakeStatus() == HandshakeStatus.FINISHED) {
                initDone = true;
            }

            netOutBuffer.flip(); // Readable mode

            sc.write(netOutBuffer);
            netOutBuffer.compact(); // Fillable mode

            if ((ser.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) && !initDone) {
                tm.doneWrite(sc); // Need to wait for reading

                return;
            }
        } else { // Empty buffer before we try sending more
            sc.write(netOutBuffer); // Readable mode

            // Check if there is more to be written.
            if (!netOutBuffer.hasRemaining()) {
                // netOutBuffer was completely written. Write intent is
                // deactivated by the next call, once nothing is left to send.
                netOutBuffer.clear(); // Fillable
            } else {
                netOutBuffer.compact(); // Fillable
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy