All Downloads are FREE. Search and download functionalities are using the official Maven repository.

polynote.kernel.interpreter.scal.ScalaInterpreter.scala Maven / Gradle / Ivy

The newest version!
package polynote.kernel
package interpreter
package scal

import java.lang.reflect.{Constructor, InvocationTargetException}

import scala.reflect.internal.util.{NoPosition, Position}
import scala.tools.nsc.interactive.Global
import polynote.messages.CellID
import zio.blocking.{Blocking, effectBlockingInterrupt}
import zio.{RIO, Task, ZIO}
import ScalaInterpreter.{addPositionUpdates, captureLastExpression}
import polynote.kernel.environment.{CurrentNotebook, CurrentRuntime, CurrentTask}
import polynote.kernel.task.TaskManager

/**
  * Interpreter for Scala cells.
  *
  * A cell's code is wrapped (via `scalaCompiler.cellCode`) in a class whose constructor takes the values and prior cell
  * instances the code needs; running a cell compiles that class and instantiates it, which executes the code.
  *
  * @param scalaCompiler compiler used to build, compile and format cell code
  * @param indexer       class indexer handed to the completer
  */
class ScalaInterpreter private[scal] (
  val scalaCompiler: ScalaCompiler,
  indexer: ClassIndexer
) extends Interpreter {
  import scalaCompiler.{CellCode, global, Imports}
  import global.{Tree, ValDef, TermName, Modifiers, EmptyTree, TypeTree, Import, Name, Type, Quasiquote, typeOf, atPos, NoType}

  ///////////////////////////////////
  // Interpreter interface methods //
  ///////////////////////////////////

  /**
    * Compile and run a Scala cell.
    *
    * Values, imports and prior cells are collected from `state.prev` (plus whatever [[injectState]] adds). The code is
    * wrapped into a cell class, transformed with [[transformCode]] and stripped of unused inputs, then compiled. If
    * compilation produced a class, it is instantiated (running the code) and its output values are read back.
    *
    * @param code  the cell's source code
    * @param state the state of this cell; its `prev` chain supplies the available values and prior cells
    * @return a [[ScalaCellState]] holding the result values, the cell's `CellCode`, and the cell instance (if any)
    */
  override def run(code: String, state: State): RIO[InterpreterEnv, State] = for {
    collectedState <- injectState(collectState(state))
    valDefs         = collectedState.values.mapValues(_._1).values.toList
    cellCode       <- scalaCompiler.cellCode(s"Cell${state.id.toString}", code, collectedState.prevCells, valDefs, collectedState.imports)
      .flatMap(_.transformStats(transformCode).pruneInputs())
    inputNames      = cellCode.inputs.map(_.name.decodedName.toString)
    inputs          = inputNames.map(collectedState.values).map(_._2)
    cls            <- scalaCompiler.compileCell(cellCode)
    resultInstance <- cls.map(cls => runClass(cls, cellCode, inputs, state).map(Some(_))).getOrElse(ZIO.succeed(None))
    resultValues   <- resultInstance.map(resultInstance => getResultValues(state.id, cellCode, resultInstance)).getOrElse(ZIO.succeed(Nil))
  } yield ScalaCellState(state.id, state.prev, resultValues, cellCode, resultInstance)

  /**
    * Compute completions at offset `pos` in `code`.
    *
    * Only the text before `pos` is used. It is prefixed with two newlines (hence the `pos + 2` offset handed to the
    * completer), suffixed with two spaces, and parsed with `strictParse = false`. State injection runs against
    * `CurrentRuntime.noRuntime`.
    */
  override def completionsAt(code: String, pos: Int, state: State): RIO[Blocking, List[Completion]] = for {
    collectedState   <- injectState(collectState(state)).provideLayer(CurrentRuntime.noRuntime)
    valDefs           = collectedState.values.mapValues(_._1).values.toList
    cellCode         <- scalaCompiler.cellCode(s"Cell${state.id.toString}", s"\n\n${code.substring(0, math.min(pos, code.length))}  ", collectedState.prevCells, valDefs, collectedState.imports, strictParse = false)
    completions      <- completer.completions(cellCode, pos + 2)
  } yield completions

  /**
    * Compute parameter hints at offset `pos` in `code`.
    *
    * The whole code is prefixed with two newlines (hence `pos + 2`), suffixed with two spaces, and parsed with
    * `strictParse = false`. State injection runs against `CurrentRuntime.noRuntime`.
    */
  override def parametersAt(code: String, pos: Int, state: State): RIO[Blocking, Option[Signatures]] = for {
    collectedState <- injectState(collectState(state)).provideLayer(CurrentRuntime.noRuntime)
    valDefs         = collectedState.values.mapValues(_._1).values.toList
    cellCode       <- scalaCompiler.cellCode(s"Cell${state.id.toString}", s"\n\n$code  ", collectedState.prevCells, valDefs, collectedState.imports, strictParse = false)
    hints          <- completer.paramHints(cellCode, pos + 2)
  } yield hints

  /** No initialization step; returns `state` unchanged. */
  override def init(state: State): RIO[InterpreterEnv, State] = ZIO.succeed(state)

  /** No-op. */
  override def shutdown(): Task[Unit] = ZIO.unit

  ///////////////////////////////////////
  // Overrideable scala-specific stuff //
  ///////////////////////////////////////

  /**
    * Overrideable method to inject some pre-defined state (values and imports) into the initial cell. The base implementation
    * injects the `kernel` value, making it available to the notebook. Override to inject more imports or values.
    */
  protected def injectState(collectedState: CollectedState): RIO[CurrentRuntime, CollectedState] =
    CurrentRuntime.access.map {
      kernelRuntime =>
        collectedState.copy(values = collectedState.values + (runtimeValDef.name.toString -> (runtimeValDef, kernelRuntime: Any)))
    }

  /**
    * Transforms the Scala statements, to inject additional things like updating the execution status and ensuring
    * the last expression is captured. Override to add additional transformations (don't forget to call super.transformCode!)
    */
  protected def transformCode(code: List[Tree]): List[Tree] = {
    addPositionUpdates(global)(captureLastExpression(global)(code))
  }


  ////////////////////////////////////////////////////
  // Protected structures for subclass implementors //
  ////////////////////////////////////////////////////

  /**
    * Container for information about available values, etc from previous cells or predefined things
    *
    * @param values    available values keyed by name: the `ValDef` used as a constructor parameter and the runtime value
    * @param imports   imports accumulated from prior Scala cells
    * @param prevCells prior Scala cells' `CellCode`, most recent first
    */
  protected case class CollectedState(
    values: Map[String, (ValDef, Any)] = Map.empty,
    imports: Imports = Imports(),
    prevCells: List[CellCode] = Nil)


  //////////////////////////////////
  // Private scala-specific stuff //
  //////////////////////////////////

  private val completer = ScalaCompleter(scalaCompiler, indexer)

  // for testing reliably
  private[scal] def awaitIndexer = indexer.await

  // create the parameter that's used to inject the `kernel` value into cell scope
  // (a `def`, so every use gets a fresh tree)
  private def runtimeValDef = ValDef(Modifiers(global.Flag.IMPLICIT), TermName("kernel"), tq"polynote.runtime.KernelRuntime", EmptyTree)

  /**
    * Goes backward through the state and collects all the output values and imports from previous cells. For Scala cells,
    * it also builds up a list of prior CellCode instances. We start by wrapping a cell's code in a class which has all
    * these available values and all available prior cells as constructor arguments, and then we prune it to only keep
    * the constructor arguments which the code requires.
    */
  private def collectState(state: State): CollectedState = state.prev.collect {
    case ScalaCellState(_, _, values, cellCode, _) =>
      val valuesMap = values.map(v => v.name -> v.value).toMap
      val inputs = cellCode.typedOutputs.map(cleanInput)
        .flatMap {
          v =>
            val (nameString, input) = v.name.decodedName.toString match {
              case "$Out" => "Out" -> v.copy(name = TermName("Out"))
              case name   => name -> v
            }
            valuesMap.get(nameString).map(value => nameString -> (input, value)).toList
        }.toMap
      (inputs, Option(cellCode))
    case state =>
      val inputs = state.values.map {
        v =>
          val name = TermName(v.name)
          name.encodedName.toString -> (ValDef(Modifiers(), name, TypeTree(v.scalaType.asInstanceOf[global.Type]), EmptyTree), v.value)
      }.toMap
      (inputs, None)
  }.foldRight(CollectedState()) {
    // later cells are combined after earlier ones, so their values override earlier values of the same name
    case ((nextInputs, cellCode), CollectedState(inputs, imports, priorCells)) =>
      val nextImports = cellCode.map(_.splitImports()).getOrElse(Imports(Nil, Nil))
      CollectedState(inputs ++ nextInputs, imports ++ nextImports, cellCode.map(_ :: priorCells).getOrElse(priorCells))
  }

  /**
    * Ensure an input [[ValDef]] is suitable as a constructor parameter: drops the `LAZY` flag and returns a
    * position-less duplicate.
    */
  private def cleanInput(input: ValDef): ValDef =
    input.copy(mods = input.mods &~ global.Flag.LAZY).duplicate.setPos(NoPosition)

  /**
    * Look up the instances of the prior cells that `code` depends on, in the order of `code.priorCells`.
    *
    * @throws IllegalStateException if a required prior cell has no recorded instance in `state.prev`
    */
  private def collectPrevInstances(code: CellCode, state: State): List[AnyRef] = {
    val allInstances = state.prev.collect {
      case ScalaCellState(_, _, _, cellCode, Some(inst)) => cellCode.cellClassSymbol -> inst
    }.toMap

    code.priorCells.map {
      cell =>
        allInstances.getOrElse(
          cell.cellClassSymbol,
          throw new IllegalStateException(s"No instance is available for prior cell class ${cell.cellClassSymbol}")
        )
    }
  }

  /** Split input values into (non-implicit, implicit) according to the modifiers of the corresponding `code.inputs`. */
  private def partitionInputs(code: CellCode, inputValues: List[Any]) = {
    val (implicitInputs, nonImplicitInputs) = code.inputs.zip(inputValues).partition(_._1.mods.isImplicit)
    (nonImplicitInputs.map(_._2.asInstanceOf[AnyRef]), implicitInputs.map(_._2.asInstanceOf[AnyRef]))
  }

  /** Invoke `constructor` with the prior cell instances followed by `inputs`. */
  private def createInstance(constructor: Constructor[_], prevInstances: List[AnyRef], inputs: List[Any]): AnyRef = {
    constructor.newInstance(prevInstances ++ inputs.map(_.asInstanceOf[AnyRef]): _*).asInstanceOf[AnyRef]
  }

  /**
    * Run the cell given the loaded compiled class and the input values carried from previous cells (not including
    * prior cell instances themselves). Collects any required prior cell instances (for dependent types and imports)
    * and constructs the cell class (running the code) in an interruptible blocking task.
    *
    * If construction throws an `InvocationTargetException` with a distinct non-null cause, the task fails with that cause.
    */
  private def runClass(cls: Class[_], code: CellCode, inputValues: List[Any], state: State) = for {
    constructor   <- ZIO(cls.getDeclaredConstructors()(0))
    // lifted into the effect so a missing prior instance fails this task instead of throwing during composition
    prevInstances <- ZIO(collectPrevInstances(code, state))
    (nonImplicitInputs, implicitInputs) = partitionInputs(code, inputValues)
    instance      <- effectBlockingInterrupt(createInstance(constructor, prevInstances, nonImplicitInputs ++ implicitInputs)).catchSome {
      case err: InvocationTargetException if !(err.getCause eq err) && err.getCause != null => ZIO.fail(err.getCause)
    }
  } yield instance

  /**
    * Read the output values of a cell instance.
    *
    * For every typed output of `code` whose type symbol is neither `Unit` nor `BoxedUnit`, invokes the accessor of
    * the same (encoded) name on `result`. The synthetic `$Out` output is reported as `Out`. A source range is attached
    * only when the output tree has a defined position.
    */
  private def getResultValues(id: CellID, code: CellCode, result: AnyRef) = {
    val cls = result.getClass
    val typedOuts = code.typedOutputs
    scalaCompiler.formatTypes(typedOuts.map(_.tpt.tpe)).flatMap {
      typeNames => effectBlockingInterrupt {
        typedOuts.zip(typeNames).collect {
          case (v, typeName) if !(v.tpt.tpe.typeSymbol.name.decoded == "Unit") && !(v.tpt.tpe.typeSymbol.name.decoded == "BoxedUnit") =>
            val value = cls.getDeclaredMethod(v.name.encodedName.toString).invoke(result)
            val name = v.name.decoded match {
              case "$Out" => "Out"
              case name => name
            }
            // Position.start/end throw on NoPosition (e.g. trees synthesized without a position), so guard them
            val sourceRange = if (v.pos.isDefined) Some((v.pos.start, v.pos.end)) else None
            ResultValue(name, typeName, Nil, id, value, v.tpt.tpe, sourceRange)
        }
      }
    }
  }

  ///////////////////////////////////////////

  /**
    * A [[State]] implementation for Scala cells. It tracks the CellCode and the instance of the cell class, which
    * we'll need to pass into future cells if they use types, classes, etc from this cell.
    */
  case class ScalaCellState(id: CellID, prev: State, values: List[ResultValue], cellCode: CellCode, instance: Option[AnyRef]) extends State {
    override def withPrev(prev: State): ScalaCellState = copy(prev = prev)
    override def updateValues(fn: ResultValue => ResultValue): State = copy(values = values.map(fn))
    override def updateValuesM[R](fn: ResultValue => RIO[R, ResultValue]): RIO[R, State] =
      ZIO.collectAll(values.map(fn)).map(values => copy(values = values))
  }

}

/**
  * Companion of [[ScalaInterpreter]]: construction, the default cell-code tree transformations, and the factory.
  */
object ScalaInterpreter {

  /** Create an interpreter from the current [[ScalaCompiler]] and the default [[ClassIndexer]]. */
  def apply(): RIO[Blocking with ScalaCompiler.Provider, ScalaInterpreter] = for {
    compiler <- ScalaCompiler.access
    index    <- ClassIndexer.default
  } yield new ScalaInterpreter(compiler, index)

  /**
    * Capture the last statement in a value named `$Out`, if it's a free expression.
    *
    * The trees are returned unchanged when they are empty, when the last tree is a `ValDef`, or when it is not a term.
    */
  def captureLastExpression(global: Global)(trees: List[global.Tree]): List[global.Tree] = {
    import global._
    trees.lastOption match {
      case None | Some(_: ValDef) =>
        trees
      case Some(expr) if expr.isTerm =>
        trees.init :+ atPos(expr.pos)(ValDef(Modifiers(), TermName("$Out"), TypeTree(NoType), expr))
      case Some(_) =>
        trees
    }
  }

  /**
    * Notify the `kernel` of progress and execution status during the cell execution.
    *
    * Before each `ValDef` and each free statement, inserts `kernel.setExecutionStatus(start, end)` (only when the tree
    * has a range position) and `kernel.setProgress(fraction, detail)`, where `fraction` is the tree's index divided by
    * the number of trees. Other member definitions and imports are left as-is. A final
    * `kernel.clearExecutionStatus()` is appended. An empty input yields an empty result.
    */
  def addPositionUpdates(global: Global)(trees: List[global.Tree]): List[global.Tree] = {
    import global._
    if (trees.isEmpty) Nil
    else {
      val numTrees = trees.size
      val lastTree = trees.last
      val instrumented = trees.zipWithIndex.flatMap {
        case (tree, index) =>
          val treeProgress = Literal(Constant(index.toDouble / numTrees))
          val lineStr = s"Line ${tree.pos.line}"
          val sPos = tree.pos.makeTransparent
          // code to notify kernel of progress in the cell
          def setProgress(detail: String) =
            atPos(sPos)(q"""kernel.setProgress($treeProgress, ${Literal(Constant(detail))})""")
          // start/end are only available on range positions
          def setPos(mark: Tree) =
            if(mark.pos.isRange)
              Some(atPos(sPos)(q"""kernel.setExecutionStatus(${Literal(Constant(mark.pos.start))}, ${Literal(Constant(mark.pos.end))})"""))
            else None

          def wrapWithProgress(name: String, tree: Tree): List[Tree] =
            setPos(tree).toList ++ List(setProgress(name), tree)

          tree match {
            case tree: global.ValDef => wrapWithProgress(tree.name.decodedName.toString, tree)
            case tree: global.MemberDef => List(tree)
            case tree: global.Import => List(tree)
            case tree => wrapWithProgress(lineStr, tree)
          }
      }
      instrumented :+ atPos(lastTree.pos.makeTransparent)(q"kernel.clearExecutionStatus()")
    }
  }

  /** Interpreter factory for the "Scala" language. */
  trait Factory extends Interpreter.Factory {
    val languageName = "Scala"
    def apply(): RIO[BaseEnv with GlobalEnv with ScalaCompiler.Provider with CurrentNotebook with CurrentTask with TaskManager, ScalaInterpreter]
  }

  /**
    * The Scala interpreter factory is a little bit special, in that it doesn't do any dependency fetching. This is
    * because the JVM dependencies must already be fetched when the kernel is booted.
    */
  object Factory extends Factory {
    override def apply(): RIO[BaseEnv with GlobalEnv with ScalaCompiler.Provider with CurrentNotebook with CurrentTask with TaskManager, ScalaInterpreter] = ScalaInterpreter()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy