All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dotty.tools.dotc.transform.TailRec.scala Maven / Gradle / Ivy

The newest version!
package dotty.tools.dotc
package transform

import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.ast.{TreeTypeMap, tpd}
import dotty.tools.dotc.config.Printers.tailrec
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds.{TailLabelName, TailLocalName, TailTempName}
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.reporting.diagnostic.messages.TailrecNotApplicable
import dotty.tools.dotc.transform.MegaPhase.MiniPhase

import scala.collection.mutable

/** A Tail Rec Transformer.
 *
 *  What it does:
 *
 *  Finds method calls in tail-position and replaces them with jumps.
 *  A call is in a tail-position if it is the last instruction to be
 *  executed in the body of a method. This includes being in
 *  tail-position of a `return` from a `Labeled` block which is itself
 *  in tail-position (which is critical for tail-recursive calls in the
 *  cases of a `match`). To identify tail positions, we recurse over
 *  the trees that may contain calls in tail-position (trees that can't
 *  contain such calls are not transformed).
 *
 *  When a method contains at least one tail-recursive call, its rhs
 *  is wrapped in the following structure:
 *
 *  ```
 *  var localForParam1: T1 = param1
 *  ...
 *  while (<empty>) {
 *    tailLabelN[Unit]: {
 *      return {
 *        // original rhs with tail recursive calls transformed (see below)
 *      }
 *    }
 *  }
 *  ```
 *
 *  Self-recursive calls in tail-position are then replaced by (a)
 *  reassigning the local `var`s substituting formal parameters and
 *  (b) a `return` from the `tailLabelN` labeled block, which has the
 *  net effect of looping back to the beginning of the method.
 *  If the receiver is modified in a recursive call, an additional `var`
 *  is used to replace `this`.
 *
 *  As a complete example of the transformation, the classical `fact`
 *  function, defined as:
 *
 *  ```
 *  def fact(n: Int, acc: Int): Int =
 *    if (n == 0) acc
 *    else fact(n - 1, acc * n)
 *  ```
 *
 *  is rewritten as:
 *
 *  ```
 *  def fact(n: Int, acc: Int): Int = {
 *    var acc$tailLocal1: Int = acc
 *    var n$tailLocal1: Int = n
 *    while (<empty>) {
 *      tailLabel1[Unit]: {
 *        return {
 *          if (n$tailLocal1 == 0)
 *            acc$tailLocal1
 *          else {
 *            val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1
 *            val acc$tailLocal1$tmp1: Int = acc$tailLocal1 * n$tailLocal1
 *            n$tailLocal1 = n$tailLocal1$tmp1
 *            acc$tailLocal1 = acc$tailLocal1$tmp1
 *            (return[tailLabel1] ()): Int
 *          }
 *        }
 *      }
 *    }
 *  }
 *  ```
 *
 *  As the JVM provides no way to jump from a method to another one,
 *  non-recursive calls in tail-position are not optimized.
 *
 *  A method call is self-recursive if it calls the current method and
 *  the method is final (otherwise, it could be a call to an overridden
 *  method in a subclass). Recursive calls on a different instance are
 *  optimized.
 *
 *  This phase has been moved after erasure to allow the use of vars
 *  for the parameters combined with a `WhileDo`. This is also
 *  beneficial to support polymorphic tail-recursive calls.
 *
 *  In scalac, if the method had type parameters, the call must contain
 *  the same parameters as type arguments. This is no longer the case in
 *  dotc thanks to being located after erasure.
 *  In scalac, this phase is named tailCall, but it only optimizes
 *  self-recursive functions; that is why it has been renamed to tailrec.
 *
 *  @author
 *    Erik Stenman, Iulian Dragos,
 *    ported and heavily modified for dotty by Dmitry Petrashko
 *    moved after erasure and adapted to emit `Labeled` blocks by Sébastien Doeraene
 */
class TailRec extends MiniPhase {
  import tpd._

  /** The name of this phase, taken from the `name` member of the `TailRec` companion object. */
  override def phaseName: String = TailRec.name

  /** This phase has to run after erasure, because the transformation assumes erased types. */
  override def runsAfter: Set[String] = Set(Erasure.name)

  /** Turns a method whose body contains self-recursive calls in tail position into a loop.
   *
   *  Only effectively final methods that are not accessors and that have a
   *  body are considered. The body is speculatively transformed by a
   *  `TailRecElimination`. If at least one call was rewritten, the result is
   *  wrapped in the local `var` definitions followed by an infinite `while`
   *  loop around a labeled block. Otherwise the original `tree` is returned,
   *  and a `TailrecNotApplicable` error is reported when the method carries
   *  the `@tailrec` annotation and no error was reported during the
   *  transformation itself.
   *
   *  @param tree the method definition to transform
   *  @return the rewritten definition, or `tree` itself if nothing was rewritten
   */
  override def transformDefDef(tree: DefDef)(implicit ctx: Context): Tree = {
    val method = tree.symbol
    val mandatory = method.hasAnnotation(defn.TailrecAnnot)

    /* Gives up on the rewrite and hands back the original definition.
     *
     * FIXME: want to report this error on `tree.nameSpan`, but
     * because of extension method getting a weird position, it is
     * better to report on method symbol so there's no overlap.
     * No new error is reported if failures were already reported
     * during the transformation.
     */
    def unchanged(failureReported: Boolean) = {
      if (mandatory && !failureReported)
        ctx.error(TailrecNotApplicable(method), method.sourcePos)

      tree
    }

    val isCandidate =
      method.isEffectivelyFinal && !method.is(Accessor) && !tree.rhs.eq(EmptyTree)

    if (!isCandidate) unchanged(failureReported = false)
    else {
      val enclosingClass = method.enclosingClass.asClass

      // Note: this could be split into two separate transforms (in different groups):
      // the first one would collect which rewritings should be applied and the
      // second one would actually apply them. For now, the tree is transformed
      // speculatively and the result is thrown away in many cases.
      val transformer = new TailRecElimination(method, enclosingClass, tree.vparamss.head.map(_.symbol), mandatory)
      val rhsSemiTransformed = transformer.transform(tree.rhs)

      if (!transformer.rewrote) unchanged(failureReported = transformer.failureReported)
      else {
        val thisVar = transformer.varForRewrittenThis
        val params = transformer.rewrittenParamSyms
        val paramVars = transformer.varsForRewrittenParamSyms

        // One `var` per rewritten parameter, initialised from the parameter itself,
        // preceded by the `var` standing in for `this` when there is one.
        val paramVarDefs = params.zip(paramVars).map {
          case (param, local) => ValDef(local.asTerm, ref(param))
        }
        val initialVarDefs = thisVar match {
          case Some(local) => ValDef(local.asTerm, This(enclosingClass)) :: paramVarDefs
          case None => paramVarDefs
        }

        // Redirect references to the rewritten parameters (and to `this`, if it
        // was rewritten) towards their replacement `var`s.
        val substitution = thisVar match {
          case Some(localThisSym) =>
            val thisRef = localThisSym.termRef
            new TreeTypeMap(
              typeMap = tp =>
                tp.substThisUnlessStatic(enclosingClass, thisRef)
                  .subst(params, paramVars.map(_.termRef)),
              treeMap = {
                case thisTree: This if thisTree.symbol == enclosingClass => Ident(thisRef)
                case other => other
              }
            )

          case None =>
            new TreeTypeMap(
              typeMap = tp => tp.subst(params, paramVars.map(_.termRef))
            )
        }
        val rhsFullyTransformed = substitution.transform(rhsSemiTransformed)

        // `while (<empty>) { label: { return rhs } }`
        val loopBody = Labeled(transformer.continueLabel.asTerm, Return(rhsFullyTransformed, method))
        cpy.DefDef(tree)(rhs = Block(initialVarDefs, WhileDo(EmptyTree, loopBody)))
      }
    }
  }

  /** Tree transformer that rewrites the self-recursive calls of `method` found in tail position.
   *
   *  A rewritten call is replaced by assignments to replacement `var`s followed
   *  by a `return` to `continueLabel`. This class only creates the symbols of
   *  those `var`s and records them in `varForRewrittenThis`,
   *  `rewrittenParamSyms` and `varsForRewrittenParamSyms`; it does not emit
   *  their definitions nor the enclosing loop.
   *
   *  @param method         the method whose body is being transformed
   *  @param enclosingClass the class enclosing `method`
   *  @param paramSyms      the parameter symbols of `method`, zipped positionally with the arguments of recursive calls
   *  @param isMandatory    whether a recursive call that cannot be rewritten is reported as an error
   */
  class TailRecElimination(method: Symbol, enclosingClass: ClassSymbol, paramSyms: List[Symbol], isMandatory: Boolean) extends TreeMap {

    /** Set to `true` as soon as one recursive call has been rewritten. */
    var rewrote: Boolean = false
    /** Set to `true` as soon as an error has been reported for a recursive call that cannot be rewritten. */
    var failureReported: Boolean = false

    /** The `tailLabelN` label symbol, used to encode a `continue` from the infinite `while` loop. */
    private[this] var myContinueLabel: Symbol = _
    /** Returns the label symbol, creating it (with type `Unit`) on first access. */
    def continueLabel(implicit ctx: Context): Symbol = {
      if (myContinueLabel == null)
        myContinueLabel = ctx.newSymbol(method, TailLabelName.fresh(), Label, defn.UnitType)
      myContinueLabel
    }

    /** The local `var` that replaces `this`, if it is modified in at least one recursive call. */
    var varForRewrittenThis: Option[Symbol] = None
    /** The subset of `paramSyms` that are modified in at least one recursive call, and which therefore need a replacement `var`. */
    var rewrittenParamSyms: List[Symbol] = Nil
    /** The replacement `var`s for the params in `rewrittenParamSyms`. */
    var varsForRewrittenParamSyms: List[Symbol] = Nil

    /** Returns the `var` replacing `this`, creating it on first use.
     *
     *  Its type is the `thisType` of `enclosingClass` when that class is a
     *  module, and the self type from its class info otherwise.
     */
    private def getVarForRewrittenThis()(implicit ctx: Context): Symbol = {
      varForRewrittenThis match {
        case Some(sym) => sym
        case none =>
          val tpe =
            if (enclosingClass.is(Module)) enclosingClass.thisType
            else enclosingClass.classInfo.selfType
          val sym = ctx.newSymbol(method, nme.SELF, Synthetic | Mutable, tpe)
          varForRewrittenThis = Some(sym)
          sym
      }
    }

    /** Returns the replacement `var` for `param`, creating and recording it on first use.
     *
     *  `rewrittenParamSyms` and `varsForRewrittenParamSyms` are always prepended
     *  to together, so the index of `param` in the former is the index of its
     *  `var` in the latter.
     */
    private def getVarForRewrittenParam(param: Symbol)(implicit ctx: Context): Symbol = {
      rewrittenParamSyms.indexOf(param) match {
        case -1 =>
          val sym = ctx.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Synthetic | Mutable, param.info)
          rewrittenParamSyms ::= param
          varsForRewrittenParamSyms ::= sym
          sym
        case index => varsForRewrittenParamSyms(index)
      }
    }

    /** Symbols of Labeled blocks that are in tail position. */
    private val tailPositionLabeledSyms = new mutable.HashSet[Symbol]()

    /** Whether the tree currently being transformed is in tail position. */
    private[this] var inTailPosition = true

    /** Rewrite this tree to contain no tail recursive calls.
     *
     *  The tree is transformed with `inTailPosition` set to `tailPosition`;
     *  the previous value is restored afterwards, even if the transformation throws.
     */
    def transform(tree: Tree, tailPosition: Boolean)(implicit ctx: Context): Tree = {
      if (inTailPosition == tailPosition) transform(tree)
      else {
        val saved = inTailPosition
        inTailPosition = tailPosition
        try transform(tree)
        finally inTailPosition = saved
      }
    }

    /** Transforms `tree` as a tree that is in tail position. */
    def yesTailTransform(tree: Tree)(implicit ctx: Context): Tree =
      transform(tree, tailPosition = true)

    /** If not in tail position a tree traversal may not be needed.
     *
     *  A recursive call may still be in tail position if within the return
     *  expression of a labeled block.
     *  A tree traversal may also be needed to report a failure to transform
     *  a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
     */
    private def isTraversalNeeded =
      isMandatory || tailPositionLabeledSyms.nonEmpty

    /** Transforms `tree` as a tree that is not in tail position, or returns it
     *  untouched when no traversal is needed (see `isTraversalNeeded`).
     */
    def noTailTransform(tree: Tree)(implicit ctx: Context): Tree =
      if (isTraversalNeeded) transform(tree, tailPosition = false)
      else tree

    /** Same as `noTailTransform`, applied to every tree of `trees`. */
    def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit ctx: Context): List[Tr] =
      if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
      else trees

    /** Transforms `tree` according to the current value of `inTailPosition`.
     *
     *  Each case below decides which sub-trees keep the current tail-position
     *  status (transformed with `transform`) and which ones are never in tail
     *  position (transformed with `noTailTransform`/`noTailTransforms`).
     */
    override def transform(tree: Tree)(implicit ctx: Context): Tree = {
      /* Rewrite an Apply to be considered for tail call transformation. */
      def rewriteApply(tree: Apply): Tree = {
        // Arguments of a call are never in tail position.
        val arguments = noTailTransforms(tree.args)

        // The call kept as a call, with its function part and arguments transformed.
        def continue =
          cpy.Apply(tree)(noTailTransform(tree.fun), arguments)

        // Reports an error (if mandatory) or logs to the `tailrec` printer, then keeps the call.
        def fail(reason: String) = {
          if (isMandatory) {
            failureReported = true
            ctx.error(s"Cannot rewrite recursive call: $reason", tree.sourcePos)
          }
          else
            tailrec.println("Cannot rewrite recursive call at: " + tree.span + " because: " + reason)
          continue
        }

        val calledMethod = tree.fun.symbol
        // The receiver of the call; `EmptyTree` for an unqualified reference to `method`.
        val prefix = tree.fun match {
          case Select(qual, _) => qual
          case x: Ident if x.symbol eq method => EmptyTree
          case x => x
        }

        val isRecursiveCall = calledMethod eq method
        // A call to a method with the same name and a matching signature, on a
        // prefix whose type the enclosing class conforms to.
        def isRecursiveSuperCall = (method.name eq calledMethod.name) &&
          method.matches(calledMethod) &&
          enclosingClass.appliedRef.widen <:< prefix.tpe.widenDealias

        if (isRecursiveCall) {
          if (inTailPosition) {
            tailrec.println("Rewriting tail recursive call:  " + tree.span)
            rewrote = true

            // Parameters whose argument is a plain reference to the parameter
            // itself are skipped: they need no reassignment.
            val assignParamPairs = for {
              (param, arg) <- paramSyms.zip(arguments)
              if (arg match {
                case arg: Ident => arg.symbol != param
                case _ => true
              })
            } yield {
              (getVarForRewrittenParam(param), arg)
            }

            val assignThisAndParamPairs = {
              if (prefix eq EmptyTree) assignParamPairs
              else {
                // TODO Opt: also avoid assigning `this` if the prefix is `this.`
                (getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
              }
            }

            // A single assignment is emitted directly. With several, every new
            // value is first stored in a temporary `val`, and the `var`s are only
            // assigned afterwards (presumably so that no right-hand side observes
            // an already updated `var`).
            val assignments = assignThisAndParamPairs match {
              case (lhs, rhs) :: Nil =>
                Assign(ref(lhs), rhs) :: Nil
              case _ :: _ =>
                val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
                  val temp = ctx.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Synthetic, lhs.info)
                  (ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withSpan(tree.span))
                }).unzip
                tempValDefs ::: assigns
              case nil =>
                Nil
            }

            /* The `Typed` node is necessary to perfectly preserve the type of the node.
             * Without it, lubbing in enclosing if/else or match can infer a different type,
             * which can cause Ycheck errors.
             */
            val tpt = TypeTree(method.info.resultType)
            seq(assignments, Typed(Return(unitLiteral.withSpan(tree.span), continueLabel), tpt))
          }
          else fail("it is not in tail position")
        }
        else if (isRecursiveSuperCall)
          fail("it targets a supertype")
        else
          continue
      }

      tree match {
        // For `||` and `&&`, the argument keeps the current tail-position status;
        // every other application goes through `rewriteApply`.
        case tree @ Apply(fun, args) =>
          val meth = fun.symbol
          if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
            cpy.Apply(tree)(noTailTransform(fun), transform(args))
          else
            rewriteApply(tree)

        // The qualifier of a selection is never in tail position.
        case tree @ Select(qual, name) =>
          cpy.Select(tree)(noTailTransform(qual), name)

        // Only the final expression of a block keeps the current status.
        case tree @ Block(stats, expr) =>
          cpy.Block(tree)(
            noTailTransforms(stats),
            transform(expr)
          )

        // Both branches keep the current status; the condition does not.
        case tree @ If(cond, thenp, elsep) =>
          cpy.If(tree)(
            noTailTransform(cond),
            transform(thenp),
            transform(elsep)
          )

        // Only the body of a case is transformed.
        case tree: CaseDef =>
          cpy.CaseDef(tree)(body = transform(tree.body))

        // The cases keep the current status; the selector does not.
        case tree @ Match(selector, cases) =>
          cpy.Match(tree)(
            noTailTransform(selector),
            transformSub(cases)
          )

        // The protected expression is never in tail position.
        case tree: Try =>
          val expr = noTailTransform(tree.expr)
          if (tree.finalizer eq EmptyTree) {
            // SI-1672 Catches are in tail position when there is no finalizer
            cpy.Try(tree)(expr, transformSub(tree.cases), EmptyTree)
          }
          else cpy.Try(tree)(
            expr,
            noTailTransforms(tree.cases),
            noTailTransform(tree.finalizer)
          )

        // Neither the condition nor the body of a loop is in tail position.
        case tree @ WhileDo(cond, body) =>
          cpy.WhileDo(tree)(
            noTailTransform(cond),
            noTailTransform(body)
          )

        case _: Alternative | _: Bind =>
          assert(false, "We should never have gotten inside a pattern")
          tree

        // The definition itself is returned unchanged: its right-hand side is
        // only traversed when `isMandatory`, and the result of that traversal
        // is discarded.
        case tree: ValOrDefDef =>
          if (isMandatory) noTailTransform(tree.rhs)
          tree

        // Leaves: returned as is.
        case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
          tree

        // A labeled block in tail position is registered while its expression is
        // transformed, so that a `return` to it can be recognised below; it is
        // unregistered afterwards, even if the transformation throws.
        case Labeled(bind, expr) =>
          if (inTailPosition)
            tailPositionLabeledSyms += bind.symbol
          try cpy.Labeled(tree)(bind, transform(expr))
          finally {
            if (inTailPosition)
              tailPositionLabeledSyms -= bind.symbol
          }

        // The returned expression is in tail position exactly when the target is
        // a label registered in `tailPositionLabeledSyms`. Note that the local
        // `inTailPosition` below shadows the field of the same name.
        case Return(expr, from) =>
          val fromSym = from.symbol
          val inTailPosition = fromSym.is(Label) && tailPositionLabeledSyms.contains(fromSym)
          cpy.Return(tree)(transform(expr, inTailPosition), from)

        case _ =>
          super.transform(tree)
      }
    }
  }
}

/** Companion object of the `TailRec` phase. */
object TailRec {
  /** The name of the phase, as returned by `TailRec#phaseName`. */
  val name: String = "tailrec"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy