All downloads are free. Search and download functionality uses the official Maven repository.

polynote.kernel.interpreter.scal.ScalaInterpreter.scala Maven / Gradle / Ivy

The newest version!
package polynote.kernel
package interpreter
package scal

import java.lang.reflect.{Constructor, InvocationTargetException}

import scala.reflect.internal.util.{NoPosition, Position}
import polynote.messages.CellID
import zio.blocking.{Blocking, effectBlockingInterrupt}
import zio.{RIO, Task, ZIO}
import ScalaInterpreter.{addPositionUpdates, captureLastExpression}
import polynote.kernel.environment.{CurrentNotebook, CurrentRuntime, CurrentTask}
import polynote.kernel.task.TaskManager

/**
  * Interpreter for Scala cells. A cell's code is turned into cell code by `scalaCompiler`, compiled, and run by
  * instantiating the resulting class. The constructor is package-private; the companion's `apply()` builds instances.
  *
  * @param scalaCompiler compiler used to build, compile and inspect cell code
  * @param indexer       class indexer handed to the completer that backs completions and parameter hints
  */
class ScalaInterpreter private[scal] (
  val scalaCompiler: ScalaCompiler,
  indexer: ClassIndexer
) extends Interpreter {
  // Bring the compiler's path-dependent types and the tree API of its `global` into scope for the code below.
  import scalaCompiler.{CellCode, global, Imports}
  import global.{Tree, ValDef, TermName, Modifiers, EmptyTree, TypeTree, Import, Name, Type, Quasiquote, typeOf, atPos, NoType}

  // Interpreter interface methods //

  /**
    * Runs a Scala cell. Values, imports and prior cell code are collected from `state` (plus anything added by
    * `injectState`), `code` is wrapped into cell code and compiled, and the compiled class is instantiated, which
    * runs the code. The result is a [[ScalaCellState]] carrying the cell's result values and instance.
    */
  override def run(code: String, state: State): RIO[InterpreterEnv, State] = for {
    // NOTE(review): several expressions in this comprehension are truncated in this copy of the source:
    //   - the empty `${}` in the cell name;
    //   - the empty right-hand sides of `inputNames` and `inputs`;
    //   - the missing receivers before `=>` on the `resultInstance` / `resultValues` lines;
    //   - the missing first argument of `getResultValues`;
    //   - the missing first argument of `ScalaCellState`.
    // It does not compile as shown; restore these from upstream.
    collectedState <- injectState(collectState(state))
    // Only the declaration half (`ValDef`) of each collected value is needed to build the cell code.
    valDefs         = collectedState.values.mapValues(_._1).values.toList
    cellCode       <- scalaCompiler.cellCode(s"Cell${}", code, collectedState.prevCells, valDefs, collectedState.imports)
    inputNames      =
    inputs          =
    cls            <- scalaCompiler.compileCell(cellCode)
    // The truncated receivers below appear to be Options (note `.getOrElse(ZIO.succeed(...))`): when empty, no
    // instance is created and no result values are produced.
    resultInstance <- => runClass(cls, cellCode, inputs, state).map(Some(_))).getOrElse(ZIO.succeed(None))
    resultValues   <- => getResultValues(, cellCode, resultInstance)).getOrElse(ZIO.succeed(Nil))
  } yield ScalaCellState(, state.prev, resultValues, cellCode, resultInstance)

  /**
    * Completions at offset `pos` in `code`. Only the text before `pos` is compiled, with `strictParse = false`. It is
    * prefixed with two newlines and followed by two spaces, so the offset handed to the completer is `pos + 2`.
    */
  override def completionsAt(code: String, pos: Int, state: State): RIO[Blocking, List[Completion]] = for {
    // This method's environment has no `CurrentRuntime`, which `injectState` requires, so `noRuntime` is provided.
    collectedState   <- injectState(collectState(state)).provideLayer(CurrentRuntime.noRuntime)
    valDefs           = collectedState.values.mapValues(_._1).values.toList
    // NOTE(review): the empty `${}` in the cell name looks truncated in this copy (it interpolates an empty block).
    // NOTE(review): `substring` throws for a negative `pos`; presumably callers never pass one — confirm.
    cellCode         <- scalaCompiler.cellCode(s"Cell${}", s"\n\n${code.substring(0, math.min(pos, code.length))}  ", collectedState.prevCells, valDefs, collectedState.imports, strictParse = false)
    // +2 compensates for the two leading newlines added to the code above.
    completions      <- completer.completions(cellCode, pos + 2)
  } yield completions

  /**
    * Parameter hints (signatures) at offset `pos` in `code`. The whole cell is compiled with `strictParse = false`.
    * It is prefixed with two newlines and followed by two spaces, so the offset handed to the completer is `pos + 2`.
    */
  override def parametersAt(code: String, pos: Int, state: State): RIO[Blocking, Option[Signatures]] = for {
    // This method's environment has no `CurrentRuntime`, which `injectState` requires, so `noRuntime` is provided.
    collectedState <- injectState(collectState(state)).provideLayer(CurrentRuntime.noRuntime)
    valDefs         = collectedState.values.mapValues(_._1).values.toList
    // NOTE(review): the empty `${}` in the cell name looks truncated in this copy (it interpolates an empty block).
    cellCode       <- scalaCompiler.cellCode(s"Cell${}", s"\n\n$code  ", collectedState.prevCells, valDefs, collectedState.imports, strictParse = false)
    // +2 compensates for the two leading newlines added to the code above.
    hints          <- completer.paramHints(cellCode, pos + 2)
  } yield hints

  /** The Scala interpreter does no initialization work: the given state is handed back unchanged. */
  override def init(state: State): RIO[InterpreterEnv, State] = {
    val unchanged: State = state
    ZIO.succeed(unchanged)
  }

  /** Nothing to release on shutdown; the effect completes immediately with `()`. */
  override def shutdown(): Task[Unit] = ZIO.succeed(())

  // Overridable Scala-specific stuff //

  /**
    * Overridable method to inject some predefined state (values and imports) into the initial cell. The base
    * implementation injects the `kernel` value, making it available to the notebook. Override to inject more imports
    * or values.
    */
  protected def injectState(collectedState: CollectedState): RIO[CurrentRuntime, CollectedState] = {
    // NOTE(review): this body is truncated in this copy of the source. Three pieces are missing:
    //   - the expression producing `kernelRuntime`, presumably read from the `CurrentRuntime` environment;
    //   - the map key, presumably "kernel", the name declared by `runtimeValDef`;
    //   - the closing braces.
    // It does not compile as shown.
      kernelRuntime =>
        // Register the runtime as an available value, declared by the implicit parameter built in `runtimeValDef`.
        collectedState.copy(values = collectedState.values + ( -> (runtimeValDef, kernelRuntime: Any)))

  /**
    * Transforms the Scala statements to inject additional things, such as updating the execution status and ensuring
    * the last expression is captured. Override to add additional transformations (don't forget to call
    * `super.transformCode`!).
    */
  protected def transformCode(code: List[Tree]): List[Tree] = {
    // NOTE(review): the body of this method is missing from this copy of the source. Given that
    // `addPositionUpdates` and `captureLastExpression` are imported from the companion, it presumably applies those
    // two transformations — confirm against upstream.

  // Protected structures for subclass implementors //

  /**
    * Container for information about available values, imports, etc. from previous cells or predefined things.
    */
  protected case class CollectedState(
    // value name -> (the ValDef declaring it for the cell, its runtime value)
    values: Map[String, (ValDef, Any)] = Map.empty[String, (ValDef, Any)],
    // imports accumulated from earlier cells
    imports: Imports = Imports(),
    // cell code of earlier Scala cells
    prevCells: List[CellCode] = List.empty[CellCode]
  )

  // Private Scala-specific stuff //

  // Completion and parameter-hint engine used by `completionsAt` and `parametersAt`, built from this
  // interpreter's compiler and class indexer.
  private val completer = ScalaCompleter(
    scalaCompiler,
    indexer
  )

  /** Package-private hook that lets tests run reliably; it simply delegates to `indexer.await`. */
  private[scal] def awaitIndexer = {
    val classIndexer = indexer
    classIndexer.await
  }

  /**
    * Builds the implicit `kernel` parameter, typed `polynote.runtime.KernelRuntime`, through which the kernel
    * runtime value is injected into cell scope. Being a `def`, it yields a fresh tree on every call.
    */
  private def runtimeValDef = {
    val implicitMods = Modifiers(global.Flag.IMPLICIT)
    val runtimeType  = tq"polynote.runtime.KernelRuntime"
    ValDef(implicitMods, TermName("kernel"), runtimeType, EmptyTree)
  }

  /**
    * Goes backward through the state and collects all the output values and imports from previous cells. For Scala
    * cells, it also builds up a list of prior CellCode instances. We start by wrapping a cell's code in a class which
    * has all these available values and all available prior cells as constructor arguments, and then we prune it to
    * keep only the constructor arguments which the code requires.
    */
  private def collectState(state: State): CollectedState = state.prev.collect {
    // NOTE(review): this method is truncated in this copy of the source. Missing pieces:
    //   - the receivers of `.toMap`, `.flatMap`, `match` and the `{ v => ... }` lambda;
    //   - the argument to `TermName(`;
    //   - the right-hand side of `nextImports`;
    //   - the left operand of `:: priorCells`;
    //   - several closing braces.
    // It does not compile as shown; restore from upstream.
    case ScalaCellState(_, _, values, cellCode, _) =>
      // Scala cell: pair each output declaration (presumably the cell code's outputs) with its recorded value,
      // looked up by name.
      val valuesMap = => -> v.value).toMap
      val inputs =
        .flatMap {
          v =>
            // The captured last expression `$Out` is exposed to later cells under the name `Out`.
            val (nameString, input) = match {
              case "$Out" => "Out" -> v.copy(name = TermName("Out"))
              case name   => name -> v
            // Outputs with no recorded value are dropped (`Option#toList` is empty for `None`).
            valuesMap.get(nameString).map(value => nameString -> (input, value)).toList
      (inputs, Option(cellCode))
    case state =>
      // Any other cell: synthesize a ValDef for each result value from its recorded Scala type.
      // NOTE(review): the cast assumes `v.scalaType` was produced by this same compiler `global` — confirm.
      val inputs = {
        v =>
          val name = TermName(
          name.encodedName.toString -> (ValDef(Modifiers(), name, TypeTree(v.scalaType.asInstanceOf[global.Type]), EmptyTree), v.value)
      (inputs, None)
  }.foldRight(CollectedState()) {
    // With foldRight, the accumulator already holds everything collected from later elements of `state.prev`.
    // In `inputs ++ nextInputs` the right operand wins, so same-named values from earlier elements take precedence.
    // The truncated `:: priorCells` presumably puts earlier elements' cell code at the head of `prevCells`.
    case ((nextInputs, cellCode), CollectedState(inputs, imports, priorCells)) =>
      val nextImports =, Nil))
      CollectedState(inputs ++ nextInputs, imports ++ nextImports, :: priorCells).getOrElse(priorCells))

  /**
    * Ensures an input [[ValDef]] is suitable as a constructor parameter.
    */
  private def cleanInput(input: ValDef): ValDef = {
    // Constructor parameters cannot be `lazy`, so that flag is cleared.
    val strictMods = input.mods &~ global.Flag.LAZY
    // Hand back a duplicated, position-free tree so nothing is shared with the original `input`.
    input.copy(mods = strictMods).duplicate.setPos(NoPosition)
  }

  private def collectPrevInstances(code: CellCode, state: State): List[AnyRef] = {
    // Looks up, for the prior cells that `code` uses, the live instances of their cell classes.
    // NOTE(review): truncated in this copy of the source. Missing pieces:
    //   - the close of the `collect` (presumably followed by `.toMap`, since it is applied like a function below);
    //   - the receiver of the `{ cell => ... }` lambda;
    //   - the method's result expression and closing brace.
    // If `allInstances` is a Map, `allInstances(...)` throws for a cell with no stored instance — confirm upstream.
    val allInstances = state.prev.collect {
      // Only Scala cells that produced an instance contribute (keyed by their cell class symbol).
      case ScalaCellState(_, _, _, cellCode, Some(inst)) => cellCode.cellClassSymbol -> inst

    val usedInstances = {
      cell => allInstances(cell.cellClassSymbol)


  private def partitionInputs(code: CellCode, inputValues: List[Any]) = {
    // Presumably splits the input values into implicit and non-implicit constructor arguments — confirm upstream.
    // NOTE(review): the right-hand side, result expression and closing brace are missing from this copy.
    // NOTE(review): the local below is named `(implicitInputs, nonImplicitInputs)`, but `runClass` destructures the
    // result as `(nonImplicitInputs, implicitInputs)`. Verify the returned tuple order matches the caller's
    // expectation.
    val (implicitInputs, nonImplicitInputs) =

  private def createInstance(constructor: Constructor[_], prevInstances: List[AnyRef], inputs: List[Any]): AnyRef = {
    // Reflectively invokes the cell class constructor with the prior cell instances first, followed by the inputs.
    // Constructing the instance is what runs the cell's code.
    // NOTE(review): the right operand of `++` (presumably `inputs` converted to AnyRef) and the closing brace are
    // missing from this copy of the source.
    constructor.newInstance(prevInstances ++[AnyRef]): _*).asInstanceOf[AnyRef]

  /**
    * Runs the cell given the loaded compiled class and the input values carried from previous cells (not including
    * prior cell instances themselves). Collects any required prior cell instances (for dependent types and imports)
    * and constructs the cell class (running the code) in an interruptible task while capturing standard output.
    */
  private def runClass(cls: Class[_], code: CellCode, inputValues: List[Any], state: State) = for {
    // The cell class is expected to have one constructor; the first declared one is used. An empty array makes this
    // effect fail with an index exception, captured by `ZIO(...)`.
    constructor   <- ZIO(cls.getDeclaredConstructors()(0))
    prevInstances  = collectPrevInstances(code, state)
    (nonImplicitInputs, implicitInputs) = partitionInputs(code, inputValues)
    // Construction (i.e. running the cell) happens in an interruptible blocking effect.
    // NOTE(review): no standard-output capture is visible here, although the documentation mentions it; presumably
    // it happens elsewhere — confirm.
    // NOTE(review): the body of the `InvocationTargetException` handler (presumably failing with the cause) and the
    // closing brace of `catchSome` are missing from this copy of the source.
    instance      <- effectBlockingInterrupt(createInstance(constructor, prevInstances, nonImplicitInputs ++ implicitInputs)).catchSome {
      case err: InvocationTargetException if !(err.getCause eq err) && err.getCause != null =>
  } yield instance

  private def getResultValues(id: CellID, code: CellCode, result: AnyRef) = {
    // Builds a ResultValue for each typed output of the cell, reading its value reflectively from the cell instance.
    // NOTE(review): truncated in this copy of the source. Missing pieces:
    //   - the argument of `formatTypes(`;
    //   - the receiver of the `{ case ... }` block;
    //   - the left operands of both `==` comparisons;
    //   - the `getDeclaredMethod(` arguments;
    //   - the scrutinee of `match`;
    //   - the closing braces.
    // It does not compile as shown.
    val cls = result.getClass
    val typedOuts = code.typedOutputs
    scalaCompiler.formatTypes( {
      typeNames => effectBlockingInterrupt { {
          // Entries whose compared name (operand truncated; presumably the type name) is "Unit" or "BoxedUnit"
          // are excluded by this guard.
          case (v, typeName) if !( == "Unit") && !( == "BoxedUnit") =>
            val value = cls.getDeclaredMethod(
            // The captured last expression `$Out` is reported under the name `Out`.
            val name = match {
              case "$Out" => "Out"
              case name => name
            // The source range of the output's declaration is recorded as (start, end) offsets.
            ResultValue(name, typeName, Nil, id, value, v.tpt.tpe, Some((v.pos.start, v.pos.end)))


  /**
    * A [[State]] implementation for Scala cells. It tracks the CellCode and the instance of the cell class, which
    * we'll need to pass into future cells if they use types, classes, etc. from this cell.
    */
  case class ScalaCellState(id: CellID, prev: State, values: List[ResultValue], cellCode: CellCode, instance: Option[AnyRef]) extends State {
    // NOTE(review): `updateValues` and `updateValuesM` are truncated in this copy of the source. The argument after
    // `copy(values =`, the receiver before `=> copy(...)` and the closing braces are missing.
    // Re-links this state onto a different predecessor, keeping all other fields.
    override def withPrev(prev: State): ScalaCellState = copy(prev = prev)
    // Presumably applies `fn` to every result value — confirm against upstream.
    override def updateValues(fn: ResultValue => ResultValue): State = copy(values =
    // Effectful variant: the per-value effects are combined with `ZIO.collectAll` and the collected values stored.
    override def updateValuesM[R](fn: ResultValue => RIO[R, ResultValue]): RIO[R, State] =
      ZIO.collectAll( => copy(values = values))


object ScalaInterpreter {

  /** Builds a [[ScalaInterpreter]] from the environment's [[ScalaCompiler]] and the default [[ClassIndexer]]. */
  def apply(): RIO[Blocking with ScalaCompiler.Provider, ScalaInterpreter] =
    ScalaCompiler.access.flatMap { compilerInstance =>
      ClassIndexer.default.map(classIndexer => new ScalaInterpreter(compilerInstance, classIndexer))
    }

  // Captures the last statement in a value named `$Out`, if it is a free term expression.
  // When the last tree is a term that is not a ValDef, it is replaced by `val $Out = <expr>`, positioned at the
  // expression. The type tree is empty (`TypeTree(NoType)`), presumably so the typer infers it.
  // A last tree that is a ValDef or not a term, as well as all earlier trees, is returned unchanged.
  // NOTE(review): the closing braces of the two matches and of the method are missing from this copy of the source.
  def captureLastExpression(global: Global)(trees: List[global.Tree]): List[global.Tree] = {
    import global._
    // Reverse so the last statement can be matched as the head.
    trees.reverse match {
      case Nil => Nil
      case l :: r => l match {
        case v: ValDef => (v :: r).reverse
        case expr if expr.isTerm => (atPos(expr.pos)(ValDef(Modifiers(), TermName("$Out"), TypeTree(NoType), expr)) :: r).reverse
        case v => (v :: r).reverse

  // Notify the `kernel` of progress and execution status during the cell execution.
  // Each ValDef and each non-definition statement is preceded by:
  //   - `kernel.setExecutionStatus(start, end)` (when `setPos` yields one);
  //   - `kernel.setProgress(index / numTrees, detail)`.
  // Other member definitions and imports pass through untouched, and `kernel.clearExecutionStatus()` is appended at
  // the end. An empty input yields an empty result.
  // NOTE(review): this copy of the source is truncated. The `if (...)` condition in `setPos` and the first argument
  // of `wrapWithProgress` in the ValDef case are missing (presumably a condition on `mark.pos` and the ValDef's name),
  // and so are the closing braces.
  def addPositionUpdates(global: Global)(trees: List[global.Tree]): List[global.Tree] = {
    import global._
    val numTrees = trees.size
    if (numTrees == 0) return Nil
    val lastTree = trees.last
    trees.zipWithIndex.flatMap {
      case (tree, index) =>
        // Fraction of statements completed before this one: 0 for the first, (n-1)/n for the last.
        val treeProgress = Literal(Constant(index.toDouble / numTrees))
        val lineStr = s"Line ${tree.pos.line}"
        // Injected trees get a transparent version of the statement's position.
        val sPos = tree.pos.makeTransparent
        // code to notify kernel of progress in the cell
        def setProgress(detail: String) =
          atPos(sPos)(q"""kernel.setProgress($treeProgress, ${Literal(Constant(detail))})""")
        // Reports the source range of `mark` as the currently executing region.
        def setPos(mark: Tree) =
            Some(atPos(sPos)(q"""kernel.setExecutionStatus(${Literal(Constant(mark.pos.start))}, ${Literal(Constant(mark.pos.end))})"""))
          else None

        def wrapWithProgress(name: String, tree: Tree): List[Tree] =
          setPos(tree).toList ++ List(setProgress(name), tree)

        tree match {
          case tree: global.ValDef => wrapWithProgress(, tree)
          case tree: global.MemberDef => List(tree)
          case tree: global.Import => List(tree)
          case tree => wrapWithProgress(lineStr, tree)
    } :+ atPos(lastTree.pos.makeTransparent)(q"kernel.clearExecutionStatus()")

  // Interpreter factory for the "Scala" language. `apply` narrows the produced interpreter type to ScalaInterpreter
  // and declares the environment it needs.
  // NOTE(review): the trait's closing brace is missing from this copy of the source.
  trait Factory extends Interpreter.Factory {
    val languageName = "Scala"
    def apply(): RIO[BaseEnv with GlobalEnv with ScalaCompiler.Provider with CurrentNotebook with CurrentTask with TaskManager, ScalaInterpreter]

  /**
    * The Scala interpreter factory is a little bit special, in that it doesn't do any dependency fetching. This is
    * because the JVM dependencies must already be fetched when the kernel is booted.
    */
  // Default factory: simply delegates to `ScalaInterpreter()`. That call requires only
  // `Blocking with ScalaCompiler.Provider`, so the wider declared environment presumably includes `Blocking` via
  // `BaseEnv` — confirm.
  // NOTE(review): the object's closing brace is missing from this copy of the source.
  object Factory extends Factory {
    override def apply(): RIO[BaseEnv with GlobalEnv with ScalaCompiler.Provider with CurrentNotebook with CurrentTask with TaskManager, ScalaInterpreter] = ScalaInterpreter()

© 2015 - 2024 Weber Informatics LLC | Privacy Policy