All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jppf.nio.SSLHandler Maven / Gradle / Ivy

The newest version!
/*
 * JPPF.
 * Copyright (C) 2005-2014 JPPF Team.
 * http://www.jppf.org
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jppf.nio;

import java.io.*;
import java.nio.*;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.*;

import javax.net.ssl.*;

import org.jppf.utils.*;
import org.slf4j.*;

/**
 * Wrapper for an {@link SSLEngine} and an associated socket channel.
 * <p>The application exchanges plain data through {@link #getApplicationSendBuffer()} and
 * {@link #getApplicationReceiveBuffer()}; {@link #write()} and {@link #read()} move that data through
 * the SSL engine to and from the underlying {@link SocketChannel}. Every operation on the four buffers
 * follows a flip / operate / compact pattern, so between calls each buffer is left in "fill" mode,
 * i.e. its position is the amount of data it currently holds.
 * @exclude
 */
public class SSLHandler
{
  /**
   * Logger for this class.
   */
  private static final Logger log = LoggerFactory.getLogger(SSLHandler.class);
  /**
   * Determines whether TRACE logging level is enabled. Evaluated once, when the class is initialized.
   */
  private static final boolean traceEnabled = log.isTraceEnabled();
  /**
   * Thread pool used to execute the delegated tasks generated by the SSL handshake.
   * Its size is read from the "jppf.ssl.thread.pool" configuration property (default 10).
   */
  private static final ExecutorService executor = createExecutor();
  /**
   * The socket channel from which data is read or to which data is written.
   */
  private final SocketChannel channel;
  /**
   * The SSLEngine performs the SSL-related operations before sending data/after receiving data.
   */
  private final SSLEngine sslEngine;
  /**
   * Contains the result of the latest wrap() or unwrap() operation on the SSLEngine.
   */
  private SSLEngineResult sslEngineResult = null;
  /**
   * The plain data the application is sending.
   */
  private final ByteBuffer applicationSendBuffer;
  /**
   * The SSL (encrypted) data produced by the SSLEngine, waiting to be written to the channel.
   */
  private final ByteBuffer channelSendBuffer;
  /**
   * The plain data the application is receiving.
   */
  private final ByteBuffer applicationReceiveBuffer;
  /**
   * The SSL (encrypted) data read from the channel, waiting to be processed by the SSLEngine.
   */
  private final ByteBuffer channelReceiveBuffer;
  /**
   * Count of bytes read from the channel, in the scope of a {@link #read()} invocation.
   * This count includes all the SSL overhead: encrypted data, handshaking, renegotiation, etc.
   */
  private long channelReadCount = 0L;
  /**
   * Count of bytes written to the channel, in the scope of a {@link #write()} invocation.
   * This count includes all the SSL overhead: encrypted data, handshaking, renegotiation, etc.
   */
  private long channelWriteCount = 0L;

  /**
   * Instantiate this SSLHandler with the specified channel and SSL engine.
   * Buffers are sized from the engine's current session (application and packet buffer sizes).
   * @param channel the channel from which data is read or to which data is written.
   * @param sslEngine performs the SSL-related operations before sending data/after receiving data.
   * @throws Exception if any error occurs.
   */
  public SSLHandler(final ChannelWrapper channel, final SSLEngine sslEngine) throws Exception
  {
    this.channel = (SocketChannel) ((SelectionKey) channel.getChannel()).channel();
    this.sslEngine = sslEngine;
    SSLSession session = sslEngine.getSession();
    this.applicationSendBuffer = ByteBuffer.wrap(new byte[session.getApplicationBufferSize()]);
    this.channelSendBuffer = ByteBuffer.wrap(new byte[session.getPacketBufferSize()]);
    this.applicationReceiveBuffer = ByteBuffer.wrap(new byte[session.getApplicationBufferSize()]);
    this.channelReceiveBuffer = ByteBuffer.wrap(new byte[session.getPacketBufferSize()]);
  }

  /**
   * Read from the channel via the SSLEngine into the application receive buffer.
   * Called in blocking mode when input is expected, or in non-blocking mode when the channel is readable.
   * @return the number of bytes held in the application receive buffer (its position), 0 when no
   * complete SSL record is available yet or the application receive buffer is full, or -1 when the
   * inbound side of the engine is done.
   * @throws EOFException if the end of the channel's stream is reached.
   * @throws Exception if any other error occurs.
   */
  public int read() throws Exception
  {
    channelReadCount = 0L;
    int sslCount = 0;
    int count = applicationReceiveBuffer.position();
    do
    {
      flush();
      if (sslEngine.isInboundDone()) return count > 0 ? count : -1;
      // the number of bytes read is accumulated in channelReadCount by doRead()
      doRead();
      channelReceiveBuffer.flip();
      sslEngineResult = sslEngine.unwrap(channelReceiveBuffer, applicationReceiveBuffer);
      channelReceiveBuffer.compact();
      switch (sslEngineResult.getStatus())
      {
        case BUFFER_UNDERFLOW:
          // not enough data for a complete SSL record: try to read more
          if (traceEnabled) log.trace("reading into netRecv=" + channelReceiveBuffer);
          sslCount = doRead();
          if (traceEnabled) log.trace("sslCount=" + sslCount + ", channelReceiveBuffer=" + channelReceiveBuffer);
          if (sslCount == 0) return count;
          // NOTE(review): doRead() throws EOFException instead of returning -1, so this branch looks unreachable
          if (sslCount == -1)
          {
            if (traceEnabled) log.trace("reached EOF, closing inbound");
            sslEngine.closeInbound();
          }
          break;

        case BUFFER_OVERFLOW:
          // the application receive buffer is full; the application must consume it first
          return 0;

        case CLOSED:
          channel.socket().shutdownInput();
          break;

        case OK:
          count = applicationReceiveBuffer.position();
          break;
      }
      while (processHandshake());
      count = applicationReceiveBuffer.position();
    }
    while (count == 0);
    // NOTE(review): unlike the early return above, this discards a non-zero count when inbound is done — confirm intended
    if (sslEngine.isInboundDone()) count = -1;
    return count;
  }

  /**
   * Write from the application send buffer to the channel via the SSLEngine.
   * If SSL data from a previous call is still pending and some of it could be flushed, nothing new is
   * wrapped and 0 is returned.
   * @return the number of bytes consumed from the application send buffer.
   * @throws BufferUnderflowException if the engine reports a buffer underflow while wrapping.
   * @throws SSLException if the outbound side of the engine is closed.
   * @throws Exception if any other error occurs.
   */
  public int write() throws Exception
  {
    if (traceEnabled) log.trace("position=" + applicationSendBuffer.position());
    channelWriteCount = 0L;
    int remaining = applicationSendBuffer.position();
    int writeCount = 0;
    if ((remaining > 0) && (flush() > 0)) return 0;
    while (remaining > 0)
    {
      if (traceEnabled) log.trace("before flip/wrap/compact " + printSendBuffers() + " count=" + remaining);
      applicationSendBuffer.flip();
      sslEngineResult = sslEngine.wrap(applicationSendBuffer, channelSendBuffer);
      applicationSendBuffer.compact();
      if (traceEnabled) log.trace("after  flip/wrap/compact " + printSendBuffers());
      switch (sslEngineResult.getStatus())
      {
        case BUFFER_UNDERFLOW:
          if (traceEnabled) log.trace("write", new BufferUnderflowException());
          throw new BufferUnderflowException();

        case BUFFER_OVERFLOW:
          // the channel send buffer is full: flush it, give up for now if nothing could be written
          if (traceEnabled) log.trace("buffer overflow, before flush() channelSendBuffer=" + channelSendBuffer);
          int flushCount = flush();
          if (traceEnabled) log.trace("buffer overflow, after  flush() channelSendBuffer=" + channelSendBuffer + ", flushCount=" + flushCount);
          if (flushCount == 0) return 0;
          continue;

        case CLOSED:
          throw new SSLException("outbound closed");

        case OK:
          int n = sslEngineResult.bytesConsumed();
          writeCount += n;
          remaining -= n;
          break;
      }
      while (processHandshake());
    }
    return writeCount;
  }

  /**
   * Write the pending content of the channel send buffer to the underlying channel.
   * @return the number of bytes written to the channel by this call.
   * @throws IOException if any error occurs.
   */
  public int flush() throws IOException
  {
    channelSendBuffer.flip();
    int n = channel.write(channelSendBuffer);
    if (n > 0) channelWriteCount += n;
    channelSendBuffer.compact();
    return n;
  }

  /**
   * Read bytes from the underlying channel into the channel receive buffer.
   * @return the number of bytes read; never negative, since end of stream raises an exception.
   * @throws EOFException if the end of the channel's stream is reached.
   * @throws IOException if any other error occurs.
   */
  private int doRead() throws IOException
  {
    int n = channel.read(channelReceiveBuffer);
    if (n > 0) channelReadCount += n;
    else if (n < 0) throw new EOFException("EOF reading inbound stream");
    return n;
  }

  /**
   * Close the underlying channel and SSL engine.
   * The channel is closed even when the SSL closing sequence fails; in that case the original
   * exception is rethrown, with any failure to close the channel attached as a suppressed exception.
   * @throws Exception if any error occurs.
   */
  public void close() throws Exception
  {
    try
    {
      closeSslSession();
    }
    catch (final Exception e)
    {
      // release the socket even though the SSL closing sequence failed
      try
      {
        channel.close();
      }
      catch (final IOException ioe)
      {
        e.addSuppressed(ioe);
      }
      throw e;
    }
    channel.close();
    if (traceEnabled) log.trace("SSLEngine closed");
  }

  /**
   * Perform the SSL closing sequence: drain pending inbound data (non-blocking channels only), flush pending
   * outbound data, close the engine's outbound side, run the closing handshake and flush its output.
   * @throws Exception if any error occurs.
   */
  private void closeSslSession() throws Exception
  {
    if (!sslEngine.isInboundDone() && !channel.isBlocking()) read();
    while (channelSendBuffer.position() > 0)
    {
      int n = flush();
      if (n == 0)
      {
        // the buffer is in fill mode, so its position is the number of bytes still pending
        log.error("unable to flush remaining " + channelSendBuffer.position() + " bytes");
        break;
      }
    }
    sslEngine.closeOutbound();
    if (traceEnabled) log.trace("close outbound handshake");
    while (processHandshake());
    if (channelSendBuffer.position() > 0 && flush() == 0) log.error("unable to flush remaining " + channelSendBuffer.position() + " bytes");
    if (traceEnabled) log.trace("close outbound done");
  }

  /**
   * Repeatedly process the status of the latest engine result, then the handshake status,
   * for as long as both steps report that progress can continue.
   * @throws Exception if any error occurs.
   */
  private void processEngineResult() throws Exception
  {
    while (processEngineResultStatus() && processHandshake()) continue;
  }

  /**
   * Process the current handshaking status, performing one wrap, unwrap or delegated-tasks step as required.
   * @return true if handshaking is still ongoing and another step can be attempted, false otherwise.
   * @throws Exception if any error occurs.
   */
  private boolean processHandshake() throws Exception
  {
    int count;
    switch (sslEngine.getHandshakeStatus())
    {
      case NOT_HANDSHAKING:
      case FINISHED:
        return false;

      case NEED_TASK:
        performDelegatedTasks();
        return true;

      case NEED_WRAP:
        applicationSendBuffer.flip();
        sslEngineResult = sslEngine.wrap(applicationSendBuffer, channelSendBuffer);
        applicationSendBuffer.compact();
        if (sslEngineResult.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW)
        {
          // make room in the channel send buffer; progress only if something was written
          count = flush();
          return count > 0;
        }
        return true;

      case NEED_UNWRAP:
        channelReceiveBuffer.flip();
        sslEngineResult = sslEngine.unwrap(channelReceiveBuffer, applicationReceiveBuffer);
        channelReceiveBuffer.compact();
        if (sslEngineResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW)
        {
          // incomplete SSL record: read more data, progress only if something was read
          if (sslEngine.isInboundDone()) count = -1;
          else count = doRead();
          if (traceEnabled) log.trace("readCount=" + count);
          return count > 0;
        }
        if (sslEngineResult.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) return false;
        return true;

      default:
        return false;
    }
  }

  /**
   * Process the status of the latest engine result, flushing or reading the channel when a buffer
   * overflow or underflow requires it.
   * @return true if a full SSL packet was read or written, false otherwise.
   * @throws Exception if any error occurs.
   */
  boolean processEngineResultStatus() throws Exception
  {
    int count;
    if (traceEnabled) log.trace("sslEngineResult=" + sslEngineResult);
    switch (sslEngineResult.getStatus())
    {
      case OK:
        return true;

      case CLOSED:
        return sslEngineResult.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;

      case BUFFER_OVERFLOW:
        switch (sslEngineResult.getHandshakeStatus())
        {
          case NEED_WRAP:
            flush();
            return channelSendBuffer.position() == 0;

          case NEED_UNWRAP:
            if (traceEnabled) log.trace(printBuffers());
            return false;

          default:
            return false;
        }

      case BUFFER_UNDERFLOW:
        if (traceEnabled) log.trace(printBuffers());
        flush();
        count = doRead();
        if (traceEnabled) log.trace("underflow: count=" + count + ", channelReceiveBuffer=" + channelReceiveBuffer);
        return count > 0;

      default:
        return false;
    }
  }

  /**
   * Run the delegated tasks for the handshake on the shared executor and wait for all of them to complete.
   * Task failures are logged, not propagated.
   */
  private void performDelegatedTasks()
  {
    Runnable delegatedTask;
    List<Future<?>> futures = new ArrayList<>();
    while ((delegatedTask = sslEngine.getDelegatedTask()) != null)
    {
      if (traceEnabled) log.trace("running delegated task " + delegatedTask);
      futures.add(executor.submit(delegatedTask));
    }
    for (Future<?> f: futures)
    {
      try
      {
        f.get();
      }
      catch (Exception e)
      {
        // NOTE(review): an InterruptedException is swallowed here without restoring the interrupt flag;
        // restoring it would make the next I/O on the (interruptible) socket channel close it — confirm intent.
        if (traceEnabled) log.trace(e.getMessage(), e);
        else log.warn(ExceptionUtils.getMessage(e));
      }
    }
  }

  /**
   * Get the application receive buffer.
   * @return a {@link ByteBuffer} instance.
   */
  public ByteBuffer getApplicationReceiveBuffer()
  {
    return applicationReceiveBuffer;
  }

  /**
   * Get the application send buffer.
   * @return a {@link ByteBuffer} instance.
   */
  public ByteBuffer getApplicationSendBuffer()
  {
    return applicationSendBuffer;
  }

  /**
   * Get the channel receive buffer.
   * @return a {@link ByteBuffer} instance.
   */
  public ByteBuffer getChannelReceiveBuffer()
  {
    return channelReceiveBuffer;
  }

  /**
   * Get the channel send buffer.
   * @return a {@link ByteBuffer} instance.
   */
  public ByteBuffer getChannelSendBuffer()
  {
    return channelSendBuffer;
  }

  /**
   * Perform an SSLEngine.wrap() operation from the application send buffer into the channel send buffer.
   * @return the resulting {@link SSLEngineResult}.
   * @throws Exception if any error occurs.
   */
  private SSLEngineResult doWrap() throws Exception
  {
    return sslEngine.wrap(applicationSendBuffer, channelSendBuffer);
  }

  /**
   * Print the state of all buffers to a string.
   * This method is intended for logging and debugging purposes.
   * @return a string representation of the buffers states.
   */
  private String printBuffers()
  {
    StringBuilder sb = new StringBuilder();
    sb.append("applicationSendBuffer=").append(applicationSendBuffer);
    sb.append(", channelSendBuffer=").append(channelSendBuffer);
    sb.append(", applicationReceiveBuffer=").append(applicationReceiveBuffer);
    sb.append(", channelReceiveBuffer=").append(channelReceiveBuffer);
    return sb.toString();
  }

  /**
   * Print the state of all send buffers to a string.
   * This method is intended for logging and debugging purposes.
   * @return a string representation of the send buffers states.
   */
  private String printSendBuffers()
  {
    StringBuilder sb = new StringBuilder();
    sb.append("applicationSendBuffer=").append(applicationSendBuffer);
    sb.append(", channelSendBuffer=").append(channelSendBuffer);
    return sb.toString();
  }

  /**
   * Get the count of bytes read from the channel during the latest {@link #read()} invocation.
   * @return the byte count as a long value.
   */
  public long getChannelReadCount()
  {
    return channelReadCount;
  }

  /**
   * Get the count of bytes written to the channel since the latest {@link #write()} invocation started.
   * @return the byte count as a long value.
   */
  public long getChannelWriteCount()
  {
    return channelWriteCount;
  }

  /**
   * Create the executor which runs the SSLEngine delegated tasks.
   * The pool has a fixed size read from the "jppf.ssl.thread.pool" configuration property (default 10),
   * an unbounded task queue, and threads that terminate after 10 seconds of inactivity.
   * @return an {@link ExecutorService} instance.
   */
  private static ExecutorService createExecutor()
  {
    int n = JPPFConfiguration.getProperties().getInt("jppf.ssl.thread.pool", 10);
    LinkedBlockingQueue<Runnable> queue = new LinkedBlockingQueue<>();
    JPPFThreadFactory tf = new JPPFThreadFactory("SSLDelegatedTasks");
    ThreadPoolExecutor exec = new ThreadPoolExecutor(n, n, 10L, TimeUnit.SECONDS, queue, tf);
    // allow even the core threads to time out, so an idle pool does not keep threads alive
    exec.allowCoreThreadTimeOut(true);
    return exec;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy