All downloads are free. Search and download functionality uses the official Maven repository.

ammonite.compiler.Parsers.scala Maven / Gradle / Ivy

The newest version!
package ammonite.compiler

import java.util.Map

import ammonite.compiler.iface.{Compiler => _, Parser => IParser, _}
import ammonite.compiler.internal.CompilerHelper
import ammonite.util.ImportTree
import ammonite.util.Util.CodeSource

import dotty.tools.dotc
import dotc.ast.untpd
import dotc.CompilationUnit
import dotc.core.Contexts.{ctx, Context, ContextBase}
import dotc.parsing.Tokens
import dotc.util.SourceFile

import scala.collection.mutable

class Parsers extends IParser {

  // FIXME Get via Compiler?
  /** Root compiler context, built on first access and reused by the parsing helpers below. */
  private lazy val initCtx: Context = {
    val base = new ContextBase
    base.initialCtx
  }

  // From
  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
  //   compiler/src/dotty/tools/repl/ParseResult.scala/#L115-L120
  /** Parses the source attached to the given context as a sequence of block statements. */
  private def parseStats(using Context): List[untpd.Tree] =
    val dottyParser = new DottyParser(ctx.source)
    val statements = dottyParser.blockStatSeq()
    // Expect end of input once the statements have been read.
    dottyParser.accept(Tokens.EOF)
    statements

  // Adapted from
  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
  //   compiler/src/dotty/tools/repl/ParseResult.scala/#L163-L184
  /** Tell whether the input is complete.
   *
   *  The code is parsed (not evaluated). The result is `true` when the parse
   *  reported errors, or when the parser never signalled that more input is
   *  required; it is `false` only for error-free input that still needs more text.
   */
  private def isComplete(sourceCode: String)(using Context): Boolean = {
    val storeReporter   = Compiler.newStoreReporter()
    val virtualFile     = SourceFile.virtual("", sourceCode, maybeIncomplete = true)
    val compilationUnit = CompilationUnit(virtualFile, mustExist = false)
    val parseCtx = ctx.fresh
      .setCompilationUnit(compilationUnit)
      .setReporter(storeReporter)

    // Flipped by the incomplete-input handler while parsing.
    var incomplete = false
    storeReporter.withIncompleteHandler((_, _) => incomplete = true) {
      parseStats(using parseCtx)
    }

    val errorFreeButUnfinished = !storeReporter.hasErrors && incomplete
    !errorFreeButUnfinished
  }

  // Matches input that consists of a single `{ ... }` pair, with optional whitespace
  // around it, and captures the text between the braces. `(?s)` lets `.` span newlines.
  private val BlockPat = """(?s)^\s*\{(.*)\}\s*$""".r

  /** Splits `code` into its top-level statements.
   *
   *  Delegates to `doSplit` and drops the start offset attached to each statement.
   *
   *  @return `None` when the input is incomplete, `Some(Left(message))` when it
   *          does not parse, `Some(Right(statements))` otherwise
   */
  def split(
    code: String,
    ignoreIncomplete: Boolean = true,
    fileName: String = "(console)"
  ): Option[Either[String, Seq[String]]] =
    doSplit(code, ignoreIncomplete, fileName).map { outcome =>
      for (indexedStatements <- outcome)
        yield indexedStatements.map { case (_, statement) => statement }
    }

  /** Parses `code` and cuts it into statements, each paired with its start offset.
   *
   *  If `code` is entirely wrapped in one pair of braces, only the text between
   *  the braces is parsed, and the returned offsets are relative to that inner text.
   *  Each statement's text runs from its own start up to the start of the next
   *  statement (or to the end of the input for the last one).
   *
   *  @param code             the source text to split
   *  @param ignoreIncomplete not consulted by this implementation
   *  @param fileName         put on the first line of the error message
   *  @return `Some(Left(message))` if the parser reported errors, `None` if it
   *          asked for more input, `Some(Right(statements))` otherwise
   */
  private def doSplit(
    code: String,
    ignoreIncomplete: Boolean,
    fileName: String
  ): Option[Either[String, Seq[(Int, String)]]] = {
    val body = code match {
      case BlockPat(inner) => inner
      case _ => code
    }

    given Context = initCtx
    val storeReporter   = Compiler.newStoreReporter()
    val virtualFile     = SourceFile.virtual("", body, maybeIncomplete = true)
    val compilationUnit = CompilationUnit(virtualFile, mustExist = false)
    val parseCtx = ctx.fresh
      .setCompilationUnit(compilationUnit)
      .setReporter(storeReporter)

    // Flipped by the incomplete-input handler while parsing.
    var incomplete = false
    val trees = storeReporter.withIncompleteHandler((_, _) => incomplete = true) {
      parseStats(using parseCtx)
    }

    val nl = System.lineSeparator()

    if (storeReporter.hasErrors) {
      // Render every buffered diagnostic; fall back to "???" if rendering itself fails.
      val rendered = storeReporter.removeBufferedMessages.map { diagnostic =>
        val text = scala.util.Try {
          CompilerHelper.messageAndPos(Compiler.messageRenderer, diagnostic)
        }.getOrElse("???")
        Compiler.messageRenderer.stripColor(text)
      }
      Some(Left(s"$fileName$nl${rendered.mkString(nl)}"))
    }
    else if (incomplete)
      None
    else {
      val starts = trees.map(_.startPos(using parseCtx).point)
      // Each statement ends where the next one starts; the last ends with the input.
      val ends   = starts.drop(1) :+ body.length
      val pieces = starts.zip(ends).map { case (from, until) => body.substring(from, until) }

      val noStatements = Vector.empty[(Int, String)]
      val nothingOpen  = Option.empty[(untpd.Tree, String, Int)]

      val (emitted, pending) =
        trees.zip(pieces).zip(starts).foldLeft((noStatements, nothingOpen)) {
          case ((done, open), ((tree, text), at)) =>
            (open, tree) match {
              // Two consecutive import trees where only the first piece of text carries
              // the `import ` keyword: glue the text together as a single statement.
              case (Some((_: untpd.Import, acc, accAt)), _: untpd.Import)
                  if acc.startsWith("import ") && !text.startsWith("import ") =>
                (done, Some((tree, acc + text, accAt)))
              case _ =>
                val flushed = done ++ open.map { case (_, acc, accAt) => (accAt, acc) }
                (flushed, Some((tree, text, at)))
            }
        }

      val all = emitted ++ pending.map { case (_, acc, accAt) => (accAt, acc) }
      Some(Right(all.toList))
    }
  }

  /** Returns the decoded path segments of an import's qualifier, outermost first.
   *
   *  Only `Ident` and `Select` nodes contribute a segment; the walk stops at the
   *  first node of any other shape.
   */
  private def importExprs(i: untpd.Import): Seq[String] = {
    // Walks from the innermost selection outwards, prepending as it goes, so the
    // accumulated list comes out in source order without a final reverse.
    def collect(tree: untpd.Tree, segments: List[String]): List[String] =
      tree match {
        case untpd.Select(qualifier, name) => collect(qualifier, name.decode.toString :: segments)
        case untpd.Ident(name) => name.decode.toString :: segments
        case _ => segments // ???
      }
    collect(i.expr, Nil)
  }

  /** Extracts import hooks (imports whose first path segment starts with `$`)
   *  from `statement`.
   *
   *  Each such import is overwritten in the returned text by everything up to its
   *  first `.`, followed by `.$`, then padded with spaces so that later offsets
   *  are not shifted. One `ImportTree` is produced per rewritten import.
   *
   *  @param statement source text of a single statement
   *  @return the possibly rewritten statement and the import trees found; if the
   *          statement has parse errors or is incomplete, it is returned
   *          unchanged with no import trees
   */
  def importHooks(statement: String): (String, Seq[ImportTree]) = {

    given Context = initCtx
    val reporter = Compiler.newStoreReporter()
    val source   = SourceFile.virtual("", statement, maybeIncomplete = true)
    val unit     = CompilationUnit(source, mustExist = false)
    val localCtx = ctx.fresh
                      .setCompilationUnit(unit)
                      .setReporter(reporter)
    var needsMore = false
    val stats = reporter.withIncompleteHandler((_, _) => needsMore = true) {
      parseStats(using localCtx)
    }

    if (reporter.hasErrors || needsMore)
      (statement, Nil)
    else {
      var updatedStatement = statement
      val importTrees = Vector.newBuilder[ImportTree]
      stats.foreach {
        case i: untpd.Import =>
          val exprs = importExprs(i)
          if (exprs.headOption.exists(_.startsWith("$"))) {
            val start = i.startPos.point
            val end = {
              var initialEnd = i.endPos.point
              // kind of meh
              // In statements like 'import $file.foo.{a, b}', endPos points at 'b' rather than '}',
              // so we work around that here.
              if (updatedStatement.iterator.drop(start).take(initialEnd - start).contains('{')) {
                while (updatedStatement.length > initialEnd &&
                        updatedStatement.charAt(initialEnd).isWhitespace)
                  initialEnd = initialEnd + 1
                if (updatedStatement.length > initialEnd &&
                      updatedStatement.charAt(initialEnd) == '}')
                  initialEnd = initialEnd + 1
              }
              initialEnd
            }
            // A selector maps to `Some(rename)` only when the rename differs from the name.
            val selectors = i.selectors.map { sel =>
              val from = sel.name.decode.toString
              val to = sel.rename.decode.toString
              from -> Some(to).filter(_ != from)
            }
            val updatedImport = updatedStatement.substring(start, end).takeWhile(_ != '.') + ".$"
            // Pad with spaces so the rewritten import occupies the same span as the original.
            updatedStatement = updatedStatement.patch(
              start,
              updatedImport + (" ") * (end - start - updatedImport.length),
              end - start
            )

            // Skip the `import ` keyword only when this tree's own text begins with it.
            // Testing the start of the whole statement instead would wrongly shift the
            // second and later trees of `import $a.x, $b.y`, and would miss the keyword
            // when the statement has leading whitespace.
            // NOTE(review): relies on only the first tree of a comma-separated import
            // clause having a span that covers the keyword -- confirm against dotty.
            val prefixLen =
              if (updatedStatement.startsWith("import ", start)) "import ".length else 0
            importTrees += ImportTree(
              exprs,
              Some(selectors).filter(_.nonEmpty),
              start + prefixLen, end
            )
          }
        case _ =>
      }
      (updatedStatement, importTrees.result())
    }
  }

  /** Runs `importHooks` on every statement and shifts the resulting import trees
   *  by that statement's start offset.
   *
   *  @param source     not consulted by this implementation
   *  @param statements pairs of (start offset, statement text)
   *  @return the rewritten statements, in order, and all import trees, flattened
   */
  def parseImportHooksWithIndices(
    source: CodeSource,
    statements: Seq[(Int, String)]
  ): (Seq[String], Seq[ImportTree]) = {

    val processed =
      for ((offset, stmt) <- statements) yield {
        val (rewritten, localTrees) = importHooks(stmt)
        // importHooks reports positions relative to the statement; make them absolute.
        val shiftedTrees = localTrees.map { tree =>
          tree.copy(start = offset + tree.start, end = offset + tree.end)
        }
        (rewritten, shiftedTrees)
      }

    (processed.map(_._1), processed.flatMap(_._2))
  }

  // Block separator in scripts: an `@` that is the first non-whitespace character on a
  // line (`(?m)` makes `^` match at line starts), followed by one or more whitespace
  // characters.
  private val scriptSplitPattern = "(?m)^\\s*@[\\s\\n\\r]+".r

  /** Splits a script into blocks and each block into statements, dropping offsets.
   *
   *  @return on failure, the message of the splitting error; on success, one
   *          `(ncomment, statements)` pair per block
   */
  def splitScript(
    rawCode: String,
    fileName: String
  ): Either[String, IndexedSeq[(String, Seq[String])]] =
    scriptBlocksWithStartIndices(rawCode, fileName) match {
      case Left(error) =>
        Left(error.getMessage)
      case Right(blocks) =>
        val simplified =
          for (block <- blocks)
            yield (block.ncomment, block.codeWithStartIndices.map(_._2))
        Right(simplified.toVector)
    }

  /** Cuts a script at every `scriptSplitPattern` match and splits each resulting
   *  block into statements, keeping offsets relative to `rawCode`.
   *
   *  A block that `doSplit` reports as incomplete is kept as a single statement
   *  with an empty leading part.
   *
   *  @return all blocks in order, or a single error that joins the messages of
   *          every block that failed to parse, one per line
   */
  def scriptBlocksWithStartIndices(
    rawCode: String,
    fileName: String
  ): Either[IParser.ScriptSplittingError, Seq[IParser.ScriptBlock]] = {

    // Offsets come in pairs: a block starts at 0 or at the end of a separator, and
    // ends at the start of the next separator or at the end of the input.
    val separatorEdges = scriptSplitPattern
      .findAllMatchIn(rawCode)
      .flatMap(m => Iterator(m.start, m.end))
    val bounds = (Iterator(0) ++ separatorEdges ++ Iterator(rawCode.length))
      .grouped(2)
      .map { case Seq(from, until) => (from, until) }
      .toVector

    val parsed: Vector[Either[String, IParser.ScriptBlock]] =
      for ((from, until) <- bounds) yield {
        val chunk = rawCode.substring(from, until)
        doSplit(chunk, ignoreIncomplete = false, fileName) match {
          case Some(Left(message)) =>
            Left(message)
          case Some(Right(stmts)) =>
            // Whatever precedes the first statement is the block's leading part.
            val leading = chunk.take(stmts.headOption.fold(0)(_._1))
            val shifted = stmts.map { case (at, stmt) => (at + from, stmt) }
            Right(IParser.ScriptBlock(from, leading, shifted))
          case None =>
            Right(IParser.ScriptBlock(from, "", Seq((from, chunk))))
        }
      }

    val (failures, okBlocks) = parsed.partitionMap(identity)
    if (failures.isEmpty)
      Right(okBlocks)
    else
      Left(new IParser.ScriptSplittingError(failures.mkString(System.lineSeparator)))
  }

  /** Highlights `buffer` with `SyntaxHighlighting`, using the given attributes.
   *
   *  Value definitions and annotations have no dedicated parameter here; both are
   *  rendered with `reset`.
   */
  def defaultHighlight(buffer: Vector[Char],
                       comment: fansi.Attrs,
                       `type`: fansi.Attrs,
                       literal: fansi.Attrs,
                       keyword: fansi.Attrs,
                       notImplemented: fansi.Attrs,
                       reset: fansi.Attrs): Vector[Char] = {
    val highlighter = new SyntaxHighlighting(
      reset,
      comment,
      keyword,
      reset, // valDef
      literal,
      `type`,
      reset, // annotation
      notImplemented
    )
    val highlighted = highlighter.highlight(buffer.mkString)(using initCtx)
    highlighted.toVector
  }

  /** Not implemented yet: answers `false` for every input. */
  def isObjDef(code: String): Boolean = {
    // TODO
    false
  }

  /** Placeholder: always returns two entries, at offsets 0 and 999999999, both
   *  carrying `endColor`. `buffer` and `ruleColors` are not consulted.
   */
  def highlightIndices[T](buffer: Vector[Char],
                          ruleColors: PartialFunction[String, T],
                          endColor: T): Seq[(Int, T)] = {
    // Not available in Scala 3
    val farOffset = 999999999
    Seq(0 -> endColor, farOffset -> endColor)
  }
}

/** Singleton instance of [[Parsers]]. */
object Parsers extends Parsers




© 2015 - 2024 Weber Informatics LLC | Privacy Policy