net.schmizz.sshj.transport.TransportImpl Maven / Gradle / Ivy
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import com.hierynomus.sshj.key.KeyAlgorithm;
import com.hierynomus.sshj.key.KeyAlgorithms;
import com.hierynomus.sshj.transport.IdentificationStringParser;
import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.Service;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.transport.verification.AlgorithmsVerifier;
import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import org.slf4j.Logger;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A thread-safe {@link Transport} implementation.
 * <p>
 * Incoming packets are decoded by a {@link Reader} thread and dispatched through {@link #handle(Message, SSHPacket)}.
 * Outgoing packets are serialized through {@link #write(SSHPacket)} under a shared {@link ReentrantLock}.
 */
public final class TransportImpl
        implements Transport, DisconnectListener {

    /** Placeholder service that is active whenever no real service has been set. */
    private static final class NullService
            extends AbstractService {

        NullService(Transport trans) {
            super("null-service", trans);
        }

    }

    /** Immutable holder for the remote endpoint and its streams, captured in {@link #init}. */
    static final class ConnInfo {

        final String host;
        final int port;
        final InputStream in;
        final OutputStream out;

        ConnInfo(String host, int port, InputStream in, OutputStream out) {
            this.host = host;
            this.port = port;
            this.in = in;
            this.out = out;
        }
    }

    private final LoggerFactory loggerFactory;

    private final Logger log;

    private final Service nullService;

    private final Config config;

    private final KeyExchanger kexer;

    private final Reader reader;

    private final Encoder encoder;

    private final Decoder decoder;

    private KeyAlgorithm hostKeyAlgorithm;

    private boolean rsaSHA2Support;

    private final Event<TransportException> serviceAccept;

    private final Event<TransportException> close;

    /**
     * Client version identification string
     */
    private final String clientID;

    private volatile int timeoutMs = 30 * 1000; // Crazy long, but it was the original default

    private volatile boolean authed = false;

    /**
     * Currently active service e.g. UserAuthService, ConnectionService
     */
    private volatile Service service;

    /**
     * The next service that will be activated, only set when sending an SSH_MSG_SERVICE_REQUEST
     */
    private volatile Service nextService;

    private DisconnectListener disconnectListener;

    private ConnInfo connInfo;

    /**
     * Server version identification string; {@code null} until a complete identification line has been received.
     */
    private String serverID;

    /**
     * Message identifier of last packet received
     */
    private Message msg;

    private final ReentrantLock writeLock = new ReentrantLock();

    /**
     * Creates a transport configured from {@code config}. No I/O happens until {@link #init} is called.
     *
     * @param config source of the logger factory, random factory, version string and algorithm lists
     */
    public TransportImpl(Config config) {
        this.config = config;
        this.loggerFactory = config.getLoggerFactory();
        this.serviceAccept = new Event<TransportException>("service accept", TransportException.chainer, loggerFactory);
        this.close = new Event<TransportException>("transport close", TransportException.chainer, loggerFactory);
        this.nullService = new NullService(this);
        this.service = nullService;
        this.log = loggerFactory.getLogger(getClass());
        this.disconnectListener = this;
        this.reader = new Reader(this);
        this.encoder = new Encoder(config.getRandomFactory().create(), writeLock, loggerFactory);
        this.decoder = new Decoder(this);
        this.kexer = new KeyExchanger(this);
        this.clientID = String.format("SSH-2.0-%s", config.getVersion());
    }

    /**
     * Performs the identification-string exchange on the given streams and then starts the reader thread.
     * <p>
     * Whether the client or the server identification goes first is controlled by
     * {@link Config#isWaitForServerIdentBeforeSendingClientIdent()}.
     *
     * @throws TransportException if the exchange fails; a {@code TransportException} raised during the exchange is
     *                            propagated unchanged, while other I/O errors are wrapped
     */
    @Override
    public void init(String remoteHost, int remotePort, InputStream in, OutputStream out)
            throws TransportException {
        connInfo = new ConnInfo(remoteHost, remotePort, in, out);

        try {
            if (config.isWaitForServerIdentBeforeSendingClientIdent()) {
                receiveServerIdent();
                sendClientIdent();
            } else {
                sendClientIdent();
                receiveServerIdent();
            }
            log.info("Server identity string: {}", serverID);
        } catch (TransportException e) {
            // Rethrow as-is so the DisconnectReason/message chosen where it was raised are not nested inside a
            // fresh wrapper exception.
            throw e;
        } catch (IOException e) {
            throw new TransportException(e);
        }

        reader.start();
    }

    /**
     * TransportImpl implements its own default DisconnectListener, which only logs the reason.
     */
    @Override
    public void notifyDisconnect(DisconnectReason reason, String message) {
        log.info("Disconnected - {}", reason);
    }

    /**
     * Reads bytes from the connection one at a time until a complete server identification line has been parsed,
     * then stores it in {@link #serverID}.
     *
     * @throws IOException if reading fails, the server closes the connection before identifying itself, or the
     *                     server does not advertise SSHv2
     */
    private void receiveServerIdent() throws IOException {
        final Buffer.PlainBuffer buf = new Buffer.PlainBuffer();
        String ident;
        // Use a local so that serverID is only published once complete; an intermediate empty value would make
        // getServerVersion() fail.
        while ((ident = readIdentification(buf)).isEmpty()) {
            int b = connInfo.in.read();
            if (b == -1) {
                log.error("Received end of connection, but no identification received. ");
                throw new TransportException("Server closed connection during identification exchange");
            }
            buf.putByte((byte) b);
        }
        serverID = ident;
    }

    /**
     * Sends the client identification string, terminated by CR LF, and flushes the output stream.
     *
     * @throws IOException If there was an error writing to the outputstream.
     */
    private void sendClientIdent() throws IOException {
        log.info("Client identity string: {}", clientID);
        connInfo.out.write((clientID + "\r\n").getBytes(IOUtils.UTF8));
        connInfo.out.flush();
    }

    /**
     * Reads the identification string from the SSH server. This is the very first string that is sent upon
     * connection by the server. It takes the form of, e.g. "SSH-2.0-OpenSSH_ver".
     * <p>
     * Several concerns are taken care of here, e.g. verifying protocol version, correct line endings as specified in
     * RFC and such.
     * <p>
     * This is not efficient but is only done once.
     *
     * @param buffer The buffer to read from.
     * @return empty string if full ident string has not yet been received
     * @throws IOException if parsing fails, or a {@link TransportException} with
     *                     {@link DisconnectReason#PROTOCOL_VERSION_NOT_SUPPORTED} if the server is neither
     *                     "SSH-2.0-" nor "SSH-1.99-"
     */
    private String readIdentification(Buffer.PlainBuffer buffer)
            throws IOException {
        String ident = new IdentificationStringParser(buffer, loggerFactory).parseIdentificationString();
        if (ident.isEmpty()) {
            return ident;
        }

        if (!ident.startsWith("SSH-2.0-") && !ident.startsWith("SSH-1.99-"))
            throw new TransportException(DisconnectReason.PROTOCOL_VERSION_NOT_SUPPORTED,
                                         "Server does not support SSHv2, identified as: " + ident);

        return ident;
    }

    @Override
    public void addHostKeyVerifier(HostKeyVerifier hkv) {
        kexer.addHostKeyVerifier(hkv);
    }

    @Override
    public void addAlgorithmsVerifier(AlgorithmsVerifier verifier) {
        kexer.addAlgorithmsVerifier(verifier);
    }

    /** Starts a key exchange via the key exchanger (passing {@code true} to {@code startKex}). */
    @Override
    public void doKex()
            throws TransportException {
        kexer.startKex(true);
    }

    public boolean isKexDone() {
        return kexer.isKexDone();
    }

    @Override
    public int getTimeoutMs() {
        return timeoutMs;
    }

    @Override
    public void setTimeoutMs(int timeoutMs) {
        this.timeoutMs = timeoutMs;
    }

    @Override
    public String getRemoteHost() {
        return connInfo.host;
    }

    @Override
    public int getRemotePort() {
        return connInfo.port;
    }

    /** Returns the client identification string without its "SSH-2.0-" prefix. */
    @Override
    public String getClientVersion() {
        return clientID.substring(8);
    }

    @Override
    public Config getConfig() {
        return config;
    }

    /**
     * Returns the server identification string without its "SSH-protoversion-" prefix, or {@code null} if the
     * server identification has not been received yet.
     */
    @Override
    public String getServerVersion() {
        final String id = serverID;
        if (id == null) {
            return null;
        }
        // The prefix is "SSH-2.0-" (8 chars) or "SSH-1.99-" (9 chars), so locate the dash that ends the protocol
        // version instead of assuming a fixed length.
        final int dash = id.indexOf('-', "SSH-".length());
        return dash < 0 ? id : id.substring(dash + 1);
    }

    @Override
    public byte[] getSessionID() {
        return kexer.getSessionID();
    }

    @Override
    public synchronized Service getService() {
        return service;
    }

    /**
     * Sets the active service; passing {@code null} activates the internal null service instead.
     */
    @Override
    public synchronized void setService(Service service) {
        if (service == null) {
            service = nullService;
        }

        log.debug("Setting active service to {}", service.getName());
        this.service = service;
    }

    /**
     * Sends SSH_MSG_SERVICE_REQUEST for {@code service} and waits up to {@link #getTimeoutMs()} milliseconds for the
     * matching SSH_MSG_SERVICE_ACCEPT. The service is activated by {@link #gotServiceAccept()} when the accept
     * arrives.
     *
     * @throws TransportException if sending fails or waiting for the accept fails
     */
    @Override
    public void reqService(Service service)
            throws TransportException {
        serviceAccept.lock();
        try {
            serviceAccept.clear();
            this.nextService = service;
            sendServiceRequest(service.getName());
            serviceAccept.await(timeoutMs, TimeUnit.MILLISECONDS);
        } finally {
            serviceAccept.unlock();
            this.nextService = null;
        }
    }

    /**
     * Sends a service request for the specified service
     *
     * @param serviceName name of the service being requested
     * @throws TransportException if there is an error while sending the request
     */
    private void sendServiceRequest(String serviceName)
            throws TransportException {
        log.debug("Sending SSH_MSG_SERVICE_REQUEST for {}", serviceName);
        write(new SSHPacket(Message.SERVICE_REQUEST).putString(serviceName));
    }

    /** Marks the transport as authenticated and propagates that to the encoder and decoder. */
    @Override
    public void setAuthenticated() {
        this.authed = true;
        encoder.setAuthenticated();
        decoder.setAuthenticated();
    }

    @Override
    public boolean isAuthenticated() {
        return authed;
    }

    /**
     * Sends SSH_MSG_UNIMPLEMENTED referencing the decoder's current sequence number.
     *
     * @return the sequence number assigned to the outgoing packet
     */
    @Override
    public long sendUnimplemented()
            throws TransportException {
        final long seq = decoder.getSequenceNumber();
        log.debug("Sending SSH_MSG_UNIMPLEMENTED for packet #{}", seq);
        return write(new SSHPacket(Message.UNIMPLEMENTED).putUInt32(seq));
    }

    @Override
    public void join()
            throws TransportException {
        close.await();
    }

    @Override
    public void join(int timeout, TimeUnit unit)
            throws TransportException {
        close.await(timeout, unit);
    }

    /** Returns {@code true} while the reader thread is alive and the transport has not been closed. */
    @Override
    public boolean isRunning() {
        return reader.isAlive() && !close.isSet();
    }

    @Override
    public void disconnect() {
        disconnect(DisconnectReason.BY_APPLICATION);
    }

    @Override
    public void disconnect(DisconnectReason reason) {
        disconnect(reason, "");
    }

    /**
     * If still running, notifies the disconnect listener and the active service, sends SSH_MSG_DISCONNECT, closes
     * the streams and marks the transport closed. Does nothing if the transport is not running.
     */
    @Override
    public void disconnect(DisconnectReason reason, String message) {
        close.lock();
        try {
            if (isRunning()) {
                disconnectListener.notifyDisconnect(reason, message);
                getService().notifyError(new TransportException(reason, "Disconnected"));
                sendDisconnect(reason, message);
                finishOff();
                close.set();
            }
        } finally {
            close.unlock();
        }
    }

    /** Sets the disconnect listener; {@code null} restores this transport's own logging listener. */
    @Override
    public void setDisconnectListener(DisconnectListener listener) {
        this.disconnectListener = listener == null ? this : listener;
    }

    @Override
    public DisconnectListener getDisconnectListener() {
        return disconnectListener;
    }

    /**
     * Encodes {@code payload} and writes it to the connection while holding the write lock.
     * <p>
     * While a key exchange is ongoing, any packet outside the transport-layer range 1-49, and any
     * SSH_MSG_SERVICE_REQUEST, waits for the exchange to finish before being sent.
     *
     * @return the sequence number assigned by the encoder
     * @throws TransportException if writing to the output stream fails
     */
    @Override
    public long write(SSHPacket payload)
            throws TransportException {
        writeLock.lock();
        try {

            if (kexer.isKexOngoing()) {
                // Only transport layer packets (1 to 49) allowed except SERVICE_REQUEST
                final Message m = Message.fromByte(payload.array()[payload.rpos()]);
                if (!m.in(1, 49) || m == Message.SERVICE_REQUEST) {
                    assert m != Message.KEXINIT;
                    kexer.waitForDone();
                }
            } else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet
                kexer.startKex(true);

            final long seq = encoder.encode(payload);
            try {
                connInfo.out.write(payload.array(), payload.rpos(), payload.available());
                connInfo.out.flush();
            } catch (IOException ioe) {
                throw new TransportException(ioe);
            }

            return seq;

        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Best-effort send of SSH_MSG_DISCONNECT; write failures are logged at debug level and otherwise ignored,
     * because the connection is being torn down anyway.
     */
    private void sendDisconnect(DisconnectReason reason, String message) {
        if (message == null)
            message = "";
        log.debug("Sending SSH_MSG_DISCONNECT: reason=[{}], msg=[{}]", reason, message);
        try {
            write(new SSHPacket(Message.DISCONNECT)
                          .putUInt32(reason.toInt())
                          .putString(message)
                          .putString(""));
        } catch (IOException worthless) {
            log.debug("Error writing packet: {}", worthless.toString());
        }
    }

    /**
     * This is where all incoming packets are handled. If they pertain to the transport layer, they are handled here;
     * otherwise they are delegated to the active service instance if any via {@link Service#handle}.
     * <p>
     * Even among the transport layer specific packets, key exchange packets are delegated to {@link
     * KeyExchanger#handle}.
     * <p>
     * This method is called in the context of the {@link #reader} thread via {@link Decoder#received} when a full
     * packet has been decoded.
     *
     * @param msg the message identifier
     * @param buf buffer containing rest of the packet
     * @throws SSHException if an error occurs during handling (unrecoverable)
     */
    @Override
    public void handle(Message msg, SSHPacket buf)
            throws SSHException {
        this.msg = msg;

        log.trace("Received packet {}", msg);

        if (msg.geq(50)) { // not a transport layer packet
            service.handle(msg, buf);
        } else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet
            kexer.handle(msg, buf);
        } else {
            switch (msg) {
                case DISCONNECT:
                    gotDisconnect(buf);
                    break;
                case IGNORE:
                    log.debug("Received SSH_MSG_IGNORE");
                    break;
                case UNIMPLEMENTED:
                    gotUnimplemented(buf);
                    break;
                case DEBUG:
                    gotDebug(buf);
                    break;
                case SERVICE_ACCEPT:
                    gotServiceAccept();
                    break;
                case EXT_INFO:
                    log.debug("Received SSH_MSG_EXT_INFO");
                    break;
                case USERAUTH_BANNER:
                    log.debug("Received USERAUTH_BANNER");
                    break;
                default:
                    sendUnimplemented();
                    break;
            }
        }
    }

    /** Logs the display flag and message of an incoming SSH_MSG_DEBUG packet. */
    private void gotDebug(SSHPacket buf)
            throws TransportException {
        try {
            final boolean display = buf.readBoolean();
            final String message = buf.readString();
            log.debug("Received SSH_MSG_DEBUG (display={}) '{}'", display, message);
        } catch (Buffer.BufferException be) {
            throw new TransportException(be);
        }
    }

    /**
     * Parses an incoming SSH_MSG_DISCONNECT and always throws a {@link TransportException} carrying the server's
     * reason code and message.
     */
    private void gotDisconnect(SSHPacket buf)
            throws TransportException {
        try {
            final DisconnectReason code = DisconnectReason.fromInt(buf.readUInt32AsInt());
            final String message = buf.readString();
            log.info("Received SSH_MSG_DISCONNECT (reason={}, msg={})", code, message);
            throw new TransportException(code, message);
        } catch (Buffer.BufferException be) {
            throw new TransportException(be);
        }
    }

    /**
     * Activates the pending {@link #nextService} and signals the waiter in {@link #reqService(Service)}.
     *
     * @throws TransportException with {@link DisconnectReason#PROTOCOL_ERROR} if no service accept was awaited
     */
    private void gotServiceAccept()
            throws TransportException {
        serviceAccept.lock();
        try {
            if (!serviceAccept.hasWaiters())
                throw new TransportException(DisconnectReason.PROTOCOL_ERROR,
                                             "Got a service accept notification when none was awaited");
            // Immediately switch to next service to prevent race condition mentioned in #559
            setService(nextService);
            serviceAccept.set();
        } finally {
            serviceAccept.unlock();
        }
    }

    /**
     * Got an SSH_MSG_UNIMPLEMENTED, so lets see where we're at and act accordingly.
     *
     * @param packet The 'unimplemented' packet received
     * @throws SSHException if the packet cannot be read, or a {@link TransportException} if a key exchange is ongoing
     */
    private void gotUnimplemented(SSHPacket packet)
            throws SSHException {
        long seqNum = packet.readUInt32();
        log.debug("Received SSH_MSG_UNIMPLEMENTED #{}", seqNum);
        if (kexer.isKexOngoing())
            throw new TransportException("Received SSH_MSG_UNIMPLEMENTED while exchanging keys");
        getService().notifyUnimplemented(seqNum);
    }

    /** Interrupts the reader thread and quietly closes both connection streams. */
    private void finishOff() {
        reader.interrupt();
        IOUtils.closeQuietly(connInfo.in);
        IOUtils.closeQuietly(connInfo.out);
    }

    /**
     * Tears the transport down because of {@code ex}, unless it is already closed.
     * <p>
     * Notifies the disconnect listener, pending events, the key exchanger and the active service, then resets the
     * active service. If the last received packet was not SSH_MSG_DISCONNECT and the cause has a known reason, it
     * tries to send SSH_MSG_DISCONNECT before closing the streams.
     */
    public void die(Exception ex) {
        close.lock();
        try {
            if (!close.isSet()) {

                log.error("Dying because - {}", ex.getMessage(), ex);

                final SSHException causeOfDeath = SSHException.chainer.chain(ex);

                disconnectListener.notifyDisconnect(causeOfDeath.getDisconnectReason(), causeOfDeath.getMessage());

                ErrorDeliveryUtil.alertEvents(causeOfDeath, close, serviceAccept);
                kexer.notifyError(causeOfDeath);
                getService().notifyError(causeOfDeath);
                setService(nullService);

                { // Perhaps can send disconnect packet to server
                    final boolean didNotReceiveDisconnect = msg != Message.DISCONNECT;
                    final boolean gotRequiredInfo = causeOfDeath.getDisconnectReason() != DisconnectReason.UNKNOWN;
                    if (didNotReceiveDisconnect && gotRequiredInfo)
                        sendDisconnect(causeOfDeath.getDisconnectReason(), causeOfDeath.getMessage());
                }

                finishOff();

                close.set();
            }
        } finally {
            close.unlock();
        }
    }

    String getClientID() {
        return clientID;
    }

    String getServerID() {
        return serverID;
    }

    Encoder getEncoder() {
        return encoder;
    }

    Decoder getDecoder() {
        return decoder;
    }

    ReentrantLock getWriteLock() {
        return writeLock;
    }

    ConnInfo getConnInfo() {
        return connInfo;
    }

    public void setHostKeyAlgorithm(KeyAlgorithm keyAlgorithm) {
        this.hostKeyAlgorithm = keyAlgorithm;
    }

    @Override
    public KeyAlgorithm getHostKeyAlgorithm() {
        return this.hostKeyAlgorithm;
    }

    public void setRSASHA2Support(boolean rsaSHA2Support) {
        this.rsaSHA2Support = rsaSHA2Support;
    }

    /**
     * Creates the client-side key algorithm for {@code keyType}.
     * <p>
     * For RSA keys with RSA/SHA-2 support enabled, returns the first configured factory named "ssh-rsa" or listed in
     * {@link KeyAlgorithms#SSH_RSA_SHA2_ALGORITHMS}. Otherwise the factory whose name equals
     * {@code keyType.toString()} is used.
     *
     * @throws TransportException if no suitable RSA key algorithm is configured
     */
    @Override
    public KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException {
        if (keyType != KeyType.RSA || !rsaSHA2Support) {
            return Factory.Named.Util.create(getConfig().getKeyAlgorithms(), keyType.toString());
        }

        List<Factory.Named<KeyAlgorithm>> factories = getConfig().getKeyAlgorithms();
        if (factories != null)
            for (Factory.Named<KeyAlgorithm> f : factories)
                if (f.getName().equals("ssh-rsa") || KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS.contains(f.getName()))
                    return f.create();
        throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy