// com.m11n.jdbc.ssh.SshTunnel (Maven / Gradle / Ivy)
package com.m11n.jdbc.ssh;
import com.jcraft.jsch.Channel;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicInteger;
import static com.m11n.jdbc.ssh.SshConfiguration.*;
/**
 * Opens an SSH session via JSch and forwards a local port on 127.0.0.1 to a remote host/port.
 *
 * <p>Connection settings are read from an {@link SshConfiguration}. The local port is chosen
 * automatically: starting from the configured {@code CONFIG_PORT_AUTO} base, the tunnel probes
 * successive ports and forwards the first one that appears free.
 */
public class SshTunnel {
    private static final Logger logger = LoggerFactory.getLogger(SshTunnel.class);

    /** How many additional candidate ports are probed after the first one. */
    private static final int MAX_PORT_SCAN = 10;

    private final SshConfiguration config;

    // Assigned in start() and read by stop(), which also runs on the JVM shutdown-hook thread;
    // volatile so that thread sees the latest session.
    private volatile Session session;

    // Last local port handed out; after a successful start() it is the forwarded local port.
    private final AtomicInteger localPort;

    /**
     * Creates a tunnel (without connecting) and registers a JVM shutdown hook that calls
     * {@link #stop()}.
     *
     * @param config SSH configuration; {@code CONFIG_PORT_AUTO} must hold an integer base port
     * @throws NumberFormatException if {@code CONFIG_PORT_AUTO} is missing or not an integer
     */
    public SshTunnel(SshConfiguration config) {
        this.config = config;
        localPort = new AtomicInteger(Integer.valueOf(config.getProperty(CONFIG_PORT_AUTO)));
        logger.info("Automatic local port assignment starts at: {}", localPort.get());
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                logger.info("Shutting down tunnel...");
                SshTunnel.this.stop();
            }
        });
    }

    /**
     * Connects the SSH session and sets up local port forwarding to
     * {@code CONFIG_HOST_REMOTE:CONFIG_PORT_REMOTE}.
     *
     * <p>On success, {@link #getLocalPort()} returns the forwarded local port. On any failure the
     * session is disconnected and an exception is thrown.
     *
     * @throws RuntimeException if connecting or port forwarding fails (the original failure, if
     *     any, is attached as the cause)
     */
    public void start() {
        int assignedPort = 0;
        try {
            String keyPrivate = config.getProperty(CONFIG_KEY_PRIVATE);
            boolean useKey = keyPrivate != null && !keyPrivate.trim().isEmpty();

            session = connectSession(new JSch(), useKey);

            String forwardHost = config.getProperty(CONFIG_HOST_REMOTE);
            int remotePort = Integer.parseInt(config.getProperty(CONFIG_PORT_REMOTE));
            // Forward exactly the port that was probed (previously an extra increment skipped it).
            int port = findFreeLocalPort();
            assignedPort = session.setPortForwardingL(port, forwardHost, remotePort);

            if (logger.isDebugEnabled()) {
                logger.debug("Server version: {}", session.getServerVersion());
                logger.debug("Client version: {}", session.getClientVersion());
                logger.debug("Host : {}", session.getHost());
                logger.debug("Port : {}", session.getPort());
                logger.debug("Forwarding : {}", session.getPortForwardingL());
                logger.debug("Connected : {}", session.isConnected());
                logger.debug("Private key : {}", useKey);
            }
        } catch (Exception e) {
            logger.error(e.toString(), e);
            // Do not leave a half-initialised session connected.
            stop();
            throw new RuntimeException("Port forwarding failed !", e);
        }
        if (assignedPort == 0) {
            stop();
            throw new RuntimeException("Port forwarding failed !");
        }
    }

    /**
     * Builds, authenticates and connects a session from the configuration, then opens and connects
     * a "shell" channel on it.
     *
     * @param jsch JSch instance used for identities and known hosts
     * @param useKey whether to authenticate with {@code CONFIG_KEY_PRIVATE} instead of a password
     * @return the connected session
     * @throws JSchException if any JSch setup or connect step fails
     * @throws NumberFormatException if {@code CONFIG_PORT} is missing or not an integer
     */
    private Session connectSession(JSch jsch, boolean useKey) throws JSchException {
        String username = config.getProperty(CONFIG_USERNAME);
        String password = config.getProperty(CONFIG_PASSWORD);
        String keyPrivate = config.getProperty(CONFIG_KEY_PRIVATE);
        String keyPublic = config.getProperty(CONFIG_KEY_PUBLIC);
        String passphrase = config.getProperty(CONFIG_PASSPHRASE);
        String knownHosts = config.getProperty(CONFIG_KNOWN_HOSTS);
        String host = config.getProperty(CONFIG_HOST);
        int port = Integer.parseInt(config.getProperty(CONFIG_PORT));
        assert host != null;

        Session newSession = jsch.getSession(username, host, port);
        jsch.setKnownHosts(knownHosts);
        if (useKey) {
            if (passphrase == null || passphrase.trim().isEmpty()) {
                jsch.addIdentity(keyPrivate, keyPublic);
            } else {
                // Explicit charset so the passphrase bytes do not depend on the platform default.
                jsch.addIdentity(keyPrivate, keyPublic, passphrase.getBytes(StandardCharsets.UTF_8));
            }
        } else {
            newSession.setPassword(password);
        }
        // The whole configuration property set is handed to JSch as session config.
        newSession.setConfig(config.getProperties());
        newSession.setDaemonThread(true);
        newSession.connect();

        // NOTE(review): a "shell" channel is opened although forwarding does not require one;
        // presumably this keeps the session active — confirm before removing.
        Channel channel = newSession.openChannel("shell");
        channel.connect();
        return newSession;
    }

    /**
     * Advances {@link #localPort} and returns the first candidate that {@link #isPortOpen} reports
     * as free on 127.0.0.1. At most {@value #MAX_PORT_SCAN} candidates are probed. If none is
     * free, the next (unprobed) candidate is returned, and the forwarding call fails if it is taken.
     *
     * @return the local port to forward; equal to {@code localPort.get()} afterwards
     */
    private int findFreeLocalPort() {
        int candidate = localPort.incrementAndGet();
        for (int i = 0; i < MAX_PORT_SCAN && !isPortOpen("127.0.0.1", candidate); i++) {
            candidate = localPort.incrementAndGet();
        }
        return candidate;
    }

    /** Disconnects the SSH session if one was created; safe to call more than once. */
    public void stop() {
        Session current = session;
        if (current != null) {
            current.disconnect();
            if (logger.isDebugEnabled()) {
                logger.debug("Disconnected.");
            }
        }
    }

    /**
     * Returns the most recently assigned local port. After a successful {@link #start()} this is
     * the forwarded local port; before that it is the configured base port.
     *
     * @return the current local port value
     */
    public Integer getLocalPort() {
        return localPort.get();
    }

    /**
     * Probes whether nothing is accepting TCP connections on {@code ip:port}.
     *
     * <p>Despite the name, this returns {@code true} when the port looks <em>free</em>: a connect
     * attempt (1 second timeout) that fails for any reason yields {@code true}, and a successful
     * connect, meaning something is already listening, yields {@code false}.
     *
     * @param ip address to probe
     * @param port TCP port to probe
     * @return {@code true} if the connect attempt failed, i.e. the port appears available
     */
    public boolean isPortOpen(String ip, int port) {
        // try-with-resources closes the socket on the failure path too (previously leaked).
        try (Socket socket = new Socket()) {
            socket.connect(new InetSocketAddress(ip, port), 1000);
            return false;
        } catch (Exception ex) {
            return true;
        }
    }
}