scala.tools.selectivecps.SelectiveCPSTransform.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of continuations Show documentation
Show all versions of continuations Show documentation
Delimited continuations compilation for Scala
The newest version!
// $Id$
package scala.tools.selectivecps
import scala.tools.nsc.transform._
import scala.tools.nsc.plugins._
import scala.tools.nsc.ast._
/**
* In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase.
*/
abstract class SelectiveCPSTransform extends PluginComponent with
InfoTransform with TypingTransformers with CPSUtils with TreeDSL {
// inherits abstract value `global` and class `Phase` from Transform
import global._ // the global environment
import definitions._ // standard classes and methods
import typer.atOwner // methods to type trees
override def description = "@cps-driven transform of selectiveanf assignments"
/** the following two members override abstract members in Transform */
val phaseName: String = "selectivecps"
protected def newTransformer(unit: CompilationUnit): Transformer =
new CPSTransformer(unit)
/** This class does not change linearization */
override def changesBaseClasses = false
/** - return symbol's transformed type,
*/
def transformInfo(sym: Symbol, tp: Type): Type = {
if (!cpsEnabled) return tp
val newtp = transformCPSType(tp)
if (newtp != tp)
debuglog("transformInfo changed type for " + sym + " to " + newtp);
if (sym == MethReifyR)
debuglog("transformInfo (not)changed type for " + sym + " to " + newtp);
newtp
}
def transformCPSType(tp: Type): Type = { // TODO: use a TypeMap? need to handle more cases?
tp match {
case PolyType(params,res) => PolyType(params, transformCPSType(res))
case NullaryMethodType(res) => NullaryMethodType(transformCPSType(res))
case MethodType(params,res) => MethodType(params, transformCPSType(res))
case TypeRef(pre, sym, args) => TypeRef(pre, sym, args.map(transformCPSType(_)))
case _ =>
getExternalAnswerTypeAnn(tp) match {
case Some((res, outer)) =>
appliedType(Context.tpeHK, List(removeAllCPSAnnotations(tp), res, outer))
case _ =>
removeAllCPSAnnotations(tp)
}
}
}
class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
private val patmatTransformer = patmat.newTransformer(unit)
override def transform(tree: Tree): Tree = {
if (!cpsEnabled) return tree
postTransform(mainTransform(tree))
}
def postTransform(tree: Tree): Tree = {
tree.setType(transformCPSType(tree.tpe))
}
def mainTransform(tree: Tree): Tree = {
tree match {
// TODO: can we generalize this?
case Apply(TypeApply(fun, targs), args)
if (fun.symbol == MethShift) =>
debuglog("found shift: " + tree)
atPos(tree.pos) {
val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct?
//gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage),
//ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR)
//gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct?
debuglog("funR.tpe: " + funR.tpe)
Apply(
TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
args.map(transform(_))
).setType(transformCPSType(tree.tpe))
}
case Apply(TypeApply(fun, targs), args)
if (fun.symbol == MethShiftUnit) =>
debuglog("found shiftUnit: " + tree)
atPos(tree.pos) {
val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct?
debuglog("funR.tpe: " + funR.tpe)
Apply(
TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe,
List(targs(0).tpe, targs(1).tpe))),
args.map(transform(_))
).setType(appliedType(Context.tpeHK, List(targs(0).tpe,targs(1).tpe,targs(1).tpe)))
}
case Apply(TypeApply(fun, targs), args)
if (fun.symbol == MethReify) =>
log("found reify: " + tree)
atPos(tree.pos) {
val funR = gen.mkAttributedRef(MethReifyR) // TODO: correct?
debuglog("funR.tpe: " + funR.tpe)
Apply(
TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
args.map(transform(_))
).setType(transformCPSType(tree.tpe))
}
case Try(block, catches, finalizer) =>
// currently duplicates the catch block into a partial function.
// this is kinda risky, but we don't expect there will be lots
// of try/catches inside catch blocks (exp. blowup unlikely).
// CAVEAT: finalizers are surprisingly tricky!
// the problem is that they cannot easily be removed
// from the regular control path and hence will
// also be invoked after creating the Context object.
/*
object Test {
def foo1 = {
throw new Exception("in sub")
shift((k:Int=>Int) => k(1))
10
}
def foo2 = {
shift((k:Int=>Int) => k(2))
20
}
def foo3 = {
shift((k:Int=>Int) => k(3))
throw new Exception("in sub")
30
}
def foo4 = {
shift((k:Int=>Int) => 4)
throw new Exception("in sub")
40
}
def bar(x: Int) = try {
if (x == 1)
foo1
else if (x == 2)
foo2
else if (x == 3)
foo3
else //if (x == 4)
foo4
} catch {
case _ =>
println("exception")
0
} finally {
println("done")
}
}
reset(Test.bar(1)) // should print: exception,done,0
reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20
reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0
reset(Test.bar(4)) // should print: 4 <-- but prints: done,4
*/
val block1 = transform(block)
val catches1 = transformCaseDefs(catches)
val finalizer1 = transform(finalizer)
if (hasAnswerTypeAnn(tree.tpe)) {
//vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe)
val (stms, expr1) = block1 match {
case Block(stms, expr) => (stms, expr)
case expr => (Nil, expr)
}
val targettp = transformCPSType(tree.tpe)
val pos = catches.head.pos
val funSym = currentOwner.newValueParameter(cpsNames.catches, pos).setInfo(appliedType(PartialFunctionClass, ThrowableTpe, targettp))
val funDef = localTyper.typedPos(pos) {
ValDef(funSym, Match(EmptyTree, catches1))
}
val expr2 = localTyper.typedPos(pos) {
Apply(Select(expr1, expr1.tpe.member(cpsNames.flatMapCatch)), List(Ident(funSym)))
}
val exSym = currentOwner.newValueParameter(cpsNames.ex, pos).setInfo(ThrowableTpe)
import CODE._
// generate a case that is supported directly by the back-end
val catchIfDefined = CaseDef(
Bind(exSym, Ident(nme.WILDCARD)),
EmptyTree,
IF ((REF(funSym) DOT nme.isDefinedAt)(REF(exSym))) THEN (REF(funSym) APPLY (REF(exSym))) ELSE Throw(REF(exSym))
)
val catch2 = localTyper.typedCases(List(catchIfDefined), ThrowableTpe, targettp)
//typedCases(tree, catches, ThrowableTpe, pt)
patmatTransformer.transform(localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1))))
/*
disabled for now - see notes above
val expr3 = if (!finalizer.isEmpty) {
val pos = finalizer.pos
val finalizer2 = duplicateTree(finalizer1)
val fun = Function(List(), finalizer2)
val expr3 = localTyper.typedPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) }
val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
chown.traverse(finalizer2)
expr3
} else
expr2
*/
} else {
treeCopy.Try(tree, block1, catches1, finalizer1)
}
case Block(stms, expr) =>
val (stms1, expr1) = transBlock(stms, expr)
treeCopy.Block(tree, stms1, expr1)
case _ =>
super.transform(tree)
}
}
def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = {
stms match {
case Nil =>
(Nil, transform(expr))
case stm::rest =>
stm match {
case vd @ ValDef(mods, name, tpt, rhs)
if (vd.symbol.hasAnnotation(MarkerCPSSym)) =>
debuglog("found marked ValDef "+name+" of type " + vd.symbol.tpe)
val tpe = vd.symbol.tpe
val rhs1 = atOwner(vd.symbol) { transform(rhs) }
rhs1.changeOwner(vd.symbol -> currentOwner) // TODO: don't traverse twice
debuglog("valdef symbol " + vd.symbol + " has type " + tpe)
debuglog("right hand side " + rhs1 + " has type " + rhs1.tpe)
debuglog("currentOwner: " + currentOwner)
debuglog("currentMethod: " + currentMethod)
val (bodyStms, bodyExpr) = transBlock(rest, expr)
// FIXME: result will later be traversed again by TreeSymSubstituter and
// ChangeOwnerTraverser => exp. running time.
// Should be changed to fuse traversals into one.
val specialCaseTrivial = bodyExpr match {
case Apply(fun, args) =>
// for now, look for explicit tail calls only.
// are there other cases that could profit from specializing on
// trivial contexts as well?
(bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol)
case _ => false
}
def applyTrivial(ctxValSym: Symbol, body: Tree) = {
val body1 = (new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)))(body)
val body2 = localTyper.typedPos(vd.symbol.pos) { body1 }
// in theory it would be nicer to look for an @cps annotation instead
// of testing for Context
if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) {
//println(body2 + "/" + body2.tpe)
unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
}
body2
}
def applyCombinatorFun(ctxR: Tree, body: Tree) = {
val arg = currentOwner.newValueParameter(name, ctxR.pos).setInfo(tpe)
val body1 = (new TreeSymSubstituter(List(vd.symbol), List(arg)))(body)
val fun = localTyper.typedPos(vd.symbol.pos) { Function(List(ValDef(arg)), body1) } // types body as well
arg.owner = fun.symbol
body1.changeOwner(currentOwner -> fun.symbol)
// see note about multiple traversals above
debuglog("fun.symbol: "+fun.symbol)
debuglog("fun.symbol.owner: "+fun.symbol.owner)
debuglog("arg.owner: "+arg.owner)
debuglog("fun.tpe:"+fun.tpe)
debuglog("return type of fun:"+body1.tpe)
var methodName = nme.map
if (body1.tpe != null) {
if (body1.tpe.typeSymbol == Context)
methodName = nme.flatMap
}
else
unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
debuglog("will use method:"+methodName)
localTyper.typedPos(vd.symbol.pos) {
Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun))
}
}
// TODO use gen.mkBlock after 2.11.0-M6. Why wait? It allows us to still build in development
// mode with `ant -DskipLocker=1`
def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr
try {
if (specialCaseTrivial) {
debuglog("will optimize possible tail call: " + bodyExpr)
// FIXME: flatMap impl has become more complicated due to
// exceptions. do we need to put a try/catch in the then part??
// val ctx =
// if (ctx.isTrivial)
// val = ctx.getTrivialValue; ... <--- TODO: try/catch ??? don't bother for the moment...
// else
// ctx.flatMap { => ... }
val ctxSym = currentOwner.newValue(newTermName("" + vd.symbol.name + cpsNames.shiftSuffix)).setInfo(rhs1.tpe)
val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1))
def ctxRef = localTyper.typed(Ident(ctxSym))
val argSym = currentOwner.newValue(vd.symbol.name.toTermName).setInfo(tpe)
val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member(cpsNames.getTrivialValue))))
val switchExpr = localTyper.typedPos(vd.symbol.pos) {
val body2 = mkBlock(bodyStms, bodyExpr).duplicate // dup before typing!
If(Select(ctxRef, ctxSym.tpe.member(cpsNames.isTrivial)),
applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)),
applyCombinatorFun(ctxRef, body2))
}
(List(ctxDef), switchExpr)
} else {
// ctx.flatMap { => ... }
// or
// ctx.map { => ... }
(Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr)))
}
} catch {
case ex:TypeError =>
unit.error(ex.pos, ex.msg)
(bodyStms, bodyExpr)
}
case _ =>
val stm1 = transform(stm)
val (a, b) = transBlock(rest, expr)
(stm1::a, b)
}
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy