
ammonite.repl.interp.Preprocessor.scala Maven / Gradle / Ivy
package ammonite.repl.interp
import acyclic.file
import ammonite.repl._
import fastparse.all._
import scala.reflect.internal.Flags
import scala.tools.nsc.{Global => G}
import collection.mutable
/**
* Responsible for all scala-source-code-munging that happens within the
* Ammonite REPL.
*
* Performs several tasks:
*
* - Takes top-level Scala expressions and assigns them to `res{1, 2, 3, ...}`
* values so they can be accessed later in the REPL
*
* - Wraps the code snippet in a wrapper `object` since Scala doesn't allow
* top-level expressions
*
* - Mangles imports from our [[ImportData]] data structure into a source
* String
*
* - Combines all of these into a complete compilation unit ready to feed into
* the Scala compiler
*/
trait Preprocessor{
  /**
   * Converts one batch of REPL statements into a complete source file.
   *
   * The parameter descriptions below follow the implementation returned by
   * `Preprocessor.apply`; other implementations may interpret them differently.
   *
   * @param stmts              raw source snippets entered by the user
   * @param resultIndex        index used when naming result variables (`res<resultIndex>...`)
   * @param leadingSpaces      text prepended to the expanded code before wrapping
   * @param pkgName            package in which the generated wrapper object is declared
   * @param indexedWrapperName name of the generated wrapper object
   * @param imports            imports emitted ahead of the wrapper object
   * @param printerTemplate    receives the printer expressions joined by ", " and
   *                           returns the body of the wrapper's main method
   * @return the generated source together with the length of the header that
   *         precedes the user's code, or a failure
   */
  def transform(
    stmts: Seq[String],
    resultIndex: String,
    leadingSpaces: String,
    pkgName: String,
    indexedWrapperName: String,
    imports: Imports,
    printerTemplate: String => String
  ): Res[Preprocessor.Output]
}
object Preprocessor{
  /**
   * Intermediate result of expanding the statements of one REPL command.
   *
   * @param code    source to place inside the wrapper object
   * @param printer source snippets that print what the statements defined; they
   *                are joined with ", " and handed to `printerTemplate`
   */
  private case class Expanded(code: String, printer: Seq[String])

  /**
   * A complete compilation unit produced by `transform`.
   *
   * @param code             the full generated source, wrapper included
   * @param prefixCharLength number of characters of generated header (package
   *                         clause, import block, wrapper object opening) that
   *                         precede the user's code inside `code`
   */
  case class Output(code: String, prefixCharLength: Int)

  /**
   * Formats a syntax error as a message, followed by the source line containing
   * the error and a caret under the error position.
   *
   * @param msg      description of the error
   * @param code     the full source that failed to parse
   * @param expected what the parser expected at `idx`; currently not included in
   *                 the rendered message
   * @param idx      character offset of the error within `code`
   * @return a multi-line message starting with "Syntax Error: "
   */
  def errMsg(msg: String, code: String, expected: String, idx: Int): String = {
    val locationString = {
      val (first, last) = code.splitAt(idx)
      // Text from the error position up to (not including) the next newline
      val lastSnippet = last.split('\n').headOption.getOrElse("")
      // Text from just after the previous newline up to the error position
      val firstSnippet = first.reverse.split('\n').headOption.getOrElse("").reverse
      firstSnippet + lastSnippet + "\n" + (" " * firstSnippet.length) + "^"
    }
    s"Syntax Error: $msg\n$locationString"
  }

  /**
   * Splits a script into blocks using `Parsers.splitScript`. Blocks are
   * separated by `@` lines, whose trailing newline is consumed by the separator
   * parser.
   *
   * @return for each block, the comment/blank-line text above it (padded with
   *         newlines, see below) and its statements; or a Failure carrying a
   *         formatted syntax error
   */
  def splitScript(rawCode: String): Res[Seq[(String, Seq[String])]] = {
    Parsers.splitScript(rawCode) match {
      case f: Parsed.Failure =>
        Timer("processScriptFailed 0b")
        Res.Failure(None, errMsg(f.msg, rawCode, f.extra.traced.expected, f.index))
      case s: Parsed.Success[Seq[(String, Seq[String])]] =>
        Timer("processCorrectScript 0b")
        // Number of newlines consumed by all previously processed blocks
        var offset = 0
        val blocks = mutable.Buffer[(String, Seq[String])]()
        // comment holds comments or empty lines above the code which is not caught along with code
        for( (comment, code) <- s.value){
          // Pad with one newline per line consumed by earlier blocks.
          // NOTE(review): presumably keeps compiler line numbers aligned with the script.
          val ncomment = comment + "\n"*offset
          // 1 is added as Separator parser eats up the '\n' following @
          offset = offset + comment.count(_ == '\n') + code.map(_.count(_ == '\n')).sum + 1
          blocks.append((ncomment, code))
        }
        Res.Success(blocks)
    }
  }

  /**
   * Builds a [[Preprocessor]] that uses `parse` to turn each statement into
   * compiler trees. The tree shape decides how a statement is bound and printed.
   *
   * @param parse parses one statement into trees, or returns an error message
   */
  def apply(parse: => String => Either[String, Seq[G#Tree]]): Preprocessor = new Preprocessor{

    def transform(stmts: Seq[String],
                  resultIndex: String,
                  leadingSpaces: String,
                  pkgName: String,
                  indexedWrapperName: String,
                  imports: Imports,
                  printerTemplate: String => String) = for{
      Preprocessor.Expanded(code, printer) <- expandStatements(stmts, resultIndex)
      (wrappedCode, importsLength) = wrapCode(
        pkgName, indexedWrapperName, leadingSpaces + code,
        printerTemplate(printer.mkString(", ")),
        imports
      )
    } yield Preprocessor.Output(wrappedCode, importsLength)

    /**
     * Adapts a partial function over `(name, code, tree)` into the
     * `(code, name, tree) => Option[Expanded]` shape stored in `decls`.
     * The first two arguments are deliberately swapped.
     */
    def Processor(cond: PartialFunction[(String, String, G#Tree), Preprocessor.Expanded]) = {
      // Pass an explicit tuple rather than relying on deprecated auto-tupling
      (code: String, name: String, tree: G#Tree) => cond.lift((name, code, tree))
    }

    /**
     * Source for a call to `ReplBridge.repl.Internal.print` on `ident`.
     * NOTE(review): `customMsg` is spliced into a string literal unescaped, so it
     * must not contain `"` or backslashes.
     */
    def pprintSignature(ident: String, customMsg: Option[String]) = {
      val customCode = customMsg.fold("_root_.scala.None")(x => s"""_root_.scala.Some("$x")""")
      s"""
_root_.ammonite
.repl
.frontend
.ReplBridge
.repl
.Internal
.print($ident, $ident, "$ident", $customCode)
"""
    }

    /** Source for a call that reports "defined <definitionLabel> <name>". */
    def definedStr(definitionLabel: String, name: String) =
      s"""
_root_.ammonite
.repl
.frontend
.ReplBridge
.repl
.Internal
.printDef("$definitionLabel", "$name")
"""

    /** Printer source for `ident` with no custom message. */
    def pprint(ident: String) = pprintSignature(ident, None)

    /**
     * Processors for declarations which all have the same shape: if `cond`
     * matches the tree, keep the code as-is and report the declared name.
     */
    def DefProc(definitionLabel: String)(cond: PartialFunction[G#Tree, G#Name]) =
      (code: String, name: String, tree: G#Tree) =>
        cond.lift(tree).map{ name =>
          Preprocessor.Expanded(
            code,
            Seq(definedStr(definitionLabel, Parsers.backtickWrap(name.decoded)))
          )
        }

    val ObjectDef = DefProc("object"){case m: G#ModuleDef => m.name}
    val ClassDef = DefProc("class"){ case m: G#ClassDef if !m.mods.isTrait => m.name }
    val TraitDef = DefProc("trait"){ case m: G#ClassDef if m.mods.isTrait => m.name }
    val DefDef = DefProc("function"){ case m: G#DefDef => m.name }
    val TypeDef = DefProc("type"){ case m: G#TypeDef => m.name }

    /** `val`/`var`/`lazy val` definitions: keep the code and print the bound value. */
    val PatVarDef = Processor { case (name, code, t: G#ValDef) =>
      Expanded(
        code,
        // Try to leave out all synthetics; we don't actually have proper
        // synthetic flags right now, because we're dumb-parsing it and not putting
        // it through a full compilation
        if (t.name.decoded.contains("$")) Nil
        else if (!t.mods.hasFlag(Flags.LAZY)) Seq(pprint(Parsers.backtickWrap(t.name.decoded)))
        // Lazy vals get a custom (empty) message.
        // NOTE(review): presumably so printing does not force the lazy value.
        else Seq(pprintSignature(Parsers.backtickWrap(t.name.decoded), Some("")))
      )
    }

    /** Imports: keep the code and echo the imported clause. */
    val Import = Processor{
      case (name, code, tree: G#Import) =>
        // Drop the leading `import` keyword. Split on any run of whitespace (not
        // just a single space) so `import\nfoo.bar` or `import\tfoo` do not
        // crash with a MatchError; fall back to the whole code otherwise.
        val body = code.trim.split("\\s+", 2) match {
          case Array(_, rest) => rest
          case _ => code
        }
        val tq = "\"\"\""
        Expanded(code, Seq(
          s"""
_root_.ammonite
.repl
.frontend
.ReplBridge
.repl
.Internal
.printImport($tq$body$tq)
"""
        ))
    }

    /**
     * Anything not matched by an earlier processor is treated as an expression.
     * It is bound to the result variable `name` so it can be referenced later,
     * and that variable is printed.
     */
    val Expr = Processor{
      case (name, code, tree) => Expanded(s"val $name = $code", Seq(pprint(name)))
    }

    // Tried in order; the first processor returning Some wins. `Expr` matches
    // everything, so it must stay last.
    val decls = Seq[(String, String, G#Tree) => Option[Preprocessor.Expanded]](
      ObjectDef, ClassDef, TraitDef, DefDef, TypeDef, PatVarDef, Import, Expr
    )

    /**
     * Unwraps `{ ... }` blocks into their individual statements, then expands
     * all of them. Returns `Res.Skip` when there is nothing to process.
     */
    def expandStatements(stmts: Seq[String],
                         wrapperIndex: String): Res[Preprocessor.Expanded] = {
      val unwrapped = stmts.flatMap{x => Parsers.unwrapBlock(x) match {
        case Some(contents) =>
          Parsers.split(contents).get.get.value
        case None => Seq(x)
      }}
      unwrapped match{
        case Nil => Res.Skip
        case postSplit =>
          complete(stmts.mkString(""), wrapperIndex, postSplit)
      }
    }

    /**
     * Parses every statement, expands each with the first matching processor,
     * and concatenates the results with ";".
     *
     * @param code        the whole command's source, used only in error messages
     * @param resultIndex index used to name result variables
     * @param postSplit   the individual statements
     * @return the merged expansion; a Failure with all parse errors; or a
     *         Failure if no statement produced any tree
     */
    def complete(code: String, resultIndex: String, postSplit: Seq[String]) = {
      val reParsed = postSplit.map(p => (parse(p), p))
      val errors = reParsed.collect{case (Left(e), _) => e }
      if (errors.nonEmpty) Res.Failure(None, errors.mkString("\n"))
      else {
        val allDecls = for {
          ((Right(trees), stmtCode), i) <- reParsed.zipWithIndex if trees.nonEmpty
        } yield {
          // Suffix the name of the result variable with the index of
          // the tree if there is more than one statement in this command
          val suffix = if (reParsed.length > 1) "_" + i else ""
          val resultName = "res" + resultIndex + suffix
          // `Expr` accepts every tree, so `next()` always finds an element
          def handleTree(t: G#Tree) = {
            decls.iterator.flatMap(_.apply(stmtCode, resultName, t)).next()
          }
          trees match {
            case Seq(tree) => handleTree(tree)
            // This handles the multi-import case `import a.b, c.d`
            case importTrees if importTrees.forall(_.isInstanceOf[G#Import]) =>
              handleTree(importTrees(0))
            // AFAIK this can only happen for pattern-matching multi-assignment,
            // which for some reason parse into a list of statements. In such a
            // scenario, aggregate all their printers, but only output the code once
            case multiple =>
              val printers = for {
                tree <- multiple
                if tree.isInstanceOf[G#ValDef]
                Preprocessor.Expanded(_, treePrinters) = handleTree(tree)
                printer <- treePrinters
              } yield printer
              Preprocessor.Expanded(stmtCode, printers)
          }
        }
        // Use reduceOption on allDecls directly (rather than destructuring its
        // first element) so a command whose statements all parse to zero trees
        // yields the Failure below instead of throwing a MatchError.
        Res(
          allDecls.reduceOption { (a, b) =>
            Expanded(
              a.code + ";" + b.code,
              a.printer ++ b.printer
            )
          },
          "Don't know how to handle " + code
        )
      }
    }
  }

  /**
   * Renders imports as Scala source. Consecutive imports with the same prefix
   * are merged into a single `import prefix.{a, b => c}` statement, preserving
   * their order.
   */
  def importBlock(importData: Imports) = {
    Timer("importBlock 0")
    // Group the imports into sliding groups according to their
    // prefix, while still maintaining their ordering
    val grouped = mutable.Buffer[mutable.Buffer[ImportData]]()
    for(data <- importData.value){
      if (grouped.isEmpty) grouped.append(mutable.Buffer(data))
      else {
        val last = grouped.last.last
        // Start a new import if we're importing from somewhere else, or
        // we're importing the same thing from the same place but aliasing
        // it to a different name, since you can't import the same thing
        // twice in a single import statement
        val startNewImport =
          last.prefix != data.prefix || grouped.last.exists(_.fromName == data.fromName)
        if (startNewImport) grouped.append(mutable.Buffer(data))
        else grouped.last.append(data)
      }
    }
    // Stringify everything
    val out = for(group <- grouped) yield {
      val printedGroup = for(item <- group) yield{
        if (item.fromName == item.toName) Parsers.backtickWrap(item.fromName)
        else s"${Parsers.backtickWrap(item.fromName)} => ${Parsers.backtickWrap(item.toName)}"
      }
      "import " + group.head.prefix + ".{\n " + printedGroup.mkString(",\n ") + "\n}\n"
    }
    val res = out.mkString
    Timer("importBlock 1")
    res
  }

  /**
   * Wraps `code` in `package pkgName`, the rendered imports, and
   * `object indexedWrapperName`. A generated main method evaluates `printCode`.
   *
   * @return the full source and the length of the header that precedes `code`
   */
  def wrapCode(pkgName: String,
               indexedWrapperName: String,
               code: String,
               printCode: String,
               imports: Imports) = {
    val topWrapper = s"""
package $pkgName
${importBlock(imports)}
object $indexedWrapperName{\n"""
    val bottomWrapper = s"""\ndef $$main() = { $printCode }
override def toString = "$indexedWrapperName"
}
"""
    val importsLen = topWrapper.length
    (topWrapper + code + bottomWrapper, importsLen)
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy