All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sshtools.ssh2.ConnectionProtocol Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2003-2016 SSHTOOLS Limited. All Rights Reserved.
 *
 * For product documentation visit https://www.sshtools.com/
 *
 * This file is part of J2SSH Maverick.
 *
 * J2SSH Maverick is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * J2SSH Maverick is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with J2SSH Maverick.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.sshtools.ssh2;

import java.io.IOException;
import java.util.Hashtable;

import com.sshtools.logging.Log;
import com.sshtools.ssh.ChannelOpenException;
import com.sshtools.ssh.SshContext;
import com.sshtools.ssh.SshException;
import com.sshtools.ssh.message.Message;
import com.sshtools.ssh.message.MessageObserver;
import com.sshtools.ssh.message.SshAbstractChannel;
import com.sshtools.ssh.message.SshChannelMessage;
import com.sshtools.ssh.message.SshMessage;
import com.sshtools.ssh.message.SshMessageRouter;
import com.sshtools.util.ByteArrayWriter;

/**
 * The SSH2 connection protocol service ("ssh-connection").
 * <p>
 * Sends channel open and global requests to the remote side and waits for
 * their replies, and answers channel open and global requests received from
 * the remote side using the registered {@link ChannelFactory} and
 * {@link GlobalRequestHandler} instances. The instance registers itself as a
 * listener on its {@link TransportProtocol} so that idle notifications can be
 * forwarded to every active channel.
 * 
 * @author Lee David Painter
 */
class ConnectionProtocol extends SshMessageRouter implements
		TransportProtocolListener {

	/** The name of this service "ssh-connection" */
	public static final String SERVICE_NAME = "ssh-connection";

	// Connection protocol message numbers (RFC 4254).
	final static int SSH_MSG_CHANNEL_OPEN = 90;
	final static int SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 91;
	final static int SSH_MSG_CHANNEL_OPEN_FAILURE = 92;

	final static int SSH_MSG_GLOBAL_REQUEST = 80;
	final static int SSH_MSG_REQUEST_SUCCESS = 81;
	final static int SSH_MSG_REQUEST_FAILURE = 82;

	/**
	 * Lock object intended for serialising channel opens. It is not
	 * referenced by any method in this class.
	 */
	Object channelOpenLock = new Object();

	/** Selects the replies to an outgoing SSH_MSG_CHANNEL_OPEN. */
	final static MessageObserver CHANNEL_OPEN_RESPONSE_MESSAGES = new MessageObserver() {
		public boolean wantsNotification(Message msg) {
			switch (msg.getMessageId()) {
			case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
			case SSH_MSG_CHANNEL_OPEN_FAILURE:
				return true;
			default:
				return false;
			}
		}
	};

	/** Selects the replies to an outgoing SSH_MSG_GLOBAL_REQUEST. */
	final static MessageObserver GLOBAL_REQUEST_MESSAGES = new MessageObserver() {
		public boolean wantsNotification(Message msg) {
			switch (msg.getMessageId()) {
			case SSH_MSG_REQUEST_SUCCESS:
			case SSH_MSG_REQUEST_FAILURE:
				return true;
			default:
				return false;
			}
		}
	};

	TransportProtocol transport;
	/** Channel type name (String) to {@link ChannelFactory}. */
	Hashtable channelfactories = new Hashtable();
	/** Global request name (String) to {@link GlobalRequestHandler}. */
	Hashtable requesthandlers = new Hashtable();

	/**
	 * Creates the connection protocol on top of an established transport and
	 * registers this instance as a listener on that transport.
	 * 
	 * @param transport
	 *            the transport carrying this service
	 * @param context
	 *            supplies the maximum number of channels
	 * @param buffered
	 *            passed through to {@link SshMessageRouter}
	 */
	public ConnectionProtocol(TransportProtocol transport, SshContext context,
			boolean buffered) {
		super(transport, context.getChannelLimit(), buffered);
		this.transport = transport;
		this.transport.addListener(this);
	}

	/**
	 * Registers a factory for every channel type it reports via
	 * {@link ChannelFactory#supportedChannelTypes()}.
	 * 
	 * @throws SshException
	 *             with {@link SshException#BAD_API_USAGE} if one of the types
	 *             is already registered. Types processed before the duplicate
	 *             remain registered.
	 */
	public void addChannelFactory(ChannelFactory factory) throws SshException {
		String[] types = factory.supportedChannelTypes();
		for (int i = 0; i < types.length; i++) {
			if (channelfactories.containsKey(types[i])) {
				throw new SshException(types[i]
						+ " channel is already registered!",
						SshException.BAD_API_USAGE);
			}
			channelfactories.put(types[i], factory);
		}
	}

	/**
	 * Registers a handler for every request name it reports via
	 * {@link GlobalRequestHandler#supportedRequests()}.
	 * 
	 * @throws SshException
	 *             with {@link SshException#BAD_API_USAGE} if one of the names
	 *             is already registered. Names processed before the duplicate
	 *             remain registered.
	 */
	public void addRequestHandler(GlobalRequestHandler handler)
			throws SshException {
		String[] types = handler.supportedRequests();
		for (int i = 0; i < types.length; i++) {
			if (requesthandlers.containsKey(types[i])) {
				throw new SshException(types[i]
						+ " request is already registered!",
						SshException.BAD_API_USAGE);
			}
			requesthandlers.put(types[i], handler);
		}
	}

	/**
	 * Sends a global request with a timeout argument of 0.
	 * 
	 * @see #sendGlobalRequest(GlobalRequest, boolean, long)
	 */
	public boolean sendGlobalRequest(GlobalRequest request, boolean wantreply)
			throws SshException {
		return sendGlobalRequest(request, wantreply, 0);
	}

	/**
	 * Sends an SSH_MSG_GLOBAL_REQUEST and, if requested, waits for the reply.
	 * 
	 * @param request
	 *            the request; on success its data is replaced by any
	 *            request-specific data in the reply, or set to null if the
	 *            reply carries none
	 * @param wantreply
	 *            whether the remote side should reply
	 * @param timeout
	 *            passed to the global message store when waiting for the
	 *            reply (the meaning of 0 is defined by that store)
	 * @return true if no reply was wanted or SSH_MSG_REQUEST_SUCCESS was
	 *         received; false if SSH_MSG_REQUEST_FAILURE was received
	 * @throws SshException
	 *             if sending fails or an I/O error occurs
	 */
	public boolean sendGlobalRequest(GlobalRequest request, boolean wantreply,
			long timeout) throws SshException {

		ByteArrayWriter msg = new ByteArrayWriter();
		try {

			msg.write(SSH_MSG_GLOBAL_REQUEST);
			msg.writeString(request.getName());
			msg.writeBoolean(wantreply);
			if (request.getData() != null) {
				msg.write(request.getData());
			}

			if (Log.isDebugEnabled()) {
				Log.debug(
						this,
						"Sending SSH_MSG_GLOBAL_REQUEST request="
								+ request.getName() + " wantreply=" + wantreply);
			}
			sendMessage(msg.toByteArray(), true);

			if (wantreply) {
				SshMessage reply = getGlobalMessages().nextMessage(
						GLOBAL_REQUEST_MESSAGES, timeout);
				if (reply.getMessageId() == SSH_MSG_REQUEST_SUCCESS) {
					if (Log.isDebugEnabled()) {
						Log.debug(this,
								"Received SSH_MSG_REQUEST_SUCCESS request="
										+ request.getName());
					}
					// Any remaining payload is request-specific response data.
					if (reply.available() > 0) {
						byte[] tmp = new byte[reply.available()];
						reply.read(tmp);
						request.setData(tmp);
					} else {
						request.setData(null);
					}
					return true;
				}
				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"Received SSH_MSG_REQUEST_FAILURE request="
									+ request.getName());
				}
				return false;
			}
			return true;
		} catch (IOException ex) {
			throw new SshException(ex, SshException.INTERNAL_ERROR);
		} finally {
			try {
				msg.close();
			} catch (IOException e) {
				// Closing an in-memory writer; nothing useful to do on failure.
			}
		}
	}

	/**
	 * Releases the local channel slot held by the given channel.
	 */
	public void closeChannel(Ssh2Channel channel) {
		freeChannel(channel);
	}

	/**
	 * @return the context of the underlying transport
	 */
	public SshContext getContext() {
		return transport.transportContext;
	}

	/**
	 * Opens a channel with a timeout argument of 0.
	 * 
	 * @see #openChannel(Ssh2Channel, byte[], long)
	 */
	public void openChannel(Ssh2Channel channel, byte[] requestdata)
			throws SshException, ChannelOpenException {
		openChannel(channel, requestdata, 0);
	}

	/**
	 * Allocates a local channel id, sends SSH_MSG_CHANNEL_OPEN and waits on
	 * the channel's message store for the confirmation or failure.
	 * 
	 * @param channel
	 *            the channel to open
	 * @param requestdata
	 *            channel-type-specific data appended to the open message, or
	 *            null for none
	 * @param timeout
	 *            passed to the channel's message store when waiting
	 * @throws ChannelOpenException
	 *             if no local channel id is available
	 *             ({@link ChannelOpenException#RESOURCE_SHORTAGE}) or the
	 *             remote side rejects the open (with the remote reason code
	 *             and description); the channel is freed in the latter case
	 * @throws SshException
	 *             if sending fails or an I/O error occurs
	 */
	public void openChannel(Ssh2Channel channel, byte[] requestdata,
			long timeout) throws SshException, ChannelOpenException {

		try {

			int channelid = allocateChannel(channel);

			if (channelid == -1) {
				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"Maximum number of channels exceeded! active="
									+ getChannelCount() + " channels="
									+ getMaxChannels());
				}
				throw new ChannelOpenException(
						"Maximum number of channels exceeded",
						ChannelOpenException.RESOURCE_SHORTAGE);
			}

			channel.init(this, channelid);
			/*
			 * byte SSH_MSG_CHANNEL_OPEN string channel type in US-ASCII only
			 * uint32 sender channel uint32 initial window size uint32 maximum
			 * packet size .... channel type specific data follows
			 */
			ByteArrayWriter msg = new ByteArrayWriter();

			try {
				msg.write(SSH_MSG_CHANNEL_OPEN);
				msg.writeString(channel.getName());
				msg.writeInt(channel.getChannelId());
				msg.writeInt(channel.getWindowSize());
				msg.writeInt(channel.getPacketSize());
				if (requestdata != null) {
					msg.write(requestdata);
				}

				if (Log.isDebugEnabled()) {
					Log.debug(
							this,
							"Sending SSH_MSG_CHANNEL_OPEN type="
									+ channel.getName() + " id="
									+ channel.getChannelId() + " window="
									+ channel.getWindowSize() + " packet="
									+ channel.getPacketSize());
				}
				transport.sendMessage(msg.toByteArray(), true);

			} finally {
				try {
					msg.close();
				} catch (IOException e) {
					// Closing an in-memory writer; nothing useful to do.
				}
			}

			SshMessage reply = channel.getMessageStore().nextMessage(
					CHANNEL_OPEN_RESPONSE_MESSAGES, timeout);

			if (reply.getMessageId() == SSH_MSG_CHANNEL_OPEN_FAILURE) {

				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"Received SSH_MSG_CHANNEL_OPEN_FAILURE id="
									+ channel.getChannelId());
				}

				freeChannel(channel);
				// The recipient channel id is not read here; the reply is
				// built by createMessage as an SshChannelMessage (ids 91-100).
				int reason = (int) reply.readInt();
				throw new ChannelOpenException(reply.readString(), reason);
			}
			int remoteid = (int) reply.readInt();
			long remotewindow = reply.readInt();
			int remotepacket = (int) reply.readInt();
			byte[] responsedata = new byte[reply.available()];
			reply.read(responsedata);

			if (Log.isDebugEnabled()) {
				Log.debug(this,
						"Received SSH_MSG_CHANNEL_OPEN_CONFIRMATION id="
								+ channel.getChannelId() + " rid=" + remoteid
								+ " window=" + remotewindow + " packet="
								+ remotepacket);
			}

			channel.open(remoteid, remotewindow, remotepacket, responsedata);
		} catch (IOException ex) {
			throw new SshException(ex, SshException.INTERNAL_ERROR);
		}
	}

	/**
	 * Sends a raw message through the underlying transport.
	 */
	protected void sendMessage(byte[] msg, boolean isActivity)
			throws SshException {
		transport.sendMessage(msg, isActivity);
	}

	/**
	 * Wraps a raw message. Message ids 91 to 100 inclusive are wrapped as
	 * {@link SshChannelMessage}; all others as {@link SshMessage}.
	 */
	protected SshMessage createMessage(byte[] msg) throws SshException {

		if (msg[0] >= 91 && msg[0] <= 100) {
			return new SshChannelMessage(msg);
		}
		return new SshMessage(msg);
	}

	/**
	 * Handles incoming messages that are not bound to an existing channel.
	 * 
	 * @return true if the message was SSH_MSG_CHANNEL_OPEN or
	 *         SSH_MSG_GLOBAL_REQUEST and was handled here; false otherwise
	 * @throws SshException
	 *             if reading the message or sending a response fails
	 */
	protected boolean processGlobalMessage(SshMessage message)
			throws SshException {

		/**
		 * We need to filter for any messages that require a response from the
		 * connection protocol such as channel open or global requests. These
		 * are not handled anywhere else within this implementation because
		 * doing so would require a thread to wait.
		 */

		try {
			switch (message.getMessageId()) {
			case SSH_MSG_CHANNEL_OPEN: {
				// Attempt to open the channel

				String type = message.readString();
				int remoteid = (int) message.readInt();
				// NOTE(review): the window is a uint32 on the wire; the int
				// cast makes values above Integer.MAX_VALUE negative.
				int remotewindow = (int) message.readInt();
				int remotepacket = (int) message.readInt();

				// Channel-type-specific data is optional; pass null when the
				// message carries none instead of reading into a null array.
				byte[] requestdata = null;
				if (message.available() > 0) {
					requestdata = new byte[message.available()];
					message.read(requestdata);
				}

				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"Received SSH_MSG_CHANNEL_OPEN rid=" + remoteid
									+ " window=" + remotewindow + " packet="
									+ remotepacket);
				}

				processChannelOpenRequest(type, remoteid, remotewindow,
						remotepacket, requestdata);
				return true;
			}
			case SSH_MSG_GLOBAL_REQUEST: {

				// Attempt to process the global request
				String requestname = message.readString();
				boolean wantreply = message.read() != 0;
				byte[] requestdata = new byte[message.available()];
				message.read(requestdata);

				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"Received SSH_MSG_GLOBAL_REQUEST request="
									+ requestname + " wantreply=" + wantreply);
				}

				// Process the request
				processGlobalRequest(requestname, wantreply, requestdata);
				return true;
			}
			default:
				return false;
			}
		} catch (IOException ex) {
			throw new SshException(ex, SshException.INTERNAL_ERROR);
		}

	}

	/**
	 * Answers a channel open request received from the remote side with
	 * exactly one SSH_MSG_CHANNEL_OPEN_CONFIRMATION or
	 * SSH_MSG_CHANNEL_OPEN_FAILURE.
	 * <p>
	 * The channel is created by the factory registered for {@code type}. If
	 * setting up the channel fails after a local id was allocated, the id is
	 * released before the failure is reported. If the channel fails after the
	 * confirmation has already been sent, no failure message can be sent any
	 * more, so the error is propagated as an {@link SshException}.
	 * 
	 * @throws SshException
	 *             if sending a response fails, an I/O error occurs, or the
	 *             channel fails after its confirmation was sent
	 */
	void processChannelOpenRequest(String type, int remoteid, int remotewindow,
			int remotepacket, byte[] requestdata) throws SshException {

		Ssh2Channel channel = null;
		// True while a local channel id is held that must be released on
		// failure.
		boolean allocated = false;
		// True once SSH_MSG_CHANNEL_OPEN_CONFIRMATION has been sent.
		boolean confirmed = false;

		try {
			ChannelFactory factory = (ChannelFactory) channelfactories
					.get(type);
			if (factory == null) {
				sendChannelOpenFailure(remoteid,
						ChannelOpenException.UNKNOWN_CHANNEL_TYPE, type
								+ " is not a supported channel type!");
				return;
			}

			try {
				channel = factory.createChannel(type, requestdata);

				if (Log.isDebugEnabled()) {
					Log.debug(this,
							"There are " + this.getChannelCount()
									+ " channels open");
				}

				int localid = allocateChannel(channel);
				if (localid < 0) {
					sendChannelOpenFailure(remoteid,
							ChannelOpenException.RESOURCE_SHORTAGE,
							"Maximum allowable open channel limit of "
									+ String.valueOf(maximumChannels())
									+ " exceeded!");
					return;
				}
				allocated = true;

				try {
					channel.init(this, localid);
					byte[] responsedata = channel.create();
					sendChannelOpenConfirmation(channel, remoteid, localid,
							responsedata);
					confirmed = true;
					channel.open(remoteid, remotewindow, remotepacket);
				} catch (SshException ex) {
					if (confirmed) {
						// A failure reply after a confirmation would be a
						// protocol violation; let the caller deal with it.
						throw ex;
					}
					freeChannel(channel);
					allocated = false;
					sendChannelOpenFailure(remoteid,
							ChannelOpenException.CONNECT_FAILED,
							ex.getMessage());
				}
			} catch (ChannelOpenException ex) {
				if (confirmed) {
					throw new SshException(ex, SshException.INTERNAL_ERROR);
				}
				if (allocated) {
					freeChannel(channel);
				}
				sendChannelOpenFailure(remoteid, ex.getReason(),
						ex.getMessage());
			}
		} catch (IOException ex) {
			throw new SshException(ex, SshException.INTERNAL_ERROR);
		}
	}

	/**
	 * Builds and sends SSH_MSG_CHANNEL_OPEN_CONFIRMATION in its own buffer.
	 */
	private void sendChannelOpenConfirmation(Ssh2Channel channel,
			int remoteid, int localid, byte[] responsedata)
			throws SshException, IOException {
		ByteArrayWriter confirmation = new ByteArrayWriter();
		try {
			confirmation.write(SSH_MSG_CHANNEL_OPEN_CONFIRMATION);
			confirmation.writeInt(remoteid);
			confirmation.writeInt(localid);
			confirmation.writeInt(channel.getWindowSize());
			confirmation.writeInt(channel.getPacketSize());
			if (responsedata != null) {
				confirmation.write(responsedata);
			}

			if (Log.isDebugEnabled()) {
				Log.debug(
						this,
						"Sending SSH_MSG_CHANNEL_OPEN_CONFIRMATION type="
								+ channel.getName() + " id="
								+ channel.getChannelId() + " rid=" + remoteid
								+ " window=" + channel.getWindowSize()
								+ " packet=" + channel.getPacketSize());
			}
			transport.sendMessage(confirmation.toByteArray(), true);
		} finally {
			try {
				confirmation.close();
			} catch (IOException e) {
				// Closing an in-memory writer; nothing useful to do.
			}
		}
	}

	/**
	 * Builds and sends SSH_MSG_CHANNEL_OPEN_FAILURE in its own buffer, with
	 * an empty language tag.
	 */
	private void sendChannelOpenFailure(int remoteid, int reason,
			String description) throws SshException, IOException {
		ByteArrayWriter failure = new ByteArrayWriter();
		try {
			failure.write(SSH_MSG_CHANNEL_OPEN_FAILURE);
			failure.writeInt(remoteid);
			failure.writeInt(reason);
			failure.writeString(description);
			failure.writeString("");

			if (Log.isDebugEnabled()) {
				Log.debug(this,
						"Sending SSH_MSG_CHANNEL_OPEN_FAILURE rid=" + remoteid);
			}
			transport.sendMessage(failure.toByteArray(), true);
		} finally {
			try {
				failure.close();
			} catch (IOException e) {
				// Closing an in-memory writer; nothing useful to do.
			}
		}
	}

	/**
	 * Dispatches a global request received from the remote side to its
	 * registered handler. An unregistered request counts as failed. If a reply
	 * is wanted, sends SSH_MSG_REQUEST_SUCCESS (with any data the handler left
	 * on the request) or SSH_MSG_REQUEST_FAILURE.
	 */
	void processGlobalRequest(String requestname, boolean wantreply,
			byte[] requestdata) throws SshException {

		ByteArrayWriter response = new ByteArrayWriter();
		try {
			boolean success = false;
			GlobalRequest request = new GlobalRequest(requestname, requestdata);
			if (requesthandlers.containsKey(requestname)) {
				success = ((GlobalRequestHandler) requesthandlers
						.get(requestname)).processGlobalRequest(request);
			}

			if (wantreply) {
				if (success) {

					response.write(SSH_MSG_REQUEST_SUCCESS);
					if (request.getData() != null) {
						response.write(request.getData());
					}

					if (Log.isDebugEnabled()) {
						Log.debug(this,
								"Sending SSH_MSG_REQUEST_SUCCESS request="
										+ requestname);
					}
					transport.sendMessage(response.toByteArray(), true);
				} else {
					// Return a response
					transport.sendMessage(
							new byte[] { SSH_MSG_REQUEST_FAILURE }, true);
				}
			}
		} catch (IOException ex) {
			throw new SshException(ex, SshException.INTERNAL_ERROR);
		} finally {
			try {
				response.close();
			} catch (IOException e) {
				// Closing an in-memory writer; nothing useful to do.
			}
		}

	}

	/**
	 * Called when the router thread exits: disconnects a still-connected
	 * transport with {@link TransportProtocol#CONNECTION_LOST}, then stops.
	 */
	protected void onThreadExit() {
		if (transport != null && transport.isConnected()) {
			transport.disconnect(TransportProtocol.CONNECTION_LOST, "Exiting");
		}
		stop();
	}

	/** Transport disconnect notification; intentionally a no-op here. */
	public void onDisconnect(String msg, int reason) {

	}

	/** Transport idle notification; forwarded to every active channel. */
	public void onIdle(long lastActivity) {

		SshAbstractChannel[] channels = getActiveChannels();
		for (int i = 0; i < channels.length; i++) {
			channels[i].idle();
		}

	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy