All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scala.tools.nsc.backend.opt.ConstantOptimization.scala Maven / Gradle / Ivy

The newest version!
/* NSC -- new Scala compiler
 * Copyright 2005-2013 LAMP/EPFL
 * @author  James Iry
 */

package scala
package tools.nsc
package backend.opt

import scala.annotation.tailrec

/**
 * ConstantOptimization uses abstract interpretation to approximate for
 * each instruction what constants a variable or stack slot might hold
 * or cannot hold. From this it will eliminate unreachable conditionals
 * where only one branch is reachable, e.g. to eliminate unnecessary
 * null checks.
 *
 * With some more work it could be extended to
 * - cache stable values (final fields, modules) in locals
 * - replace the copy propagation in ClosureElilmination
 * - fold constants
 * - eliminate unnecessary stores and loads
 * - propagate knowledge gathered from conditionals for further optimization
 */
abstract class ConstantOptimization extends SubComponent {
  import global._
  import icodes._
  import icodes.opcodes._

  val phaseName = "constopt"

  /** Create a new phase */
  override def newPhase(p: Phase) = new ConstantOptimizationPhase(p)

  override val enabled: Boolean = settings.YconstOptimization

  /**
   * The constant optimization phase.
   */
  class ConstantOptimizationPhase(prev: Phase) extends ICodePhase(prev) {

    def name = phaseName

    override def apply(c: IClass) {
      if (settings.YconstOptimization) {
        val analyzer = new ConstantOptimizer
        analyzer optimizeClass c
      }
    }
  }

  class ConstantOptimizer {
    def optimizeClass(cls: IClass) {
      log(s"Analyzing ${cls.methods.size} methods in $cls.")
      cls.methods foreach { m =>
        optimizeMethod(m)
      }
    }

    def optimizeMethod(m: IMethod) {
      if (m.hasCode) {
        log(s"Analyzing ${m.symbol}")
        val replacementInstructions = interpretMethod(m)
        for (block <- m.blocks) {
          if (replacementInstructions contains block) {
            val instructions = replacementInstructions(block)
            block.replaceInstruction(block.lastInstruction, instructions)
          }
        }
      }
    }

    /**
     * A single possible (or impossible) datum that can be held in Contents
     */
    private sealed abstract class Datum
    /**
     * A constant datum
     */
    private case class Const(c: Constant) extends Datum {
      def isIntAssignable = c.tag >= BooleanTag && c.tag <= IntTag
      def toInt = c.tag match {
        case BooleanTag => if (c.booleanValue) 1 else 0
        case _ => c.intValue
      }

      /**
       * True if this constant would compare to other as true under primitive eq
       */
      override def equals(other: Any) = other match {
        case oc @ Const(o) => (this eq oc) || (if (this.isIntAssignable && oc.isIntAssignable) this.toInt == oc.toInt else c.value == o.value)
        case _ => false
      }

      /**
       * Hash code consistent with equals
       */
      override def hashCode = if (this.isIntAssignable) this.toInt else c.hashCode

    }
    /**
     * A datum that has been Boxed via a BOX instruction
     */
    private case class Boxed(c: Datum) extends Datum

    /**
     * The knowledge we have about the abstract state of one location in terms
     * of what constants it might or cannot hold. Forms a lower
     * lattice where lower elements in the lattice indicate less knowledge.
     *
     * With the following partial ordering (where '>' indicates more precise knowledge)
     *
     * Possible(xs) > Possible(xs + y)
     * Possible(xs) > Impossible(ys)
     * Impossible(xs + y) > Impossible(xs)
     *
     * and the following merges, which indicate merging knowledge from two paths through
     * the code,
     *
     * // left must be 1 or 2, right must be 2 or 3 then we must have a 1, 2 or 3
     * Possible(xs) merge Possible(ys) => Possible(xs union ys)
     *
     * // Left says can't be 2 or 3, right says can't be 3 or 4
     * // then it's not 3 (it could be 2 from the right or 4 from the left)
     * Impossible(xs) merge Impossible(ys) => Impossible(xs intersect ys)
     *
     * // Left says it can't be 2 or 3, right says it must be 3 or 4, then
     * // it can't be 2 (left rules out 4 and right says 3 is possible)
     * Impossible(xs) merge Possible(ys) => Impossible(xs -- ys)
     *
     * Intuitively, Possible(empty) says that a location can't hold anything,
     * it's uninitialized. However, Possible(empty) never appears in the code.
     *
     * Conversely, Impossible(empty) says nothing is impossible, it could be
     * anything. Impossible(empty) is given a synonym UNKNOWN and is used
     * for, e.g., the result of an arbitrary method call.
     */
    private sealed abstract class Contents {
      /**
       * Join this Contents with another coming from another path. Join enforces
       * the lattice structure. It is symmetrical and never moves upward in the
       * lattice
       */
      final def merge(other: Contents): Contents = if (this eq other) this else (this, other) match {
        case (Possible(possible1), Possible(possible2)) =>
          Possible(possible1 union possible2)
        case (Impossible(impossible1), Impossible(impossible2)) =>
          Impossible(impossible1 intersect impossible2)
        case (Impossible(impossible), Possible(possible)) =>
          Impossible(impossible -- possible)
        case (Possible(possible), Impossible(impossible)) =>
          Impossible(impossible -- possible)
      }
      // TODO we could have more fine-grained knowledge, e.g. know that 0 < x < 3. But for now equality/inequality is a good start.
      def mightEqual(other: Contents): Boolean
      def mightNotEqual(other: Contents): Boolean
    }
    private def SingleImpossible(x: Datum) = new Impossible(Set(x))

    /**
     * The location is known to have one of a set of values.
     */
    private case class Possible(possible: Set[Datum]) extends Contents {
      assert(possible.nonEmpty, "Contradiction: had an empty possible set indicating an uninitialized location")
      def mightEqual(other: Contents): Boolean = (this eq other) || (other match {
        // two Possibles might be equal if they have any possible members in common
        case Possible(possible2) => (possible intersect possible2).nonEmpty
        // a possible can be equal to an impossible if the impossible doesn't rule
        // out all the possibilities
        case Impossible(possible2) => (possible -- possible2).nonEmpty
      })
      def mightNotEqual(other: Contents): Boolean = (this ne other) && (other match {
        // two Possibles might not be equal if either has possible members that the other doesn't
        case Possible(possible2) => (possible -- possible2).nonEmpty || (possible2 -- possible).nonEmpty
        case Impossible(_) => true
      })
    }
    private def SinglePossible(x: Datum) = new Possible(Set(x))

    /**
     * The location is known to not have any of a set of values value (e.g null).
     */
    private case class Impossible(impossible: Set[Datum]) extends Contents {
      def mightEqual(other: Contents): Boolean = (this eq other) || (other match {
        case Possible(_) => other mightEqual this
        case _ => true
      })
      def mightNotEqual(other: Contents): Boolean = (this eq other) || (other match {
        case Possible(_) => other mightNotEqual this
        case _ => true
      })
    }

    /**
     * Our entire knowledge about the contents of all variables and the stack. It forms
     * a lattice primarily driven by the lattice structure of Contents.
     *
     * In addition to the rules of contents, State has the following properties:
     * - The merge of two sets of locals holds the merges of locals found in the intersection
     * of the two sets of locals. Locals not found in a
     * locals map are thus possibly uninitialized and attempting to load them results
     * in an error.
     * - The stack heights of two states must match otherwise it's an error to merge them
     *
     * State is immutable in order to aid in structure sharing of local maps and stacks
     */
    private case class State(locals: Map[Local, Contents], stack: List[Contents]) {
      def mergeLocals(olocals: Map[Local, Contents]): Map[Local, Contents] = if (locals eq olocals) locals else Map((for {
        key <- (locals.keySet intersect olocals.keySet).toSeq
      } yield (key, locals(key) merge olocals(key))): _*)

      def merge(other: State): State = if (this eq other) this else {
        @tailrec def mergeStacks(l: List[Contents], r: List[Contents], out: List[Contents]): List[Contents] = (l, r) match {
          case (Nil, Nil) => out.reverse
          case (l, r) if l eq r => out.reverse ++ l
          case (lhead :: ltail, rhead :: rtail) => mergeStacks(ltail, rtail, (lhead merge rhead) :: out)
          case _ => sys.error("Mismatched stack heights")
        }

        val newLocals = mergeLocals(other.locals)

        val newStack = if (stack eq other.stack) stack else mergeStacks(stack, other.stack, Nil)
        State(newLocals, newStack)
      }

      /**
       * Peek at the top of the stack without modifying it. Error if the stack is empty
       */
      def peek(n: Int): Contents = stack(n)
      /**
       * Push contents onto a stack
       */
      def push(contents: Contents): State = this copy (stack = contents :: stack)
      /**
       * Drop n elements from the stack
       */
      def drop(number: Int): State = this copy (stack = stack drop number)
      /**
       * Store the top of the stack into the specified local. An error if the stack
       * is empty
       */
      def store(variable: Local): State = {
        val contents = stack.head
        val newVariables = locals + ((variable, contents))
        new State(newVariables, stack.tail)
      }
      /**
       * Load the specified local onto the top of the stack. An error the the local is uninitialized.
       */
      def load(variable: Local): State = {
        val contents: Contents = locals.getOrElse(variable, sys.error(s"$variable is not initialized"))
        push(contents)
      }
      /**
       * A copy of this State with an empty stack
       */
      def cleanStack: State = if (stack.isEmpty) this else this copy (stack = Nil)
    }

    // some precomputed constants
    private val NULL = Const(Constant(null: Any))
    private val UNKNOWN = Impossible(Set.empty)
    private val NOT_NULL = SingleImpossible(NULL)
    private val CONST_UNIT = SinglePossible(Const(Constant(())))
    private val CONST_FALSE = SinglePossible(Const(Constant(false)))
    private val CONST_ZERO_BYTE = SinglePossible(Const(Constant(0: Byte)))
    private val CONST_ZERO_SHORT = SinglePossible(Const(Constant(0: Short)))
    private val CONST_ZERO_CHAR = SinglePossible(Const(Constant(0: Char)))
    private val CONST_ZERO_INT = SinglePossible(Const(Constant(0: Int)))
    private val CONST_ZERO_LONG = SinglePossible(Const(Constant(0: Long)))
    private val CONST_ZERO_FLOAT = SinglePossible(Const(Constant(0.0f)))
    private val CONST_ZERO_DOUBLE = SinglePossible(Const(Constant(0.0d)))
    private val CONST_NULL = SinglePossible(NULL)

    /**
     * Given a TypeKind, figure out what '0' for it means in order to interpret CZJUMP
     */
    private def getZeroOf(k: TypeKind): Contents = k match {
      case UNIT => CONST_UNIT
      case BOOL => CONST_FALSE
      case BYTE => CONST_ZERO_BYTE
      case SHORT => CONST_ZERO_SHORT
      case CHAR => CONST_ZERO_CHAR
      case INT => CONST_ZERO_INT
      case LONG => CONST_ZERO_LONG
      case FLOAT => CONST_ZERO_FLOAT
      case DOUBLE => CONST_ZERO_DOUBLE
      case REFERENCE(_) => CONST_NULL
      case ARRAY(_) => CONST_NULL
      case BOXED(_) => CONST_NULL
      case ConcatClass => abort("no zero of ConcatClass")
    }

    // normal locals can't be null, so we use null to mean the magic 'this' local
    private val THIS_LOCAL: Local = null

    /**
     * interpret a single instruction to find its impact on the abstract state
     */
    private def interpretInst(in: State, inst: Instruction): State = {
      // pop the consumed number of values off the `in` state's stack, producing a new state
      def dropConsumed: State = in drop inst.consumed

      inst match {
        case THIS(_) =>
          in load THIS_LOCAL

        case CONSTANT(k) =>
          // treat NaN as UNKNOWN because NaN must never equal NaN
          val const = if (k.isNaN) UNKNOWN
          else SinglePossible(Const(k))
          in push const

        case LOAD_ARRAY_ITEM(_) | LOAD_FIELD(_, _) | CALL_PRIMITIVE(_) =>
          dropConsumed push UNKNOWN

        case LOAD_LOCAL(local) =>
          // TODO if a local is known to hold a constant then we can replace this instruction with a push of that constant
          in load local

        case STORE_LOCAL(local) =>
          in store local

        case STORE_THIS(_) =>
          // if a local is already known to have a constant and we're replacing with the same constant then we can
          // replace this with a drop
          in store THIS_LOCAL

        case CALL_METHOD(_, _) =>
          // TODO we could special case implementations of equals that are known, e.g. String#equals
          // We could turn Possible(string constants).equals(Possible(string constants) into an eq check
          // We could turn nonConstantString.equals(constantString) into constantString.equals(nonConstantString)
          //  and eliminate the null check that likely precedes this call
          val initial = dropConsumed
          (0 until inst.produced).foldLeft(initial) { case (know, _) => know push UNKNOWN }

        case BOX(_) =>
          val value = in peek 0
          // we simulate boxing by, um, boxing the possible/impossible contents
          // so if we have Possible(1,2) originally then we'll end up with
          // a Possible(Boxed(1), Boxed(2))
          // Similarly, if we know the input is not a 0 then we'll know the
          // output is not a Boxed(0)
          val newValue = value match {
            case Possible(values) => Possible(values map Boxed)
            case Impossible(values) => Impossible(values map Boxed)
          }
          dropConsumed push newValue

        case UNBOX(_) =>
          val value = in peek 0
          val newValue = value match {
            // if we have a Possible, then all the possibilities
            // should themselves be Boxes. In that
            // case we can merge them to figure out what the UNBOX will produce
            case Possible(inners) =>
              assert(inners.nonEmpty, "Empty possible set indicating an uninitialized location")
              val sanitized: Set[Contents] = (inners map {
                case Boxed(content) => SinglePossible(content)
                case _ => UNKNOWN
              })
              sanitized reduce (_ merge _)
            // if we have an impossible then the thing that's impossible
            // should be a box. We'll unbox that to see what we get
            case unknown@Impossible(inners) =>
              if (inners.isEmpty) {
                unknown
              } else {
                val sanitized: Set[Contents] = (inners map {
                  case Boxed(content) => SingleImpossible(content)
                  case _ => UNKNOWN
                })
                sanitized reduce (_ merge _)
              }
          }
          dropConsumed push newValue

        case LOAD_MODULE(_) | NEW(_) | LOAD_EXCEPTION(_) =>
          in push NOT_NULL

        case CREATE_ARRAY(_, _) =>
          dropConsumed push NOT_NULL

        case IS_INSTANCE(_) =>
          // TODO IS_INSTANCE is going to be followed by a C(Z)JUMP
          // and if IS_INSTANCE/C(Z)JUMP the branch for "true" can
          // know that whatever was checked was not a null
          // see the TODO on CJUMP for more information about propagating null
          // information
          // TODO if the top of stack is guaranteed null then we can eliminate this IS_INSTANCE check and
          // replace with a constant false, but how often is a knowable null checked for instanceof?
          // TODO we could track type information and statically know to eliminate IS_INSTANCE
          // which might be a nice win under specialization
          dropConsumed push UNKNOWN // it's actually a Possible(true, false) but since the following instruction
        // will be a conditional jump comparing to true or false there
        // nothing to be gained by being more precise

        case CHECK_CAST(_) =>
          // TODO we could track type information and statically know to eliminate CHECK_CAST
          // but that's probably not a huge win
          in

        case DUP(_) =>
          val value = in peek 0
          in push value

        case DROP(_) | MONITOR_ENTER() | MONITOR_EXIT() | STORE_ARRAY_ITEM(_) | STORE_FIELD(_, _) =>
          dropConsumed

        case SCOPE_ENTER(_) | SCOPE_EXIT(_) =>
          in

        case JUMP(_) | CJUMP(_, _, _, _) | CZJUMP(_, _, _, _) | RETURN(_) | THROW(_) | SWITCH(_, _) =>
          dumpClassesAndAbort("Unexpected block ending instruction: " + inst)
      }
    }
    /**
     * interpret the last instruction of a block which will be jump, a conditional branch, a throw, or a return.
     * It will result in a map from target blocks to the input state computed for that block. It
     * also computes a replacement list of instructions
     */
    private def interpretLast(in: State, inst: Instruction): (Map[BasicBlock, State], List[Instruction]) = {
      def canSwitch(in1: Contents, tagSet: List[Int]) = {
        in1 mightEqual Possible(tagSet.toSet map { tag: Int => Const(Constant(tag)) })
      }

      /* common code for interpreting CJUMP and CZJUMP */
      def interpretConditional(kind: TypeKind, val1: Contents, val2: Contents, success: BasicBlock, failure: BasicBlock, cond: TestOp): (Map[BasicBlock, State], List[Instruction]) = {
        // TODO use reaching analysis to update the state in the two branches
        // e.g. if the comparison was checking null equality on local x
        // then the in the success branch we know x is null and
        // on the failure branch we know it is not
        // in fact, with copy propagation we could propagate that knowledge
        // back through a chain of locations
        //
        // TODO if we do all that we need to be careful in the
        // case that success and failure are the same target block
        // because we're using a Map and don't want one possible state to clobber the other
        // alternative mayb we should just replace the conditional with a jump if both targets are the same

        def mightEqual = val1 mightEqual val2
        def mightNotEqual = val1 mightNotEqual val2
        def guaranteedEqual = mightEqual && !mightNotEqual

        def succPossible = cond match {
          case EQ => mightEqual
          case NE => mightNotEqual
          case LT | GT => !guaranteedEqual // if the two are guaranteed to be equal then they can't be LT/GT
          case LE | GE => true
        }

        def failPossible = cond match {
          case EQ => mightNotEqual
          case NE => mightEqual
          case LT | GT => true
          case LE | GE => !guaranteedEqual // if the two are guaranteed to be equal then they must be LE/GE
        }

        val out = in drop inst.consumed

        var result = Map[BasicBlock, State]()
        if (succPossible) {
          result += ((success, out))
        }

        if (failPossible) {
          result += ((failure, out))
        }

        val replacements = if (result.size == 1) List.fill(inst.consumed)(DROP(kind)) :+ JUMP(result.keySet.head)
        else inst :: Nil

        (result, replacements)
      }

      inst match {
        case JUMP(whereto) =>
          (Map((whereto, in)), inst :: Nil)

        case CJUMP(success, failure, cond, kind) =>
          val in1 = in peek 0
          val in2 = in peek 1
          interpretConditional(kind, in1, in2, success, failure, cond)

        case CZJUMP(success, failure, cond, kind) =>
          val in1 = in peek 0
          val in2 = getZeroOf(kind)
          interpretConditional(kind, in1, in2, success, failure, cond)

        case SWITCH(tags, labels) =>
          val in1 = in peek 0
          val reachableNormalLabels = tags zip labels collect { case (tagSet, label) if canSwitch(in1, tagSet) => label }
          val reachableLabels = if (tags.isEmpty) {
            assert(labels.size == 1, s"When SWITCH node has empty array of tags it should have just one (default) label: $labels")
            labels
          } else if (labels.lengthCompare(tags.length) > 0) {
            // if we've got an extra label then it's the default
            val defaultLabel = labels.last
            // see if the default is reachable by seeing if the input might be out of the set
            // of all tags
            val allTags = Possible(tags.flatten.toSet map { tag: Int => Const(Constant(tag)) })
            if (in1 mightNotEqual allTags) {
              reachableNormalLabels :+ defaultLabel
            } else {
              reachableNormalLabels
            }
          } else {
            reachableNormalLabels
          }
          // TODO similar to the comment in interpretConditional, we should update our the State going into each
          // branch based on which tag is being matched. Also, just like interpretConditional, if target blocks
          // are the same we need to merge State rather than clobber

          // alternative, maybe we should simplify the SWITCH to not have same target labels
          val newState = in drop inst.consumed
          val result = Map(reachableLabels map { label => (label, newState) }: _*)
          if (reachableLabels.size == 1) (result, DROP(INT) :: JUMP(reachableLabels.head) :: Nil)
          else (result, inst :: Nil)

        // these instructions don't have target blocks
        // (exceptions are assumed to be reachable from all instructions)
        case RETURN(_) | THROW(_) =>
          (Map.empty, inst :: Nil)

        case _ =>
          dumpClassesAndAbort("Unexpected non-block ending instruction: " + inst)
      }
    }

    /**
     * Analyze a single block to find how it transforms an input state into a states for its successor blocks
     * Also computes a list of instructions to be used to replace its last instruction
     */
    private def interpretBlock(in: State, block: BasicBlock): (Map[BasicBlock, State], Map[BasicBlock, State], List[Instruction]) = {
      debuglog(s"interpreting block $block")
      // number of instructions excluding the last one
      val normalCount = block.size - 1

      val exceptionState = in.cleanStack
      var normalExitState = in
      var idx = 0
      while (idx < normalCount) {
        val inst = block(idx)
        normalExitState = interpretInst(normalExitState, inst)
        if (normalExitState.locals ne exceptionState.locals)
          exceptionState.copy(locals = exceptionState mergeLocals normalExitState.locals)
        idx += 1
      }

      val pairs = block.exceptionSuccessors map { b => (b, exceptionState) }
      val exceptionMap = Map(pairs: _*)

      val (normalExitMap, newInstructions) = interpretLast(normalExitState, block.lastInstruction)

      (normalExitMap, exceptionMap, newInstructions)
    }

    /**
     * Analyze a single method to find replacement instructions
     */
    private def interpretMethod(m: IMethod): Map[BasicBlock, List[Instruction]] = {
      import scala.collection.mutable.{ Set => MSet, Map => MMap }

      debuglog(s"interpreting method $m")
      var iterations = 0

      // initially we know that 'this' is not null and the params are initialized to some unknown value
      val initThis: Iterator[(Local, Contents)] = if (m.isStatic) Iterator.empty else Iterator.single((THIS_LOCAL, NOT_NULL))
      val initOtherLocals: Iterator[(Local, Contents)] = m.params.iterator map { param => (param, UNKNOWN) }
      val initialLocals: Map[Local, Contents] = Map((initThis ++ initOtherLocals).toSeq: _*)
      val initialState = State(initialLocals, Nil)

      // worklist of basic blocks to process, initially the start block
      val worklist = MSet(m.startBlock)
      // worklist of exception basic blocks. They're kept in a separate set so they can be
      // processed after normal flow basic blocks. That's because exception basic blocks
      // are more likely to have multiple predecessors and queueing them for later
      // increases the chances that they'll only need to be interpreted once
      val exceptionlist = MSet[BasicBlock]()
      // our current best guess at what the input state is for each block
      // initially we only know about the start block
      val inputState = MMap[BasicBlock, State]((m.startBlock, initialState))

      // update the inputState map based on new information from interpreting a block
      // When the input state of a block changes, add it back to the work list to be
      // reinterpreted
      def updateInputStates(outputStates: Map[BasicBlock, State], worklist: MSet[BasicBlock]) {
        for ((block, newState) <- outputStates) {
          val oldState = inputState get block
          val updatedState = oldState map (x => x merge newState) getOrElse newState
          if (oldState != Some(updatedState)) {
            worklist add block
            inputState(block) = updatedState
          }
        }
      }

      // the instructions to be used as the last instructions on each block
      val replacements = MMap[BasicBlock, List[Instruction]]()

      while (worklist.nonEmpty || exceptionlist.nonEmpty) {
        if (worklist.isEmpty) {
          // once the worklist is empty, start processing exception blocks
          val block = exceptionlist.head
          exceptionlist remove block
          worklist add block
        } else {
          iterations += 1
          val block = worklist.head
          worklist remove block
          val (normalExitMap, exceptionMap, newInstructions) = interpretBlock(inputState(block), block)

          updateInputStates(normalExitMap, worklist)
          updateInputStates(exceptionMap, exceptionlist)
          replacements(block) = newInstructions
        }
      }

      debuglog(s"method $m with ${m.blocks.size} reached fixpoint in $iterations iterations")
      replacements.toMap
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy