// ammonite.compiler.Compiler.scala
package ammonite.compiler
import java.net.URL
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, Paths}
import java.io.{ByteArrayInputStream, OutputStream}
import ammonite.compiler.iface.{
Compiler => ICompiler,
CompilerBuilder => ICompilerBuilder,
CompilerLifecycleManager => ICompilerLifecycleManager,
Preprocessor => IPreprocessor,
_
}
import ammonite.compiler.internal.CompilerHelper
import ammonite.util.{ImportData, Imports, PositionOffsetConversion, Printer}
import ammonite.util.Util.newLine
import dotty.tools.dotc
import dotc.{CompilationUnit, Compiler => DottyCompiler, Run, ScalacCommand}
import dotc.ast.{tpd, untpd}
import dotc.ast.Positioned
import dotc.classpath
import dotc.config.{CompilerCommand, JavaPlatform}
import dotc.core.Contexts._
import dotc.core.{Flags, MacroClassLoader, Mode}
import dotc.core.Comments.{ContextDoc, ContextDocstrings}
import dotc.core.Phases.{Phase, unfusedPhases}
import dotc.core.Symbols.{defn, Symbol}
import dotc.fromtasty.TastyFileUtil
import dotc.interactive.Completion
import dotc.report
import dotc.reporting
import dotc.semanticdb
import dotc.transform.{PostTyper, Staging}
import dotc.util.{Property, SourceFile, SourcePosition}
import dotc.util.Spans.Span
import dotty.tools.io.{
AbstractFile,
ClassPath,
ClassRepresentation,
File,
VirtualDirectory,
VirtualFile,
PlainFile
}
import dotty.tools.repl.CollectTopLevelImports
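/**
 * In-memory Scala 3 (Dotty) compiler used by Ammonite to compile REPL snippets.
 *
 * `dynamicClassPath` accumulates the class files of previously compiled snippets,
 * `initialClassPath` is only exposed through `whiteList`, and `classPath` is the full
 * user classpath; all three are aggregated into the platform classpath in `initCtx`.
 *
 * Rough usage sketch (the argument values below are illustrative assumptions, not
 * taken from this file):
 * {{{
 *   val compiler = new Compiler(dynCp, initialCp, fullCp, getClass.getClassLoader, Set.empty)
 *   compiler.compile(code.getBytes("UTF-8"), printer, importsLen = 0,
 *     userCodeNestingLevel = 1, fileName = "cmd.sc")
 * }}}
 */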
class Compiler(
dynamicClassPath: AbstractFile,
initialClassPath: Seq[URL],
classPath: Seq[URL],
macroClassLoader: ClassLoader,
whiteList: Set[Seq[String]],
dependencyCompleteOpt: => Option[String => (Int, Seq[String])] = None,
contextInit: FreshContext => Unit = _ => (),
settings: Seq[String] = Nil,
reporter: Option[ICompilerBuilder.Message => Unit] = None
) extends ICompiler:
self =>
import Compiler.{enumerateVdFiles, files}
private val outputDir = new VirtualDirectory("(memory)")
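// Root compiler context whose platform classpath aggregates the whitelisted initial classpath,
// the user classpath, and the dynamic classpath of previously compiled snippets.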
private def initCtx: Context =
val base: ContextBase =
new ContextBase:
override protected def newPlatform(using Context) =
new JavaPlatform:
private var classPath0: ClassPath = null
override def classPath(using Context) =
if (classPath0 == null)
classPath0 = classpath.AggregateClassPath(Seq(
asDottyClassPath(initialClassPath, whiteListed = true),
asDottyClassPath(self.classPath),
classpath.ClassPathFactory.newClassPath(dynamicClassPath)
))
classPath0
base.initialCtx
private def sourcesRequired = false
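// dotc's MacroClassLoader property key, fetched reflectively through `macroClassLoader`;
// it is registered on the context in `setup` below.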
private lazy val MacroClassLoaderKey =
val cls = macroClassLoader.loadClass("dotty.tools.dotc.core.MacroClassLoader$")
val fld = cls.getDeclaredField("MacroClassLoaderKey")
fld.setAccessible(true)
fld.get(null).asInstanceOf[Property.Key[ClassLoader]]
// Originally adapted from
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/dotc/Driver.scala/#L67-L81
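// Parses `args` as scalac settings on a fresh context, installs the macro class loader,
// and returns the remaining file names together with the configured context.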
private def setup(args: Array[String], rootCtx: Context): (List[String], Context) =
given ictx: FreshContext = rootCtx.fresh
val summary = ScalacCommand.distill(args, ictx.settings)(ictx.settingsState)(using ictx)
ictx.setSettings(summary.sstate)
ictx.setProperty(MacroClassLoaderKey, macroClassLoader)
Positioned.init
if !ictx.settings.YdropComments.value then
ictx.setProperty(ContextDoc, new ContextDocstrings)
val fileNamesOpt = ScalacCommand.checkUsage(
summary,
sourcesRequired
)(using ictx.settings)(using ictx.settingsState)
val fileNames = fileNamesOpt.getOrElse {
throw new Exception("Error initializing compiler")
}
contextInit(ictx)
(fileNames, ictx)
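// Wraps plain classpath URLs as dotty ClassPath instances (directories and jars handled
// separately); when `whiteListed` is set, the result is filtered through `whiteList`.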
private def asDottyClassPath(
cp: Seq[URL],
whiteListed: Boolean = false
)(using Context): ClassPath =
val (dirs, jars) = cp.partition { url =>
url.getProtocol == "file" && Files.isDirectory(Paths.get(url.toURI))
}
val dirsCp = dirs.map(u => classpath.ClassPathFactory.newClassPath(AbstractFile.getURL(u)))
val jarsCp = jars
.filter(ammonite.util.Classpath.canBeOpenedAsJar)
.map(u => classpath.ZipAndJarClassPathFactory.create(AbstractFile.getURL(u)))
if (whiteListed) new dotty.ammonite.compiler.WhiteListClasspath(dirsCp ++ jarsCp, whiteList)
else classpath.AggregateClassPath(dirsCp ++ jarsCp)
// Originally adapted from
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/repl/ReplDriver.scala/#L67-L73
/** Create a fresh and initialized context with IDE mode enabled */
lazy val initialCtx =
val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive)
rootCtx.setSetting(rootCtx.settings.YcookComments, true)
// FIXME Disabled for the tests to pass
rootCtx.setSetting(rootCtx.settings.color, "never")
// FIXME We lose possible custom openStream implementations on the URLs of initialClassPath and
// classPath
val initialClassPath0 = initialClassPath
// .filter(!_.toURI.toASCIIString.contains("fansi_2.13"))
// .filter(!_.toURI.toASCIIString.contains("pprint_2.13"))
rootCtx.setSetting(rootCtx.settings.outputDir, outputDir)
val (_, ictx) = setup(settings.toArray, rootCtx)
ictx.base.initialize()(using ictx)
ictx
private var userCodeNestingLevel = -1
// Originally adapted from
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/repl/ReplCompiler.scala/#L34-L39
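// REPL pipeline: the standard frontend phases, then SemanticDB extraction, the AmmonitePhase
// (which collects the snippet's top-level imports and the earlier definitions it uses),
// and PostTyper.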
val compiler =
new DottyCompiler:
override protected def frontendPhases: List[List[Phase]] =
CompilerHelper.frontEndPhases ++
List(
List(new semanticdb.ExtractSemanticDB.ExtractSemanticInfo),
List(new AmmonitePhase(userCodeNestingLevel, userCodeNestingLevel == 2)),
List(new PostTyper)
)
// Originally adapted from
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/repl/Rendering.scala/#L97-L103
/** Formats errors using the `messageRenderer` */
private def formatError(dia: reporting.Diagnostic)(implicit ctx: Context): reporting.Diagnostic =
val renderedMessage = CompilerHelper.messageAndPos(Compiler.messageRenderer, dia)
new reporting.Diagnostic(
reporting.NoExplanation(renderedMessage),
dia.pos,
dia.level
)
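// Compiles one preprocessed snippet. `importsLen` is the length of the generated import
// prefix in `src`; diagnostics and class-file line numbers are shifted back by that prefix
// so they point into the user's own code. Returns None when compilation reports errors.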
def compile(
src: Array[Byte],
printer: Printer,
importsLen: Int,
userCodeNestingLevel: Int,
fileName: String
): Option[ICompiler.Output] =
// println(s"Compiling\n${new String(src, StandardCharsets.UTF_8)}\n")
self.userCodeNestingLevel = userCodeNestingLevel
val reporter0 = reporter match {
case None =>
Compiler.newStoreReporter()
case Some(rep) =>
val simpleReporter = new dotc.interfaces.SimpleReporter {
def report(diag: dotc.interfaces.Diagnostic) = {
val severity = diag.level match {
case dotc.interfaces.Diagnostic.ERROR => "ERROR"
case dotc.interfaces.Diagnostic.WARNING => "WARNING"
case dotc.interfaces.Diagnostic.INFO => "INFO"
case _ => "INFO" // should not happen
}
val pos = Some(diag.position).filter(_.isPresent).map(_.get)
val start = pos.fold(0)(_.start)
val end = pos.fold(new String(src, "UTF-8").length)(_.end)
val msg = ICompilerBuilder.Message(severity, start, end, diag.message)
rep(msg)
}
}
reporting.Reporter.fromSimpleReporter(simpleReporter)
}
val run = new Run(compiler, initialCtx.fresh.setReporter(reporter0))
val semanticDbEnabled = run.runContext.settings.Xsemanticdb.value(using run.runContext)
val sourceFile =
if (semanticDbEnabled) {
// semanticdb needs the sources to be written on disk, so we assume they're there already
val root = run.runContext.settings.sourceroot.value(using run.runContext)
SourceFile(AbstractFile.getFile(Paths.get(root).resolve(fileName)), "UTF-8")
} else {
val vf = new VirtualFile(fileName.split("/", -1).last, fileName)
val out = vf.output
out.write(src)
out.close()
new SourceFile(vf, new String(src, "UTF-8").toCharArray)
}
implicit val ctx: Context = run.runContext.withSource(sourceFile)
val unit =
new CompilationUnit(ctx.source, null):
// as done in
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/repl/ReplCompillationUnit.scala/#L8
override def isSuspendable: Boolean = false
ctx
.run
.compileUnits(unit :: Nil)
val result =
if (ctx.reporter.hasErrors) Left(reporter.fold(ctx.reporter.removeBufferedMessages)(_ => Nil))
else Right((reporter.fold(ctx.reporter.removeBufferedMessages)(_ => Nil), unit))
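// Remaps diagnostic positions from the generated source back to the user code (dropping the
// `importsLen` prefix) and renders them with `messageRenderer`.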
def formatDiagnostics(diagnostics: List[reporting.Diagnostic]): List[String] = {
val scalaPosToScPos = PositionOffsetConversion.scalaPosToScPos(
new String(src).drop(importsLen),
0,
0,
new String(src),
importsLen
)
val scFile = new SourceFile(sourceFile.file, sourceFile.content().drop(importsLen))
def scalaOffsetToScOffset(scalaOffset: Int): Option[Int] =
scalaPosToScPos(sourceFile.offsetToLine(scalaOffset), sourceFile.column(scalaOffset)).map {
case (scLine, scCol) => scFile.lineToOffset(scLine) + scCol
}
def scalaSpanToScSpan(scalaSpan: Span): Option[Span] =
for {
scStart <- scalaOffsetToScOffset(scalaSpan.start)
scEnd <- scalaOffsetToScOffset(scalaSpan.end)
scPoint <- scalaOffsetToScOffset(scalaSpan.point)
} yield Span(scStart, scEnd, scPoint)
def scalaSourcePosToScSourcePos(sourcePos: SourcePosition): Option[SourcePosition] =
if (sourcePos.source == sourceFile)
scalaSpanToScSpan(sourcePos.span).map { scSpan =>
SourcePosition(scFile, scSpan, sourcePos.outer)
}
else
None
def scalaDiagnosticToScDiagnostic(diag: reporting.Diagnostic): Option[reporting.Diagnostic] =
scalaSourcePosToScSourcePos(diag.pos).map { scPos =>
new reporting.Diagnostic(diag.msg, scPos, diag.level)
}
diagnostics
.map(d => scalaDiagnosticToScDiagnostic(d).getOrElse(d))
.map(formatError)
.map(_.msg.toString)
}
result match {
case Left(errors) =>
for (err <- formatDiagnostics(errors))
printer.error(err)
None
case Right((warnings, unit)) =>
for (warn <- formatDiagnostics(warnings))
printer.warning(warn)
val newImports = unfusedPhases.collectFirst {
case p: AmmonitePhase => p.importData
}.getOrElse(Seq.empty[ImportData])
val usedEarlierDefinitions = unfusedPhases.collectFirst {
case p: AmmonitePhase => p.usedEarlierDefinitions
}.getOrElse(Seq.empty[String])
val fileCount = enumerateVdFiles(outputDir).length
val classes = files(outputDir).toArray
// outputDir is None here, dynamicClassPath should already correspond to an on-disk directory
Compiler.addToClasspath(classes, dynamicClassPath, None)
outputDir.clear()
val lineShift = PositionOffsetConversion.offsetToPos(new String(src)).apply(importsLen).line
val mappings = Map(sourceFile.file.name -> (sourceFile.file.name, -lineShift))
val postProcessedClasses = classes.toVector.map {
case (path, byteCode) if path.endsWith(".class") =>
val updatedByteCodeOpt = AsmPositionUpdater.postProcess(
mappings,
new ByteArrayInputStream(byteCode)
)
(path, updatedByteCodeOpt.getOrElse(byteCode))
case other =>
other
}
val output = ICompiler.Output(
postProcessedClasses,
Imports(newImports),
Some(usedEarlierDefinitions)
)
Some(output)
}
def objCompiler = compiler
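// Returns a Preprocessor working on a fresh context for the given (virtual) file name.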
def preprocessor(fileName: String, markGeneratedSections: Boolean): IPreprocessor =
new Preprocessor(
initialCtx.fresh.withSource(SourceFile.virtual(fileName, "")),
markGeneratedSections: Boolean
)
// Originally adapted from
// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
// compiler/src/dotty/tools/repl/ReplCompiler.scala/#L224-L286
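// Type-checks the source, stopping after the typer phase, and returns the typed tree and the
// resulting context. Used by `complete` below.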
def tryTypeCheck(
src: Array[Byte],
fileName: String
) =
val sourceFile = SourceFile.virtual(fileName, new String(src, StandardCharsets.UTF_8))
val reporter0 = Compiler.newStoreReporter()
val run = new Run(
compiler,
initialCtx.fresh
.addMode(Mode.ReadPositions | Mode.Interactive)
.setReporter(reporter0)
.setSetting(initialCtx.settings.YstopAfter, List("typer"))
)
implicit val ctx: Context = run.runContext.withSource(sourceFile)
val unit =
new CompilationUnit(ctx.source, null):
override def isSuspendable: Boolean = false
ctx
.run
.compileUnits(unit :: Nil, ctx)
(unit.tpdTree, ctx)
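// Completions at `offset` in `snippet`: the snippet is wrapped into a synthetic
// `AutocompleteWrapper` object together with `previousImports`, type-checked, and the
// completion engine is queried at the translated offset. Returns the start offset relative
// to the snippet, the completion labels, and signatures for completed methods.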
def complete(
offset: Int,
previousImports: String,
snippet: String
): (Int, Seq[String], Seq[String]) = {
val prefix = previousImports + newLine +
"object AutocompleteWrapper{ val expr: _root_.scala.Unit = {" + newLine
val suffix = newLine + "()}}"
val allCode = prefix + snippet + suffix
val index = offset + prefix.length
// Originally based on
// https://github.com/lampepfl/dotty/blob/3.0.0-M1/
// compiler/src/dotty/tools/repl/ReplDriver.scala/#L179-L191
val (tree, ctx0) = tryTypeCheck(allCode.getBytes("UTF-8"), "<completions>")
val ctx = ctx0.fresh
val file = SourceFile.virtual("<completions>", allCode, maybeIncomplete = true)
val unit = CompilationUnit(file)(using ctx)
unit.tpdTree = {
given Context = ctx
import tpd._
tree match {
case PackageDef(_, p) =>
p.collectFirst {
case TypeDef(_, tmpl: Template) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd }
.getOrElse(???)
}.getOrElse(???)
case _ => ???
}
}
val ctx1 = ctx.fresh.setCompilationUnit(unit)
val srcPos = SourcePosition(file, Span(index))
val (start, completions) = dotty.ammonite.compiler.AmmCompletion.completions(
srcPos,
dependencyCompleteOpt = dependencyCompleteOpt,
enableDeep = false
)(using ctx1)
val blacklistedPackages = Set("shaded")
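// Recursively walks packages from the root package to offer fully-qualified completions
// for `name`, skipping blacklisted packages.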
def deepCompletion(name: String): List[String] = {
given Context = ctx1
def rec(t: Symbol): Seq[Symbol] = {
if (blacklistedPackages(t.name.toString))
Nil
else {
val children =
if (t.is(Flags.Package) || t.is(Flags.PackageVal) || t.is(Flags.PackageClass))
t.denot.info.allMembers.map(_.symbol).filter(_ != t).flatMap(rec)
else Nil
t +: children.toSeq
}
}
for {
member <- defn.RootClass.denot.info.allMembers.map(_.symbol).toList
sym <- rec(member)
// Scala 2 comment: sketchy name munging because I don't know how to do this properly
// Note lack of back-quoting support.
strippedName = sym.name.toString.stripPrefix("package$").stripSuffix("$")
if strippedName.startsWith(name)
(pref, _) = sym.fullName.toString.splitAt(sym.fullName.toString.lastIndexOf('.') + 1)
out = pref + strippedName
if out != ""
} yield out
}
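// Completion noise filter: drops members on the blacklist below, implicits/givens,
// generated `cache...` objects, and synthetic or `$`-mangled names.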
def blacklisted(s: Symbol) = {
given Context = ctx1
val blacklist = Set(
"scala.Predef.any2stringadd.+",
"scala.Any.##",
"java.lang.Object.##",
"scala.",
"scala.",
"scala.",
"scala.",
"scala.Predef.StringFormat.formatted",
"scala.Predef.Ensuring.ensuring",
"scala.Predef.ArrowAssoc.->",
"scala.Predef.ArrowAssoc.→",
"java.lang.Object.synchronized",
"java.lang.Object.ne",
"java.lang.Object.eq",
"java.lang.Object.wait",
"java.lang.Object.notifyAll",
"java.lang.Object.notify",
"java.lang.Object.clone",
"java.lang.Object.finalize"
)
blacklist(s.fullName.toString) ||
s.isOneOf(Flags.GivenOrImplicit) ||
// Cache objects, which you should probably never need to
// access directly, and apart from that have annoyingly long names
"cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decode.toString).isDefined ||
// s.isDeprecated ||
s.name.decode.toString == "" ||
s.name.decode.toString.contains('$')
}
val filteredCompletions = completions.filter { c =>
c.symbols.isEmpty || c.symbols.exists(!blacklisted(_))
}
val signatures = {
given Context = ctx1
for {
c <- filteredCompletions
s <- c.symbols
isMethod = s.denot.is(Flags.Method)
if isMethod
} yield s"def ${s.name}${s.denot.info.widenTermRefExpr.show}"
}
(start - prefix.length, filteredCompletions.map(_.label.replace(".package$.", ".")), signatures)
}
object Compiler:
/** Create empty outer store reporter */
def newStoreReporter(): reporting.StoreReporter =
new reporting.StoreReporter(null)
with reporting.UniqueMessagePositions with reporting.HideNonSensicalMessages
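// Recursively lists every file under an in-memory virtual directory.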
private def enumerateVdFiles(d: VirtualDirectory): Iterator[AbstractFile] =
val (subs, files) = d.iterator.partition(_.isDirectory)
files ++ subs.map(_.asInstanceOf[VirtualDirectory]).flatMap(enumerateVdFiles)
private def files(d: VirtualDirectory): Iterator[(String, Array[Byte])] =
for (x <- enumerateVdFiles(d) if x.name.endsWith(".class") || x.name.endsWith(".tasty")) yield {
val segments = x.path.split("/").toList.tail
(x.path.stripPrefix("(memory)/"), x.toByteArray)
}
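// Returns an output stream for `path` under `d`, creating intermediate subdirectories as needed.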
private def writeDeep(
d: AbstractFile,
path: List[String]
): OutputStream = path match {
case head :: Nil => d.fileNamed(path.head).output
case head :: rest =>
writeDeep(
d.subdirectoryNamed(head), //.asInstanceOf[VirtualDirectory],
rest
)
// We should never write to an empty path, and one of the above cases
// should catch this and return before getting here
case Nil => ???
}
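// Writes compiled class files into the dynamic classpath (and, if given, mirrors them to an
// on-disk output directory) so that they are visible to subsequent compilation runs.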
def addToClasspath(classFiles: Traversable[(String, Array[Byte])],
dynamicClasspath: AbstractFile,
outputDir: Option[Path]): Unit = {
val outputDir0 = outputDir.map(os.Path(_, os.pwd))
for((name, bytes) <- classFiles){
val elems = name.split('/').toList
val output = writeDeep(dynamicClasspath, elems)
output.write(bytes)
output.close()
for (dir <- outputDir0)
os.write.over(dir / elems, bytes, createFolders = true)
}
}
private[compiler] val messageRenderer =
new reporting.MessageRendering {}