
// com.sumologic.shellbase.ShellBase.scala Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.sumologic.shellbase
import java.io.{File, FileOutputStream, PrintStream}
import java.util.concurrent.TimeUnit
import java.util.regex.Pattern
import com.sumologic.shellbase.cmdline.RichCommandLine._
import com.sumologic.shellbase.cmdline.{CommandLineArgument, CommandLineFlag, CommandLineOption}
import com.sumologic.shellbase.interrupts.{InterruptKeyMonitor, KillableSingleThread}
import com.sumologic.shellbase.notifications.{InMemoryShellNotificationManager, NotificationCommandSet, RingingNotification, ShellNotificationManager}
import com.sumologic.shellbase.timeutil.{TimeFormats, TimedBlock}
import jline.console.ConsoleReader
import jline.console.completer.{ArgumentCompleter, Completer, NullCompleter, StringsCompleter}
import jline.console.history.FileHistory
import org.apache.commons.cli.{CommandLine, GnuParser, HelpFormatter, Options, ParseException, Option => CLIOption}
import org.apache.commons.io.output.TeeOutputStream
import org.slf4j.LoggerFactory
import scala.collection.JavaConversions._
import scala.concurrent.duration.Duration
import scala.util.{Success, Try}
/**
 * A shell base that can be used to build shell-like applications.
 *
 * Subclasses customize the shell by overriding the hooks below (commands, prompt, banner,
 * additional command line options, ...) and start it through [[main]]. The built-in root
 * commands (`clear`, `exit`, `sleep`, `echo`, `tee`, `time`, `run_script` and the notification
 * commands) are registered while this class is constructed.
 *
 * @param name the shell's name; used for the root command set, the default history file, the
 *             default script extension, the help output and command notifications.
 */
abstract class ShellBase(val name: String) {

  // -----------------------------------------------------------------------------------------------
  // Hooks to override (every hook has a default implementation).
  // -----------------------------------------------------------------------------------------------

  /**
   * Return the list of commands.
   */
  def commands: Seq[ShellCommand] = List[ShellCommand]()

  /**
   * Return additional command line options.
   */
  def additionalOptions: Seq[CLIOption] = List[CLIOption]()

  /**
   * Process additional command line options.
   *
   * @return false to abort start-up; [[actualMain]] then prints an error and returns 1.
   */
  def init(cmdLine: CommandLine): Boolean = true

  /**
   * Override for a custom prompt.
   */
  def prompt: String = "$"

  /**
   * Return a custom banner.
   */
  def banner: String = ""

  /**
   * Where to look for scripts (used for `run_script` argument completion).
   */
  def scriptDir: File = new File("scripts/")

  /**
   * File extension for scripts (used for `run_script` argument completion; `null` disables the
   * extension filter).
   */
  def scriptExtension: String = name

  /**
   * Prints commands to stdout as they're executed.
   */
  def verboseMode: Boolean = false

  /**
   * Name of the history file. Evaluated while this class is constructed.
   */
  def historyPath: File = new File("%s/.%s_history".format(System.getProperty("user.home"), name))

  /**
   * Manages notifications.
   */
  lazy val notificationManager: ShellNotificationManager =
    new InMemoryShellNotificationManager(Seq(new RingingNotification))

  /**
   * Pre, post command hooks. Both are evaluated while this class is constructed, so overrides
   * must not depend on fields initialized by a subclass.
   */
  protected[this] def preCommandHooks: Seq[ShellCommandSet.ExecuteHook] = List()
  protected[this] def postCommandHooks: Seq[ShellCommandSet.ExecuteHook] = List()

  private val _logger = LoggerFactory.getLogger(getClass)

  /**
   * Exit the shell. Called by [[main]] and by the built-in `exit` command.
   */
  protected def exitShell(exitValue: Int): Unit = {
    System.exit(exitValue)
  }

  // -----------------------------------------------------------------------------------------------
  // Construction-time state.
  // -----------------------------------------------------------------------------------------------

  // rootSet is visible for testing only
  private[shellbase] val rootSet = new ShellCommandSet(name, "")
  rootSet.preExecuteHooks.appendAll(preCommandHooks)
  rootSet.postExecuteHooks.appendAll(postCommandHooks)

  private val history = new FileHistory(historyPath)
  history.setMaxSize(1000)

  // Created before the shutdown hook is registered, so the hook can never observe an
  // uninitialized (null) monitor.
  private val interruptKeyMonitor = new InterruptKeyMonitor()

  // NOTE(stefan, 2014-01-06): From the jline JavaDoc: Implementers should install shutdown hook to
  // call {@link FileHistory#flush} to save history to disk.
  Runtime.getRuntime.addShutdownHook(new Thread() {
    override def run(): Unit = {
      history.flush()
      interruptKeyMonitor.shutdown()
    }
  })

  interruptKeyMonitor.init()

  private val reader = new ConsoleReader()
  reader.setHistory(history)
  reader.setHandleUserInterrupt(false)

  // -----------------------------------------------------------------------------------------------
  // Entry points.
  // -----------------------------------------------------------------------------------------------

  /**
   * Program entry point: runs [[actualMain]] and exits the shell with its result. The interrupt
   * key monitor is shut down even if [[actualMain]] throws.
   */
  def main(args: Array[String]): Unit = {
    Thread.currentThread.setName("Shell main")
    try {
      val result = actualMain(args)
      exitShell(result)
    } finally {
      interruptKeyMonitor.shutdown()
    }
  }

  /**
   * Parses the command line, initializes the shell and then runs it.
   *
   * If positional arguments are given they are executed as one command; with `--no-exit` the
   * interactive loop is entered afterwards. Without arguments the interactive loop starts directly.
   *
   * @return the process exit code: 1 if option parsing, [[init]] or the non-interactive command
   *         failed, 0 otherwise.
   */
  def actualMain(args: Array[String]): Int = {
    val options = new Options
    for (optn <- additionalOptions) {
      options.addOption(optn)
    }
    options.addOption(null, "no-exit", false,
      "Don't exit after executing the command passed on the command line.")

    val cmdLine: CommandLine = try {
      new GnuParser().parse(options, args, true)
    } catch {
      case e: ParseException =>
        println(e.getMessage)
        new HelpFormatter().printHelp(name, options)
        return 1
    }

    if (init(cmdLine)) {
      initializeCommands()
      val arguments = cmdLine.getArgs
      if (arguments.nonEmpty) {
        val interactiveAfterScript = cmdLine.hasOption("no-exit")
        if (interactiveAfterScript) {
          reader.clearScreen()
          println(banner)
        }
        val scriptSucceeded = rootSet.executeLine(arguments.toList)
        if (!interactiveAfterScript && !scriptSucceeded) {
          println(s"Execution failed! Input was ${arguments.mkString(" ")}")
          return 1
        }
        if (interactiveAfterScript) {
          interactiveMainLoop()
        }
      } else {
        reader.clearScreen()
        println(banner)
        interactiveMainLoop()
      }
      0
    } else {
      println("Could not initialize!")
      1
    }
  }

  /**
   * Registers the subclass-provided [[commands]] in the root command set and validates them.
   */
  final def initializeCommands(): Unit = {
    rootSet.commands ++= commands
    validateCommands()
  }

  /**
   * Validates the root command set.
   */
  def validateCommands() = {
    rootSet.validateCommands()
  }

  /**
   * Reads and runs commands until `readLine` returns null (CTRL-D).
   */
  private def interactiveMainLoop(): Unit = {
    println(" Enter your commands below. Type 'help' for help. ")
    reader.setAutoprintThreshold(128) // allow more completions candidates without prompting
    reader.addCompleter(rootSet.argCompleter)
    // Line is null on CTRL-D, which ends the loop.
    Iterator.continually(reader.readLine(prompt))
      .takeWhile(_ != null)
      .foreach(line => runKillableCommand(line))
    println()
    println("Exiting...")
  }

  /**
   * Runs `line` on a [[KillableSingleThread]] so that the interrupt key can kill it.
   *
   * Commands finishing within 200ms never engage the interrupt key monitor.
   *
   * @return true iff the command completed and reported success.
   */
  private def runKillableCommand(line: String): Boolean = {
    val commandRunner = new KillableSingleThread(runCommand(line))

    // Wait a bit for completion, so quick commands don't need to go through the keyMonitor.
    val startTime = now
    commandRunner.start()
    commandRunner.waitForCompletion(Duration(200, TimeUnit.MILLISECONDS))

    // Start the keyMonitor to watch for key interrupts.
    if (!commandRunner.future.isCompleted) {
      interruptKeyMonitor.startMonitoring(interruptCallback = commandRunner.synchronized {
        if (!commandRunner.future.isCompleted) {
          println(s"Caught interrupt.")
          println(s"Killing command with 1s grace period: `$line`...")
          commandRunner.kill(Duration(1, TimeUnit.SECONDS))
        }
      })
      try {
        commandRunner.waitForCompletion(Duration.Inf)
        val elapsedMillis = now - startTime
        val runTimeInSec = elapsedMillis / 1000
        val timing = TimeFormats.formatAsTersePeriod(elapsedMillis)
        _logger.debug(s"Done running command `$line` (took: ${runTimeInSec}s [$timing]).")
      } finally {
        // Always stop monitoring, even if waiting for the command failed.
        interruptKeyMonitor.stopMonitoring()
      }
    }

    commandRunner.future.value match {
      case Some(Success(result)) => result
      case _ => false
    }
  }

  /**
   * Parses `line` and runs it as a sequence of commands separated by `&&` tokens, stopping at
   * the first failing command.
   *
   * @return true if every command succeeded; an empty or comment-only line counts as success.
   */
  def runCommand(line: String): Boolean = {
    // Each step splits off the command before the next `&&`. This avoids reusing an iterator after
    // `takeWhile`, which the Scala collections documentation leaves undefined.
    @scala.annotation.tailrec
    def runFrom(tokens: List[String]): Boolean = tokens match {
      case Nil => true
      case _ =>
        val (command, rest) = tokens.span(_.trim != "&&")
        val remaining = rest.drop(1) // drop the `&&` separator itself
        _logger.debug(s"Executing: $command")
        if (verboseMode) {
          println(s"Executing: $command")
        }
        if (runSingleTokenizedCommand(command)) {
          runFrom(remaining)
        } else {
          if (remaining.nonEmpty) {
            println(s"Execution failed for `${command.mkString(" ")}`.")
            println(s"Skipping the remaining commands: `${remaining.mkString(" ")}`")
          }
          false
        }
    }

    runFrom(parseLine(line))
  }

  /**
   * Executes one tokenized command and sends a success/failure notification.
   */
  private def runSingleTokenizedCommand(tokens: List[String]): Boolean = {
    val succeeded = rootSet.executeLine(tokens)
    val msg =
      if (succeeded) s"[$name] Command finished successfully"
      else s"[$name] Command failed"
    notificationManager.notify(msg)
    succeeded
  }

  private def now = System.currentTimeMillis()

  // A token is an unquoted word, or a "double"-, 'single'- or `back`-quoted string.
  private val tokenPattern = Pattern.compile("[^\\s\"'`]+|\"([^\"]*)\"|'([^']*)'|`([^`]*)`")

  /**
   * Splits `line` into tokens.
   *
   * Double- and single-quoted strings become one token without their quotes. Backtick-quoted
   * strings (sub-commands) keep their backticks. Blank lines and lines starting with `#` yield
   * no tokens.
   */
  protected def parseLine(line: String): List[String] = {
    val trimmed = line.trim
    if (trimmed.startsWith("#") || trimmed.isEmpty) {
      List[String]()
    } else {
      val matcher = tokenPattern.matcher(line)
      val tokens = List.newBuilder[String]
      while (matcher.find()) {
        tokens += {
          if (matcher.group(1) != null) {
            matcher.group(1) // double-quoted string without the quotes
          } else if (matcher.group(2) != null) {
            matcher.group(2) // single-quoted string without the quotes
          } else if (matcher.group(3) != null) {
            "`" + matcher.group(3) + "`" // `-quoted string with the quotes
          } else {
            matcher.group() // unquoted word
          }
        }
      }
      tokens.result()
    }
  }

  // -----------------------------------------------------------------------------------------------
  // Little helpers.
  // -----------------------------------------------------------------------------------------------

  /**
   * Returns the comma-separated values of option `name`, or an empty list if it is absent.
   */
  def getListParameter(cmdLine: CommandLine, name: String): Seq[String] =
    Option(cmdLine.getOptionValue(name)).map(_.split(",").toSeq).getOrElse(List[String]())

  // -----------------------------------------------------------------------------------------------
  // Root commands.
  // -----------------------------------------------------------------------------------------------

  val subCommandExtractor = "`([^`]*)`".r // Regex extractor for commands within `-quotes.

  rootSet.commands += new ShellCommand("clear", "Clear the screen.") {
    def execute(cmdLine: CommandLine): Boolean = {
      // Reuse the shell's reader rather than constructing a new ConsoleReader on every call.
      ShellBase.this.reader.clearScreen()
      true
    }
  }

  rootSet.commands += new ShellCommand("exit", "Quit the shell.", List("quit")) {
    def execute(cmdLine: CommandLine): Boolean = {
      // Go through the overridable exit hook, as main() does.
      ShellBase.this.exitShell(0)
      true
    }
  }

  rootSet.commands += new ShellCommand("sleep", "Sleeps the specified time period.", List("zzz")) {
    private val PeriodArgument = new CommandLineArgument("period", 0, true)
    private val ShowIndicatorFlag = new CommandLineFlag("v", "verbose", "whether to print . every second while sleeping")

    override def maxNumberOfArguments = 1

    override def addOptions(opts: Options): Unit = {
      opts += PeriodArgument
      opts += ShowIndicatorFlag
    }

    def execute(cmdLine: CommandLine): Boolean = {
      val showSleepIndicator = cmdLine.checkFlag(ShowIndicatorFlag)
      cmdLine.get(PeriodArgument) match {
        case Some(period) =>
          val milliseconds = TimeFormats.parseTersePeriod(period).
            getOrElse(throw new IllegalArgumentException(s"Could not parse $period"))
          val canonicalPeriod = TimeFormats.formatWithMillis(milliseconds)
          println(s"Sleeping $canonicalPeriod.")
          val startTime = now
          var timeRemaining = milliseconds
          var interrupted = false
          // Sleep in slices of at most 1s so the optional indicator prints once per slice.
          while (timeRemaining > 0 && !interrupted) {
            try {
              Thread.sleep(1000L min timeRemaining)
            } catch {
              case ie: InterruptedException =>
                _logger.debug("Error sleeping", ie)
                // Stop sleeping and restore the interrupt status instead of swallowing it.
                Thread.currentThread.interrupt()
                interrupted = true
            }
            if (showSleepIndicator) {
              print(".")
            }
            timeRemaining = startTime + milliseconds - now
          }
          if (showSleepIndicator) {
            println()
          }
          !interrupted
        case None =>
          println("Missing argument to sleep command!")
          false
      }
    }
  }

  rootSet.commands += new ShellCommand("echo", "Write to the screen") {
    override def maxNumberOfArguments = -1 //unlimited

    def execute(cmdLine: CommandLine): Boolean = {
      if (cmdLine.getArgList.size() > 0) {
        println(cmdLine.getArgList.mkString(" "))
      }
      true
    }
  }

  rootSet.commands += new ShellCommand("tee", "Forks the stdout of a command so it also prints to a file") {
    private val CommandArgument = new CommandLineArgument("command", 0, true)
    private val OutputFileOption = new CommandLineOption("o", "outputFile", false, "Filename of output file (defaults to ~/tee.out)")
    private val AppendFileFlag = new CommandLineFlag("a", "append", "Append the output to the file rather than overwriting it")

    override def maxNumberOfArguments = 1

    override def addOptions(opts: Options): Unit = {
      opts += CommandArgument
      opts += OutputFileOption
      opts += AppendFileFlag
    }

    def execute(cmdLine: CommandLine): Boolean = {
      val outputFile = cmdLine.get(OutputFileOption).getOrElse(System.getProperty("user.home") + "/tee.out")
      val appendFile = cmdLine.checkFlag(AppendFileFlag)
      cmdLine.get(CommandArgument) match {
        case Some(subCommandExtractor(cmd)) =>
          val fileOut = new FileOutputStream(outputFile, appendFile)
          try {
            val newOut = new PrintStream(new TeeOutputStream(Console.out, fileOut))
            val status = Console.withOut(newOut) {
              println(s"Running `$cmd` and outputting to '$outputFile' [append=$appendFile].")
              runCommand(cmd)
            }
            // Flush only: closing newOut would also close Console.out through the tee.
            newOut.flush()
            status
          } finally {
            // Close the file even if the sub-command throws; close errors are ignored.
            Try(fileOut.close())
          }
        case badCmd =>
          println(s"Usage: tee `<command>`, but found $badCmd.")
          false
      }
    }
  }

  rootSet.commands += new ShellCommand("time", "Measure the execution time of a command") {
    private val CommandArgument = new CommandLineArgument("command", 0, true)

    override def maxNumberOfArguments = 1

    override def addOptions(opts: Options): Unit = {
      opts += CommandArgument
    }

    def execute(cmdLine: CommandLine): Boolean = {
      cmdLine.get(CommandArgument) match {
        case Some(subCommandExtractor(cmd)) =>
          val start = now
          val exitStatus = runCommand(cmd)
          val dt = now - start
          val dtMessage = s"Execution took $dt ms (${TimeFormats.formatAsTersePeriod(dt)})"
          _logger.info(s"$dtMessage for `$cmd`")
          println(s"\n$dtMessage\n")
          exitStatus
        case badCmd =>
          println(s"Usage: time `<command>`, but found $badCmd.")
          false
      }
    }
  }

  rootSet.commands += new ShellCommand("run_script",
    "Run the script from the specified file.", List("script")) {

    override def maxNumberOfArguments = -1

    /**
     * Runs the script named by the first argument line by line; the remaining arguments are
     * passed to the [[ScriptRenderer]]. Without `--continue`, stops at the first failing line.
     */
    def execute(cmdLine: CommandLine): Boolean = {
      val continue = cmdLine.hasOption("continue")
      val args: Array[String] = cmdLine.getArgs
      if (args.isEmpty) {
        println("Please specify a script to run!")
        return false
      }
      val scriptFileName = args(0)

      // Every candidate is checked and the LAST existing one wins, so a path that exists as given
      // takes precedence over a file of that name under scripts/.
      val scriptFile = List("scripts/%s.dsh", "scripts/%s", "%s")
        .map(pattern => new File(pattern.format(scriptFileName)))
        .filter(_.exists)
        .lastOption

      scriptFile match {
        case None =>
          println(s"Could not find the script $scriptFileName! Please make sure the script file exists locally.")
          false
        case Some(file) =>
          var succeeded = true
          // Execute the script, line by line.
          TimedBlock(s"Executing script $scriptFileName", println(_)) {
            val scriptLines = new ScriptRenderer(file, args.tail).
              getLines.filterNot(parseLine(_).isEmpty)
            require(scriptLines.nonEmpty, s"No non-comment lines found in $scriptFileName")
            // forall short-circuits on the first failure unless `continue` keeps the predicate true.
            succeeded = scriptLines.forall(line => runCommand(line) || continue)
          }
          succeeded
      }
    }

    override def argCompleter: Completer = {
      if (scriptDir.exists) {
        // listFiles() returns null on I/O errors; treat that as "no scripts".
        val files = Option(scriptDir.listFiles()).getOrElse(Array.empty[File])
        val allNames = files.filter(_.isFile).map(_.getName)
        val scriptNames =
          if (scriptExtension != null) allNames.filter(_.endsWith(scriptExtension)) else allNames
        new ArgumentCompleter(List(new StringsCompleter(scriptNames: _*), new NullCompleter))
      } else {
        new NullCompleter
      }
    }

    override def addOptions(opts: Options): Unit = {
      opts.addOption("c", "continue", false, "Continue even if there was a failure in execution.")
    }
  }

  // NOTE(chris, 2014-02-05): This has to be near the end for overrides to work
  rootSet.commands += new NotificationCommandSet(notificationManager)
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy