All downloads are free. Search and download functionality uses the official Maven repository.

scala.tools.nsc.transform.async.LiveVariables.scala Maven / Gradle / Ivy

/*
 * Scala (https://www.scala-lang.org)
 *
 * Copyright EPFL and Lightbend, Inc.
 *
 * Licensed under Apache License 2.0
 * (http://www.apache.org/licenses/LICENSE-2.0).
 *
 * See the NOTICE file distributed with this work for
 * additional information regarding copyright ownership.
 */

package scala.tools.nsc.transform.async

import scala.collection.immutable.ArraySeq
import scala.collection.mutable
import scala.reflect.internal.Flags._

trait LiveVariables extends ExprBuilder {
  import global._

  /**
   *  Live variables data-flow analysis.
   *
   *  Find, for each lifted field, the states at which it stops being live, so that the field can be
   *  nulled out there.
   *
   *  Only lifted `var`s are candidates. A candidate is excluded (never nulled out) when its type is a
   *  primitive value class or `Nothing`, when it is referenced by a liftable that is not itself a
   *  candidate, or when it is referenced from a nested class, module, method, function or
   *  `synchronized` argument inside some state.
   *
   *  @param   asyncStates the states of an `async` block
   *  @param   finalState  the final state of the `async` block; not consulted by this analysis
   *  @param   liftables   the lifted fields
   *  @return              a map from state id to `(pre, post)`, where `pre` holds the fields that are live
   *                       on exit from some predecessor state but dead on entry to this state, and `post`
   *                       holds the fields that are live on entry to this state but dead on its exit.
   */
  def fieldsToNullOut(asyncStates: List[AsyncState], finalState: AsyncState,
                      liftables: List[Tree]): mutable.LinkedHashMap[Int, (mutable.LinkedHashSet[Symbol], mutable.LinkedHashSet[Symbol])] = {

    val liftedSyms = mutable.LinkedHashSet[Symbol]()

    // include only vars
    liftedSyms ++= liftables.iterator.collect {
      case vd : ValDef if vd.symbol.hasFlag(MUTABLE) =>
        vd.symbol
    }

    // Determine which fields should also be live at the end (will not be nulled out).
    // Iterate over a snapshot: removing elements from a mutable set while traversing it with
    // `foreach` is unsafe (the traversal may stop early or skip the remaining elements).
    // The predicate still reads the live `liftedSyms`, so a symbol dropped earlier in this loop
    // counts as "not lifted" when later symbols are checked.
    liftedSyms.toList.foreach { sym =>
      val tpSym = sym.info.typeSymbol
      val keepLiveAtEnd =
        tpSym.isPrimitiveValueClass || tpSym == definitions.NothingClass ||
          liftables.exists(tree => !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym))
      if (keepLiveAtEnd)
        liftedSyms -= sym
    }

    /*
     *  Traverse statements of an `AsyncState`, collecting the lifted fields it references and defines.
     *
     *  Side effect: a lifted field referenced from inside a nested class, module, method, function or
     *  `synchronized` argument is removed from `liftedSyms`, so it is never nulled out.
     *
     *  @param  as  a state of an `async` expression
     *  @return     `(usedBeforeAssignment, assignedFields)`: the lifted fields referenced in `as` (in
     *              traversal order) before being defined/assigned there, and the lifted fields
     *              defined/assigned in `as` outside captured code
     */
    def fieldsUsedIn(as: AsyncState): (collection.Set[Symbol], collection.Set[Symbol]) = {
      class FindUseTraverser extends AsyncTraverser {
        val usedBeforeAssignment = new mutable.LinkedHashSet[Symbol]()
        val assignedFields = new mutable.LinkedHashSet[Symbol]()
        // True while traversing code that may capture lifted fields (see the `nested*` overrides).
        private var capturing: Boolean = false
        private def whileCapturing[A](body: => A): A = {
          val saved = capturing
          try {
            capturing = true
            body
          } finally capturing = saved
        }
        private def capturingCheck(tree: Tree) = whileCapturing(super[Traverser].traverse(tree))
        override def traverse(tree: Tree) = tree match {
          // NOTE(review): this guard tests the symbol of the `Assign` node itself, whereas the body
          // records the lhs `Ident`'s symbol. If `Assign` carries no symbol, this case never fires and
          // assignments fall through to the `Ident` case below, i.e. they are treated as reads
          // (conservative). Switching the guard to `i.symbol` would also require traversing `rhs`
          // before recording the assignment and ignoring conditional assignments; confirm intent first.
          case Assign(i @ Ident(_), rhs) if liftedSyms(tree.symbol) =>
            if (!capturing)
              assignedFields += i.symbol
            traverse(rhs)
          case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
            if (!capturing)
              assignedFields += tree.symbol
            traverse(rhs)
          case Ident(_) if liftedSyms(tree.symbol) =>
            if (capturing) {
              // captured by nested code: never null this field out
              liftedSyms -= tree.symbol
            } else if (!assignedFields.contains(tree.symbol)) {
              usedBeforeAssignment += tree.symbol
            }
          case _ =>
            super.traverse(tree)
        }

        override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef)

        override def nestedModuleClass(moduleClass: ClassDef): Unit = capturingCheck(moduleClass)

        override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef)

        override def synchronizedCall(arg: Tree): Unit = capturingCheck(arg)

        override def function(function: Function): Unit = capturingCheck(function)
        override def function(expandedFunction: ClassDef): Unit = capturingCheck(expandedFunction)
      }

      val findUses = new FindUseTraverser
      findUses.traverse(Block(as.stats: _*))
      (findUses.usedBeforeAssignment, findUses.assignedFields)
    }
    val graph: Graph[AsyncState] = {
      val g = new Graph[AsyncState]
      val stateIdToState = asyncStates.iterator.map(x => (x.state, x)).toMap
      for (asyncState <- asyncStates) {
        val (used, assigned) = fieldsUsedIn(asyncState)
        g.addNode(asyncState, used, assigned, asyncState.nextStates.map(stateIdToState).toList)
      }
      g.finish()
    }

    graph.lastReferences[Int](ArraySeq.unsafeWrapArray(liftedSyms.toArray[Symbol]))(_.t.state)
  }

  /** Control-flow graph over `T` nodes, used for a backward live-variables analysis on bit sets. */
  private final class Graph[T] {
    import java.util.BitSet
    private val nodes = mutable.LinkedHashMap[T, Node]()
    private class Node(val t: T, val refs: collection.Set[Symbol], val assign: collection.Set[Symbol], val succTs: List[T]) {
      val succ = new Array[Node](succTs.size)
      val pred = new mutable.ArrayBuffer[Node](4)
      // Live variables at node entry
      val entry: BitSet = new BitSet
      // Live variables at node exit
      val exit: BitSet = new BitSet
      // Variables referenced in this node before being assigned in it (from `refs`)
      val gen = new BitSet
      // Variables assigned in this node (from `assign`)
      val kill = new BitSet
      // Set by the first `updateEntry`, so the first visit always propagates to predecessors.
      var visited: Boolean = false

      /** Recompute `entry = (exit \ kill) ∪ gen`; returns true on the first call or if `entry` changed. */
      def updateEntry(): Boolean = {
        val card = entry.cardinality()
        entry.clear()
        entry.or(exit)
        entry.andNot(kill)
        entry.or(gen)
        if (!visited) {
          visited = true
          true
        } else (entry.cardinality() != card) // `exit` only grows, so `entry` only grows
      }

      /** Add every successor's `entry` into `exit`; returns true if `exit` grew. */
      def updateExit(): Boolean = {
        val card = exit.cardinality()
        succ.foreach(s => exit.or(s.entry))
        card != exit.cardinality()
      }

      /** Variables live on exit from some predecessor but dead on entry to this node. */
      def deadOnEntryLiveOnPredecessorExit: BitSet = {
        val result = new BitSet
        pred.foreach(p => result.or(p.exit))
        result.andNot(entry)
        result
      }

      /** Variables live on entry to this node but dead on its exit. */
      def deadOnExitLiveOnEntry: BitSet = {
        val result = entry.clone.asInstanceOf[BitSet]
        result.andNot(exit)
        result
      }
      override def toString = s"Node($t, gen = $gen, kill = $kill, entry = $entry, exit = $exit, null = $deadOnEntryLiveOnPredecessorExit)"
    }
    def addNode(t: T, refs: collection.Set[Symbol], assign: collection.Set[Symbol], succ: List[T]): Unit = {
      nodes(t) = new Node(t, refs, assign, succ)
    }
    private var finished = false

    /** Resolve successor links (and the reverse predecessor links); must be called exactly once. */
    def finish(): this.type = {
      assert(!finished, "cannot finish when already finished")
      for (node <- nodes.valuesIterator) {
        foreachWithIndex(node.succTs) {(succT, i) =>
          val succ = nodes(succT)
          node.succ(i) = succ
          succ.pred += node
        }
      }
      finished = true
      this
    }

    /**
     *  Run the backward live-variables fixpoint and report, per node, which of `syms` die on entry
     *  (`pre`) and on exit (`post`), keyed by `keyMapping`.
     */
    def lastReferences[K](syms: IndexedSeq[Symbol])(keyMapping: Node => K): mutable.LinkedHashMap[K, (mutable.LinkedHashSet[Symbol], mutable.LinkedHashSet[Symbol])] = {
      assert(finished, "lastReferences before finished")
      val symIndices: Map[Symbol, Int] = syms.zipWithIndex.toMap
      val nodeValues = nodes.values.toArray
      nodeValues.foreach { node =>
        node.refs.foreach(ref => symIndices.get(ref).foreach(i => node.gen.set(i)))
        node.assign.foreach(ref => symIndices.get(ref).foreach(i => node.kill.set(i)))
      }
      // Backward worklist iteration seeded from the last node added. Nodes from which that node is
      // not reachable are never visited and keep empty live sets.
      nodeValues.lastOption.foreach { terminal =>
        val workList = mutable.Queue[Node](terminal)
        while (workList.nonEmpty) {
          val node = workList.dequeue()
          node.updateExit()
          if (node.updateEntry())
            workList ++= node.pred
        }
      }
      // One instance shared by every empty result; mutating it would affect all those entries.
      val empty = mutable.LinkedHashSet[Symbol]()
      def toSymSet(indices: BitSet): mutable.LinkedHashSet[Symbol] = {
        if (indices.isEmpty) empty
        else {
          val result = mutable.LinkedHashSet[Symbol]()
          indices.stream().forEach(i => result += syms(i))
          result
        }
      }
      mutable.LinkedHashMap(ArraySeq.unsafeWrapArray(nodeValues.map { x =>
        val pre = toSymSet(x.deadOnEntryLiveOnPredecessorExit)
        val post = toSymSet(x.deadOnExitLiveOnEntry)
        (keyMapping(x), (pre, post))
      }): _*)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy