All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.snmp4j.transport.TLSTM Maven / Gradle / Ivy

There is a newer version: 3.9.0
Show newest version
/*_############################################################################
  _## 
  _##  SNMP4J - TLSTM.java  
  _## 
  _##  Copyright (C) 2003-2018  Frank Fock and Jochen Katz (SNMP4J.org)
  _##  
  _##  Licensed under the Apache License, Version 2.0 (the "License");
  _##  you may not use this file except in compliance with the License.
  _##  You may obtain a copy of the License at
  _##  
  _##      http://www.apache.org/licenses/LICENSE-2.0
  _##  
  _##  Unless required by applicable law or agreed to in writing, software
  _##  distributed under the License is distributed on an "AS IS" BASIS,
  _##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  _##  See the License for the specific language governing permissions and
  _##  limitations under the License.
  _##  
  _##########################################################################*/

package org.snmp4j.transport;

import org.snmp4j.SNMP4JSettings;
import org.snmp4j.TransportStateReference;
import org.snmp4j.asn1.BER;
import org.snmp4j.asn1.BERInputStream;
import org.snmp4j.event.CounterEvent;
import org.snmp4j.log.LogAdapter;
import org.snmp4j.log.LogFactory;
import org.snmp4j.mp.CounterSupport;
import org.snmp4j.mp.SnmpConstants;
import org.snmp4j.security.SecurityLevel;
import org.snmp4j.smi.*;
import org.snmp4j.transport.tls.*;
import org.snmp4j.util.CommonTimer;
import org.snmp4j.util.SnmpConfigurator;
import org.snmp4j.util.WorkerTask;

import javax.net.ssl.*;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.*;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.security.*;
import java.security.cert.*;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/**
 * The TLSTM implements the Transport Layer Security
 * Transport Mapping (TLS-TM) as defined by RFC 5953
 * with the new IO API and {@link javax.net.ssl.SSLEngine}.
 * 

* It uses a single thread for processing incoming and outgoing messages. * The thread is started when the listen method is called, or * when an outgoing request is sent using the sendMessage method. * * @author Frank Fock * @version 3.0.5 * @since 2.0 */ public class TLSTM extends TcpTransportMapping implements X509TlsTransportMappingConfig { private static final LogAdapter logger = LogFactory.getLogger(TLSTM.class); public static final int TLS_MAX_FRAGMENT_SIZE = 16384; private WorkerTask server; private ServerThread serverThread; private CommonTimer socketCleaner; // 1 minute default timeout private long connectionTimeout = 60000; private boolean serverEnabled = false; private long nextSessionID = 1; private SSLEngineConfigurator sslEngineConfigurator; private TlsTmSecurityCallback securityCallback; private CounterSupport counterSupport; public static final String DEFAULT_TLSTM_PROTOCOLS = "TLSv1"; public static final int MAX_TLS_PAYLOAD_SIZE = 32 * 1024; private String localCertificateAlias; private String keyStore; private String keyStorePassword; private String trustStore; private String trustStorePassword; private String[] tlsProtocols; private TLSTMTrustManagerFactory trustManagerFactory = new DefaultTLSTMTrustManagerFactory(); private int tlsMaxFragmentSize; /** * Creates a default TCP transport mapping with the server for incoming * messages disabled. * * @throws UnknownHostException * if the local host cannot be determined. */ public TLSTM() throws UnknownHostException { super(new TlsAddress(InetAddress.getLocalHost(), 0)); this.counterSupport = CounterSupport.getInstance(); super.maxInboundMessageSize = MAX_TLS_PAYLOAD_SIZE; this.tlsMaxFragmentSize = TLS_MAX_FRAGMENT_SIZE; } /** * Creates a TLS transport mapping with the server for incoming * messages bind to the given address. The {@code securityCallback} * needs to be specified before {@link #listen()} is called. * * @param address * the address to bind for incoming requests. * * @throws java.io.IOException * on failure of binding a local port. */ public TLSTM(TlsAddress address) throws IOException { super(address); super.maxInboundMessageSize = MAX_TLS_PAYLOAD_SIZE; this.serverEnabled = true; this.counterSupport = CounterSupport.getInstance(); this.tlsMaxFragmentSize = TLS_MAX_FRAGMENT_SIZE; try { if (Class.forName("javax.net.ssl.X509ExtendedTrustManager") != null) { Class trustManagerFactoryClass = Class.forName("org.snmp4j.transport.tls.TLSTMExtendedTrustManagerFactory"); Constructor c = trustManagerFactoryClass.getConstructors()[0]; TLSTMTrustManagerFactory trustManagerFactory = (TLSTMTrustManagerFactory) c.newInstance(CounterSupport.getInstance(), securityCallback); setTrustManagerFactory(trustManagerFactory); } } catch (ClassNotFoundException ex) { //throw new IOException("Failed to load TLSTMTrustManagerFactory: "+ex.getMessage(), ex); } catch (InvocationTargetException ex) { throw new IOException("Failed to init TLSTMTrustManagerFactory: " + ex.getMessage(), ex); } catch (IllegalArgumentException ex) { throw new IOException("Failed to setup TLSTMTrustManagerFactory: " + ex.getMessage(), ex); } catch (IllegalAccessException ex) { throw new IOException("Failed to access TLSTMTrustManagerFactory: " + ex.getMessage(), ex); } catch (InstantiationException ex) { throw new IOException("Failed to instantiate TLSTMTrustManagerFactory: " + ex.getMessage(), ex); } } /** * Creates a TLS transport mapping that binds to the given address * (interface) on the local host. * * @param securityCallback * a security name callback to resolve X509 certificates to tmSecurityNames. * @param serverAddress * the TcpAddress instance that describes the server address to listen * on incoming connection requests. * * @throws java.io.IOException * if the given address cannot be bound. */ public TLSTM(TlsTmSecurityCallback securityCallback, TlsAddress serverAddress) throws IOException { this(securityCallback, serverAddress, CounterSupport.getInstance()); } /** * Creates a TLS transport mapping that binds to the given address * (interface) on the local host. * * @param securityCallback * a security name callback to resolve X509 certificates to tmSecurityNames. * @param serverAddress * the TcpAddress instance that describes the server address to listen * on incoming connection requests. * @param counterSupport * The CounterSupport instance to be used to count events created by this * TLSTM instance. To get a default instance, use * {@link CounterSupport#getInstance()}. * * @throws java.io.IOException * if the given address cannot be bound. */ public TLSTM(TlsTmSecurityCallback securityCallback, TlsAddress serverAddress, CounterSupport counterSupport) throws IOException { super(serverAddress); super.maxInboundMessageSize = MAX_TLS_PAYLOAD_SIZE; this.serverEnabled = true; this.securityCallback = securityCallback; this.counterSupport = counterSupport; } public String getLocalCertificateAlias() { if (localCertificateAlias == null) { return System.getProperty(SnmpConfigurator.P_TLS_LOCAL_ID, null); } return localCertificateAlias; } /** * Gets the maximum fragment size of supported for this transport mapping when acting as TLS server. * @return * the maximum TLS fragment size as defined by RFC 6066 section 4. */ public int getTlsMaxFragmentSize() { return tlsMaxFragmentSize; } /** * Sets the maximum TLS fragment size that this transport mapping should support as server. There is no need to * change that from the default {@link #TLS_MAX_FRAGMENT_SIZE} unless, a new Java version allows to set the * maximum fragment size to a lower value. * @param tlsMaxFragmentSize * a value as defined by RFC 6066 section 4. * @since 3.0.5 */ public void setTlsMaxFragmentSize(int tlsMaxFragmentSize) { this.tlsMaxFragmentSize = tlsMaxFragmentSize; } /** * Gets the TLS protocols supported by this transport mapping. * @return * an array of TLS protocol (version) names supported by the SunJSSE provider. * @deprecated Use {@link #getProtocolVersions} instead. */ @Deprecated public String[] getTlsProtocols() { return getProtocolVersions(); } /** * Sets the TLS protocols/versions that TLSTM should use during handshake. * The default is defined by {@link #DEFAULT_TLSTM_PROTOCOLS}. * * @param tlsProtocols * an array of TLS protocol (version) names supported by the SunJSSE provider. * The order in the array defines which protocol is tried during handshake * first. * * @since 2.0.3 * @deprecated Use {@link #setProtocolVersions(String[])} instead. */ @Deprecated public void setTlsProtocols(String[] tlsProtocols) { setProtocolVersions(tlsProtocols); } /** * Sets the TLS protocols/versions that TLSTM should use during handshake. * The default is defined by {@link #DEFAULT_TLSTM_PROTOCOLS}. * * @param protocolVersions * an array of TLS protocol (version) names supported by the SunJSSE provider. * The order in the array defines which protocol is tried during handshake * first. * @since 3.0 */ @Override public void setProtocolVersions(String[] protocolVersions) { this.tlsProtocols = protocolVersions; } @Override public String[] getProtocolVersions() { if (tlsProtocols == null) { String s = System.getProperty(getProtocolVersionPropertyName(), DEFAULT_TLSTM_PROTOCOLS); return s.split(","); } return tlsProtocols; } /** * Returns the property name that is used by this transport mapping to determine the protocol versions * from system properties. * * @return a property name like {@link SnmpConfigurator#P_TLS_VERSION} or * {@link SnmpConfigurator#P_DTLS_VERSION}. * @since 3.0 */ @Override public String getProtocolVersionPropertyName() { return SnmpConfigurator.P_TLS_VERSION; } public String getKeyStore() { if (keyStore == null) { return System.getProperty("javax.net.ssl.keyStore"); } return keyStore; } public void setKeyStore(String keyStore) { this.keyStore = keyStore; } public String getKeyStorePassword() { if (keyStorePassword == null) { return System.getProperty("javax.net.ssl.keyStorePassword"); } return keyStorePassword; } public void setKeyStorePassword(String keyStorePassword) { this.keyStorePassword = keyStorePassword; } public String getTrustStore() { if (trustStore == null) { return System.getProperty("javax.net.ssl.trustStore"); } return trustStore; } public void setTrustStore(String trustStore) { this.trustStore = trustStore; } public String getTrustStorePassword() { if (trustStorePassword == null) { return System.getProperty("javax.net.ssl.trustStorePassword"); } return trustStorePassword; } public void setTrustStorePassword(String trustStorePassword) { this.trustStorePassword = trustStorePassword; } /** * Sets the certificate alias used for client and server authentication * by this TLSTM. Setting this property to a value other than {@code null} * filters out any certificates which are not in the chain of the given * alias. * * @param localCertificateAlias * a certificate alias which filters a single certification chain from * the {@code javax.net.ssl.keyStore} key store to be used to * authenticate this TLS transport mapping. If {@code null} no * filtering appears, which could lead to more than a single chain * available for authentication by the peer, which would violate the * TLSTM standard requirements. */ public void setLocalCertificateAlias(String localCertificateAlias) { this.localCertificateAlias = localCertificateAlias; } public CounterSupport getCounterSupport() { return counterSupport; } @Override public Class getSupportedAddressClass() { return TlsAddress.class; } public TlsTmSecurityCallback getSecurityCallback() { return securityCallback; } public void setSecurityCallback(TlsTmSecurityCallback securityCallback) { this.securityCallback = securityCallback; } public SSLEngineConfigurator getSslEngineConfigurator() { return sslEngineConfigurator; } /** * Sets the configurator for the {@link SSLEngine} internally used to run the TLS communication. This method should * be called before any new connection is established that should use this configurator/configuration. * @param sslEngineConfigurator * a {@link SSLEngineConfigurator} instance like {@link DefaultSSLEngineConfiguration}. * @since 3.0.5 */ public void setSslEngineConfigurator(SSLEngineConfigurator sslEngineConfigurator) { this.sslEngineConfigurator = sslEngineConfigurator; } public TLSTMTrustManagerFactory getTrustManagerFactory() { return trustManagerFactory; } /** * Set the TLSTM trust manager factory. Using a trust manager factory other than the * default allows to add support for Java 1.7 X509ExtendedTrustManager. * * @param trustManagerFactory * a X.509 trust manager factory implementing the interface {@link TLSTMTrustManagerFactory}. * * @since 2.0.3 */ public void setTrustManagerFactory(TLSTMTrustManagerFactory trustManagerFactory) { if (trustManagerFactory == null) { throw new NullPointerException(); } this.trustManagerFactory = trustManagerFactory; } /** * Listen for incoming and outgoing requests. If the {@code serverEnabled} * member is {@code false} the server for incoming requests is not * started. This starts the internal server thread that processes messages. * * @throws java.net.SocketException * when the transport is already listening for incoming/outgoing messages. * @throws java.io.IOException * if the listen port could not be bound to the server thread. */ public synchronized void listen() throws IOException { if (server != null) { throw new SocketException("Port already listening"); } try { serverThread = new ServerThread(); if (logger.isInfoEnabled()) { logger.info("TCP address " + tcpAddress + " bound successfully"); } } catch (NoSuchAlgorithmException e) { throw new IOException("SSL not available: " + e.getMessage(), e); } server = SNMP4JSettings.getThreadFactory().createWorkerThread( "TLSTM_" + getAddress(), serverThread, true); if (connectionTimeout > 0) { // run as daemon socketCleaner = SNMP4JSettings.getTimerFactory().createTimer(); } server.run(); } /** * Sets the name of the listen thread for this UDP transport mapping. * This method has no effect, if called before {@link #listen()} has been * called for this transport mapping. * * @param name * the new thread name. * * @since 1.6 */ public void setThreadName(String name) { WorkerTask st = server; if (st instanceof Thread) { ((Thread) st).setName(name); } } /** * Returns the name of the listen thread. * * @return the thread name if in listening mode, otherwise {@code null}. * @since 1.6 */ public String getThreadName() { WorkerTask st = server; if (st != null) { return ((Thread) st).getName(); } else { return null; } } /** * Closes all open sockets and stops the internal server thread that * processes messages. */ public void close() { for (SocketEntry entry : sockets.values()) { entry.closeSession(); } WorkerTask st = server; server = null; if (st != null) { st.terminate(); st.interrupt(); try { st.join(); } catch (InterruptedException ex) { logger.warn(ex); } closeSockets(sockets); if (socketCleaner != null) { socketCleaner.cancel(); } socketCleaner = null; } } /** * Sends a SNMP message to the supplied address. * * @param address * an {@code TcpAddress}. A {@code ClassCastException} is thrown * if {@code address} is not a {@code TcpAddress} instance. * @param message * byte[] * the message to sent. * @param tmStateReference * the (optional) transport model state reference as defined by * RFC 5590 section 6.1. * @param timeoutMillis * maximum number of milli seconds the connection creation might take (if connection based). * @param maxRetries * maximum retries during connection creation. * * @throws java.io.IOException * if an IO exception occurs while trying to send the message. */ public void sendMessage(TcpAddress address, byte[] message, TransportStateReference tmStateReference, long timeoutMillis, int maxRetries) throws IOException { if (server == null) { listen(); } serverThread.sendMessage(address, message, tmStateReference); } /** * Gets the connection timeout. This timeout specifies the time a connection * may be idle before it is closed. * * @return long * the idle timeout in milliseconds. */ public long getConnectionTimeout() { return connectionTimeout; } /** * Sets the connection timeout. This timeout specifies the time a connection * may be idle before it is closed. * * @param connectionTimeout * the idle timeout in milliseconds. A zero or negative value will disable * any timeout and connections opened by this transport mapping will stay * opened until they are explicitly closed. */ public void setConnectionTimeout(long connectionTimeout) { this.connectionTimeout = connectionTimeout; } /** * Gets the {@link CommonTimer} that controls socket cleanup operations. * * @return a socket cleaner timer. * @since 3.0 */ @Override public CommonTimer getSocketCleaner() { return socketCleaner; } /** * Checks whether a server for incoming requests is enabled. * * @return boolean */ public boolean isServerEnabled() { return serverEnabled; } @Override public MessageLengthDecoder getMessageLengthDecoder() { return null; } /** * Sets whether a server for incoming requests should be created when * the transport is set into listen state. Setting this value has no effect * until the {@link #listen()} method is called (if the transport is already * listening, {@link #close()} has to be called before). * * @param serverEnabled * if {@code true} if the transport will listens for incoming * requests after {@link #listen()} has been called. */ public void setServerEnabled(boolean serverEnabled) { this.serverEnabled = serverEnabled; } @Override public void setMessageLengthDecoder(MessageLengthDecoder messageLengthDecoder) { /* if (messageLengthDecoder == null) { throw new NullPointerException(); } this.messageLengthDecoder = messageLengthDecoder; */ } /** * Gets the inbound buffer size for incoming requests. When SNMP packets are * received that are longer than this maximum size, the messages will be * silently dropped and the connection will be closed. * * @return the maximum inbound buffer size in bytes. */ public int getMaxInboundMessageSize() { return super.getMaxInboundMessageSize(); } /** * Sets the maximum buffer size for incoming requests. When SNMP packets are * received that are longer than this maximum size, the messages will be * silently dropped and the connection will be closed. * * @param maxInboundMessageSize * the length of the inbound buffer in bytes. */ public void setMaxInboundMessageSize(int maxInboundMessageSize) { this.maxInboundMessageSize = maxInboundMessageSize; } private synchronized void timeoutSocket(SocketEntry entry) { if (connectionTimeout > 0) { SocketTimeout socketTimeout = new SocketTimeout<>(this, entry); entry.setSocketTimeout(socketTimeout); socketCleaner.schedule(socketTimeout, connectionTimeout); } } public boolean isListening() { return (server != null); } @Override public TcpAddress getListenAddress() { int port = tcpAddress.getPort(); ServerThread serverThreadCopy = serverThread; try { port = serverThreadCopy.ssc.socket().getLocalPort(); } catch (NullPointerException npe) { if (logger.isDebugEnabled()) { logger.debug("TLSTM.getListenAddress called but TLSTM is not listening yet"); } } return new TcpAddress(tcpAddress.getInetAddress(), port); } class SocketEntry extends AbstractSocketEntry { private LinkedList message = new LinkedList(); private ByteBuffer inNetBuffer; private ByteBuffer inAppBuffer; private ByteBuffer outAppBuffer; private ByteBuffer outNetBuffer; private SSLEngine sslEngine; private long sessionID; private TransportStateReference tmStateReference; private boolean handshakeFinished; private final Object outboundLock = new Object(); private final Object inboundLock = new Object(); public SocketEntry(TcpAddress address, Socket socket, boolean useClientMode, TransportStateReference tmStateReference) throws NoSuchAlgorithmException { super(address, socket); this.inAppBuffer = ByteBuffer.allocate(getMaxInboundMessageSize()); this.inNetBuffer = ByteBuffer.allocate(getMaxInboundMessageSize()); this.outNetBuffer = ByteBuffer.allocate(getMaxInboundMessageSize()); this.tmStateReference = tmStateReference; if (tmStateReference == null) { counterSupport.fireIncrementCounter(new CounterEvent(this, SnmpConstants.snmpTlstmSessionAccepts)); } SSLEngineConfigurator sslEngineConfigurator = ensureSslEngineConfigurator(); SSLContext sslContext = sslEngineConfigurator.getSSLContext(useClientMode, tmStateReference); this.sslEngine = sslContext.createSSLEngine(address.getInetAddress().getHostName(), address.getPort()); sslEngine.setUseClientMode(useClientMode); sslEngineConfigurator.configure(sslEngine); synchronized (TLSTM.this) { sessionID = nextSessionID++; } } public synchronized void addMessage(byte[] message) { this.message.add(message); } public synchronized byte[] nextMessage() { if (this.message.size() > 0) { return this.message.removeFirst(); } return null; } public synchronized boolean hasMessage() { return !this.message.isEmpty(); } @Override public void setSocketTimeout(SocketTimeout socketTimeout) { } public void setInNetBuffer(ByteBuffer byteBuffer) { this.inNetBuffer = byteBuffer; } public ByteBuffer getInNetBuffer() { return inNetBuffer; } public ByteBuffer getOutNetBuffer() { return outNetBuffer; } public void setOutNetBuffer(ByteBuffer outNetBuffer) { this.outNetBuffer = outNetBuffer; } public String toString() { return "SocketEntry[peerAddress=" + getPeerAddress() + ",socket=" + socket + ",lastUse=" + new Date(getLastUse() / SnmpConstants.MILLISECOND_TO_NANOSECOND) + ",inNetBuffer=" + inNetBuffer + ",inAppBuffer=" + inAppBuffer + ",outAppBuffer=" + outAppBuffer + ",outNetBuffer=" + outNetBuffer + ",socketTimeout=" + getSocketTimeout() + "]"; } public void checkTransportStateReference() { if (tmStateReference == null) { tmStateReference = new TransportStateReference(TLSTM.this, getPeerAddress(), new OctetString(), SecurityLevel.authPriv, SecurityLevel.authPriv, true, sessionID); OctetString securityName = null; if (securityCallback != null) { try { securityName = securityCallback.getSecurityName( (X509Certificate[]) sslEngine.getSession().getPeerCertificates()); } catch (SSLPeerUnverifiedException e) { logger.error("SSL peer '" + getPeerAddress() + "' is not verified: " + e.getMessage(), e); sslEngine.setEnableSessionCreation(false); } } tmStateReference.setSecurityName(securityName); } else if (tmStateReference.getTransportSecurityLevel().equals(SecurityLevel.undefined)) { tmStateReference.setTransportSecurityLevel(SecurityLevel.authPriv); } } public void setInAppBuffer(ByteBuffer inAppBuffer) { this.inAppBuffer = inAppBuffer; } public ByteBuffer getInAppBuffer() { return inAppBuffer; } public boolean isHandshakeFinished() { return handshakeFinished; } public void setHandshakeFinished(boolean handshakeFinished) { this.handshakeFinished = handshakeFinished; } public boolean isAppOutPending() { synchronized (outboundLock) { return (outAppBuffer != null) && (outAppBuffer.limit() > 0); } } public long getSessionID() { return sessionID; } public void closeSession() { sslEngine.closeOutbound(); counterSupport.fireIncrementCounter(new CounterEvent(this, SnmpConstants.snmpTlstmSessionServerCloses)); try { SSLEngineResult result; do { result = sendNetMessage(this); } while (result != null && (result.getStatus() != SSLEngineResult.Status.CLOSED) && (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP)); } catch (IOException e) { logger.error("IOException while closing outbound channel of " + this + ": " + e.getMessage(), e); } } } /** * Returns the configured {@link #setSslEngineConfigurator(SSLEngineConfigurator)} or the * {@link DefaultSSLEngineConfiguration} which will then become the configured SSL engine configurator. This method * is not synchronized against concurrent execution of {@link #setSslEngineConfigurator(SSLEngineConfigurator)}. * @return * a non-null {@link SSLEngineConfigurator}. * @since 3.0.5 */ protected SSLEngineConfigurator ensureSslEngineConfigurator() { if (sslEngineConfigurator == null) { sslEngineConfigurator = new DefaultSSLEngineConfiguration(this, trustManagerFactory, DEFAULT_TLSTM_PROTOCOLS); } return sslEngineConfigurator; } class ServerThread extends AbstractTcpServerThread { private Throwable lastError = null; private ServerSocketChannel ssc; private BlockingQueue outQueue = new LinkedBlockingQueue(); private BlockingQueue inQueue = new LinkedBlockingQueue(); public ServerThread() throws IOException, NoSuchAlgorithmException { super(TLSTM.this); // Selector for incoming requests if (serverEnabled) { // Create a new server socket and set to non blocking mode ssc = ServerSocketChannel.open(); ssc.configureBlocking(false); // Bind the server socket InetSocketAddress isa = new InetSocketAddress(tcpAddress.getInetAddress(), tcpAddress.getPort()); setSocketOptions(ssc.socket()); ssc.socket().bind(isa); // Register accepts on the server socket with the selector. This // step tells the selector that the socket wants to be put on the // ready list when accept operations occur, so allowing multiplexed // non-blocking I/O to take place. ssc.register(selector, SelectionKey.OP_ACCEPT); } } private synchronized void processQueues() { while (!outQueue.isEmpty() || !inQueue.isEmpty()) { while (!outQueue.isEmpty()) { SocketEntry entry = null; try { SSLEngineResult result; entry = outQueue.take(); result = sendNetMessage(entry); if ((result != null) && runDelegatedTasks(result, entry)) { if (entry.isAppOutPending()) { writeMessage(entry, entry.getSocket().getChannel()); } } } catch (IOException iox) { logger.error("IO exception caught while SSL processing: " + iox.getMessage(), iox); while (inQueue.remove(entry)) { // no body } } catch (InterruptedException e) { logger.error("SSL processing interrupted: " + e.getMessage(), e); return; } } while (!inQueue.isEmpty()) { SocketEntry entry = null; try { entry = inQueue.take(); synchronized (entry.inboundLock) { entry.inNetBuffer.flip(); logger.debug("TLS inNetBuffer = " + entry.inNetBuffer); SSLEngineResult nextResult = entry.sslEngine.unwrap(entry.inNetBuffer, entry.inAppBuffer); adjustInNetBuffer(entry, nextResult); if (runDelegatedTasks(nextResult, entry)) { switch (nextResult.getStatus()) { case BUFFER_UNDERFLOW: entry.inNetBuffer.limit(entry.inNetBuffer.capacity()); entry.addRegistration(selector, SelectionKey.OP_READ); break; case BUFFER_OVERFLOW: // TODO break; case CLOSED: continue; case OK: if (entry.isAppOutPending()) { // we have a message to send writeMessage(entry, entry.getSocket().getChannel()); } entry.inAppBuffer.flip(); logger.debug("Dispatching inAppBuffer=" + entry.inAppBuffer); if (entry.inAppBuffer.limit() > 0) { dispatchMessage(entry.getPeerAddress(), entry.inAppBuffer, entry.inAppBuffer.limit(), entry.sessionID, entry.tmStateReference); } entry.inAppBuffer.clear(); } } } } catch (IOException iox) { logger.error("IO exception caught while SSL processing: " + iox.getMessage(), iox); while (inQueue.remove(entry)) { // no body } } catch (InterruptedException e) { logger.error("SSL processing interrupted: " + e.getMessage(), e); return; } } } } private void processPending() { synchronized (pending) { for (int i = 0; i < pending.size(); i++) { SocketEntry entry = pending.getFirst(); try { // Register the channel with the selector, indicating // interest in connection completion and attaching the // target object so that we can get the target back // after the key is added to the selector's // selected-key set if (entry.getSocket().isConnected()) { if (entry.isHandshakeFinished()) { entry.addRegistration(selector, SelectionKey.OP_WRITE); } } else { entry.addRegistration(selector, SelectionKey.OP_CONNECT); } } catch (CancelledKeyException ckex) { logger.warn(ckex); pending.remove(entry); try { entry.getSocket().getChannel().close(); TransportStateEvent e = new TransportStateEvent(TLSTM.this, entry.getPeerAddress(), TransportStateEvent.STATE_CLOSED, null); fireConnectionStateChanged(e); } catch (IOException ex) { logger.error(ex); } } catch (IOException iox) { logger.error(iox); pending.remove(entry); // Something went wrong, so close the channel and // record the failure try { entry.getSocket().getChannel().close(); TransportStateEvent e = new TransportStateEvent(TLSTM.this, entry.getPeerAddress(), TransportStateEvent.STATE_CLOSED, iox); fireConnectionStateChanged(e); } catch (IOException ex) { logger.error(ex); } lastError = iox; if (SNMP4JSettings.isForwardRuntimeExceptions()) { throw new RuntimeException(iox); } } } } } /** * If the result indicates that we have outstanding tasks to do, * go ahead and run them in this thread. * * @param result * the SSLEngine wrap/unwrap result. * @param entry * the session to use. * * @return {@code true} if processing of delegated tasks has been * finished, {@code false} otherwise. */ public boolean runDelegatedTasks(SSLEngineResult result, SocketEntry entry) throws IOException { if (logger.isDebugEnabled()) { logger.debug("Running delegated task on " + entry + ": " + result); } SSLEngineResult.HandshakeStatus status = result.getHandshakeStatus(); if (status == SSLEngineResult.HandshakeStatus.NEED_TASK) { Runnable runnable; while ((runnable = entry.sslEngine.getDelegatedTask()) != null) { logger.debug("Running delegated task..."); runnable.run(); } status = entry.sslEngine.getHandshakeStatus(); if (status == SSLEngineResult.HandshakeStatus.NEED_TASK) { throw new IOException("Inconsistent Handshake status"); } logger.info("Handshake status = " + status); } switch (result.getStatus()) { case BUFFER_UNDERFLOW: entry.inNetBuffer.limit(entry.inNetBuffer.capacity()); entry.addRegistration(selector, SelectionKey.OP_READ); return false; case CLOSED: return false; } switch (status) { case NEED_WRAP: outQueue.add(entry); // entry.addRegistration(selector, SelectionKey.OP_WRITE); break; case NEED_UNWRAP: logger.debug("NEED_UNRWAP processing with inNetBuffer=" + entry.inNetBuffer); inQueue.add(entry); entry.addRegistration(selector, SelectionKey.OP_READ); break; case FINISHED: logger.debug("TLS handshake finished"); entry.setHandshakeFinished(true);/* if (result.bytesProduced() > 0) { writeNetBuffer(entry, entry.getSocket().getChannel()); } /* if (entry.isAppOutPending()) { writeMessage(entry, entry.getSocket().getChannel()); } */ // fall through case NOT_HANDSHAKING: if (result.bytesProduced() > 0) { writeNetBuffer(entry, entry.getSocket().getChannel()); } return true; } return false; } public Throwable getLastError() { return lastError; } public void sendMessage(Address address, byte[] message, TransportStateReference tmStateReference) throws IOException { Socket s = null; SocketEntry entry = sockets.get(address); if (logger.isDebugEnabled()) { logger.debug("Looking up connection for destination '" + address + "' returned: " + entry); logger.debug(sockets.toString()); } if (entry != null) { if ((tmStateReference != null) && (tmStateReference.getSessionID() != null) && (!tmStateReference.getSessionID().equals(entry.getSessionID()))) { // session IDs do not match -> drop message counterSupport.fireIncrementCounter( new CounterEvent(this, SnmpConstants.snmpTlstmSessionNoSessions)); throw new IOException("Session " + tmStateReference.getSessionID() + " not available"); } s = entry.getSocket(); } if ((s == null) || (s.isClosed()) || (!s.isConnected())) { if (logger.isDebugEnabled()) { logger.debug("Socket for address '" + address + "' is closed, opening it..."); } synchronized (pending) { pending.remove(entry); } SocketChannel sc; try { InetSocketAddress targetAddress = new InetSocketAddress(((TcpAddress) address).getInetAddress(), ((TcpAddress) address).getPort()); if ((s == null) || (s.isClosed())) { // Open the channel, set it to non-blocking, initiate connect sc = SocketChannel.open(); sc.configureBlocking(false); sc.connect(targetAddress); counterSupport.fireIncrementCounter( new CounterEvent(this, SnmpConstants.snmpTlstmSessionOpens)); } else { sc = s.getChannel(); sc.configureBlocking(false); if (!sc.isConnectionPending()) { sc.connect(targetAddress); counterSupport.fireIncrementCounter( new CounterEvent(this, SnmpConstants.snmpTlstmSessionOpens)); } else { if (matchingStateReferences(tmStateReference, entry.tmStateReference)) { entry.addMessage(message); synchronized (pending) { pending.add(entry); } selector.wakeup(); return; } else { logger.error("TransportStateReferences refNew=" + tmStateReference + ",refOld=" + entry.tmStateReference + " do not match, message dropped"); throw new IOException("Transport state reference does not match existing reference" + " for this session/target"); } } } s = sc.socket(); entry = new SocketEntry((TcpAddress) address, s, true, tmStateReference); entry.addMessage(message); sockets.put(address, entry); synchronized (pending) { pending.add(entry); } selector.wakeup(); logger.debug("Trying to connect to " + address); } catch (IOException iox) { logger.error(iox); throw iox; } catch (NoSuchAlgorithmException e) { logger.error("NoSuchAlgorithmException while sending message to " + address + ": " + e.getMessage(), e); } } else if (matchingStateReferences(tmStateReference, entry.tmStateReference)) { entry.addMessage(message); synchronized (pending) { pending.addFirst(entry); } logger.debug("Waking up selector for new message"); selector.wakeup(); } else { logger.error("TransportStateReferences refNew=" + tmStateReference + ",refOld=" + entry.tmStateReference + " do not match, message dropped"); throw new IOException("Transport state reference does not match existing reference" + " for this session/target"); } } @Override public void run() { // Here's where everything happens. The select method will // return when any operations registered above have occurred, the // thread has been interrupted, etc. try { while (!stop) { try { processQueues(); if (selector.select() > 0) { if (stop) { break; } // Someone is ready for I/O, get the ready keys Set readyKeys = selector.selectedKeys(); Iterator it = readyKeys.iterator(); // Walk through the ready keys collection and process date requests. while (it.hasNext()) { try { SocketEntry entry = null; SelectionKey sk = it.next(); it.remove(); SocketChannel readChannel = null; TcpAddress incomingAddress = null; if (sk.isAcceptable()) { logger.debug("Key is acceptable"); // The key indexes into the selector so you // can retrieve the socket that's ready for I/O ServerSocketChannel nextReady = (ServerSocketChannel) sk.channel(); Socket s = nextReady.accept().socket(); readChannel = s.getChannel(); readChannel.configureBlocking(false); incomingAddress = new TcpAddress(s.getInetAddress(), s.getPort()); entry = new SocketEntry(incomingAddress, s, false, null); entry.addRegistration(selector, SelectionKey.OP_READ); sockets.put(incomingAddress, entry); timeoutSocket(entry); TransportStateEvent e = new TransportStateEvent(TLSTM.this, incomingAddress, TransportStateEvent. STATE_CONNECTED, null); fireConnectionStateChanged(e); if (e.isCancelled()) { logger.warn("Incoming connection cancelled"); s.close(); sockets.remove(incomingAddress); readChannel = null; } } else if (sk.isWritable()) { logger.debug("Key is writable"); incomingAddress = writeData(sk, incomingAddress); } else if (sk.isReadable()) { logger.debug("Key is readable"); readChannel = (SocketChannel) sk.channel(); incomingAddress = new TcpAddress(readChannel.socket().getInetAddress(), readChannel.socket().getPort()); } else if (sk.isConnectable()) { logger.debug("Key is connectable"); connectChannel(sk, incomingAddress); } if (readChannel != null) { logger.debug("Key is reading"); try { readMessage(sk, readChannel, incomingAddress, entry); } catch (IOException iox) { // IO exception -> channel closed remotely logger.warn(iox); iox.printStackTrace(); sk.cancel(); readChannel.close(); TransportStateEvent e = new TransportStateEvent(TLSTM.this, incomingAddress, TransportStateEvent. STATE_DISCONNECTED_REMOTELY, iox); fireConnectionStateChanged(e); } } } catch (CancelledKeyException ckex) { if (logger.isDebugEnabled()) { logger.debug("Selection key cancelled, skipping it"); } } catch (NoSuchAlgorithmException e) { logger.error("NoSuchAlgorithm while reading from server socket: " + e.getMessage(), e); } } } } catch (NullPointerException npex) { // There seems to happen a NullPointerException within the select() npex.printStackTrace(); logger.warn("NullPointerException within select()?"); stop = true; } processPending(); } if (ssc != null) { ssc.close(); } if (selector != null) { selector.close(); } } catch (IOException iox) { logger.error(iox); lastError = iox; } if (!stop) { stop = true; synchronized (TLSTM.this) { server = null; } } if (logger.isDebugEnabled()) { logger.debug("Worker task finished: " + getClass().getName()); } } @Override protected TcpAddress writeData(SelectionKey sk, TcpAddress incomingAddress) { SocketEntry entry = (SocketEntry) sk.attachment(); try { SocketChannel sc = (SocketChannel) sk.channel(); incomingAddress = new TcpAddress(sc.socket().getInetAddress(), sc.socket().getPort()); if ((entry != null) && (!entry.hasMessage())) { synchronized (pending) { pending.remove(entry); entry.removeRegistration(selector, SelectionKey.OP_WRITE); } } if (entry != null) { writeMessage(entry, sc); } } catch (IOException iox) { logger.warn(iox); TransportStateEvent e = new TransportStateEvent(TLSTM.this, incomingAddress, TransportStateEvent. STATE_DISCONNECTED_REMOTELY, iox); fireConnectionStateChanged(e); // make sure channel is closed properly: closeChannel(sk.channel()); } return incomingAddress; } private void readMessage(SelectionKey sk, SocketChannel readChannel, TcpAddress incomingAddress, SocketEntry session) throws IOException { SocketEntry entry = (SocketEntry) sk.attachment(); if (entry == null) { entry = session; } if (entry == null) { logger.error("SocketEntry null in readMessage"); } assert (entry != null); // note that socket has been used entry.used(); ByteBuffer inNetBuffer = entry.getInNetBuffer(); ByteBuffer inAppBuffer = entry.getInAppBuffer(); try { long bytesRead = readChannel.read(inNetBuffer); inNetBuffer.flip(); if (logger.isDebugEnabled()) { logger.debug("Read " + bytesRead + " bytes from " + incomingAddress); logger.debug("TLS inNetBuffer: " + inNetBuffer); } if (bytesRead < 0) { logger.debug("Socket closed remotely"); sk.cancel(); readChannel.close(); TransportStateEvent e = new TransportStateEvent(TLSTM.this, incomingAddress, TransportStateEvent. STATE_DISCONNECTED_REMOTELY, null); fireConnectionStateChanged(e); return; } if (bytesRead == 0) { entry.inNetBuffer.clear(); //entry.addRegistration(selector, SelectionKey.OP_READ); } else { SSLEngineResult result; synchronized (entry.inboundLock) { result = entry.sslEngine.unwrap(inNetBuffer, inAppBuffer); adjustInNetBuffer(entry, result); switch (result.getStatus()) { case BUFFER_OVERFLOW: logger.error("BUFFER_OVERFLOW"); throw new IOException("BUFFER_OVERFLOW"); } if (runDelegatedTasks(result, entry)) { logger.info("SSL session established"); if (result.bytesProduced() > 0) { entry.inAppBuffer.flip(); logger.debug("SSL established, dispatching inAppBuffer=" + entry.inAppBuffer); // SSL session is established if (entry.inAppBuffer.remaining() % tlsMaxFragmentSize == 0) { if (logger.isDebugEnabled()) { logger.debug("Checking PDU header for fragmented message: "+entry); } try { BER.decodeHeader(new BERInputStream(entry.inAppBuffer.asReadOnlyBuffer()), new BER.MutableByte(), true); } catch (IOException iox) { entry.inAppBuffer.position(entry.inAppBuffer.limit()); entry.inAppBuffer.limit(entry.inAppBuffer.capacity()); // wait to get rest of the PDU first if (logger.isDebugEnabled()) { logger.debug("Waiting for rest of packet because: "+iox.getMessage()+ ", inAppBuffer="+entry.inAppBuffer); } return; } } entry.checkTransportStateReference(); dispatchMessage(incomingAddress, inAppBuffer, inAppBuffer.limit(), entry.sessionID, entry.tmStateReference); entry.getInAppBuffer().clear(); } else if (entry.isAppOutPending()) { writeMessage(entry, entry.getSocket().getChannel()); } } } } } catch (ClosedChannelException ccex) { sk.cancel(); if (logger.isDebugEnabled()) { logger.debug("Read channel not open, no bytes read from " + incomingAddress); } } } private ByteBuffer createBufferCopy(ByteBuffer buffer) { byte[] conInNetData = new byte[buffer.limit()]; int buflen = buffer.limit() - buffer.remaining(); buffer.flip(); buffer.get(conInNetData, 0, buflen); ByteBuffer bufferCopy = ByteBuffer.wrap(conInNetData); bufferCopy.position(buflen); return bufferCopy; } private void dispatchMessage(TcpAddress incomingAddress, ByteBuffer byteBuffer, long bytesRead, Object sessionID, TransportStateReference tmStateReference) { byteBuffer.flip(); if (logger.isDebugEnabled()) { logger.debug("Received message from " + incomingAddress + " with length " + bytesRead + ": " + new OctetString(byteBuffer.array(), 0, (int) bytesRead).toHexString()); } ByteBuffer bis; if (isAsyncMsgProcessingSupported()) { byte[] bytes = new byte[(int) bytesRead]; System.arraycopy(byteBuffer.array(), 0, bytes, 0, (int) bytesRead); bis = ByteBuffer.wrap(bytes); } else { bis = ByteBuffer.wrap(byteBuffer.array(), 0, (int) bytesRead); } fireProcessMessage(incomingAddress, bis, tmStateReference); } private void writeMessage(SocketEntry entry, SocketChannel sc) throws IOException { synchronized (entry.outboundLock) { boolean sendNextFragment = false; do { sendNextFragment = false; if (entry.outAppBuffer == null) { byte[] message = entry.nextMessage(); if (message != null) { entry.outAppBuffer = ByteBuffer.wrap(message); if (logger.isDebugEnabled()) { logger.debug("Sending message with length " + message.length + " to " + entry.getPeerAddress() + ": " + new OctetString(message).toHexString()); } } else { entry.removeRegistration(selector, SelectionKey.OP_WRITE); // Make sure that we did not clear a selection key that was concurrently // added: if (entry.hasMessage() && !entry.isRegistered(SelectionKey.OP_WRITE)) { entry.addRegistration(selector, SelectionKey.OP_WRITE); logger.debug("Waking up selector"); selector.wakeup(); } entry.addRegistration(selector, SelectionKey.OP_READ); return; } } SSLEngineResult result; result = entry.sslEngine.wrap(entry.outAppBuffer, entry.outNetBuffer); if (result.getStatus() == SSLEngineResult.Status.OK) { if (result.bytesProduced() > 0) { writeNetBuffer(entry, sc); } } else if (runDelegatedTasks(result, entry)) { logger.debug("SSL session OK"); } if (result.bytesConsumed() >= entry.outAppBuffer.limit()) { logger.debug("Payload sent completely"); entry.outAppBuffer = null; } else if (result.bytesConsumed() > 0) { logger.debug("Fragment of size " + result.bytesConsumed() + " sent: " + entry); sendNextFragment = true; } } while (sendNextFragment); } entry.addRegistration(selector, SelectionKey.OP_READ); } private void writeNetBuffer(SocketEntry entry, SocketChannel sc) throws IOException { entry.outNetBuffer.flip(); // Send SSL/TLS encoded data to peer while (entry.outNetBuffer.hasRemaining()) { logger.debug("Writing TLS outNetBuffer(PAYLOAD): " + entry.outNetBuffer); int num = sc.write(entry.outNetBuffer); logger.debug("Wrote TLS " + num + " bytes from outNetBuffer(PAYLOAD)"); if (num == -1) { throw new IOException("TLS connection closed"); } else if (num == 0) { entry.outNetBuffer.compact(); //entry.outNetBuffer.limit(entry.outNetBuffer.capacity()); return; } } entry.outNetBuffer.clear(); } } private boolean matchingStateReferences(TransportStateReference tmStateReferenceNew, TransportStateReference tmStateReferenceExisting) { if ((tmStateReferenceExisting == null) || (tmStateReferenceNew == null)) { logger.error("Failed to compare TransportStateReferences refNew=" + tmStateReferenceNew + ",refOld=" + tmStateReferenceExisting); return false; } if ((tmStateReferenceNew.getSecurityName() == null) || (tmStateReferenceExisting.getSecurityName() == null)) { logger.error("Could not match TransportStateReferences refNew=" + tmStateReferenceNew + ",refOld=" + tmStateReferenceExisting); return false; } else if (!tmStateReferenceNew.getSecurityName().equals(tmStateReferenceExisting.getSecurityName())) { return false; } return true; } private SSLEngineResult sendNetMessage(SocketEntry entry) throws IOException { SSLEngineResult result; synchronized (entry.outboundLock) { if (!entry.outNetBuffer.hasRemaining()) { return null; } result = entry.sslEngine.wrap(ByteBuffer.allocate(0), entry.outNetBuffer); entry.outNetBuffer.flip(); logger.debug("TLS outNetBuffer = " + entry.outNetBuffer); entry.socket.getChannel().write(entry.outNetBuffer); entry.outNetBuffer.clear(); } return result; } private void adjustInNetBuffer(SocketEntry entry, SSLEngineResult result) { if (result.bytesConsumed() == entry.inNetBuffer.limit()) { entry.inNetBuffer.clear(); } else if (result.bytesConsumed() > 0) { entry.inNetBuffer.compact(); } } private class DefaultTLSTMTrustManagerFactory implements TLSTMTrustManagerFactory { public X509TrustManager create(X509TrustManager trustManager, boolean useClientMode, TransportStateReference tmStateReference) { return new TlsTrustManager(trustManager, useClientMode, tmStateReference, counterSupport, securityCallback); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy