All Downloads are FREE. Search and download functionalities are using the official Maven repository.

polynote.kernel.ScalaCompiler.scala Maven / Gradle / Ivy

The newest version!
package polynote.kernel

import java.io.File
import java.net.URL
import java.util.concurrent.atomic.AtomicInteger
import polynote.kernel.environment.Config
import polynote.kernel.util.{KernelReporter, LimitedSharingClassLoader, pathOf}
import zio.blocking.Blocking
import zio.system.{System, env}
import zio.internal.{ExecutionMetrics, Executor}
import zio.{Has, RIO, Task, UIO, ZIO}

import scala.collection.mutable
import scala.reflect.internal.util.{AbstractFileClassLoader, NoSourceFile, Position, SourceFile}
import scala.reflect.io.VirtualDirectory
import scala.reflect.runtime.universe
import scala.tools.nsc.Settings
import scala.tools.nsc.interactive.{Global, NscThief}
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import ScalaCompiler.OriginalPos

import scala.util.control.NonFatal

class ScalaCompiler private (
  val global: Global,
  val notebookPackage: String,
  val classLoader: AbstractFileClassLoader,
  val dependencies: List[File],
  val otherClasspath: List[File]
) {
  import global._
  private val packageName = TermName(notebookPackage)
  private val reporter = global.reporter.asInstanceOf[KernelReporter]

  private val compilerThread: Executor = new Executor {
    def yieldOpCount: Int = Int.MaxValue
    def metrics: Option[ExecutionMetrics] = None
    def submit(runnable: Runnable): Boolean = {
      global.ask(runnable.run)
      true
    }
    def here: Boolean = false
  }

  private def formatTypeInternal(typ: Type): String = typ match {
    case mt @ MethodType(params: List[Symbol], result: Type) =>
      val paramStr = params.map {
        sym => s"${sym.nameString}: ${formatTypeInternal(sym.typeSignatureIn(mt))}"
      }.mkString(", ")
      val resultType = formatTypeInternal(result)
      s"($paramStr) => $resultType"
    case ConstantType(Constant(v)) => v.toString
    case NoType => ""
    case typ if typ.typeSymbol.name == "TypedPythonObject" =>
      typ.typeArgs.headOption.fold("PythonObject")(name => formatTypeInternal(name) + " (Python)")
    case _ =>
      val typName = typ.typeSymbolDirect.rawname
      val typNameStr = typName.decoded
      typ.typeArgs.map(formatTypeInternal) match {
        case Nil => typNameStr
        case a if typNameStr == "" => s"=> $a"
        case a :: b :: Nil if typName.isOperatorName => s"$a $typNameStr $b"
        case a :: b :: Nil if typ.typeSymbol.owner.nameString == "scala" && (typNameStr == "Function1") =>
          s"$a => $b"
        case args if typ.typeSymbol.owner.nameString == "scala" && (typNameStr startsWith "Function") =>
          s"(${args.dropRight(1).mkString(",")}) => ${args.last}"
        case args => s"$typName[${args.mkString(", ")}]"
      }
  }

  private object saveOriginalPos extends Traverser {
    override def traverse(tree: Tree): Unit = {
      tree.updateAttachment(OriginalPos(tree.pos))
      super.traverse(tree)
    }
  }

  private[kernel] def unsafeFormatType(typ: Type): String = formatTypeInternal(typ)

  private val runtimeMirror = effectMemoize(scala.reflect.runtime.universe.runtimeMirror(classLoader))

  private val importer: global.Importer { val from: scala.reflect.runtime.universe.type } = global.mkImporter(scala.reflect.runtime.universe)

  def importType[T : scala.reflect.runtime.universe.TypeTag]: Type = importer.importType(scala.reflect.runtime.universe.typeOf[T])

  def reflect(value: Any): ZIO[Any, Throwable, universe.InstanceMirror] = runtimeMirror.flatMap {
    mirror => ZIO(mirror.reflect(value))
  }


  def formatType(typ: Type): RIO[Blocking, String] =
    zio.blocking.effectBlocking(formatTypeInternal(typ)).lock(compilerThread)

  def formatTypes(types: List[Type]): RIO[Blocking, List[String]] =
    zio.blocking.effectBlocking(types.map(formatTypeInternal)).lock(compilerThread)

  def compileCell(
    cellCode: CellCode
  ): RIO[Blocking, Option[Class[_]]] =
    for {
      compiled  <- cellCode.compile()
      className  = s"${packageName.encodedName.toString}.${cellCode.assignedTypeName.encodedName.toString}"
      loadClass  = zio.blocking.effectBlocking(Option(Class.forName(className, false, classLoader).asInstanceOf[Class[AnyRef]]))
      cls       <- if (cellCode.cellClassSymbol.nonEmpty) loadClass else ZIO.succeed(None)
    } yield cls

  private def parse(unit: CompilationUnit, strictParse: Boolean, packageCell: Boolean) = {
    def parseList(): List[Tree] = if (packageCell) List(newUnitParser(unit).parse()) else newUnitParser(unit).parseStats()
    if (strictParse) ZIO(reporter.attempt(parseList())).absolve else ZIO(parseList())
  }

  def cellCode(
    name: String,
    code: String,
    priorCells: List[CellCode] = Nil,
    inputs: List[ValDef] = Nil,
    inheritedImports: Imports = Imports(),
    strictParse: Boolean = true
  ): Task[CellCode] = for {
    sourceFile  <- ZIO(newSourceFile(code, name))
    compileUnit <- ZIO(new global.RichCompilationUnit(sourceFile))
    parsed      <- parse(compileUnit, strictParse, code.trim().startsWith("package "))
  } yield {
    val definedTerms = parsed.collect {
      case tree: DefTree if tree.name.isTermName => tree.name.decoded
    }.toSet

    val (allowedPriorCells, allowedInputs) = parsed match {
      case PackageDef(_, _) :: Nil => (Nil, Nil)
      case _ => (priorCells, inputs.filterNot(definedTerms contains _.name.decoded))
    }

    CellCode(
      name, parsed, allowedPriorCells, allowedInputs, inheritedImports, compileUnit, sourceFile
    )
  }

  // TODO: currently, there's no caching of the implicits, because what if you define a new implicit? Should we cache?
  def inferImplicits(types: List[Type]): RIO[Blocking, List[Option[AnyRef]]] = {
    val (names, trees) = types.map {
      typ =>
        val name = freshTermName("anon$")(globalFreshNameCreator)
        name.encodedName.toString -> q"val $name: ${TypeTree(typ)} = implicitly[${TypeTree(typ)}]"
    }.unzip

    val cellName = "anonImplicits"

    def construct: RIO[Class[_], AnyRef] =
      ZIO.accessM[Class[_]](cls => ZIO(cls.getDeclaredConstructors.head.newInstance().asInstanceOf[AnyRef]))

    def getField(name: String)(instance: AnyRef): RIO[Class[_], Option[AnyRef]] =
      ZIO.accessM[Class[_]](cls => ZIO(Option(cls.getDeclaredMethod(name).invoke(instance))) orElse ZIO.succeed(None))

    def getFields(names: List[String])(instance: AnyRef): RIO[Class[_], List[Option[AnyRef]]] =
      ZIO.foreach(names)(getField(_)(instance))

    // first we'll try to get all of them at once.
    compileCell(CellCode(cellName, trees)).flatMap {
      case Some(cls) => (construct >>= getFields(names)).provide(cls)
      case None => ZIO.fail(new IllegalStateException("Compiler provided no class for implicit results"))
    }.catchAll {
      err =>
        // if that doesn't compile (i.e. some implicits are missing) we'll try to compile each individually
        ZIO.foreach(trees.zip(names)) {
          case (tree, name) =>
            compileCell(CellCode(cellName, List(tree))).flatMap {
              case Some(cls) => (construct >>= getField(name)).provide(cls)
              case None => ZIO.fail(new IllegalStateException("Compiler provided no class for implicit results"))
            }.catchAllCause {
              cause =>
                // TODO: log it?
                ZIO.none
            }
        }
    }
  }

  private def copyAndReset[T <: Tree](tree: T): T = resetAttrs(tree.duplicate.setPos(NoPosition)).asInstanceOf[T]
  private def copyAndReset[T <: Tree](trees: List[T]): List[T] = trees.map(tree => copyAndReset(tree))

  private def template(stats: List[Tree]): Template = Template(
    List(treeBuilder.scalaDot(typeNames.AnyRef)), noSelfType, stats
  )

  // the default implementation of atPos does some additional junk that messes up things (and isn't threadsafe or nesting-safe)
  private def atPos[T <: Tree](pos: Position)(tree: T): T = {
    if (tree.pos == null || !tree.pos.isDefined) {
      tree.setPos(pos)
    }
    tree.children.foreach(atPos(tree.pos))
    tree
  }

  case class CellCode private[ScalaCompiler] (
    name: String,
    code: List[Tree],
    priorCells: List[CellCode] = Nil,
    inputs: List[ValDef] = Nil,
    inheritedImports: Imports = Imports(),
    compilationUnit: RichCompilationUnit = new global.RichCompilationUnit(NoSourceFile),
    sourceFile: SourceFile = NoSourceFile
  ) {

    // this copy of the code will be mutated by compile
    lazy val compiledCode: List[Tree] = code.foldLeft(List.empty[Tree]) {
      (accum, stat) =>
        val copied = duplicateAndKeepPositions(stat)
        if (copied.pos != null && copied.pos.isDefined) {
          saveOriginalPos.traverse(copied)
        } else {
          val pos = accum.headOption.map(_.pos.focusEnd.makeTransparent).getOrElse(beforePos.makeTransparent)
          atPos(pos)(copied)
        }
        copied :: accum
    }.reverse


    // The name of the class (and its companion object, in case one is needed)
    lazy val assignedTypeName: TypeName = freshTypeName(s"$name$$")(global.globalFreshNameCreator)
    lazy val assignedTermName: TermName = assignedTypeName.toTermName

    private lazy val priorCellNames = priorCells.map(_.assignedTypeName)

    // create constructor parameters to hold instances of prior cells; these are needed to access path-dependent types
    // for any classes, traits, type aliases, etc defined by previous cells
    lazy val priorCellInputs: List[ValDef] = priorCells.flatMap {
      cell =>
        cell.cellInstType.toList.map {
          cellInstType => ValDef(Modifiers(), TermName(s"_input${cell.assignedTypeName.decodedName.toString}"), TypeTree(cellInstType).setType(cellInstType), EmptyTree)
        }
    }

    lazy val typedInputs: List[ValDef] = inputs

    // Separate the implicit inputs, since they must be in their own parameter list
    lazy val (implicitInputs: List[ValDef], nonImplicitInputs: List[ValDef]) = typedInputs.partition(_.mods.isImplicit)

    // a position to encompass the whole synthetic tree
    private val end = math.max(0, sourceFile.length - 1)
    private lazy val wrappedPos = Position.range(sourceFile, 0, 0, end)

    // positions for imports mustn't be transparent, or they won't be added to context
    // so we should try to place everything at opaque positions
    private lazy val beforePos = Position.range(sourceFile, 0, 0, 0)
    private lazy val afterPos = Position.range(sourceFile, end, end, end)

    // create imports for all types and methods defined by previous cells
    lazy val priorCellImports: List[Import] = priorCells.zip(priorCellInputs).foldLeft(Map.empty[String, (TermName, Name)]) {
      case (accum, (cell, input)) =>
        val newNames = cell.definedTypesAndMethods.collect {
          case name if !accum.contains(name.decoded) => name
        }
        accum ++ newNames.map(name => name.decoded -> (input.name, name)).toMap
    }.values.groupBy(_._1).toList.map {
      case (input, imports) =>
        val importSelectors = imports.map(_._2).zipWithIndex.map {
          case (name, index) => ImportSelector(name, index, name, index)
        }.toList
        Import(Ident(input), importSelectors)
    }

    // what output values does this code define?
    lazy val outputs: List[ValDef] = code.collect {
      case valDef: ValDef if valDef.mods.isPublic => duplicateAndKeepPositions(valDef).asInstanceOf[ValDef]
    }

    lazy val typedOutputs: List[ValDef] = {
      val prev = outputs.map(v => v.name -> v).toMap

      def notUnit(tpe: Type) = {
        val notNullType = tpe != null
        // val notUnitType = !(tpe <:< typeOf[Unit])
        // val notBoxedUnitType = !(tpe <:< typeOf[BoxedUnit])
        // for some reason the type comparisons sometimes return false positives! WHAT?
        val notUnitType = tpe.typeSymbol.name.decoded != "Unit"
        val notBoxedUnitType = tpe.typeSymbol.name.decoded != "BoxedUnit"
        notNullType && notUnitType && notBoxedUnitType
      }

      // We have to deal with dependent types. If a cell declares a class, and another cell creates a value of that
      // class, it will be typed as i.e. Cell2.this._inputCell1.Foo. This will ordinarily be OK, but if we try
      // to ascribe a type to it, like `foo: Foo` inside of Cell3, this will break because here `Foo` means
      // `Cell3.this._inputCell1.Foo`. So the solution is that we create a fake value of each cell type in its
      // companion, which is never used and is just null. *But*, we target all the values of a cell-defined type
      // to be dependent on that dummy value's singleton type, so that the types will agree everywhere. And when
      // a class takes a previous cell as input, we also type that input as the singleton type. It is a bit odd, but
      // it works at runtime because the dummy value (as a *term*) is never actually referenced.
      val mySym = cellClassSymbol
      val myType = cellClassType
      val dependentTypeMapping = priorCellInputs.zip(priorCells).flatMap {
        case (input, cell) => cell.cellInstType.toList.map(input.name.toString -> _)
      }.toMap

      def transformType(typ: Type): Type = typ match {
        case TypeRef(pre, sym, Nil) =>
          typeRef(transformType(pre), sym, Nil)
        case TypeRef(pre, sym, args) =>
          typeRef(transformType(pre), sym, args.mapConserve(transformType))
        case SingleType(pre, sym) if mySym.contains(pre.typeSymbol) && (dependentTypeMapping contains sym.name.toString) =>
          global.internal.singleType(dependentTypeMapping(sym.name.toString), sym)
        case SingleType(pre, sym) =>
          global.internal.singleType(transformType(pre), sym)
        case typ if myType contains typ =>
          cellInstType.getOrElse(typ)
        case _ =>
          typ
      }

      exitingTyper(compiledCode).collect {
        case vd@ValDef(_, name, tpt, _) if (prev contains name) && notUnit(vd.symbol.originalInfo) =>
          val tpe = vd.symbol.originalInfo
          val preTyper = prev(name)
          val subst = transformType(tpe)
          val typeTree = TypeTree(subst)
          typeTree.setType(subst)
          ValDef(preTyper.mods, preTyper.name, typeTree, EmptyTree).setPos(preTyper.pos)
      }
    }

    // what things does this code import?
    lazy val imports: List[Import] = code.collect {
      case i: Import => i.duplicate
    }

    lazy val compiledImports: List[Import] = compiledCode.collect {
      case i: Import => i.duplicate
    }

    // what types (classes, traits, objects, type aliases) and methods does this code define?
    lazy val definedTypesAndMethods: List[Name] = code.collect {
      case ModuleDef(mods, name, _) if mods.isPublic => name
      case ClassDef(mods, name, _, _) if mods.isPublic => name.toTermName
      case TypeDef(mods, name, _, _) if mods.isPublic => name.toTermName
      case DefDef(mods, name, _, _, _, _) if mods.isPublic => name
    }.distinct

    lazy val wrappedImports = copyAndReset(inheritedImports.externalImports) ++ copyAndReset(inheritedImports.localImports)

    // The code all wrapped up in a class definition, with constructor arguments for the given prior cells and inputs
    private lazy val wrappedClass: ClassDef = {
      val transparentBefore = beforePos.makeTransparent

      def toParam(param: ValDef) = atPos(transparentBefore)(param.copy(mods = param.mods | Flag.PARAMACCESSOR))
      val priorCellParamList = copyAndReset(priorCellInputs).map(toParam)
      val nonImplicitParamList = copyAndReset(nonImplicitInputs).map(toParam)
      val implicitParamList = copyAndReset(implicitInputs).map(toParam)

      // have to do this manually rather than with a quasiquote; latter messes up positions
      val cls = gen.mkClassDef(
        Modifiers(),
        assignedTypeName,
        Nil,
        gen.mkTemplate(
          List(atPos(transparentBefore)(gen.scalaDot(tpnme.Serializable))),
          noSelfType.setPos(transparentBefore),
          Modifiers(),
          List(priorCellParamList, nonImplicitParamList, implicitParamList).filter(_.nonEmpty),
          (priorCellImports ++ wrappedImports).map(atPos(beforePos.makeTransparent)) ++ compiledCode,
          wrappedPos
        )).setPos(wrappedPos)

      /*
        // equivalent tree:
          q"""final class $assignedTypeName(..$priorCellParamList)(..$nonImplicitParamList)(..$implicitParamList) extends ${atPos(transparentBefore)(gen.scalaDot(tpnme.Serializable))} {
            ..${priorCellImports.map(atPos(beforePos))}
            ..${wrappedImports.map(atPos(beforePos))}
            ..${compiledCode}
          }""".setPos(wrappedPos)
       */

      cls
    }

    private lazy val companion: ModuleDef = {
      q"""object $assignedTermName extends {
           final val instance: $assignedTypeName = null
           type Inst = instance.type
         }"""
    }

    // Wrap the code in a class within the given package. Constructing the class runs the code.
    // The constructor parameters are
    private lazy val wrapped: PackageDef = compiledCode match {
      case (pkg @ PackageDef(_, stats)) :: Nil => atPos(wrappedPos)(pkg)
      case code => atPos(wrappedPos) {
        q"""package ${atPos(beforePos)(Ident(packageName))} {
            $wrappedClass
            ${atPos(afterPos)(companion)}
          }"""
      }
    }

    // the type representing this cell's class. It may be null or NoType if invoked before compile is done!
    private def ifNotPackage[T](t: => T): Option[T] = code match {
      case PackageDef(_, _) :: Nil => None
      case _ => Option(t)
    }

    lazy val cellClassType: Option[Type] = ifNotPackage(exitingTyper(wrappedClass.symbol.info))
    lazy val cellClassSymbol: Option[ClassSymbol] = ifNotPackage(exitingTyper(wrappedClass.symbol.asClass))
    lazy val cellCompanionSymbol: Option[ModuleSymbol] = ifNotPackage(exitingTyper(companion.symbol.asModule))
    lazy val cellInstSymbol: Option[Symbol] = cellCompanionSymbol.map(sym => exitingTyper(sym.info.member(TermName("instance")).accessedOrSelf))
    lazy val cellInstType: Option[Type] = cellCompanionSymbol.map(sym => exitingTyper(sym.info.member(TypeName("Inst")).info.dealias))

    // Note – you mustn't typecheck and then compile the same instance; those trees are done for. Instead, make a copy
    // of this CellCode and typecheck that if you need info about the typed trees without compiling all the way
    private[kernel] lazy val typed = {
      val run = new Run()
      compilationUnit.body = wrapped
      compilationUnit.lastBody = wrapped
      unitOfFile.put(sourceFile.file, compilationUnit)
      global.globalPhase = run.namerPhase // make sure globalPhase matches run phase
      run.namerPhase.asInstanceOf[global.GlobalPhase].apply(compilationUnit)
      global.globalPhase = run.typerPhase // make sure globalPhase matches run phase
      run.typerPhase.asInstanceOf[global.GlobalPhase].apply(compilationUnit)
      exitingTyper(compilationUnit.body)
    }

    private[kernel] def typedTreeAt(pos: Int) = zio.blocking.effectBlocking {
      val run = new Run()
      compilationUnit.body = wrapped
      compilationUnit.lastBody = wrapped
      unitOfFile.put(sourceFile.file, compilationUnit)
      // this does a targeted-typecheck, which is more tolerant of errors. But, we have to make sure it's past the
      // namer first or it will go through the parser which will throw position validation errors.
      enteringPhase(currentRun.namerPhase)(currentRun.namerPhase.asInstanceOf[GlobalPhase].applyPhase(unitOfFile(sourceFile.file)))
      compilationUnit.status = JustParsed
      NscThief.typedTreeAt(global, Position.offset(sourceFile, pos))
    }.lock(compilerThread)

    private[ScalaCompiler] def compile() = ZIO {
      val run = new Run()
      compilationUnit.body = wrapped
      unitOfFile.put(sourceFile.file, compilationUnit)
      ZIO {
        reporter.attempt {
          withContextClassLoader(classLoader)(run.compileUnits(List(compilationUnit), run.namerPhase))
          exitingTyper(compilationUnit.body)
          // materialize these lazy vals now while the run is still active
          val _1 = cellClassSymbol
          val _2 = cellClassType
          val _3 = cellCompanionSymbol
          val _4 = cellInstSymbol
          val _5 = cellInstType
          val _6 = typedOutputs
        }
      }.lock(compilerThread).absolve
    }.flatten

    // compute which inputs (values from previous cells) are actually referenced in the code
    private def usedInputs = {
      val classSymbol = cellClassSymbol
      val inputNames = inputs.map(_.name).toSet
      compilationUnit.body.collect {
        case Select(This(`assignedTypeName`), name: TermName) if inputNames contains name => name
        case id@Ident(name: TermName) if (id.symbol.ownerChain contains classSymbol) && (inputNames contains name) => name
      }
    }

    // compute which prior cells are actually used – due to referencing a type or method from that cell
    private def usedPriorCells: List[CellCode] = cellClassSymbol.map { cellClassSymbol =>
      val inputCellSymbols = cellClassSymbol.primaryConstructor.paramss.head.toSet
      val inputCellSymbolNames = inputCellSymbols.map(_.name)
      val used = new mutable.HashSet[Name]()

      // first pull out all of the trees that correspond to the input code after typechecking. This includes all
      // synthetic stuff that was generated for the user's code, but not any of our wrapper junk.
      val classBody = compilationUnit.body.collect {
        case ClassDef(_, `assignedTypeName`, _, Template(_, _, stats)) => stats.filter {
          case ValDef(mods, _, _, _) if mods.isParamAccessor => false
          case DefDef(mods, name, _, _, _, _) if mods.isParamAccessor || name.decoded == "" => false
          case tree => !tree.pos.isTransparent
        }
      }.flatten

      // Now go through each of those trees and look for references to previous cells. Ideally this could be just
      // symbol matching, but the symbols don't always seem to match when they ought to, so it matches on the
      // previous cell's accessor name.
      val traverser = new Traverser {
        override def traverse(tree: Tree): Unit = {
          // TODO: this doesn't cover type trees which reference a type imported from a cell
          if (tree.symbol != null) {
            if (inputCellSymbolNames.contains(tree.symbol.name)) {
              used.add(tree.symbol.name)
            }
          }

          // type trees don't get traversed in a useful way by default; traverse its original tree if it has one
          tree match {
            case tree@TypeTree() if tree.original != null => super.traverse(tree.original)
            case tree@TypeTree() =>
              // HACKY: check types in type args and search symbol chains to see whether they originate in previous cells.
              val typeChains = tree.tpe.typeArgs.flatMap(_.prefixChain)
              typeChains.foreach {
                tpe =>
                  val name = tpe.termSymbol.name
                  if (inputCellSymbolNames.contains(name)) {
                    used.add(name)
                  }
              }
            case _ =>
          }

          super.traverse(tree)
        }
      }

      classBody.foreach(traverser.traverse)

      val results = priorCells.zip(priorCellInputs).filter {
        case (cell, term) => used contains term.name
      }.map(_._1)

      results
    }.getOrElse(Nil)

    def splitImports(): Imports = code match {
      case (pkg @ PackageDef(pkgId, stats)) :: Nil =>
        // TODO: should things you import from inside a package cell be imported elsewhere? Or should it be isolated?
        val selectors = stats.collect {
          case defTree: DefTree => defTree.name
        }.groupBy(_.toString).values.map(_.maxBy(_.isTermName)).zipWithIndex.map {
          case (name, index) => ImportSelector(name.toTermName, index, name.toTermName, index)
        }.toList

        if (selectors.nonEmpty) {
          Imports(Nil, List(Import(copyAndReset(pkgId), selectors)))
        } else {
          Imports(Nil, Nil)
        }
      case _ =>
        val priorCellSymbols = priorCells.flatMap(_.cellClassSymbol)
        val imports = compiledCode.zip(code).collect {
          case (i: Import, orig: Import) => (i, orig)
        }
        val (local, external) = imports.partition {
          case (Import(expr, names), _) =>
            (expr.symbol.ownerChain contains cellClassSymbol.getOrElse(NoSymbol)) || expr.symbol.ownerChain.intersect(priorCellSymbols).nonEmpty
        }
        Imports(local.map(_._2.duplicate), external.map(_._2.duplicate))
    }

    private def typedWithContextClassloader =
      ZIO(withContextClassLoader(classLoader)(reporter.attempt(typed))).lock(compilerThread).absolve

    /**
      * Make a new [[CellCode]] that uses a minimal subset of inputs and prior cells.
      * After invoking, this [[CellCode]] will not be compilable – bin it!
      */
    def pruneInputs(): Task[CellCode] = if (inputs.nonEmpty || priorCells.nonEmpty) {
      for {
        typedTree  <- typedWithContextClassloader
        usedNames  <- ZIO(usedInputs).lock(compilerThread)
        usedDeps   <- ZIO(usedPriorCells)
        usedNameSet = usedNames.map(_.decodedName.toString).toSet
      } yield copy(priorCells = usedDeps, inputs = inputs.filter(usedNameSet contains _.name.decodedName.toString))
    } else ZIO.succeed(copy(priorCells = Nil, inputs = Nil))

    /**
      * Transform the code statements using the given function.
      */
    def transformStats(fn: List[Tree] => List[Tree]): CellCode = code match {
      case PackageDef(_, _) :: Nil => this
      case code                    => copy(code = fn(code))
    }
  }

  case class Imports(
    localImports: List[Import] = Nil,
    externalImports: List[Import] = Nil
  ) {
    def ++(that: Imports): Imports = Imports(localImports ++ that.localImports, externalImports ++ that.externalImports)
  }

}

object ScalaCompiler {

  def access: ZIO[Provider, Nothing, ScalaCompiler]    = ZIO.access[ScalaCompiler.Provider](_.get)
  def settings: ZIO[Provider, Nothing, Settings]       = access.map(_.global.settings)
  def dependencies: ZIO[Provider, Nothing, List[File]] = access.map(_.dependencies)

  private val _kernelCounter = new AtomicInteger(0)
  private[kernel] def kernelCounter: UIO[Int] = ZIO.effectTotal(_kernelCounter.getAndIncrement())

  // Available for testing
  private[polynote] def apply(settings: Settings, classLoader: AbstractFileClassLoader, notebookPackage: String = "notebook"): Task[ScalaCompiler] =
    ZIO {
      val global = new Global(settings, KernelReporter(settings))
      new ScalaCompiler(global, notebookPackage, classLoader, Nil, settings.classpath.value.split(File.pathSeparatorChar).toList.map(new File(_)))
    }


  /**
    * @param dependencyClasspath List of class path entries for direct dependencies. Classes from these entries will be
    *                            prioritized by auto-import/classpath-based completions in participating JVM languages.
    * @param transitiveClasspath List of class path entries for transitive dependencies. These are still loaded in the
    *                            notebook class loader, but they don't get higher priority for autocomplete.
    * @param otherClasspath      List of class path entries which the compiler needs to know about, but which aren't
    *                            going to be loaded by the dependency class loader (and thus will be loaded by the boot
    *                            class loader). This basically means Spark and its ilk.
    * @param modifySettings      A function which will receive the base compiler [[Settings]], and can return modified
    *                            settings which will be used to construct the compiler.
    * @return A [[ScalaCompiler]] instance
    */
  def apply(
    dependencyClasspath: List[File],
    transitiveClasspath: List[File],
    otherClasspath: List[File],
    modifySettings: Settings => Settings
  ): RIO[Config with System, ScalaCompiler] = for {
    settings          <- ZIO(modifySettings(defaultSettings(new Settings(), dependencyClasspath ++ transitiveClasspath ++ otherClasspath)))
    global            <- ZIO(new Global(settings, KernelReporter(settings)))
    counter           <- kernelCounter
    notebookPackage    = s"notebook$counter"
    classLoader       <- makeClassLoader(settings, dependencyClasspath ++ transitiveClasspath)
  } yield new ScalaCompiler(global, notebookPackage, classLoader, dependencyClasspath, otherClasspath)

  def apply(dependencyClasspath: List[File], transitiveClasspath: List[File], otherClasspath: List[File]): RIO[Config with System, ScalaCompiler] =
    apply(dependencyClasspath, transitiveClasspath, otherClasspath, identity[Settings])

  def makeClassLoader(settings: Settings, dependencyClasspath: List[File]): RIO[Config, AbstractFileClassLoader] = for {
    dependencyClassLoader <- makeDependencyClassLoader(settings, dependencyClasspath)
    compilerOutput        <- ZIO.fromOption(settings.outputDirs.getSingleOutput).mapError(_ => new IllegalArgumentException("Compiler must have a single output directory"))
  } yield new AbstractFileClassLoader(compilerOutput, dependencyClassLoader)

  def makeDependencyClassLoader(settings: Settings, dependencyClasspath: List[File]): RIO[Config, URLClassLoader] = Config.access.flatMap {
    config => ZIO {
      if (config.behavior.dependencyIsolation) {
        new LimitedSharingClassLoader(
          config.behavior.getSharedString,
          dependencyClasspath.map(_.toURI.toURL),
          getClass.getClassLoader)
      } else {
        new URLClassLoader(dependencyClasspath.map(_.toURI.toURL), getClass.getClassLoader)
      }
    }
  }

  private def pathAsFile(url: URL): File = url match {
    case url if url.getProtocol == "file" => new File(url.getPath)
    case url => throw new IllegalStateException(s"Required path $url must be a local file, not ${url.getProtocol}")
  }

  val requiredPolynotePaths: List[File] = List(
    pathOf(polynote.runtime.Runtime.getClass),
    pathOf(classOf[jep.python.PyObject])
  ).map(pathAsFile)

  val requiredPaths: List[File] = requiredPolynotePaths ++ List(
    pathOf(classOf[List[_]]),
    pathOf(classOf[scala.reflect.runtime.JavaUniverse]),
    pathOf(classOf[scala.tools.nsc.Global])
  ).distinct.map(pathAsFile)

  def defaultSettings(initial: Settings, classPath: List[File] = Nil): Settings = {
    val cp = classPath ++ requiredPaths

    val settings = initial.copy()
    settings.target.value = if (settings.target.choices.contains("jvm-1.8")) "jvm-1.8" else "8" // set Java8 by default
    settings.classpath.append(cp.map(_.getCanonicalPath).mkString(File.pathSeparator))
    settings.Yrangepos.value = true

    // -YpartialUnification was removed as an option in 2.13 (because it's always-on). We still want to enable it for
    // previous Scala versions, but it will cause compile error if we try to set it in 2.13. So this is a hack that
    // attempts to enable it via reflection, and ignores errors from trying to do so
    try {
      Option(settings.getClass.getMethod("YpartialUnification")).foreach {
        method => method.invoke(settings).asInstanceOf[settings.BooleanSetting].value = true
      }
    } catch {
      case NonFatal(_) =>  // not on Scala 2.11.11+ - that's OK, just won't get partial unification
    }
    settings.exposeEmptyPackage.value = true
    settings.Ymacroexpand.value = settings.MacroExpand.Normal
    settings.YpresentationAnyThread.value = true
    settings.Ydelambdafy.value = "inline"
    settings.outputDirs.setSingleOutput(new VirtualDirectory("(memory)", None))
    settings
  }

  type Provider = Has[ScalaCompiler]

  object Provider {
    def of(compiler: ScalaCompiler): Provider = Has(compiler)
  }

  // Attachment for saving the original position of a tree (before the typer)
  case class OriginalPos(pos: Position)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy