org.jline.builtins.ssh.Ssh Maven / Gradle / Ivy
The newest version!
/*
* Copyright (c) 2002-2017, the original author or authors.
*
* This software is distributable under the BSD license. See the terms of the
* BSD license in the documentation provided with this software.
*
* https://opensource.org/licenses/BSD-3-Clause
*/
package org.jline.builtins.ssh;
import java.io.*;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.auth.keyboard.UserInteraction;
import org.apache.sshd.client.channel.ChannelShell;
import org.apache.sshd.client.channel.ClientChannel;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.future.ConnectFuture;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.NamedResource;
import org.apache.sshd.common.channel.PtyMode;
import org.apache.sshd.common.config.keys.FilePasswordProvider;
import org.apache.sshd.common.session.SessionContext;
import org.apache.sshd.common.util.io.NoCloseInputStream;
import org.apache.sshd.common.util.io.NoCloseOutputStream;
import org.apache.sshd.scp.server.ScpCommandFactory;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
import org.apache.sshd.server.session.ServerSession;
import org.apache.sshd.sftp.server.SftpSubsystemFactory;
import org.jline.builtins.Options;
import org.jline.builtins.Options.HelpException;
import org.jline.reader.LineReader;
import org.jline.terminal.Attributes;
import org.jline.terminal.Size;
import org.jline.terminal.Terminal;
public class Ssh {
public static final String[] functions = {"ssh", "sshd"};
/**
 * Immutable holder for the values describing one interactive shell request:
 * the environment, the server session, the terminal and a closer callback.
 */
public static class ShellParams {
    private final Map<String, String> env;
    private final Terminal terminal;
    private final Runnable closer;
    private final ServerSession session;

    /**
     * @param env      environment variables for the shell
     * @param session  server-side SSH session the shell belongs to
     * @param terminal terminal the shell reads from and writes to
     * @param closer   callback associated with the shell; presumably run to
     *                 terminate it -- confirm against the shell factory
     */
    public ShellParams(Map<String, String> env, ServerSession session, Terminal terminal, Runnable closer) {
        this.env = env;
        this.session = session;
        this.terminal = terminal;
        this.closer = closer;
    }

    /** @return the environment map exactly as passed to the constructor (not copied) */
    public Map<String, String> getEnv() {
        return env;
    }

    /** @return the server session passed to the constructor */
    public ServerSession getSession() {
        return session;
    }

    /** @return the terminal passed to the constructor */
    public Terminal getTerminal() {
        return terminal;
    }

    /** @return the closer callback passed to the constructor */
    public Runnable getCloser() {
        return closer;
    }
}
/**
 * Immutable holder for the values describing one command execution request:
 * the command line, the environment, the server session and the three
 * standard streams.
 */
public static class ExecuteParams {
    private final String command;
    private final Map<String, String> env;
    private final ServerSession session;
    private final InputStream in;
    private final OutputStream out;
    private final OutputStream err;

    /**
     * @param command command line to execute
     * @param env     environment variables for the command
     * @param session server-side SSH session the command belongs to
     * @param in      stream the command reads its input from
     * @param out     stream receiving the command's standard output
     * @param err     stream receiving the command's error output
     */
    public ExecuteParams(String command, Map<String, String> env, ServerSession session, InputStream in, OutputStream out, OutputStream err) {
        this.command = command;
        this.session = session;
        this.env = env;
        this.in = in;
        this.out = out;
        this.err = err;
    }

    /** @return the command line passed to the constructor */
    public String getCommand() {
        return command;
    }

    /** @return the environment map exactly as passed to the constructor (not copied) */
    public Map<String, String> getEnv() {
        return env;
    }

    /** @return the server session passed to the constructor */
    public ServerSession getSession() {
        return session;
    }

    /** @return the input stream passed to the constructor */
    public InputStream getIn() {
        return in;
    }

    /** @return the standard output stream passed to the constructor */
    public OutputStream getOut() {
        return out;
    }

    /** @return the error output stream passed to the constructor */
    public OutputStream getErr() {
        return err;
    }
}
/** Listen port advertised as the default in the {@code sshd} usage text. */
private static final int defaultPort = 2022;

/** Callback passed to the server's shell factory when the server is started. */
private final Consumer<ShellParams> shell;
/** Callback passed to the command delegate created when the server is started. */
private final Consumer<ExecuteParams> execute;
/** Supplies the {@link SshServer} instance configured and started by {@code sshd start}. */
private final Supplier<SshServer> serverBuilder;
/** Supplies a fresh {@link SshClient} for every {@code ssh} invocation. */
private final Supplier<SshClient> clientBuilder;

/** The running server, or {@code null} when no server is running. */
private SshServer server;
/** Listen port taken from the last {@code sshd start}; also the default port used by {@code ssh}. */
private int port;
/** Listen interface taken from the last {@code sshd start}. */
private String ip;

/**
 * Creates the command set.
 *
 * @param shell         consumer handed to the server's shell factory
 * @param execute       consumer handed to the server's command delegate
 * @param serverBuilder supplier of the server instance to configure and start
 * @param clientBuilder supplier of the client instance used for outgoing connections
 */
public Ssh(Consumer<ShellParams> shell,
        Consumer<ExecuteParams> execute,
        Supplier<SshServer> serverBuilder,
        Supplier<SshClient> clientBuilder) {
    this.shell = shell;
    this.execute = execute;
    this.serverBuilder = serverBuilder;
    this.clientBuilder = clientBuilder;
}
/**
 * Implements the {@code ssh} command: connects to {@code [user@]hostname[:port]}
 * and either runs a single remote command or opens an interactive shell bound
 * to the given terminal.
 *
 * <p>NOTE(review): when no {@code :port} suffix is given, the port recorded by the
 * last {@code sshd start} is used; that value is 0 if the local server was never
 * started -- confirm this default is intended.
 *
 * @param terminal terminal whose attributes, size and signals are mirrored to the remote shell
 * @param reader   line reader used to prompt for passwords and interactive answers
 * @param user     user name used when the target does not contain {@code user@}
 * @param stdin    input forwarded to the remote shell
 * @param stdout   stream receiving the remote standard output
 * @param stderr   stream receiving the remote error output
 * @param argv     command arguments: target followed by an optional remote command
 * @throws HelpException if help is requested or no target is given
 * @throws Exception     if connecting, authenticating or channel I/O fails
 */
public void ssh(Terminal terminal,
        LineReader reader,
        String user,
        InputStream stdin,
        PrintStream stdout,
        PrintStream stderr,
        String[] argv) throws Exception {
    final String[] usage = {"ssh - connect to a server using ssh",
        "Usage: ssh [user@]hostname [command]",
        " -? --help show help"};
    Options opt = Options.compile(usage).parse(argv, true);
    List<String> args = opt.args();
    if (opt.isSet("help") || args.isEmpty()) {
        throw new HelpException(opt.usage());
    }

    // Split the first argument into user, host and port.
    String username = user;
    String hostname = args.remove(0);
    int port = this.port;
    int idx = hostname.indexOf('@');
    if (idx >= 0) {
        username = hostname.substring(0, idx);
        hostname = hostname.substring(idx + 1);
    }
    idx = hostname.indexOf(':');
    if (idx >= 0) {
        port = Integer.parseInt(hostname.substring(idx + 1));
        hostname = hostname.substring(0, idx);
    }
    // Everything after the target forms the remote command, if any.
    String command = args.isEmpty() ? null : String.join(" ", args);

    try (SshClient client = clientBuilder.get()) {
        JLineUserInteraction ui = new JLineUserInteraction(terminal, reader, stderr);
        client.setFilePasswordProvider(ui);
        client.setUserInteraction(ui);
        client.start();
        try (ClientSession sshSession = connectWithRetries(terminal.writer(), client, username, hostname, port, 3)) {
            sshSession.auth().verify();
            if (command != null) {
                // One-shot execution: no input, output copied to the caller's streams.
                try (ClientChannel channel = sshSession.createChannel("exec", command + "\n")) {
                    channel.setIn(new ByteArrayInputStream(new byte[0]));
                    channel.setOut(new NoCloseOutputStream(stdout));
                    channel.setErr(new NoCloseOutputStream(stderr));
                    channel.open().verify();
                    channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0);
                }
            } else {
                try (ChannelShell channel = sshSession.createShellChannel()) {
                    // enterRawMode() returns the attributes that were active before the
                    // switch; they describe the remote PTY and are restored on exit.
                    Attributes attributes = terminal.enterRawMode();
                    try {
                        channel.setPtyModes(toPtyModes(attributes));
                        channel.setPtyColumns(terminal.getWidth());
                        channel.setPtyLines(terminal.getHeight());
                        channel.setAgentForwarding(true);
                        channel.setEnv("TERM", terminal.getType());
                        // TODO: channel.setEnv("LC_CTYPE", terminal.encoding().toString());
                        channel.setIn(new NoCloseInputStream(stdin));
                        channel.setOut(new NoCloseOutputStream(stdout));
                        channel.setErr(new NoCloseOutputStream(stderr));
                        channel.open().verify();

                        Terminal.SignalHandler prevWinchHandler = terminal.handle(Terminal.Signal.WINCH, signal -> {
                            try {
                                Size size = terminal.getSize();
                                channel.sendWindowChange(size.getColumns(), size.getRows());
                            } catch (IOException ignored) {
                                // Best effort: a failed resize notification must not kill the session.
                            }
                        });
                        Terminal.SignalHandler prevQuitHandler = terminal.handle(Terminal.Signal.QUIT,
                                signal -> forwardControlChar(channel, attributes, Attributes.ControlChar.VQUIT));
                        Terminal.SignalHandler prevIntHandler = terminal.handle(Terminal.Signal.INT,
                                signal -> forwardControlChar(channel, attributes, Attributes.ControlChar.VINTR));
                        // TSTP is raised by the suspend character, so forward VSUSP rather than
                        // the delayed-suspend character VDSUSP, which many platforms leave unset.
                        Terminal.SignalHandler prevStopHandler = terminal.handle(Terminal.Signal.TSTP,
                                signal -> forwardControlChar(channel, attributes, Attributes.ControlChar.VSUSP));
                        try {
                            channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0);
                        } finally {
                            terminal.handle(Terminal.Signal.WINCH, prevWinchHandler);
                            terminal.handle(Terminal.Signal.INT, prevIntHandler);
                            terminal.handle(Terminal.Signal.TSTP, prevStopHandler);
                            terminal.handle(Terminal.Signal.QUIT, prevQuitHandler);
                        }
                    } finally {
                        terminal.setAttributes(attributes);
                    }
                }
            }
        }
    }
}

/** Translates local terminal attributes into the PTY mode map sent with the shell request. */
private static Map<PtyMode, Integer> toPtyModes(Attributes attributes) {
    Map<PtyMode, Integer> modes = new HashMap<>();
    // Control chars
    modes.put(PtyMode.VINTR, attributes.getControlChar(Attributes.ControlChar.VINTR));
    modes.put(PtyMode.VQUIT, attributes.getControlChar(Attributes.ControlChar.VQUIT));
    modes.put(PtyMode.VERASE, attributes.getControlChar(Attributes.ControlChar.VERASE));
    modes.put(PtyMode.VKILL, attributes.getControlChar(Attributes.ControlChar.VKILL));
    modes.put(PtyMode.VEOF, attributes.getControlChar(Attributes.ControlChar.VEOF));
    modes.put(PtyMode.VEOL, attributes.getControlChar(Attributes.ControlChar.VEOL));
    modes.put(PtyMode.VEOL2, attributes.getControlChar(Attributes.ControlChar.VEOL2));
    modes.put(PtyMode.VSTART, attributes.getControlChar(Attributes.ControlChar.VSTART));
    modes.put(PtyMode.VSTOP, attributes.getControlChar(Attributes.ControlChar.VSTOP));
    modes.put(PtyMode.VSUSP, attributes.getControlChar(Attributes.ControlChar.VSUSP));
    modes.put(PtyMode.VDSUSP, attributes.getControlChar(Attributes.ControlChar.VDSUSP));
    modes.put(PtyMode.VREPRINT, attributes.getControlChar(Attributes.ControlChar.VREPRINT));
    modes.put(PtyMode.VWERASE, attributes.getControlChar(Attributes.ControlChar.VWERASE));
    modes.put(PtyMode.VLNEXT, attributes.getControlChar(Attributes.ControlChar.VLNEXT));
    modes.put(PtyMode.VSTATUS, attributes.getControlChar(Attributes.ControlChar.VSTATUS));
    modes.put(PtyMode.VDISCARD, attributes.getControlChar(Attributes.ControlChar.VDISCARD));
    // Input flags
    modes.put(PtyMode.IGNPAR, getFlag(attributes, Attributes.InputFlag.IGNPAR));
    modes.put(PtyMode.PARMRK, getFlag(attributes, Attributes.InputFlag.PARMRK));
    modes.put(PtyMode.INPCK, getFlag(attributes, Attributes.InputFlag.INPCK));
    modes.put(PtyMode.ISTRIP, getFlag(attributes, Attributes.InputFlag.ISTRIP));
    modes.put(PtyMode.INLCR, getFlag(attributes, Attributes.InputFlag.INLCR));
    modes.put(PtyMode.IGNCR, getFlag(attributes, Attributes.InputFlag.IGNCR));
    modes.put(PtyMode.ICRNL, getFlag(attributes, Attributes.InputFlag.ICRNL));
    modes.put(PtyMode.IXON, getFlag(attributes, Attributes.InputFlag.IXON));
    modes.put(PtyMode.IXANY, getFlag(attributes, Attributes.InputFlag.IXANY));
    modes.put(PtyMode.IXOFF, getFlag(attributes, Attributes.InputFlag.IXOFF));
    // Local flags
    modes.put(PtyMode.ISIG, getFlag(attributes, Attributes.LocalFlag.ISIG));
    modes.put(PtyMode.ICANON, getFlag(attributes, Attributes.LocalFlag.ICANON));
    modes.put(PtyMode.ECHO, getFlag(attributes, Attributes.LocalFlag.ECHO));
    modes.put(PtyMode.ECHOE, getFlag(attributes, Attributes.LocalFlag.ECHOE));
    modes.put(PtyMode.ECHOK, getFlag(attributes, Attributes.LocalFlag.ECHOK));
    modes.put(PtyMode.ECHONL, getFlag(attributes, Attributes.LocalFlag.ECHONL));
    modes.put(PtyMode.NOFLSH, getFlag(attributes, Attributes.LocalFlag.NOFLSH));
    modes.put(PtyMode.TOSTOP, getFlag(attributes, Attributes.LocalFlag.TOSTOP));
    modes.put(PtyMode.IEXTEN, getFlag(attributes, Attributes.LocalFlag.IEXTEN));
    // Output flags
    modes.put(PtyMode.OPOST, getFlag(attributes, Attributes.OutputFlag.OPOST));
    modes.put(PtyMode.ONLCR, getFlag(attributes, Attributes.OutputFlag.ONLCR));
    modes.put(PtyMode.OCRNL, getFlag(attributes, Attributes.OutputFlag.OCRNL));
    modes.put(PtyMode.ONOCR, getFlag(attributes, Attributes.OutputFlag.ONOCR));
    modes.put(PtyMode.ONLRET, getFlag(attributes, Attributes.OutputFlag.ONLRET));
    return modes;
}

/** Writes the given control character to the remote shell's input, ignoring I/O failures. */
private static void forwardControlChar(ChannelShell channel, Attributes attributes, Attributes.ControlChar controlChar) {
    try {
        channel.getInvertedIn().write(attributes.getControlChar(controlChar));
        channel.getInvertedIn().flush();
    } catch (IOException ignored) {
        // Best effort: a signal that cannot be forwarded is dropped rather than ending the session.
    }
}
/** Maps the state of an input flag to 1 (set) or 0 (clear). */
private static int getFlag(Attributes attributes, Attributes.InputFlag flag) {
    if (attributes.getInputFlag(flag)) {
        return 1;
    }
    return 0;
}

/** Maps the state of an output flag to 1 (set) or 0 (clear). */
private static int getFlag(Attributes attributes, Attributes.OutputFlag flag) {
    if (attributes.getOutputFlag(flag)) {
        return 1;
    }
    return 0;
}

/** Maps the state of a local flag to 1 (set) or 0 (clear). */
private static int getFlag(Attributes attributes, Attributes.LocalFlag flag) {
    if (attributes.getLocalFlag(flag)) {
        return 1;
    }
    return 0;
}
/**
 * Connects to the given host, retrying after a two second pause whenever an
 * attempt fails.
 *
 * <p>Despite its name, {@code maxAttempts} counts the retries made after the
 * first attempt, so up to {@code maxAttempts + 1} connections are tried.
 *
 * @param stdout      writer receiving the "retrying" progress messages
 * @param client      started client used to open the connection
 * @param username    user name for the connection
 * @param host        target host
 * @param port        target port
 * @param maxAttempts number of retries allowed after the first failed attempt
 * @return the connected session, never {@code null}
 * @throws Exception the failure of the last attempt once the retries are
 *                   exhausted, or {@link InterruptedException} if interrupted
 *                   while waiting
 */
private ClientSession connectWithRetries(PrintWriter stdout, SshClient client, String username, String host, int port, int maxAttempts) throws Exception {
    int retries = 0;
    while (true) {
        ConnectFuture future = client.connect(username, host, port);
        future.await();
        Exception failure;
        try {
            ClientSession session = future.getSession();
            if (session != null) {
                return session;
            }
            // A completed future without a session and without an exception would
            // otherwise make the loop reconnect forever with no delay.
            failure = new IOException("Connection to " + host + ":" + port + " did not produce a session");
        } catch (Exception ex) {
            failure = ex;
        }
        if (retries++ >= maxAttempts) {
            throw failure;
        }
        Thread.sleep(2 * 1000);
        stdout.println("retrying (attempt " + retries + ") ...");
        // The writer may be buffered; flush so the progress message is visible.
        stdout.flush();
    }
}
/**
 * Implements the {@code sshd} command, which starts, stops or reports the
 * status of the embedded SSH server.
 *
 * @param stdout stream receiving status messages
 * @param stderr error stream; not used by this method
 * @param argv   command arguments: options followed by {@code start}, {@code stop} or {@code status}
 * @throws HelpException         if help is requested or no sub-command is given
 * @throws IllegalStateException if {@code start} is requested while a server is running,
 *                               or {@code stop} while none is
 * @throws Exception             if the sub-command is unknown or starting/stopping the server fails
 */
public void sshd(PrintStream stdout, PrintStream stderr, String[] argv) throws Exception {
    final String[] usage = {"sshd - start an ssh server",
        "Usage: sshd [-i ip] [-p port] start | stop | status",
        " -i --ip=INTERFACE listen interface (default=127.0.0.1)",
        " -p --port=PORT listen port (default=" + defaultPort + ")",
        " -? --help show help"};
    Options opt = Options.compile(usage).parse(argv, true);
    List<String> args = opt.args();
    if (opt.isSet("help") || args.isEmpty()) {
        throw new HelpException(opt.usage());
    }

    String command = args.get(0);
    if ("start".equals(command)) {
        if (server != null) {
            throw new IllegalStateException("sshd is already running on port " + port);
        }
        // Record the requested endpoint before starting; start() and status() read these fields.
        ip = opt.get("ip");
        port = opt.getNumber("port");
        start();
        status(stdout);
    } else if ("stop".equals(command)) {
        if (server == null) {
            throw new IllegalStateException("sshd is not running.");
        }
        stop();
    } else if ("status".equals(command)) {
        status(stdout);
    } else {
        throw opt.usageError("bad command: " + command);
    }
}
/** Prints whether the server is running and, if so, the interface and port it was started on. */
private void status(PrintStream stdout) {
    String message = (server == null)
            ? "sshd is not running."
            : "sshd is running on " + ip + ":" + port;
    stdout.println(message);
}
/**
 * Builds, configures and starts the server on the recorded interface and port.
 *
 * <p>The {@code server} field is only assigned once the server has started, so a
 * failed start leaves this object in the "not running" state and can be retried.
 *
 * @throws IOException if the server cannot be started
 */
private void start() throws IOException {
    SshServer candidate = serverBuilder.get();
    candidate.setPort(port);
    candidate.setHost(ip);
    candidate.setShellFactory(new ShellFactoryImpl(shell));
    // SCP requests are handled by the SCP factory; any other command goes to the execute callback.
    candidate.setCommandFactory(new ScpCommandFactory.Builder()
            .withDelegate((channel, command) -> new ShellCommand(execute, command)).build());
    candidate.setSubsystemFactories(Collections.singletonList(
            new SftpSubsystemFactory.Builder().build()
    ));
    candidate.setKeyPairProvider(new SimpleGeneratorHostKeyProvider());
    candidate.start();
    // Publish only after a successful start so status() and the "already running" check stay truthful.
    server = candidate;
}
/**
 * Stops the running server and clears the {@code server} field, even when
 * stopping fails.
 *
 * @throws IOException if the server reports a failure while stopping
 */
private void stop() throws IOException {
    final SshServer running = server;
    try {
        running.stop();
    } finally {
        // Always forget the instance so a later start is possible.
        server = null;
    }
}
/**
 * Bridges the SSH client's prompts (key-file passwords, keyboard-interactive
 * challenges, banners) to a JLine terminal and line reader.
 */
private static class JLineUserInteraction implements UserInteraction, FilePasswordProvider {
    private final Terminal terminal;
    private final LineReader reader;
    private final PrintStream stderr;

    /**
     * @param terminal terminal receiving banners and server version lines
     * @param reader   line reader used to read answers and passwords
     * @param stderr   stream receiving a message when reading an answer fails
     */
    public JLineUserInteraction(Terminal terminal, LineReader reader, PrintStream stderr) {
        this.terminal = terminal;
        this.reader = reader;
        this.stderr = stderr;
    }

    /** Prompts, without echo, for the password protecting the given key resource. */
    @Override
    public String getPassword(SessionContext session, NamedResource resourceKey, int retryIndex) throws IOException {
        return readLine("Enter password for " + resourceKey + ":", false);
    }

    /** Prints the server banner on the terminal. */
    @Override
    public void welcome(ClientSession session, String banner, String lang) {
        terminal.writer().println(banner);
        // The terminal writer may be buffered; flush so the banner shows before any prompt.
        terminal.writer().flush();
    }

    /**
     * Asks every prompt in order, echoing input only where {@code echo} allows it.
     * If reading fails, the failure is reported on {@code stderr} and the answers
     * collected so far are returned; the remaining entries stay {@code null}.
     */
    @Override
    public String[] interactive(ClientSession s, String name, String instruction, String lang, String[] prompt, boolean[] echo) {
        String[] answers = new String[prompt.length];
        try {
            for (int i = 0; i < prompt.length; i++) {
                answers[i] = readLine(prompt[i], echo[i]);
            }
        } catch (Exception e) {
            stderr.append(e.getClass().getSimpleName()).append(" while read prompts: ").println(e.getMessage());
        }
        return answers;
    }

    /** Interaction is always permitted. */
    @Override
    public boolean isInteractionAllowed(ClientSession session) {
        return true;
    }

    /** Prints each server version line on the terminal, prefixed with a tab. */
    @Override
    public void serverVersionInfo(ClientSession session, List<String> lines) {
        for (String l : lines) {
            terminal.writer().append('\t').println(l);
        }
        terminal.writer().flush();
    }

    /**
     * Prompts, without echo, for a replacement password.
     *
     * @return the entered password, or {@code null} if reading failed (the failure
     *         is reported on {@code stderr})
     */
    @Override
    public String getUpdatedPassword(ClientSession session, String prompt, String lang) {
        try {
            return readLine(prompt, false);
        } catch (Exception e) {
            stderr.append(e.getClass().getSimpleName()).append(" while reading password: ").println(e.getMessage());
        }
        return null;
    }

    /** Reads one line after the prompt; when echo is off the NUL character is passed as the mask. */
    private String readLine(String prompt, boolean echo) {
        return reader.readLine(prompt + " ", echo ? null : '\0');
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy