All downloads are free. Search and download functionality uses the official Maven repository.

me.bazhenov.groovysh.GroovyShellCommand Maven / Gradle / Ivy

There is a newer version: 2.2.3
Show newest version
package me.bazhenov.groovysh;

import groovy.lang.Binding;
import groovy.lang.Closure;
import me.bazhenov.groovysh.thread.ServerSessionAwareThreadFactory;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.session.helpers.AbstractSession;
import org.apache.sshd.server.Command;
import org.apache.sshd.server.Environment;
import org.apache.sshd.server.ExitCallback;
import org.apache.sshd.server.SessionAware;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.session.ServerSession;
import org.codehaus.groovy.tools.shell.Groovysh;
import org.codehaus.groovy.tools.shell.IO;

import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Map;

import static java.util.Arrays.asList;
import static me.bazhenov.groovysh.GroovyShellService.SHELL_KEY;

class GroovyShellCommand implements Command, SessionAware {

	private final SshServer sshd;
	private final Map bindings;
	private final List defaultScripts;
	private final ServerSessionAwareThreadFactory threadFactory;
	private InputStream in;
	private OutputStream out;
	private OutputStream err;
	private ExitCallback callback;
	private Thread wrapper;
	private ServerSession session;

	public GroovyShellCommand(SshServer sshd, Map bindings, List defaultScripts, ServerSessionAwareThreadFactory threadFactory) {
		this.sshd = sshd;
		this.bindings = bindings;
		this.defaultScripts = defaultScripts;
		this.threadFactory = threadFactory;
	}

	@Override
	public void setSession(ServerSession session) {
		this.session = session;
	}

	@Override
	public void setInputStream(InputStream in) {
		this.in = in;
	}

	@Override
	public void setOutputStream(OutputStream out) {
		this.out = out;
	}

	@Override
	public void setErrorStream(OutputStream err) {
		this.err = err;
	}

	@Override
	public void setExitCallback(ExitCallback callback) {
		this.callback = callback;
	}

	@Override
	public void start(final Environment env) throws IOException {
		TtyFilterOutputStream out = new TtyFilterOutputStream(this.out);
		TtyFilterOutputStream err = new TtyFilterOutputStream(this.err);

		IO io = new IO(in, out, err);
		io.setVerbosity(IO.Verbosity.DEBUG);
		final Groovysh shell = new Groovysh(createBinding(bindings, out, err), io);
		shell.setErrorHook(new Closure(this) {
			@Override
			public Object call(Object... args) {
				if (args[0] instanceof InterruptedIOException || args[0] instanceof SshException) {
					// Stopping groovysh thread in case of broken client channel
					shell.getRunner().setRunning(false);
				}
				return shell.getDefaultErrorHook().call(args);
			}
		});

		try {
			loadDefaultScripts(shell);
		} catch (Exception e) {
			createPrintStream(err).println("Unable to load default scripts: "
				+ e.getClass().getName() + ": " + e.getMessage());
		}

		session.setAttribute(SHELL_KEY, shell);

		wrapper = threadFactory.newThread(new Runnable() {
			@Override
			public void run() {
				try {
					SshTerminal.registerEnvironment(env);
					shell.run("");
					callback.onExit(0);
				} catch (RuntimeException e) {
					callback.onExit(-1, e.getMessage());
				} catch (Error e) {
					callback.onExit(-1, e.getMessage());
				}
			}
		}, session);
		wrapper.start();
	}

	private Binding createBinding(Map objects, OutputStream out, OutputStream err)
		throws UnsupportedEncodingException {
		Binding binding = new Binding();

		if (objects != null)
			for (Map.Entry row : objects.entrySet())
				binding.setVariable(row.getKey(), row.getValue());

		binding.setVariable("out", createPrintStream(out));
		binding.setVariable("err", createPrintStream(err));
		binding.setVariable("activeSessions", new Closure>(this) {
			@Override
			public List call() {
				return sshd.getActiveSessions();
			}
		});

		return binding;
	}

	private static PrintStream createPrintStream(OutputStream out) throws UnsupportedEncodingException {
		return new PrintStream(out, true, "utf8");
	}

	@SuppressWarnings({"unchecked", "serial"})
	private void loadDefaultScripts(final Groovysh shell) {
		if (!defaultScripts.isEmpty()) {
			Closure defaultResultHook = shell.getResultHook();

			try {
				// Set a "no-op closure so we don't get per-line value output when evaluating the default script
				shell.setResultHook(new Closure(this) {
					@Override
					public Groovysh call(Object... args) {
						return shell;
					}
				});

				org.codehaus.groovy.tools.shell.Command cmd = shell.getRegistry().find("load");
				for (String script : defaultScripts) {
					cmd.execute(asList(script));
				}
			} finally {
				// Restoring original result hook
				shell.setResultHook(defaultResultHook);
			}
		}
	}

	@Override
	public void destroy() {
		wrapper.interrupt();
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy