com.sshtools.ssh2.ConnectionProtocol Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2003-2016 SSHTOOLS Limited. All Rights Reserved.
*
* For product documentation visit https://www.sshtools.com/
*
* This file is part of J2SSH Maverick.
*
* J2SSH Maverick is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* J2SSH Maverick is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with J2SSH Maverick. If not, see <http://www.gnu.org/licenses/>.
*/
package com.sshtools.ssh2;
import java.io.IOException;
import java.util.Hashtable;
import com.sshtools.logging.Log;
import com.sshtools.ssh.ChannelOpenException;
import com.sshtools.ssh.SshContext;
import com.sshtools.ssh.SshException;
import com.sshtools.ssh.message.Message;
import com.sshtools.ssh.message.MessageObserver;
import com.sshtools.ssh.message.SshAbstractChannel;
import com.sshtools.ssh.message.SshChannelMessage;
import com.sshtools.ssh.message.SshMessage;
import com.sshtools.ssh.message.SshMessageRouter;
import com.sshtools.util.ByteArrayWriter;
/**
 * Implements the SSH connection protocol ("ssh-connection") on top of a
 * {@link TransportProtocol}.
 * <p>
 * The class opens locally requested channels, answers channel open requests
 * and global requests received from the remote side (dispatching them to the
 * registered {@link ChannelFactory} and {@link GlobalRequestHandler}
 * instances) and forwards transport idle notifications to the active
 * channels.
 *
 * @author Lee David Painter
 */
class ConnectionProtocol extends SshMessageRouter implements
        TransportProtocolListener {

    /** The name of this service "ssh-connection" */
    public static final String SERVICE_NAME = "ssh-connection";

    /* Message ids used by this class. */
    final static int SSH_MSG_CHANNEL_OPEN = 90;
    final static int SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 91;
    final static int SSH_MSG_CHANNEL_OPEN_FAILURE = 92;
    final static int SSH_MSG_GLOBAL_REQUEST = 80;
    final static int SSH_MSG_REQUEST_SUCCESS = 81;
    final static int SSH_MSG_REQUEST_FAILURE = 82;

    /**
     * Not used within this class: the synchronized block that referenced it
     * in {@link #openChannel(Ssh2Channel, byte[], long)} is commented out.
     */
    Object channelOpenLock = new Object();

    /**
     * Observer that accepts only the two possible replies to a channel open
     * request (confirmation or failure).
     */
    final static MessageObserver CHANNEL_OPEN_RESPONSE_MESSAGES = new MessageObserver() {
        public boolean wantsNotification(Message msg) {
            switch (msg.getMessageId()) {
            case SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
            case SSH_MSG_CHANNEL_OPEN_FAILURE:
                return true;
            default:
                return false;
            }
        }
    };

    /**
     * Observer that accepts only the two possible replies to a global
     * request (success or failure).
     */
    final static MessageObserver GLOBAL_REQUEST_MESSAGES = new MessageObserver() {
        public boolean wantsNotification(Message msg) {
            switch (msg.getMessageId()) {
            case SSH_MSG_REQUEST_SUCCESS:
            case SSH_MSG_REQUEST_FAILURE:
                return true;
            default:
                return false;
            }
        }
    };

    /** The transport all messages of this protocol are sent through. */
    TransportProtocol transport;

    /** Maps a channel type name (String) to its {@link ChannelFactory}. */
    Hashtable channelfactories = new Hashtable();

    /** Maps a request name (String) to its {@link GlobalRequestHandler}. */
    Hashtable requesthandlers = new Hashtable();

    /**
     * Creates the connection protocol and registers it as a listener on the
     * transport.
     *
     * @param transport the transport used to send messages
     * @param context supplies the channel limit passed to the message router
     * @param buffered passed through to the {@link SshMessageRouter}
     *            constructor
     */
    public ConnectionProtocol(TransportProtocol transport, SshContext context,
            boolean buffered) {
        super(transport, context.getChannelLimit(), buffered);
        this.transport = transport;
        this.transport.addListener(this);
    }

    /**
     * Registers a factory for every channel type it reports as supported.
     * <p>
     * Types are registered one by one; if a duplicate is found the types
     * preceding it in the array remain registered.
     *
     * @param factory the factory to register
     * @throws SshException with {@link SshException#BAD_API_USAGE} if one of
     *             the types already has a factory
     */
    public void addChannelFactory(ChannelFactory factory) throws SshException {
        String[] types = factory.supportedChannelTypes();
        for (int i = 0; i < types.length; i++) {
            if (channelfactories.containsKey(types[i])) {
                throw new SshException(types[i]
                        + " channel is already registered!",
                        SshException.BAD_API_USAGE);
            }
            channelfactories.put(types[i], factory);
        }
    }

    /**
     * Registers a handler for every global request name it reports as
     * supported.
     * <p>
     * Names are registered one by one; if a duplicate is found the names
     * preceding it in the array remain registered.
     *
     * @param handler the handler to register
     * @throws SshException with {@link SshException#BAD_API_USAGE} if one of
     *             the request names already has a handler
     */
    public void addRequestHandler(GlobalRequestHandler handler)
            throws SshException {
        String[] types = handler.supportedRequests();
        for (int i = 0; i < types.length; i++) {
            if (requesthandlers.containsKey(types[i])) {
                throw new SshException(types[i]
                        + " request is already registered!",
                        SshException.BAD_API_USAGE);
            }
            requesthandlers.put(types[i], handler);
        }
    }

    /**
     * Sends a global request using a timeout value of 0.
     *
     * @see #sendGlobalRequest(GlobalRequest, boolean, long)
     */
    public boolean sendGlobalRequest(GlobalRequest request, boolean wantreply)
            throws SshException {
        return sendGlobalRequest(request, wantreply, 0);
    }

    /**
     * Sends an SSH_MSG_GLOBAL_REQUEST and, if requested, waits for the reply.
     * <p>
     * When a reply is received the request's data is replaced: with the
     * reply payload on success (or <code>null</code> if the reply carried
     * none); it is left untouched on failure.
     *
     * @param request the request name and optional request data
     * @param wantreply whether the remote side should reply
     * @param timeout passed to the global message store while waiting for
     *            the reply (presumably 0 means no timeout - TODO confirm)
     * @return <code>true</code> if no reply was requested or the remote side
     *         answered SSH_MSG_REQUEST_SUCCESS, <code>false</code> otherwise
     * @throws SshException if sending or waiting fails; I/O errors are
     *             wrapped with {@link SshException#INTERNAL_ERROR}
     */
    public boolean sendGlobalRequest(GlobalRequest request, boolean wantreply,
            long timeout) throws SshException {
        ByteArrayWriter msg = new ByteArrayWriter();
        try {
            msg.write(SSH_MSG_GLOBAL_REQUEST);
            msg.writeString(request.getName());
            msg.writeBoolean(wantreply);
            if (request.getData() != null) {
                msg.write(request.getData());
            }
            if (Log.isDebugEnabled()) {
                Log.debug(
                        this,
                        "Sending SSH_MSG_GLOBAL_REQUEST request="
                                + request.getName() + " wantreply=" + wantreply);
            }
            sendMessage(msg.toByteArray(), true);

            if (!wantreply) {
                return true;
            }

            SshMessage reply = getGlobalMessages().nextMessage(
                    GLOBAL_REQUEST_MESSAGES, timeout);
            if (reply.getMessageId() == SSH_MSG_REQUEST_SUCCESS) {
                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "Received SSH_MSG_REQUEST_SUCCESS request="
                                    + request.getName());
                }
                // Hand any response payload back to the caller through the
                // request object.
                if (reply.available() > 0) {
                    byte[] tmp = new byte[reply.available()];
                    reply.read(tmp);
                    request.setData(tmp);
                } else {
                    request.setData(null);
                }
                return true;
            }

            if (Log.isDebugEnabled()) {
                Log.debug(this,
                        "Received SSH_MSG_REQUEST_FAILURE request="
                                + request.getName());
            }
            return false;
        } catch (IOException ex) {
            throw new SshException(ex, SshException.INTERNAL_ERROR);
        } finally {
            try {
                msg.close();
            } catch (IOException ignored) {
                // closing an in-memory buffer; nothing useful to report
            }
        }
    }

    /**
     * Releases the channel's slot in the message router.
     *
     * @param channel the channel to release
     */
    public void closeChannel(Ssh2Channel channel) {
        freeChannel(channel);
    }

    /**
     * @return the context held by the underlying transport
     */
    public SshContext getContext() {
        return transport.transportContext;
    }

    /**
     * Opens a channel using a timeout value of 0.
     *
     * @see #openChannel(Ssh2Channel, byte[], long)
     */
    public void openChannel(Ssh2Channel channel, byte[] requestdata)
            throws SshException, ChannelOpenException {
        openChannel(channel, requestdata, 0);
    }

    /**
     * Allocates a local channel id, sends SSH_MSG_CHANNEL_OPEN and waits for
     * the remote side to confirm or reject the channel.
     * <p>
     * NOTE(review): the channel is freed when the remote side answers with
     * SSH_MSG_CHANNEL_OPEN_FAILURE, but not when sending or waiting throws;
     * confirm whether callers are expected to release it in that case.
     *
     * @param channel the channel to open
     * @param requestdata channel type specific data appended to the open
     *            message, may be <code>null</code>
     * @param timeout passed to the channel's message store while waiting for
     *            the reply (presumably 0 means no timeout - TODO confirm)
     * @throws SshException if sending or waiting fails; I/O errors are
     *             wrapped with {@link SshException#INTERNAL_ERROR}
     * @throws ChannelOpenException if no channel id is available
     *             ({@link ChannelOpenException#RESOURCE_SHORTAGE}) or the
     *             remote side rejected the channel
     */
    public void openChannel(Ssh2Channel channel, byte[] requestdata,
            long timeout) throws SshException, ChannelOpenException {
        try {
            int channelid = allocateChannel(channel);
            if (channelid == -1) {
                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "Maximum number of channels exceeded! active="
                                    + getChannelCount() + " channels="
                                    + getMaxChannels());
                }
                throw new ChannelOpenException(
                        "Maximum number of channels exceeded",
                        ChannelOpenException.RESOURCE_SHORTAGE);
            }
            channel.init(this, channelid);

            /*
             * byte SSH_MSG_CHANNEL_OPEN string channel type in US-ASCII only
             * uint32 sender channel uint32 initial window size uint32 maximum
             * packet size .... channel type specific data follows
             */
            ByteArrayWriter msg = new ByteArrayWriter();
            try {
                msg.write(SSH_MSG_CHANNEL_OPEN);
                msg.writeString(channel.getName());
                msg.writeInt(channel.getChannelId());
                msg.writeInt(channel.getWindowSize());
                msg.writeInt(channel.getPacketSize());
                if (requestdata != null) {
                    msg.write(requestdata);
                }
                if (Log.isDebugEnabled()) {
                    Log.debug(
                            this,
                            "Sending SSH_MSG_CHANNEL_OPEN type="
                                    + channel.getName() + " id="
                                    + channel.getChannelId() + " window="
                                    + channel.getWindowSize() + " packet="
                                    + channel.getPacketSize());
                }
                transport.sendMessage(msg.toByteArray(), true);
            } finally {
                try {
                    msg.close();
                } catch (IOException ignored) {
                    // closing an in-memory buffer; nothing useful to report
                }
            }

            SshMessage reply = channel.getMessageStore().nextMessage(
                    CHANNEL_OPEN_RESPONSE_MESSAGES, timeout);

            if (reply.getMessageId() == SSH_MSG_CHANNEL_OPEN_FAILURE) {
                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "Received SSH_MSG_CHANNEL_OPEN_FAILURE id="
                                    + channel.getChannelId());
                }
                freeChannel(channel);
                int reason = (int) reply.readInt();
                throw new ChannelOpenException(reply.readString(), reason);
            }

            // SSH_MSG_CHANNEL_OPEN_CONFIRMATION: remote id, window, packet
            // size, then any channel specific data.
            int remoteid = (int) reply.readInt();
            long remotewindow = reply.readInt();
            int remotepacket = (int) reply.readInt();
            byte[] responsedata = new byte[reply.available()];
            reply.read(responsedata);
            if (Log.isDebugEnabled()) {
                Log.debug(this,
                        "Received SSH_MSG_CHANNEL_OPEN_CONFIRMATION id="
                                + channel.getChannelId() + " rid=" + remoteid
                                + " window=" + remotewindow + " packet="
                                + remotepacket);
            }
            channel.open(remoteid, remotewindow, remotepacket, responsedata);
        } catch (IOException ex) {
            throw new SshException(ex, SshException.INTERNAL_ERROR);
        }
    }

    /**
     * Sends a raw message through the transport.
     *
     * @param msg the encoded message
     * @param isActivity passed through to the transport
     * @throws SshException if the transport fails to send
     */
    protected void sendMessage(byte[] msg, boolean isActivity)
            throws SshException {
        transport.sendMessage(msg, isActivity);
    }

    /**
     * Wraps a raw message: ids 91 to 100 inclusive become an
     * {@link SshChannelMessage}, everything else a plain {@link SshMessage}.
     *
     * @param msg the raw message; the first byte is the message id
     */
    protected SshMessage createMessage(byte[] msg) throws SshException {
        if (msg[0] >= 91 && msg[0] <= 100) {
            return new SshChannelMessage(msg);
        }
        return new SshMessage(msg);
    }

    /**
     * Handles SSH_MSG_CHANNEL_OPEN and SSH_MSG_GLOBAL_REQUEST messages.
     *
     * @param message the message to inspect
     * @return <code>true</code> if the message was one of the two handled
     *         types, <code>false</code> otherwise
     * @throws SshException if processing fails; I/O errors are wrapped with
     *             {@link SshException#INTERNAL_ERROR}
     */
    protected boolean processGlobalMessage(SshMessage message)
            throws SshException {
        /*
         * We need to filter for any messages that require a response from the
         * connection protocol such as channel open or global requests. These
         * are not handled anywhere else within this implementation because
         * doing so would require a thread to wait.
         */
        try {
            switch (message.getMessageId()) {
            case SSH_MSG_CHANNEL_OPEN: {
                // Attempt to open the channel
                String type = message.readString();
                int remoteid = (int) message.readInt();
                int remotewindow = (int) message.readInt();
                int remotepacket = (int) message.readInt();

                // Channel specific data is optional; only read when present
                // so that a null buffer is never passed to read().
                byte[] requestdata = null;
                if (message.available() > 0) {
                    requestdata = new byte[message.available()];
                    message.read(requestdata);
                }

                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "Received SSH_MSG_CHANNEL_OPEN rid=" + remoteid
                                    + " window=" + remotewindow + " packet="
                                    + remotepacket);
                }
                processChannelOpenRequest(type, remoteid, remotewindow,
                        remotepacket, requestdata);
                return true;
            }
            case SSH_MSG_GLOBAL_REQUEST: {
                // Attempt to process the global request
                String requestname = message.readString();
                boolean wantreply = message.read() != 0;
                byte[] requestdata = new byte[message.available()];
                message.read(requestdata);
                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "Received SSH_MSG_GLOBAL_REQUEST request="
                                    + requestname + " wantreply=" + wantreply);
                }
                // Process the request
                processGlobalRequest(requestname, wantreply, requestdata);
                return true;
            }
            default:
                return false;
            }
        } catch (IOException ex) {
            throw new SshException(ex, SshException.INTERNAL_ERROR);
        }
    }

    /**
     * Answers a channel open request from the remote side.
     * <p>
     * The channel is created by the factory registered for the type,
     * allocated a local id, initialised and confirmed. If any step fails an
     * SSH_MSG_CHANNEL_OPEN_FAILURE is sent instead and, if a local id had
     * already been allocated, it is released again.
     *
     * @param type the requested channel type
     * @param remoteid the remote side's channel id
     * @param remotewindow the remote side's initial window size
     * @param remotepacket the remote side's maximum packet size
     * @param requestdata channel type specific data, may be
     *            <code>null</code>
     * @throws SshException if a reply cannot be sent; I/O errors are wrapped
     *             with {@link SshException#INTERNAL_ERROR}
     */
    void processChannelOpenRequest(String type, int remoteid, int remotewindow,
            int remotepacket, byte[] requestdata) throws SshException {
        ByteArrayWriter response = new ByteArrayWriter();
        try {
            if (!channelfactories.containsKey(type)) {
                sendChannelOpenFailure(remoteid,
                        ChannelOpenException.UNKNOWN_CHANNEL_TYPE, type
                                + " is not a supported channel type!");
                return;
            }

            Ssh2Channel channel = null;
            boolean allocated = false;
            try {
                channel = ((ChannelFactory) channelfactories.get(type))
                        .createChannel(type, requestdata);

                // Allocate a channel
                if (Log.isDebugEnabled()) {
                    Log.debug(this,
                            "There are " + this.getChannelCount()
                                    + " channels open");
                }
                int localid = allocateChannel(channel);
                if (!(localid > -1)) {
                    sendChannelOpenFailure(remoteid,
                            ChannelOpenException.RESOURCE_SHORTAGE,
                            "Maximum allowable open channel limit of "
                                    + String.valueOf(maximumChannels())
                                    + " exceeded!");
                    return;
                }
                allocated = true;

                try {
                    channel.init(this, localid);
                    byte[] responsedata = channel.create();
                    response.write(SSH_MSG_CHANNEL_OPEN_CONFIRMATION);
                    response.writeInt(remoteid);
                    response.writeInt(localid);
                    response.writeInt(channel.getWindowSize());
                    response.writeInt(channel.getPacketSize());
                    if (responsedata != null) {
                        response.write(responsedata);
                    }
                    if (Log.isDebugEnabled()) {
                        Log.debug(
                                this,
                                "Sending SSH_MSG_CHANNEL_OPEN_CONFIRMATION type="
                                        + channel.getName() + " id="
                                        + channel.getChannelId()
                                        + " rid=" + remoteid
                                        + " window="
                                        + channel.getWindowSize()
                                        + " packet="
                                        + channel.getPacketSize());
                    }
                    transport.sendMessage(response.toByteArray(), true);
                    channel.open(remoteid, remotewindow, remotepacket);
                } catch (SshException ex) {
                    // Release the slot taken above, and build the failure in
                    // a fresh buffer: 'response' may already contain
                    // confirmation bytes at this point.
                    // NOTE(review): if the exception came from open() the
                    // confirmation has already been sent - confirm whether a
                    // failure reply is the desired reaction in that case.
                    freeChannel(channel);
                    allocated = false;
                    sendChannelOpenFailure(remoteid,
                            ChannelOpenException.CONNECT_FAILED,
                            ex.getMessage());
                }
            } catch (ChannelOpenException ex) {
                if (allocated) {
                    freeChannel(channel);
                }
                sendChannelOpenFailure(remoteid, ex.getReason(),
                        ex.getMessage());
            }
        } catch (IOException ex1) {
            // Keep the cause, as every other IOException wrapper here does.
            throw new SshException(ex1, SshException.INTERNAL_ERROR);
        } finally {
            try {
                response.close();
            } catch (IOException ignored) {
                // closing an in-memory buffer; nothing useful to report
            }
        }
    }

    /**
     * Builds and sends an SSH_MSG_CHANNEL_OPEN_FAILURE message (recipient
     * channel, reason code, description, empty language tag).
     *
     * @param remoteid the remote side's channel id
     * @param reason one of the {@link ChannelOpenException} reason codes
     * @param description the text sent as the failure description
     * @throws IOException if the message cannot be encoded
     * @throws SshException if the transport fails to send
     */
    private void sendChannelOpenFailure(int remoteid, int reason,
            String description) throws IOException, SshException {
        ByteArrayWriter failure = new ByteArrayWriter();
        try {
            failure.write(SSH_MSG_CHANNEL_OPEN_FAILURE);
            failure.writeInt(remoteid);
            failure.writeInt(reason);
            failure.writeString(description);
            failure.writeString("");
            if (Log.isDebugEnabled()) {
                Log.debug(this,
                        "Sending SSH_MSG_CHANNEL_OPEN_FAILURE rid=" + remoteid);
            }
            transport.sendMessage(failure.toByteArray(), true);
        } finally {
            try {
                failure.close();
            } catch (IOException ignored) {
                // closing an in-memory buffer; nothing useful to report
            }
        }
    }

    /**
     * Dispatches a global request to its registered handler and, if the
     * remote side asked for one, sends the reply.
     * <p>
     * A request without a registered handler is treated as failed.
     *
     * @param requestname the name of the request
     * @param wantreply whether the remote side expects a reply
     * @param requestdata request specific data
     * @throws SshException if the handler or the transport fails; I/O errors
     *             are wrapped with {@link SshException#INTERNAL_ERROR}
     */
    void processGlobalRequest(String requestname, boolean wantreply,
            byte[] requestdata) throws SshException {
        ByteArrayWriter response = new ByteArrayWriter();
        try {
            boolean success = false;
            GlobalRequest request = new GlobalRequest(requestname, requestdata);
            if (requesthandlers.containsKey(requestname)) {
                success = ((GlobalRequestHandler) requesthandlers
                        .get(requestname)).processGlobalRequest(request);
            }
            if (wantreply) {
                if (success) {
                    response.write(SSH_MSG_REQUEST_SUCCESS);
                    // The handler may have replaced the request data with
                    // response data.
                    if (request.getData() != null) {
                        response.write(request.getData());
                    }
                    if (Log.isDebugEnabled()) {
                        Log.debug(this,
                                "Sending SSH_MSG_REQUEST_SUCCESS request="
                                        + requestname);
                    }
                    transport.sendMessage(response.toByteArray(), true);
                } else {
                    if (Log.isDebugEnabled()) {
                        Log.debug(this,
                                "Sending SSH_MSG_REQUEST_FAILURE request="
                                        + requestname);
                    }
                    // Return a response
                    transport.sendMessage(
                            new byte[] { SSH_MSG_REQUEST_FAILURE }, true);
                }
            }
        } catch (IOException ex) {
            throw new SshException(ex, SshException.INTERNAL_ERROR);
        } finally {
            try {
                response.close();
            } catch (IOException ignored) {
                // closing an in-memory buffer; nothing useful to report
            }
        }
    }

    /**
     * Disconnects the transport if it is still connected, then stops this
     * router.
     */
    protected void onThreadExit() {
        if (transport != null && transport.isConnected()) {
            transport.disconnect(TransportProtocol.CONNECTION_LOST, "Exiting");
        }
        stop();
    }

    /**
     * Transport listener callback; intentionally does nothing.
     */
    public void onDisconnect(String msg, int reason) {
    }

    /**
     * Transport listener callback: notifies every active channel that the
     * connection is idle.
     *
     * @param lastActivity not used by this implementation
     */
    public void onIdle(long lastActivity) {
        SshAbstractChannel[] channels = getActiveChannels();
        for (int i = 0; i < channels.length; i++) {
            channels[i].idle();
        }
    }
}