// scala.tools.nsc.backend.opt.ConstantOptimization.scala Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of scala-compiler Show documentation
// Compiler for the SubScript extension of the Scala Programming Language
// The newest version!
/* NSC -- new Scala compiler
* Copyright 2005-2013 LAMP/EPFL
* @author James Iry
*/
package scala
package tools.nsc
package backend.opt
import scala.annotation.tailrec
/**
* ConstantOptimization uses abstract interpretation to approximate for
* each instruction what constants a variable or stack slot might hold
* or cannot hold. From this it will eliminate unreachable conditionals
* where only one branch is reachable, e.g. to eliminate unnecessary
* null checks.
*
* With some more work it could be extended to
* - cache stable values (final fields, modules) in locals
* - replace the copy propagation in ClosureElimination
* - fold constants
* - eliminate unnecessary stores and loads
* - propagate knowledge gathered from conditionals for further optimization
*/
abstract class ConstantOptimization extends SubComponent {
import global._
import icodes._
import icodes.opcodes._
/** Name reported by the phases created in `newPhase` (see `ConstantOptimizationPhase.name`). */
val phaseName = "constopt"
/** Create a new phase */
override def newPhase(p: Phase) = new ConstantOptimizationPhase(p)
/** Overrides the inherited `enabled` member with the value of the `YconstOptimization` setting. */
override val enabled: Boolean = settings.YconstOptimization
/**
 * The ICode phase that drives constant optimization: every class handed to
 * `apply` is processed by a freshly created `ConstantOptimizer`.
 */
class ConstantOptimizationPhase(prev: Phase) extends ICodePhase(prev) {
  def name = phaseName

  /** Optimizes `c`, but only when the `YconstOptimization` setting is switched on. */
  override def apply(c: IClass): Unit =
    if (settings.YconstOptimization) new ConstantOptimizer().optimizeClass(c)
}
class ConstantOptimizer {
/**
 * Logs how many methods `cls` has and then runs `optimizeMethod` on each of them.
 */
def optimizeClass(cls: IClass): Unit = {
  log(s"Analyzing ${cls.methods.size} methods in $cls.")
  for (method <- cls.methods) optimizeMethod(method)
}
/**
 * Optimizes a single method. Methods without code are skipped. Otherwise the
 * method is abstractly interpreted and, for every block the interpretation
 * produced a result for, the block's last instruction is replaced by the
 * computed instruction list.
 */
def optimizeMethod(m: IMethod): Unit =
  if (m.hasCode) {
    log(s"Analyzing ${m.symbol}")
    val newBlockEndings = interpretMethod(m)
    m.blocks foreach { block =>
      // blocks that were never interpreted have no entry and are left untouched
      newBlockEndings.get(block) foreach { ending =>
        block.replaceInstruction(block.lastInstruction, ending)
      }
    }
  }
/**
 * A single possible (or impossible) datum that can be held in Contents.
 * It is the element type of the sets carried by `Possible` and `Impossible`;
 * the subclasses defined in this file are `Const` and `Boxed`.
 */
private sealed abstract class Datum
/**
 * A datum that is a compile-time constant.
 *
 * Equality is redefined to follow primitive comparison rather than structural
 * equality of `Constant`: constants whose tag lies between `BooleanTag` and
 * `IntTag` are compared by their int value, so that e.g. `false` and a zero
 * of an int-like kind are the same datum.
 */
private case class Const(c: Constant) extends Datum {
  /** True when the constant's tag lies in the range `BooleanTag` to `IntTag` (inclusive). */
  def isIntAssignable = BooleanTag <= c.tag && c.tag <= IntTag

  /** The constant as an int; booleans map to 1 (true) and 0 (false). */
  def toInt =
    if (c.tag == BooleanTag) { if (c.booleanValue) 1 else 0 }
    else c.intValue

  /**
   * True if this constant would compare to other as true under primitive eq
   */
  override def equals(other: Any) = other match {
    case that @ Const(thatConstant) =>
      if (this eq that) true
      else if (isIntAssignable && that.isIntAssignable) toInt == that.toInt
      else c.value == thatConstant.value
    case _ =>
      false
  }

  /**
   * Hash code matching the int-based comparison done by `equals` for int-assignable constants
   */
  override def hashCode = if (isIntAssignable) toInt else c.hashCode
}
/**
 * A datum that has been Boxed via a BOX instruction; `c` is the datum that was boxed.
 * `interpretInst` wraps data in it when interpreting BOX and unwraps it again
 * when interpreting UNBOX.
 */
private case class Boxed(c: Datum) extends Datum
/**
 * What is known about the value held by one location (a local or a stack
 * slot), expressed either as the set of values it may hold (`Possible`) or as
 * a set of values it cannot hold (`Impossible`). The elements form a lower
 * lattice in which lower elements carry less knowledge.
 *
 * Partial ordering ('>' means "more precise knowledge"):
 *
 *   Possible(xs)       > Possible(xs + y)
 *   Possible(xs)       > Impossible(ys)
 *   Impossible(xs + y) > Impossible(xs)
 *
 * Merging combines the knowledge arriving along two different paths:
 *
 *   // left is 1 or 2, right is 2 or 3: the result is 1, 2 or 3
 *   Possible(xs) merge Possible(ys)     => Possible(xs union ys)
 *
 *   // left excludes 2 and 3, right excludes 3 and 4: only 3 is excluded on
 *   // both paths (2 may arrive from the right, 4 from the left)
 *   Impossible(xs) merge Impossible(ys) => Impossible(xs intersect ys)
 *
 *   // left excludes 2 and 3, right is 3 or 4: only 2 stays excluded, because
 *   // the right path can deliver 3
 *   Impossible(xs) merge Possible(ys)   => Impossible(xs -- ys)
 *
 * Possible(empty) would describe a location that cannot hold anything, i.e.
 * an uninitialized one; it is never constructed (`Possible` asserts that its
 * set is non-empty).
 *
 * Impossible(empty) rules nothing out, so the location could hold anything.
 * It has the synonym UNKNOWN and is used, for example, for the result of an
 * arbitrary method call.
 */
private sealed abstract class Contents {
  /**
   * Join this Contents with the Contents arriving from another path. The
   * operation is symmetrical and never moves upward in the lattice. Merging
   * an object with itself returns that same object.
   */
  final def merge(other: Contents): Contents =
    if (this eq other) this
    else this match {
      case Possible(mine) =>
        other match {
          case Possible(theirs)     => Possible(mine union theirs)
          case Impossible(excluded) => Impossible(excluded -- mine)
        }
      case Impossible(excluded) =>
        other match {
          case Possible(theirs)          => Impossible(excluded -- theirs)
          case Impossible(otherExcluded) => Impossible(excluded intersect otherExcluded)
        }
    }

  // TODO we could have more fine-grained knowledge, e.g. know that 0 < x < 3. But for now equality/inequality is a good start.

  /** Whether a value described by this Contents could compare equal to one described by `other`. */
  def mightEqual(other: Contents): Boolean

  /** Whether a value described by this Contents could compare unequal to one described by `other`. */
  def mightNotEqual(other: Contents): Boolean
}
/** Contents stating that the location cannot hold the single datum `x`. */
private def SingleImpossible(x: Datum) = Impossible(Set(x))
/**
 * The location is known to have one of a (non-empty) set of values.
 *
 * Constructing a `Possible` with an empty set fails an assertion, because an
 * empty set would describe an uninitialized location.
 */
private case class Possible(possible: Set[Datum]) extends Contents {
  assert(possible.nonEmpty, "Contradiction: had an empty possible set indicating an uninitialized location")

  /**
   * True when a value drawn from this set could equal a value described by `other`.
   */
  def mightEqual(other: Contents): Boolean = (this eq other) || (other match {
    // two Possibles might be equal if they have any possible members in common
    case Possible(possible2) => (possible intersect possible2).nonEmpty
    // a possible can be equal to an impossible if the impossible doesn't rule
    // out all the possibilities
    case Impossible(possible2) => (possible -- possible2).nonEmpty
  })

  /**
   * True when a value drawn from this set could differ from a value described
   * by `other`.
   *
   * Two Possibles are only guaranteed to be equal when both are known to hold
   * the same, single value. Having identical sets is not enough: if both
   * locations hold "1 or 2", one may hold 1 while the other holds 2. Sharing
   * the same Contents object is not enough either, for the same reason, so no
   * identity shortcut is taken here.
   *
   * Boxed data never count as guaranteed equal: references are compared by
   * identity, and two boxes of the same value need not be the same object.
   *
   * Answering true more often than strictly necessary is always safe; it
   * only makes callers keep a branch they might have removed.
   */
  def mightNotEqual(other: Contents): Boolean = other match {
    case Possible(possible2) =>
      val sameSingleValue = possible.size == 1 && possible == possible2
      val comparesByValue = possible forall {
        case Boxed(_) => false
        case _        => true
      }
      val mustEqual = sameSingleValue && comparesByValue
      !mustEqual
    // an Impossible only excludes values, so it can always differ from whatever we hold
    case Impossible(_) => true
  }
}
/** Contents stating that the location holds exactly the datum `x`. */
private def SinglePossible(x: Datum) = Possible(Set(x))
/**
 * The location is known not to hold any value from a set of values (e.g. null).
 * An empty set rules nothing out.
 */
private case class Impossible(impossible: Set[Datum]) extends Contents {
  /**
   * Against a `Possible` the question is delegated to that `Possible`;
   * against another `Impossible` the answer is always true.
   */
  def mightEqual(other: Contents): Boolean = other match {
    case candidates: Possible => candidates mightEqual this
    case _                    => true
  }

  /**
   * Against a `Possible` the question is delegated to that `Possible`;
   * against another `Impossible` the answer is always true.
   */
  def mightNotEqual(other: Contents): Boolean = other match {
    case candidates: Possible => candidates mightNotEqual this
    case _                    => true
  }
}
/**
 * Everything known about the contents of all locals and of the operand stack
 * at one program point. It forms a lattice driven mainly by the lattice
 * structure of Contents.
 *
 * On top of the rules of Contents:
 *  - Merging two local maps keeps only the locals present in both maps and
 *    merges their contents. A local missing from the map is therefore
 *    possibly uninitialized, and loading it is an error.
 *  - Two states can only be merged when their stacks have the same height;
 *    anything else is an error.
 *
 * State is immutable in order to aid structure sharing of local maps and stacks.
 */
private case class State(locals: Map[Local, Contents], stack: List[Contents]) {
  /**
   * Merges this state's locals with `olocals`: the result is keyed by the
   * locals found in both maps, each mapped to the merge of its two contents.
   * If both maps are the same object that object is returned as is.
   */
  def mergeLocals(olocals: Map[Local, Contents]): Map[Local, Contents] =
    if (locals eq olocals) locals
    else {
      val sharedKeys = (locals.keySet intersect olocals.keySet).toSeq
      val mergedEntries = sharedKeys map { key => (key, locals(key) merge olocals(key)) }
      Map(mergedEntries: _*)
    }

  /**
   * Merges this state with `other`, combining locals via `mergeLocals` and
   * the stacks slot by slot. Fails with "Mismatched stack heights" when the
   * stacks differ in length.
   */
  def merge(other: State): State =
    if (this eq other) this
    else {
      // walks both stacks in lock step; `acc` collects merged slots in reverse order
      @tailrec def mergeStacks(left: List[Contents], right: List[Contents], acc: List[Contents]): List[Contents] =
        if (left eq right) acc reverse_::: left // shared (or both empty) remainder: nothing left to merge
        else (left, right) match {
          case (leftTop :: leftRest, rightTop :: rightRest) =>
            mergeStacks(leftRest, rightRest, (leftTop merge rightTop) :: acc)
          case _ =>
            sys.error("Mismatched stack heights")
        }

      val mergedStack = if (stack eq other.stack) stack else mergeStacks(stack, other.stack, Nil)
      State(mergeLocals(other.locals), mergedStack)
    }

  /**
   * Look at the n-th stack slot (0 is the top) without modifying the stack.
   * An error if the stack does not have that many slots.
   */
  def peek(n: Int): Contents = stack(n)

  /**
   * A copy of this state with `contents` pushed on top of the stack
   */
  def push(contents: Contents): State = copy(stack = contents :: stack)

  /**
   * A copy of this state with the top `number` stack slots removed
   */
  def drop(number: Int): State = copy(stack = stack drop number)

  /**
   * Pop the top of the stack and record it as the contents of `variable`.
   * An error if the stack is empty.
   */
  def store(variable: Local): State =
    State(locals + ((variable, stack.head)), stack.tail)

  /**
   * Push the contents of `variable` onto the stack. An error if the local is
   * not present in the locals map, i.e. not initialized.
   */
  def load(variable: Local): State =
    locals get variable match {
      case Some(contents) => push(contents)
      case None           => sys.error(s"$variable is not initialized")
    }

  /**
   * This state with an empty stack (this very object if the stack is already empty)
   */
  def cleanStack: State = if (stack.isEmpty) this else copy(stack = Nil)
}
// Some precomputed constants, shared by the interpreter instead of being rebuilt per instruction.
// Declaration order matters: NULL is used by NOT_NULL and CONST_NULL below.
// The null constant as a datum.
private val NULL = Const(Constant(null: Any))
// Nothing is ruled out: the location could hold anything.
private val UNKNOWN = Impossible(Set.empty)
// Anything except null.
private val NOT_NULL = SingleImpossible(NULL)
// The "zero" value of each type kind, as returned by getZeroOf.
private val CONST_UNIT = SinglePossible(Const(Constant(())))
private val CONST_FALSE = SinglePossible(Const(Constant(false)))
private val CONST_ZERO_BYTE = SinglePossible(Const(Constant(0: Byte)))
private val CONST_ZERO_SHORT = SinglePossible(Const(Constant(0: Short)))
private val CONST_ZERO_CHAR = SinglePossible(Const(Constant(0: Char)))
private val CONST_ZERO_INT = SinglePossible(Const(Constant(0: Int)))
private val CONST_ZERO_LONG = SinglePossible(Const(Constant(0: Long)))
private val CONST_ZERO_FLOAT = SinglePossible(Const(Constant(0.0f)))
private val CONST_ZERO_DOUBLE = SinglePossible(Const(Constant(0.0d)))
// Exactly null; the zero of the reference, array and boxed kinds.
private val CONST_NULL = SinglePossible(NULL)
/**
 * The Contents that '0' stands for in the given TypeKind, which is what a
 * CZJUMP of that kind compares its operand against. Reference-like kinds
 * (references, arrays, boxed values) have null as their zero. Aborts for
 * ConcatClass, which has no zero.
 */
private def getZeroOf(k: TypeKind): Contents = k match {
  case REFERENCE(_) | ARRAY(_) | BOXED(_) => CONST_NULL
  case UNIT                               => CONST_UNIT
  case BOOL                               => CONST_FALSE
  case BYTE                               => CONST_ZERO_BYTE
  case SHORT                              => CONST_ZERO_SHORT
  case CHAR                               => CONST_ZERO_CHAR
  case INT                                => CONST_ZERO_INT
  case LONG                               => CONST_ZERO_LONG
  case FLOAT                              => CONST_ZERO_FLOAT
  case DOUBLE                             => CONST_ZERO_DOUBLE
  case ConcatClass                        => abort("no zero of ConcatClass")
}
// Key under which the contents of `this` are tracked in State.locals (used when
// interpreting THIS and STORE_THIS and when seeding the initial state).
// normal locals can't be null, so we use null to mean the magic 'this' local
private val THIS_LOCAL: Local = null
/**
 * Interprets one non-terminating instruction and returns the abstract state
 * that holds after it, given the state `in` that holds before it.
 *
 * Block-ending instructions (jumps, RETURN, THROW, SWITCH) are not handled
 * here; encountering one aborts the compiler.
 */
private def interpretInst(in: State, inst: Instruction): State = {
  // the `in` state with the values consumed by `inst` popped off its stack
  def popped: State = in drop inst.consumed

  // We simulate boxing by, um, boxing the possible/impossible contents:
  // Possible(1, 2) becomes Possible(Boxed(1), Boxed(2)), and knowing that the
  // input is not 0 tells us the output is not Boxed(0).
  def boxAll(unboxed: Contents): Contents = unboxed match {
    case Possible(values)   => Possible(values map Boxed)
    case Impossible(values) => Impossible(values map Boxed)
  }

  // The inverse of boxAll. Every datum is unboxed on its own and the results
  // are merged; a datum that is not a Boxed tells us nothing, so it
  // contributes UNKNOWN.
  def unboxAll(boxed: Contents): Contents = boxed match {
    case Possible(inners) =>
      assert(inners.nonEmpty, "Empty possible set indicating an uninitialized location")
      val unboxedAlternatives: Set[Contents] = inners map {
        case Boxed(content) => SinglePossible(content)
        case _              => UNKNOWN
      }
      unboxedAlternatives reduce (_ merge _)
    case Impossible(inners) if inners.isEmpty =>
      // nothing was known about the box, so nothing is known about its content
      boxed
    case Impossible(inners) =>
      val unboxedExclusions: Set[Contents] = inners map {
        case Boxed(content) => SingleImpossible(content)
        case _              => UNKNOWN
      }
      unboxedExclusions reduce (_ merge _)
  }

  inst match {
    // ---- loads and stores of locals and `this`
    case THIS(_) =>
      in load THIS_LOCAL
    case LOAD_LOCAL(local) =>
      // TODO if a local is known to hold a constant then we can replace this instruction with a push of that constant
      in load local
    case STORE_LOCAL(local) =>
      in store local
    case STORE_THIS(_) =>
      // TODO if a local is already known to have a constant and we're replacing with the same constant
      // then we can replace this with a drop
      in store THIS_LOCAL

    // ---- values we know precisely
    case CONSTANT(k) =>
      // treat NaN as UNKNOWN because NaN must never equal NaN
      if (k.isNaN) in push UNKNOWN
      else in push SinglePossible(Const(k))
    case DUP(_) =>
      in push (in peek 0)
    case BOX(_) =>
      val boxedTop = boxAll(in peek 0)
      popped push boxedTop
    case UNBOX(_) =>
      val unboxedTop = unboxAll(in peek 0)
      popped push unboxedTop

    // ---- values we only know to be non-null
    case LOAD_MODULE(_) | NEW(_) | LOAD_EXCEPTION(_) =>
      in push NOT_NULL
    case CREATE_ARRAY(_, _) =>
      popped push NOT_NULL

    // ---- values we know nothing about
    case LOAD_ARRAY_ITEM(_) | LOAD_FIELD(_, _) | CALL_PRIMITIVE(_) =>
      popped push UNKNOWN
    case CALL_METHOD(_, _) =>
      // TODO we could special case implementations of equals that are known, e.g. String#equals
      // We could turn Possible(string constants).equals(Possible(string constants) into an eq check
      // We could turn nonConstantString.equals(constantString) into constantString.equals(nonConstantString)
      // and eliminate the null check that likely precedes this call
      (1 to inst.produced).foldLeft(popped)((state, _) => state push UNKNOWN)
    case IS_INSTANCE(_) =>
      // TODO IS_INSTANCE is going to be followed by a C(Z)JUMP
      // and if IS_INSTANCE/C(Z)JUMP the branch for "true" can
      // know that whatever was checked was not a null
      // see the TODO on CJUMP for more information about propagating null
      // information
      // TODO if the top of stack is guaranteed null then we can eliminate this IS_INSTANCE check and
      // replace with a constant false, but how often is a knowable null checked for instanceof?
      // TODO we could track type information and statically know to eliminate IS_INSTANCE
      // which might be a nice win under specialization
      //
      // The result is actually a Possible(true, false), but since the following
      // instruction will be a conditional jump comparing to true or false there
      // is nothing to be gained by being more precise.
      popped push UNKNOWN

    // ---- instructions that only consume values
    case DROP(_) | MONITOR_ENTER() | MONITOR_EXIT() | STORE_ARRAY_ITEM(_) | STORE_FIELD(_, _) =>
      popped

    // ---- instructions that leave the abstract state untouched
    case CHECK_CAST(_) =>
      // TODO we could track type information and statically know to eliminate CHECK_CAST
      // but that's probably not a huge win
      in
    case SCOPE_ENTER(_) | SCOPE_EXIT(_) =>
      in

    // ---- block terminators must go through interpretLast instead
    case JUMP(_) | CJUMP(_, _, _, _) | CZJUMP(_, _, _, _) | RETURN(_) | THROW(_) | SWITCH(_, _) =>
      dumpClassesAndAbort("Unexpected block ending instruction: " + inst)
  }
}
/**
 * Interprets the last instruction of a block, which must be a jump, a
 * conditional branch, a switch, a throw or a return (anything else aborts).
 *
 * Returns a map from each target block that can be reached to the state that
 * flows into it, together with the list of instructions that should replace
 * the last instruction. When exactly one target is reachable the replacement
 * drops the consumed operands and jumps unconditionally; otherwise it is the
 * original instruction.
 */
private def interpretLast(in: State, inst: Instruction): (Map[BasicBlock, State], List[Instruction]) = {
  // the Contents meaning "one of these int constants"
  def constantsFor(tagValues: List[Int]): Contents =
    Possible(tagValues.toSet map { tag: Int => Const(Constant(tag)) })

  // whether the switched-on value could match one of the tags of a case
  def canSwitch(in1: Contents, tagSet: List[Int]) =
    in1 mightEqual constantsFor(tagSet)

  /* common code for interpreting CJUMP and CZJUMP */
  def interpretConditional(kind: TypeKind, val1: Contents, val2: Contents, success: BasicBlock, failure: BasicBlock, cond: TestOp): (Map[BasicBlock, State], List[Instruction]) = {
    // TODO use reaching analysis to update the state in the two branches
    // e.g. if the comparison was checking null equality on local x
    // then in the success branch we know x is null and
    // on the failure branch we know it is not
    // in fact, with copy propagation we could propagate that knowledge
    // back through a chain of locations
    //
    // TODO if we do all that we need to be careful in the
    // case that success and failure are the same target block
    // because we're using a Map and don't want one possible state to clobber the other
    // alternatively maybe we should just replace the conditional with a jump if both targets are the same
    val canBeEqual = val1 mightEqual val2
    val canDiffer = val1 mightNotEqual val2
    val alwaysEqual = canBeEqual && !canDiffer

    val (successReachable, failureReachable) = cond match {
      case EQ      => (canBeEqual, canDiffer)
      case NE      => (canDiffer, canBeEqual)
      // if the two are guaranteed to be equal then they can't be LT/GT
      case LT | GT => (!alwaysEqual, true)
      // if the two are guaranteed to be equal then they must be LE/GE
      case LE | GE => (true, !alwaysEqual)
    }

    val out = in drop inst.consumed
    val reachableTargets = List((success, successReachable), (failure, failureReachable)) collect {
      case (target, true) => (target, out)
    }
    val result = Map(reachableTargets: _*)

    val replacements: List[Instruction] = result.keys.toList match {
      case onlyTarget :: Nil => List.fill(inst.consumed)(DROP(kind)) ::: List(JUMP(onlyTarget))
      case _                 => inst :: Nil
    }
    (result, replacements)
  }

  inst match {
    case JUMP(whereto) =>
      (Map((whereto, in)), inst :: Nil)

    case CJUMP(success, failure, cond, kind) =>
      // compares the two topmost stack slots with each other
      interpretConditional(kind, in peek 0, in peek 1, success, failure, cond)

    case CZJUMP(success, failure, cond, kind) =>
      // compares the top of the stack with the zero of its kind
      interpretConditional(kind, in peek 0, getZeroOf(kind), success, failure, cond)

    case SWITCH(tags, labels) =>
      val scrutinee = in peek 0
      val matchableLabels = for ((tagSet, label) <- tags zip labels if canSwitch(scrutinee, tagSet)) yield label

      val reachableLabels =
        if (tags.isEmpty) {
          assert(labels.size == 1, s"When SWITCH node has empty array of tags it should have just one (default) label: $labels")
          labels
        } else if (labels.lengthCompare(tags.length) > 0 && (scrutinee mightNotEqual constantsFor(tags.flatten))) {
          // there is an extra label, which is the default, and the input
          // might be out of the set of all tags, so the default is reachable
          matchableLabels :+ labels.last
        } else {
          matchableLabels
        }

      // TODO similar to the comment in interpretConditional, we should update the State going into each
      // branch based on which tag is being matched. Also, just like interpretConditional, if target blocks
      // are the same we need to merge State rather than clobber
      // alternatively, maybe we should simplify the SWITCH to not have same target labels
      val out = in drop inst.consumed
      val result = Map(reachableLabels map { label => (label, out) }: _*)
      val replacements: List[Instruction] = reachableLabels match {
        case onlyLabel :: Nil => DROP(INT) :: JUMP(onlyLabel) :: Nil
        case _                => inst :: Nil
      }
      (result, replacements)

    // these instructions don't have target blocks
    // (exceptions are assumed to be reachable from all instructions)
    case RETURN(_) | THROW(_) =>
      (Map.empty, inst :: Nil)

    case _ =>
      dumpClassesAndAbort("Unexpected non-block ending instruction: " + inst)
  }
}
/**
 * Analyzes a single block to find how it transforms the input state `in`
 * into states for its successor blocks.
 *
 * @return a triple of
 *         - the states flowing into the normal successors, as computed from the block's last instruction,
 *         - the state flowing into each exception successor,
 *         - the instructions to be used to replace the block's last instruction.
 */
private def interpretBlock(in: State, block: BasicBlock): (Map[BasicBlock, State], Map[BasicBlock, State], List[Instruction]) = {
  debuglog(s"interpreting block $block")
  // number of instructions excluding the last one
  val normalCount = block.size - 1

  // An exception may be raised between any two instructions, so the state seen
  // by a handler has to cover the locals as they are at block entry and after
  // every instruction. Its stack is empty. This has to be a `var` that is
  // reassigned in the loop: `copy` returns a new State, and dropping that
  // result would leave handlers with the block-entry locals only, hiding
  // every store performed inside the block.
  var exceptionState = in.cleanStack
  var normalExitState = in
  var idx = 0
  while (idx < normalCount) {
    val inst = block(idx)
    normalExitState = interpretInst(normalExitState, inst)
    // skip the merge when the instruction left the locals map untouched
    if (normalExitState.locals ne exceptionState.locals)
      exceptionState = exceptionState.copy(locals = exceptionState mergeLocals normalExitState.locals)
    idx += 1
  }

  val pairs = block.exceptionSuccessors map { b => (b, exceptionState) }
  val exceptionMap = Map(pairs: _*)
  val (normalExitMap, newInstructions) = interpretLast(normalExitState, block.lastInstruction)
  (normalExitMap, exceptionMap, newInstructions)
}
/**
 * Abstractly interprets a whole method until the input states of its blocks
 * stop changing, and returns, for every block that was interpreted, the
 * instructions that should replace the block's last instruction.
 */
private def interpretMethod(m: IMethod): Map[BasicBlock, List[Instruction]] = {
  import scala.collection.mutable.{ Set => MSet, Map => MMap }

  debuglog(s"interpreting method $m")

  var iterations = 0

  // initially we know that 'this' is not null and the params are initialized to some unknown value
  val thisEntry: List[(Local, Contents)] = if (m.isStatic) Nil else List((THIS_LOCAL, NOT_NULL))
  val paramEntries: List[(Local, Contents)] = m.params map { param => (param, UNKNOWN) }
  val initialState = State(Map((thisEntry ++ paramEntries): _*), Nil)

  // worklist of basic blocks to process, initially the start block
  val worklist = MSet(m.startBlock)
  // Exception basic blocks are queued separately so that they are processed
  // after the normal flow blocks. They are more likely to have multiple
  // predecessors, and queueing them for later increases the chances that
  // they only need to be interpreted once.
  val exceptionlist = MSet[BasicBlock]()

  // our current best guess at what the input state is for each block;
  // initially we only know about the start block
  val inputState = MMap[BasicBlock, State]((m.startBlock, initialState))

  // Folds freshly computed output states into the input states of their
  // target blocks. A block whose input state changed (or that is seen for
  // the first time) is added to `pending` so that it gets (re)interpreted.
  def updateInputStates(outputStates: Map[BasicBlock, State], pending: MSet[BasicBlock]): Unit = {
    def schedule(block: BasicBlock, state: State): Unit = {
      pending add block
      inputState(block) = state
    }
    for ((block, newState) <- outputStates) {
      inputState get block match {
        case Some(known) =>
          val merged = known merge newState
          if (known != merged) schedule(block, merged)
        case None =>
          schedule(block, newState)
      }
    }
  }

  // the instructions to be used as the last instructions on each block
  val replacements = MMap[BasicBlock, List[Instruction]]()

  while (worklist.nonEmpty || exceptionlist.nonEmpty) {
    if (worklist.nonEmpty) {
      iterations += 1
      val block = worklist.head
      worklist remove block
      val (normalExitMap, exceptionMap, newInstructions) = interpretBlock(inputState(block), block)
      updateInputStates(normalExitMap, worklist)
      updateInputStates(exceptionMap, exceptionlist)
      replacements(block) = newInstructions
    } else {
      // once the worklist is empty, move one exception block over to it
      val block = exceptionlist.head
      exceptionlist remove block
      worklist add block
    }
  }

  debuglog(s"method $m with ${m.blocks.size} reached fixpoint in $iterations iterations")
  replacements.toMap
}
}
}