All downloads are free. Search and download functionality uses the official Maven repository.

com.marklogic.io.SslByteChannel Maven / Gradle / Ivy

/*
 * Copyright 2003-2019 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.io;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.logging.Logger;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

// TODO: Add more logging?
/**
 * A ByteChannel that passes the data through an SSLEngine.
 *
 * <p>Cleartext handed to {@link #write(ByteBuffer)} is staged, wrapped by the engine and written
 * to the wrapped channel; bytes read from the wrapped channel are unwrapped by the engine and
 * handed out by {@link #read(ByteBuffer)}. Handshaking is driven on demand from inside the read,
 * write and close paths.
 *
 * <p>The class performs no locking of its own apart from the timed read, which synchronizes on
 * the wrapped channel; NOTE(review): it looks intended for use by one thread at a time - confirm
 * against callers.
 */
public class SslByteChannel implements ByteChannel {
    /**
     * Value returned by {@link #readInsideHandshake(ByteBuffer)} when a handshake was completed
     * during the read itself, so the caller (another handshake loop) should stop.
     */
    private static final int HANDSHAKE_DONE_IN_NESTED_CALL = -2;

    private final ByteChannel wrappedChannel;
    private final SSLEngine engine;
    protected final Logger logger;

    private ByteBuffer inAppData; // cleartext decoded from SSL
    private final ByteBuffer outAppData; // cleartext data to send
    private ByteBuffer inNetData; // SSL data read from wrappedChannel
    private final ByteBuffer outNetData; // SSL data to send on wrappedChannel

    private boolean closed = false;
    private int timeoutMillis = 0;
    private Selector selector = null; // lazily opened by timedRead()

    /**
     * Sets the read timeout.
     *
     * @param timeoutMillis
     *            Timeout in milliseconds; a value of zero or less makes reads go straight to the
     *            wrapped channel with no timeout handling.
     */
    public void setTimeout(int timeoutMillis) {
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * @return the read timeout in milliseconds as last set by {@link #setTimeout(int)}
     */
    public int getTimeout() {
        return timeoutMillis;
    }

    /**
     * Creates a new instance of SSLByteChannel
     *
     * @param wrappedChannel
     *            The byte channel on which this ssl channel is built. This channel contains
     *            encrypted data.
     * @param engine
     *            A SSLEngine instance that will remember SSL current context. Warning, such an
     *            instance CAN NOT be shared
     * @param logger
     *            Logger for logging.
     */
    public SslByteChannel(ByteChannel wrappedChannel, SSLEngine engine, Logger logger) {
        this.wrappedChannel = wrappedChannel;
        this.engine = engine;
        this.logger = logger;

        SSLSession session = engine.getSession();

        int appSize = session.getApplicationBufferSize();
        inAppData = ByteBuffer.allocate(appSize);
        outAppData = ByteBuffer.allocate(appSize);
        logger.fine("app buffer size=" + appSize);

        int netSize = session.getPacketBufferSize();
        inNetData = ByteBuffer.allocate(netSize);
        outNetData = ByteBuffer.allocate(netSize);
        // This message used to be labelled "app buffer size" as well, which made the two
        // log lines indistinguishable.
        logger.fine("net buffer size=" + netSize);
    }

    /**
     * Ends SSL operation and close the wrapped byte channel
     *
     * @throws java.io.IOException
     *             May be raised by close operation on wrapped byte channel
     */
    public void close() throws IOException {
        close(true);
    }

    /**
     * Ends SSL operation and optionally closes the wrapped byte channel. Calling it again after
     * the channel has been marked closed does nothing.
     *
     * <p>The SSL shutdown is best effort: an IOException raised while sending the closing
     * messages is logged at FINE level and otherwise ignored.
     *
     * @param closeSocket
     *            When true, the selector used for timed reads (if one was opened) and the
     *            wrapped channel are closed as well.
     * @throws java.io.IOException
     *             May be raised by close operation on wrapped byte channel
     */
    public void close(boolean closeSocket) throws IOException {
        if (closed) {
            return;
        }
        logger.fine("closing SslByteChannel");
        try {
            shutdownOutboundQuietly();
        } finally {
            // Runs even when the shutdown fails, so neither the selector nor the socket leaks.
            closed = true;
            if (closeSocket) {
                closeSelectorQuietly();
                wrappedChannel.close();
            }
        }
    }

    /** Closes the engine's outbound side and flushes its closing messages, ignoring I/O errors. */
    private void shutdownOutboundQuietly() {
        try {
            engine.closeOutbound();
            SSLEngineResult ser = wrapAppData();
            if (ser.getStatus() != Status.CLOSED) {
                logger.fine("SSLEngine not closed, calling handshake");
                handleHandshake(ser);
            }
        } catch (IOException ignored) {
            // Deliberately best effort: the peer may already be gone.
            logger.fine("ignoring error during SSL shutdown: " + ignored);
        }
    }

    /** Closes the timed-read selector if one was opened, ignoring I/O errors. */
    private void closeSelectorQuietly() {
        if (selector == null) {
            return;
        }
        try {
            selector.close();
        } catch (IOException ignored) {
            logger.fine("ignoring error closing selector: " + ignored);
        }
    }

    /**
     * Is the channel open ?
     *
     * @return true if the channel is still open
     */
    public boolean isOpen() {
        return !closed;
    }

    /**
     * Fill the given buffer with some bytes and return the number of bytes added in the buffer.
     * This method must be use exactly in the same way of ByteChannel read operation, so be
     * careful with buffer position, limit, ... Check corresponding javadoc.
     *
     * @param clientBuffer
     *            The buffer that will received read bytes
     * @return The number of bytes read; 0 if the buffer has no room; -1 if no cleartext could
     *         be obtained (end of stream or engine closed)
     * @throws java.io.IOException
     *             May be raised by ByteChannel read operation, or on read timeout
     */
    public int read(ByteBuffer clientBuffer) throws IOException {
        // A full destination can accept nothing; reporting -1 here would look like EOF.
        if (!clientBuffer.hasRemaining()) {
            return 0;
        }

        // first try to copy out anything left over from last time
        int bytesCopied = copyOutClientData(clientBuffer);
        if (bytesCopied > 0) {
            return bytesCopied;
        }

        fillInAppData(false);

        bytesCopied = copyOutClientData(clientBuffer);
        return bytesCopied > 0 ? bytesCopied : -1;
    }

    /**
     * Variant of {@link #read(ByteBuffer)} used while a handshake is already in progress.
     *
     * @param clientBuffer
     *            The buffer that will received read bytes
     * @return The number of bytes copied, -1 if nothing could be copied, or -2 if a handshake
     *         was completed during this call and the engine then asked for more network data
     * @throws java.io.IOException
     *             May be raised by ByteChannel read operation, or on read timeout
     */
    public int readInsideHandshake(ByteBuffer clientBuffer) throws IOException {
        // first try to copy out anything left over from last time
        int bytesCopied = copyOutClientData(clientBuffer);
        if (bytesCopied > 0) {
            logger.fine("read bytesCopied=" + bytesCopied);
            return bytesCopied;
        }

        int handShake = fillInAppData(true);
        if (handShake == HANDSHAKE_DONE_IN_NESTED_CALL) {
            return handShake; // Done handshake from within another handshake call
        }

        bytesCopied = copyOutClientData(clientBuffer);
        if (bytesCopied > 0) {
            logger.fine("read bytesCopied=" + bytesCopied);
            return bytesCopied;
        }
        return -1;
    }

    /**
     * Unwraps network data until the engine produces cleartext, is closed, or has nothing more
     * to do, reading from the wrapped channel and running handshakes as required.
     *
     * @param nestedInHandshake
     *            true when called from inside a handshake loop; enables the early -2 return
     * @return -2 when nestedInHandshake is set, a handshake was run here, and the engine then
     *         reported BUFFER_UNDERFLOW; otherwise 0
     */
    private int fillInAppData(boolean nestedInHandshake) throws IOException {
        boolean doneHandshake = false;
        while (true) {
            SSLEngineResult ser = unwrapNetData();
            if (ser.bytesProduced() > 0) {
                return 0;
            }

            switch (ser.getStatus()) {
            case CLOSED:
                close();
                return 0;
            case BUFFER_OVERFLOW:
                growInAppData();
                continue; // retry operation
            case BUFFER_UNDERFLOW:
                if (nestedInHandshake && doneHandshake) {
                    // Related to support CASE #16254
                    // Summary:
                    // The handshake is completed in this function during a read
                    // that was triggered from within another handshake call
                    // due to UNDERFLOW mode (yes, nested handshake calls).
                    //
                    // This happens with a busy server that doesn't immediately
                    // respond to an SSL handshake. The client gets in UNDERFLOW
                    // mode during the sslAccept procedure. It then needs to
                    // wait for the server to respond before it can really
                    // finish the handshake.
                    // But while reading, another handshake is initiated in this
                    // function. If the handshake is successful the story should
                    // end. However, the client keeps looping back here thinking
                    // that it still has something to read. Meanwhile, the
                    // server is done with the sslAccept (after a successful
                    // handshake). The server starts a background thread and
                    // waits for the client that never sends anything since it's
                    // stuck reading over and over. Eventually, the caller of
                    // this code kills the read due to timeout and attempts to
                    // cleanly close the connection with the server. The
                    // connection closing doesn't fully complete until the
                    // waiting Server thread times out typically after 30
                    // seconds (default "request timeout" on server).
                    return HANDSHAKE_DONE_IN_NESTED_CALL;
                }
                if (readMoreNetData() != -1) {
                    continue; // retry operation
                }
                break; // end of stream: fall through to the handshake check
            default: // OK
                break;
            }

            if (ser.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
                return 0;
            }
            handleHandshake(ser);
            doneHandshake = true;
        }
    }

    /** Replaces inAppData with a larger buffer, keeping the cleartext already decoded. */
    private void growInAppData() {
        int appSize = engine.getSession().getApplicationBufferSize();
        ByteBuffer bigger = ByteBuffer.allocate(appSize + inAppData.position());
        inAppData.flip();
        bigger.put(inAppData);
        inAppData = bigger;
    }

    /**
     * Reads more encrypted bytes into inNetData, first enlarging it if the session now reports
     * a bigger packet size.
     *
     * @return the result of the underlying read (-1 at end of stream)
     * @throws IOException
     *             if a timeout is configured and the read returned no bytes
     */
    private int readMoreNetData() throws IOException {
        int netSize = engine.getSession().getPacketBufferSize();
        if (netSize > inNetData.capacity()) {
            ByteBuffer bigger = ByteBuffer.allocate(netSize);
            inNetData.flip();
            bigger.put(inNetData);
            inNetData = bigger;
        }
        int rc = timedRead(inNetData, timeoutMillis);
        if (rc == 0 && timeoutMillis > 0) {
            throw new IOException("Timeout waiting for read (" + timeoutMillis + " milliseconds)");
        }
        return rc;
    }

    /**
     * Reads from the wrapped channel, waiting at most timeoutMillis for data when that value is
     * positive. The timed path casts the wrapped channel to SelectableChannel, switches it to
     * non-blocking mode for the wait, and always switches it back to blocking mode afterwards.
     */
    private int timedRead(ByteBuffer buf, int timeoutMillis) throws IOException {
        if (timeoutMillis <= 0) {
            return wrappedChannel.read(buf);
        }

        SelectableChannel ch = (SelectableChannel) wrappedChannel;
        synchronized (ch) {
            SelectionKey key = null;
            if (selector == null) {
                selector = Selector.open();
            }
            try {
                selector.selectNow(); // Needed to clear old key state
                ch.configureBlocking(false);
                key = ch.register(selector, SelectionKey.OP_READ);
                selector.select(timeoutMillis);
                return wrappedChannel.read(buf);
            } finally {
                if (key != null) {
                    key.cancel();
                }
                ch.configureBlocking(true);
            }
        }
    }

    /**
     * Write remaining bytes of the given byte buffer. This method must be use exactly in the
     * same way of ByteChannel write operation, so be careful with buffer position, limit, ...
     * Check corresponding javadoc.
     *
     * @param clientBuffer
     *            buffer with remaining bytes to write
     * @return The number of bytes written
     * @throws ClosedChannelException
     *             if this channel is closed and no more bytes can be accepted
     * @throws java.io.IOException
     *             May be raised by ByteChannel write operation
     */
    public int write(ByteBuffer clientBuffer) throws IOException {
        int bytesWritten = 0;
        while (clientBuffer.hasRemaining()) {
            int accepted = pushToEngine(clientBuffer);
            if (accepted == 0 && closed) {
                // The engine is closed and the staging buffer is full: nothing can ever be
                // accepted again, so looping would spin forever.
                throw new ClosedChannelException();
            }
            bytesWritten += accepted;
        }
        return bytesWritten;
    }

    /**
     * Stages cleartext from clientBuffer and wraps it until the staging buffer is drained.
     * Returns early when the engine reports CLOSED (after closing this channel) or
     * BUFFER_UNDERFLOW.
     *
     * @return the number of bytes taken from clientBuffer
     */
    private int pushToEngine(ByteBuffer clientBuffer) throws IOException {
        int bytesWritten = 0;
        while (clientBuffer.hasRemaining()) {
            bytesWritten += copyInClientData(clientBuffer);
            logger.fine("bytesWritten=" + bytesWritten);
            while (outAppData.position() > 0) {
                SSLEngineResult ser = wrapAppData();
                logger.fine("ser.getStatus()=" + ser.getStatus());
                logger.fine("ser.getHandshakeStatus()=" + ser.getHandshakeStatus());
                logger.fine("app bytes after wrap()=" + outAppData.position());
                switch (ser.getStatus()) {
                case OK:
                    break;
                case CLOSED:
                    pushNetData();
                    close();
                    return bytesWritten;
                case BUFFER_OVERFLOW:
                    // wrapAppData() has already flushed outNetData, so simply wrap again.
                    continue;
                case BUFFER_UNDERFLOW:
                    return bytesWritten;
                // TODO: handshake needed here?
                }
                switch (ser.getHandshakeStatus()) {
                case NOT_HANDSHAKING:
                    break;
                default:
                    handleHandshake(ser);
                    break;
                }
            }
        }
        return bytesWritten;
    }

    /**
     * Drives the handshake state machine starting from the given engine result until the engine
     * reports FINISHED, NOT_HANDSHAKING or CLOSED, or a nested read reports that the handshake
     * was already completed.
     *
     * @throws EOFException
     *             if the wrapped channel reaches end of stream while handshake data is needed
     */
    private void handleHandshake(SSLEngineResult initialSer) throws IOException {
        SSLEngineResult ser = initialSer;
        while (ser.getStatus() != Status.CLOSED) {
            if (ser.getStatus() == Status.BUFFER_UNDERFLOW) {
                // NOTE(review): this passes inNetData as the destination of a cleartext read;
                // looks intentional for pulling in more network data, but confirm.
                int n = readInsideHandshake(inNetData);
                if (n == HANDSHAKE_DONE_IN_NESTED_CALL) {
                    return; // done handshake from within another handshake call
                }
                if (n < 0) {
                    throw new EOFException("SSL wrapped byte channel");
                }
            }
            switch (ser.getHandshakeStatus()) {
            case NEED_TASK:
                Runnable task;
                while ((task = engine.getDelegatedTask()) != null) {
                    task.run();
                }
                pushNetData();
                ser = wrapAppData();
                break;
            case NEED_WRAP:
                pushNetData();
                ser = wrapAppData();
                break;
            case NEED_UNWRAP:
                pushNetData();
                if (inNetData.position() == 0) {
                    int n = wrappedChannel.read(inNetData);
                    if (n < 0) {
                        throw new EOFException("SSL wrapped byte channel");
                    }
                }
                ser = unwrapNetData();
                break;
            case FINISHED:
            case NOT_HANDSHAKING:
                return;
            }
        }
    }

    /** Unwraps whatever is buffered in inNetData into inAppData. */
    private SSLEngineResult unwrapNetData() throws SSLException {
        inNetData.flip();
        SSLEngineResult ser = engine.unwrap(inNetData, inAppData);
        inNetData.compact();
        return ser;
    }

    /** Wraps staged cleartext into outNetData and flushes it to the wrapped channel. */
    private SSLEngineResult wrapAppData() throws IOException {
        outAppData.flip();
        SSLEngineResult ser = engine.wrap(outAppData, outNetData);
        outAppData.compact();
        pushNetData();
        return ser;
    }

    /** Writes all of outNetData to the wrapped channel, looping until it is empty. */
    private void pushNetData() throws IOException {
        outNetData.flip();
        while (outNetData.hasRemaining()) {
            wrappedChannel.write(outNetData);
        }
        outNetData.compact();
    }

    // ------------------------------------------------------------

    /**
     * Copies as many bytes as fit from clientBuffer into the outgoing cleartext buffer.
     *
     * @return the number of bytes copied
     */
    private int copyInClientData(ByteBuffer clientBuffer) {
        return transfer(clientBuffer, outAppData);
    }

    /**
     * Copies as many decoded bytes as fit from the incoming cleartext buffer into clientBuffer.
     *
     * @return the number of bytes copied
     */
    private int copyOutClientData(ByteBuffer clientBuffer) {
        inAppData.flip();
        try {
            return transfer(inAppData, clientBuffer);
        } finally {
            // Always return inAppData to "fill" mode, even if the destination rejects the put.
            inAppData.compact();
        }
    }

    /**
     * Bulk-copies min(src.remaining(), dst.remaining()) bytes from src to dst, leaving src's
     * limit unchanged.
     *
     * @return the number of bytes copied
     */
    private static int transfer(ByteBuffer src, ByteBuffer dst) {
        int count = Math.min(src.remaining(), dst.remaining());
        if (count == 0) {
            return 0;
        }
        int savedLimit = src.limit();
        src.limit(src.position() + count); // cap the bulk put at what dst can hold
        try {
            dst.put(src);
        } finally {
            src.limit(savedLimit);
        }
        return count;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy