Lib.Krakatau.java.javamethod.py Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of krakatau-lib Show documentation
Show all versions of krakatau-lib Show documentation
Assembler, disassembler, decompiler and compiler tools library for Java.
import collections
import operator
from functools import partial
from ..ssa import objtypes
from .. import graph_util
from ..namegen import NameGen, LabelGen
from ..verifier.descriptors import parseMethodDescriptor
from . import ast, ast2, boolize
from . import graphproxy, structuring, astgen
class DeclInfo(object):
    '''Bookkeeping record that findVarDeclInfo keeps for one expression.

    Holds the explicit declaration scope (if any), the combined scope of all
    uses, and the list of right-hand sides assigned to the expression.
    '''
    # Fixed attribute set, so instances carry no per-instance __dict__.
    __slots__ = ('declScope', 'scope', 'defs')

    def __init__(self):
        # Set by visitDeclExpr for predeclared expressions and catch variables.
        self.declScope = None
        # Combined (via ast.StatementBlock.join) scope of every recorded use.
        self.scope = None
        # Right-hand sides of local assignments targeting this expression.
        self.defs = []
def findVarDeclInfo(root, predeclared):
    '''Collect a DeclInfo for every Local and Literal that appears under root.

    root: the top StatementBlock to walk.
    predeclared: expressions treated as declared at root (callers in this file
        pass the method's argument locals).
    Returns an OrderedDict mapping each such expression to its DeclInfo, in
    first-encounter order (predeclared expressions first).
    '''
    info = collections.OrderedDict()
    def visit(scope, expr):
        # Operands are visited first, so the target Local of an assignment
        # already has an entry when its def is appended below.
        for param in expr.params:
            visit(scope, param)

        if expr.isLocalAssign():
            left, right = expr.params
            info[left].defs.append(right)
        elif isinstance(expr, (ast.Local, ast.Literal)):
            #this would be so much nicer if we had Ordered defaultdicts
            info.setdefault(expr, DeclInfo())
            # Widen the recorded use scope so it also covers this occurrence.
            info[expr].scope = ast.StatementBlock.join(info[expr].scope, scope)

    def visitDeclExpr(scope, expr):
        # Record an explicit declaration point; each expression may get at most one.
        info.setdefault(expr, DeclInfo())
        assert(scope is not None and info[expr].declScope is None)
        info[expr].declScope = scope

    for expr in predeclared:
        visitDeclExpr(root, expr)

    # Iterative walk; each entry pairs a statement with the StatementBlock that contains it.
    stack = [(root,root)]
    while stack:
        scope, stmt = stack.pop()
        if isinstance(stmt, ast.StatementBlock):
            stack.extend((stmt,sub) for sub in stmt.statements)
        else:
            stack.extend((subscope,subscope) for subscope in stmt.getScopes())
            #temp hack
            if stmt.expr is not None:
                visit(scope, stmt.expr)
            # Catch variables are declared by the handler body they belong to.
            if isinstance(stmt, ast.TryStatement):
                for catchdecl, body in stmt.pairs:
                    visitDeclExpr(body, catchdecl.local)
    return info
def reverseBoolExpr(expr):
    '''Return a boolean expression equivalent to the logical negation of expr.

    Comparison operators are swapped for their complement, a leading `!` is
    stripped, and anything else is wrapped in a new `!` prefix. Ordered
    comparisons with a float/double operand are never complemented, because
    with NaN both `a < b` and `a >= b` can be false.
    '''
    assert(expr.dtype == objtypes.BoolTT)
    if isinstance(expr, ast.BinaryInfix):
        complement = {'==': '!=', '!=': '==', '<': '>=', '>=': '<', '>': '<=', '<=': '>'}
        if expr.opstr in complement:
            left, right = expr.params
            floatTypes = (objtypes.FloatTT, objtypes.DoubleTT)
            equalityTest = expr.opstr in ('==', '!=')
            if equalityTest or (left.dtype not in floatTypes and right.dtype not in floatTypes):
                return ast.BinaryInfix(complement[expr.opstr], (left, right), objtypes.BoolTT)
    elif isinstance(expr, ast.UnaryPrefix) and expr.opstr == '!':
        # !!x -> x
        return expr.params[0]
    return ast.UnaryPrefix('!', expr)
def getSubscopeIter(root):
    '''Lazily yield every StatementBlock nested under root (root included if it is one).

    root may be a statement or a StatementBlock. Non-block statements are not
    yielded; only their scopes are explored.
    '''
    pending = [root]
    while pending:
        node = pending.pop()
        if not isinstance(node, ast.StatementBlock):
            pending.extend(node.getScopes())
            continue
        pending.extend(node.statements)
        yield node
def mayBreakTo(root, forbidden):
    '''Return True if a reachable jump under root may target a key in forbidden.

    root: a statement or StatementBlock; every StatementBlock nested inside it
        (as yielded by getSubscopeIter) is examined.
    forbidden: collection of jump keys to look for; must not contain None.
    True means such a jump may exist (the test is conservative). False means
    no reachable forbidden jump was found.
    '''
    assert(None not in forbidden)
    for scope in getSubscopeIter(root):
        if scope.jumpKey in forbidden:
            #We return true if scope has forbidden jump and is reachable
            #We assume there is no unreachable code, so in order for a scope
            #jump to be unreachable, it must end in a return, throw, or a
            #compound statement, all of which are not reachable or do not
            #break out of the statement. We omit adding last.breakKey to
            #forbidden since it should always match scope.jumpKey anyway
            # An empty block always reaches its own jump.
            if not scope.statements:
                return True
            last = scope.statements[-1]
            if not last.getScopes():
                # A simple final statement falls through unless it returns or throws.
                if not isinstance(last, (ast.ReturnStatement, ast.ThrowStatement)):
                    return True
            else:
                #If and switch statements may allow fallthrough
                #A while statement with condition may break implicitly
                if isinstance(last, ast.IfStatement) and len(last.getScopes()) == 1:
                    return True
                if isinstance(last, ast.SwitchStatement) and not last.hasDefault():
                    return True
                if isinstance(last, ast.WhileStatement) and last.expr != ast.Literal.TRUE:
                    return True

                # Sanity check: the remaining compound statement's scopes all break to this scope's jump.
                if not isinstance(last, ast.WhileStatement):
                    for sub in last.getScopes():
                        assert(sub.breakKey == last.breakKey == scope.jumpKey)
    return False
def replaceKeys(top, replace):
    '''Rewrite break/jump keys in top and everything nested beneath it, in place.

    top: a statement or StatementBlock. Statements without scopes are left untouched.
    replace: dict mapping old key -> new key. Keys not in the dict are kept;
        None must not be used as an old key.
    Returns None.
    '''
    assert(None not in replace)
    get = lambda k:replace.get(k,k)

    if top.getScopes():
        if isinstance(top, ast.StatementBlock) and get(top.breakKey) is None:
            #breakkey can be None with non-None jumpkey when we're a scope in a switch statement that falls through
            #and the end of the switch statement is unreachable
            assert(get(top.jumpKey) is None or not top.labelable)

        top.breakKey = get(top.breakKey)
        if isinstance(top, ast.StatementBlock):
            # Blocks also carry a jumpKey; recurse into their statements.
            top.jumpKey = get(top.jumpKey)
            for item in top.statements:
                replaceKeys(item, replace)
        else:
            # Compound statements: recurse into each of their scopes.
            for scope in top.getScopes():
                replaceKeys(scope, replace)
# Immutable set holding only None, i.e. "plain fallthrough / no jump key".
NONE_SET = frozenset((None,))
def _preorder(scope, func):
    '''Apply func(scope, item) to every statement under scope and splice in the results.

    Despite the name, nested scopes of a statement are processed before func
    is called on the statement itself. func returns None to keep the
    statement, or an iterable of statements that replaces it (empty to
    delete it). scope.statements is rebuilt in place; returns None.
    '''
    rebuilt = []
    for stmt in scope.statements:
        for child in stmt.getScopes():
            _preorder(child, func)
        replacement = func(scope, stmt)
        if replacement is None:
            replacement = [stmt]
        rebuilt += replacement
    scope.statements = rebuilt
def _fixObjectCreations(scope, item):
    '''Combines new/invokeinit pairs into Java constructor calls

    _preorder callback.
    - A statement assigning a freshly allocated object (an ast.Dummy with
      isNew set) is removed.
    - A constructor invocation (MethodInvocation named '<init>') on a value
      becomes `value = new T(args)`, where the receiver (params[0]) is dropped
      from the argument list and tts is sliced the same way.
    Returns [] to drop item, otherwise None to keep it (item.expr may have
    been rewritten in place).
    '''

    #Thanks to the copy propagation pass prior to AST generation, as well as the fact that
    #uninitialized types never merge, we can safely assume there are no copies to worry about
    expr = item.expr
    if isinstance(expr, ast.Assignment):
        left, right = expr.params
        if isinstance(right, ast.Dummy) and right.isNew:
            return [] #remove item
    # JVM constructors are always named '<init>'; comparing against '' could never match
    elif isinstance(expr, ast.MethodInvocation) and expr.name == '<init>':
        left = expr.params[0]
        newexpr = ast.ClassInstanceCreation(ast.TypeName(left.dtype), expr.tts[1:], expr.params[1:])
        item.expr = ast.Assignment(left, newexpr)
def _pruneRethrow_cb(item):
    '''Convert try{A} catch(T e) {throw e;} to {A}

    Trailing catch clauses whose body consists solely of rethrowing the caught
    variable are dropped. If no catch clause remains, the try body itself is
    returned (marked labelable); otherwise item is returned with its shortened
    pairs.
    '''
    def rethrowsOnly(pair):
        decl, body = pair
        caught, lines = decl.local, body.statements
        if len(lines) != 1:
            return False
        only = lines[0]
        return isinstance(only, ast.ThrowStatement) and only.expr == caught

    while item.pairs and rethrowsOnly(item.pairs[-1]):
        item.pairs = item.pairs[:-1]

    if item.pairs:
        return item

    replacement = item.tryb
    assert(replacement.breakKey == item.breakKey)
    assert(replacement.continueKey == item.continueKey)
    assert(not replacement.labelable)
    replacement.labelable = True
    return replacement
def _pruneIfElse_cb(item):
    '''Convert if(A) {B} else {} to if(A) {B}

    Also performs these related rewrites on an IfStatement:
      if(A) {} else {B}     -> if(!A) {B}
      if(!A) {B} else {C}   -> if(A) {C} else {B}
      if(A) {if(B) {C}}     -> if(A && B) {C}
    Empty blocks are only dropped when they fall through. Mutates and returns item.
    '''
    if len(item.scopes) > 1:
        tblock, fblock = item.scopes

        #if true block is empty, swap it with false so we can remove it
        if not tblock.statements and tblock.doesFallthrough():
            item.expr = reverseBoolExpr(item.expr)
            tblock, fblock = fblock, tblock
            item.scopes = tblock, fblock

        if not fblock.statements and fblock.doesFallthrough():
            item.scopes = tblock,
        # If cond is !(x), reverse it back to simplify cond
        elif isinstance(item.expr, ast.UnaryPrefix) and item.expr.opstr == '!':
            item.expr = reverseBoolExpr(item.expr)
            item.scopes = fblock, tblock

    # if(A) {if(B) {C}} -> if(A && B) {C}
    tblock = item.scopes[0]
    if len(item.scopes) == 1 and len(tblock.statements) == 1 and tblock.doesFallthrough():
        first = tblock.statements[0]
        if isinstance(first, ast.IfStatement) and len(first.scopes) == 1:
            item.expr = ast.BinaryInfix('&&',[item.expr, first.expr], objtypes.BoolTT)
            item.scopes = first.scopes
    return item
def _whileCondition_cb(item):
    '''Convert while(true) {if(A) {B break;} else {C} D} to while(!A) {{C} D} {B}

    Returns a pair (rest, replacement). On failure it is ([], item), leaving
    the loop unchanged. On success it is ([item], trueb): the caller
    (_simplifyBlocksSub) places the loop first and then the extracted block.
    '''
    failure = [], item #what to return if we didn't inline
    body = item.getScopes()[0]
    if not body.statements or not isinstance(body.statements[0], ast.IfStatement):
        return failure

    head = body.statements[0]
    cond = head.expr
    # Pad with None so a one-armed if yields falseb = None.
    trueb, falseb = (head.getScopes() + (None,))[:2]

    #Make sure it doesn't continue the loop or break out of the if statement
    badjumps1 = frozenset([head.breakKey, item.continueKey]) - NONE_SET
    if mayBreakTo(trueb, badjumps1):
        # Try the other arm instead, negating the condition to match.
        if falseb is not None and not mayBreakTo(falseb, badjumps1):
            cond = reverseBoolExpr(cond)
            trueb, falseb = falseb, trueb
        else:
            return failure
    assert(not mayBreakTo(trueb, badjumps1))

    # "Trivial" = the chosen arm is just a bare break out of the loop.
    trivial = not trueb.statements and trueb.jumpKey == item.breakKey
    #If we already have a condition, only a simple break is allowed
    if not trivial and item.expr != ast.Literal.TRUE:
        return failure

    #If break body is nontrivial, we can't insert this after the end of the loop unless
    #We're sure that nothing else in the loop breaks out
    badjumps2 = frozenset([item.breakKey]) - NONE_SET
    if not trivial:
        restloop = [falseb] if falseb is not None else []
        restloop += body.statements[1:]
        if body.jumpKey == item.breakKey or any(mayBreakTo(s, badjumps2) for s in restloop):
            return failure

    #Now inline everything
    item.expr = _simplifyExpressions(ast.BinaryInfix('&&', [item.expr, reverseBoolExpr(cond)]))
    if falseb is None:
        body.statements.pop(0)
    else:
        body.statements[0] = falseb
        falseb.labelable = True
    trueb.labelable = True

    if item.breakKey is None: #Make sure to maintain invariant that bkey=None -> jkey=None
        assert(trueb.doesFallthrough())
        trueb.jumpKey = trueb.breakKey = None
    trueb.breakKey = item.breakKey
    assert(trueb.continueKey is not None)
    if not trivial:
        item.breakKey = trueb.continueKey

    #Trueb doesn't break to head.bkey but there might be unreachable jumps, so we replace
    #it too. We don't replace item.ckey because it should never appear, even as an
    #unreachable jump
    replaceKeys(trueb, {head.breakKey:trueb.breakKey, item.breakKey:trueb.breakKey})
    return [item], trueb
def _simplifyBlocksSub(scope, item, isLast):
    '''Simplify one statement of scope and return the statements that replace it.

    First applies the try/if/while rewrites (_pruneRethrow_cb, _pruneIfElse_cb,
    _whileCondition_cb). If the result is a StatementBlock, as many of its
    leading statements as possible are hoisted into scope without separating
    a break from its target.
    isLast: whether item is currently the last statement of scope. This is
        required when the block's jump has to be transferred onto scope.
    Returns a (possibly empty) list of statements to put in item's place.
    '''
    rest = []
    if isinstance(item, ast.TryStatement):
        item = _pruneRethrow_cb(item)
    elif isinstance(item, ast.IfStatement):
        item = _pruneIfElse_cb(item)
    elif isinstance(item, ast.WhileStatement):
        rest, item = _whileCondition_cb(item)

    if isinstance(item, ast.StatementBlock):
        assert(item.breakKey is not None or item.jumpKey is None)
        #If bkey is None, it can't be broken to
        #If contents can also break to enclosing scope, it's always safe to inline
        bkey = item.breakKey
        if bkey is None or (bkey == scope.breakKey and scope.labelable):
            rest, item.statements = rest + item.statements, []

        # Hoist leading statements until one might break to this block's end.
        for sub in item.statements[:]:
            if sub.getScopes() and sub.breakKey != bkey and mayBreakTo(sub, frozenset([bkey])):
                break
            rest.append(item.statements.pop(0))

        # Everything was hoisted: drop the block, moving its jump onto scope if needed.
        if not item.statements:
            if item.jumpKey != bkey:
                assert(isLast)
                scope.jumpKey = item.jumpKey
                assert(scope.breakKey is not None or scope.jumpKey is None)
            return rest
    return rest + [item]
def _simplifyBlocks(scope):
    '''Recursively simplify the statements of scope in place, from last to first.

    If the last remaining statement is compound and its breakKey differs from
    scope.jumpKey, its keys are first redirected to scope.jumpKey. Each
    statement's nested scopes are simplified before _simplifyBlocksSub is
    applied to the statement itself.
    '''
    newitems = []
    for item in reversed(scope.statements):
        isLast = not newitems #may be true if all subsequent items pruned
        if isLast and item.getScopes():
            if item.breakKey != scope.jumpKey:# and item.breakKey is not None:
                # print 'sib replace', scope, item, item.breakKey, scope.jumpKey
                replaceKeys(item, {item.breakKey: scope.jumpKey})

        for sub in reversed(item.getScopes()):
            _simplifyBlocks(sub)
        vals = _simplifyBlocksSub(scope, item, isLast)
        # newitems is built back-to-front and reversed once at the end.
        newitems += reversed(vals)
    scope.statements = newitems[::-1]
# Comparison operators encoded as the set of outcomes for which they hold:
# 1 = less, 2 = equal, 4 = greater, 8 = unordered (only meaningful for
# float/double operands, which use the 4-bit mask in _mergeComparisons).
_op2bits = dict([('==', 2), ('!=', 13), ('<', 1), ('<=', 3), ('>', 4), ('>=', 6)])
# Inverse mapping for float/double operands (all four outcome bits kept).
_bit2ops_float = dict((bits, op) for op, bits in _op2bits.items())
# Inverse mapping for other operands: the unordered bit is dropped, so '!=' maps from 5.
_bit2ops = dict((bits & 7, op) for op, bits in _op2bits.items())
def _getBitfield(expr):
    '''Describe expr as (outcome bitmask, (lhs, rhs)) when it is a side-effect-free comparison combo.

    Handles a comparison whose operands are Literals or Locals, `!` of such an
    expression, and &, &&, |, || of two such expressions over the same
    operands. Returns (0, None) when expr does not have that shape. The
    bitmask is not masked here; callers apply the appropriate mask.
    '''
    nothing = (0, None)
    if not isinstance(expr, ast.BinaryInfix):
        if isinstance(expr, ast.UnaryPrefix) and expr.opstr == '!':
            innerBits, innerArgs = _getBitfield(expr.params[0])
            return ~innerBits, innerArgs
        return nothing

    op = expr.opstr
    if op in ('==', '!=', '<', '<=', '>', '>='):
        # Only literals and locals, so merging can never drop a side effect.
        operands = expr.params
        if all(isinstance(p, (ast.Literal, ast.Local)) for p in operands):
            return _op2bits[op], tuple(operands)
        return nothing

    if op in ('&', '&&', '|', '||'):
        lhsBits, lhsArgs = _getBitfield(expr.params[0])
        rhsBits, rhsArgs = _getBitfield(expr.params[1])
        if lhsArgs == rhsArgs:
            if op in ('&', '&&'):
                return lhsBits & rhsBits, lhsArgs
            return lhsBits | rhsBits, lhsArgs
    return nothing
def _mergeComparisons(expr):
    '''Fold a boolean combination of comparisons over the same operands into one comparison.

    For example `a <= b && a != b` becomes `a < b`. _getBitfield supplies the
    set of outcomes (bitmask) for which expr is true. Returns a literal when
    that set is empty (always false) or complete (always true), a single
    (possibly negated) comparison when one exists, and otherwise expr
    unchanged.
    '''
    # a <= b && a != b -> a < b, etc.
    bits, args = _getBitfield(expr)
    if args is None:
        return expr

    assert(not hasSideEffects(args[0]) and not hasSideEffects(args[1]))
    # Floating point comparisons have a fourth, unordered outcome (bit 8)
    if args[0].dtype in (objtypes.FloatTT, objtypes.DoubleTT):
        mask, d = 15, _bit2ops_float
    else:
        mask, d = 7, _bit2ops

    bits &= mask
    notbits = (~bits) & mask

    # bits is the set of outcomes where expr holds: an empty set means expr can
    # never be true, and a full set (notbits empty) means it is always true.
    # (These two results were previously swapped.)
    if bits == 0:
        return ast.Literal.FALSE
    elif notbits == 0:
        return ast.Literal.TRUE
    elif bits in d:
        return ast.BinaryInfix(d[bits], args, objtypes.BoolTT)
    elif notbits in d:
        return ast.UnaryPrefix('!', ast.BinaryInfix(d[notbits], args, objtypes.BoolTT))
    return expr
def _simplifyExpressions(expr):
    '''Recursively simplify expr and return the (possibly different) simplified expression.

    Operands are simplified first and expr.params is replaced in place. Then
    local rewrites are applied to expr itself:
    - constant folding of literal comparisons;
    - pushing a comparison into a ternary;
    - boolean ternaries to &&/||;
    - && / || identities with literals;
    - merging of comparisons (_mergeComparisons);
    - removal of comparisons against boolean literals;
    - `!c ? a : b` normalisation;
    - `0L - a` -> `-a`.
    '''
    TRUE, FALSE = ast.Literal.TRUE, ast.Literal.FALSE
    bools = {True:TRUE, False:FALSE}
    opfuncs = {'<': operator.lt, '<=': operator.le, '>': operator.gt, '>=': operator.ge}

    simplify = _simplifyExpressions
    # Build a real list (not a lazy map object): params are unpacked, indexed
    # and assigned to below (e.g. left.params[1] = ...)
    expr.params = [simplify(param) for param in expr.params]

    if isinstance(expr, ast.BinaryInfix):
        left, right = expr.params
        op = expr.opstr
        if op in ('==','!=','<','<=','>','>=') and isinstance(right, ast.Literal):
            # la cmp lb -> result (i.e. constant propagation on literal comparisons)
            if isinstance(left, ast.Literal):
                if op in ('==','!='):
                    #these could be string or class literals, but those are always nonnull so it still works
                    res = (left == right) == (op == '==')
                else:
                    assert(left.dtype == right.dtype)
                    res = opfuncs[op](left.val, right.val)
                expr = bools[res]
            # (a ? lb : c) cmp ld -> a ? (lb cmp ld) : (c cmp ld)
            elif isinstance(left, ast.Ternary) and isinstance(left.params[1], ast.Literal):
                left.params[1] = simplify(ast.BinaryInfix(op, [left.params[1], right], expr._dtype))
                left.params[2] = simplify(ast.BinaryInfix(op, [left.params[2], right], expr._dtype))
                expr = left

    # a ? true : b -> a || b
    # a ? false : b -> !a && b
    if isinstance(expr, ast.Ternary) and expr.dtype == objtypes.BoolTT:
        cond, val1, val2 = expr.params
        if not isinstance(val1, ast.Literal): #try to get bool literal to the front
            cond, val1, val2 = reverseBoolExpr(cond), val2, val1

        if val1 == TRUE:
            expr = ast.BinaryInfix('||', [cond, val2], objtypes.BoolTT)
        elif val1 == FALSE:
            expr = ast.BinaryInfix('&&', [reverseBoolExpr(cond), val2], objtypes.BoolTT)

    # true && a -> a, etc.
    # Dropping an operand that is always evaluated is only allowed when it has no side effects
    if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('&&','||'):
        left, right = expr.params
        if expr.opstr == '&&':
            if left == TRUE or (right == FALSE and not hasSideEffects(left)):
                expr = right
            elif left == FALSE or right == TRUE:
                expr = left
        else:
            if left == TRUE or right == FALSE:
                expr = left
            elif left == FALSE or (right == TRUE and not hasSideEffects(left)):
                expr = right

    # a > b || a == b -> a >= b, etc.
    expr = _mergeComparisons(expr)

    # a == true -> a
    # a == false -> !a
    # (exact operator match against a tuple, not a substring test on one string)
    if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('==', '!=') and expr.params[0].dtype == objtypes.BoolTT:
        left, right = expr.params
        if not isinstance(left, ast.Literal): #try to get bool literal to the front
            left, right = right, left
        if isinstance(left, ast.Literal):
            flip = (left == TRUE) != (expr.opstr == '==')
            expr = reverseBoolExpr(right) if flip else right

    # !a ? b : c -> a ? c : b
    if isinstance(expr, ast.Ternary) and isinstance(expr.params[0], ast.UnaryPrefix):
        cond, val1, val2 = expr.params
        if cond.opstr == '!':
            expr.params = [reverseBoolExpr(cond), val2, val1]

    # 0 - a -> -a
    if isinstance(expr, ast.BinaryInfix) and expr.opstr == '-':
        if expr.params[0] == ast.Literal.LZERO:
            expr = ast.UnaryPrefix('-', expr.params[1])
    return expr
def _setScopeParents(scope):
    '''Set `bases` on every scope nested under scope.

    Each nested scope gets its parent's bases tuple with itself appended.
    scope.bases must already be set (generateAST sets the root's to (root,)).
    '''
    pending = [scope]
    while pending:
        parent = pending.pop()
        for stmt in parent.statements:
            for child in stmt.getScopes():
                child.bases = parent.bases + (child,)
                pending.append(child)
def _replaceExpressions(scope, item, rdict):
    '''_preorder callback: substitute subexpressions of item.expr according to rdict.

    Statements that end up as a self-assignment (x = x;) are removed by
    returning []; otherwise [item] is returned.
    '''
    #Must be done before local declarations are created since it doesn't touch/remove them
    if item.expr is not None:
        item.expr = item.expr.replaceSubExprs(rdict)

    assigned = item.expr
    if not isinstance(assigned, ast.Assignment):
        return [item]
    assert(isinstance(item, ast.ExpressionStatement))
    target, value = assigned.params
    #drop redundant assignments i.e. x=x;
    return [] if target == value else [item]
def _mergeVariables(root, predeclared):
    '''Coalesce locals that are only copies of one another into a single variable.

    Builds a graph where each local points to the locals/literals assigned to
    it. Consider each strongly connected component whose defs from outside the
    component all resolve to one common target. If that target has the same
    dtype and a compatible declaration scope, every member is mapped to it.
    Locals with a non-local/literal def or an explicit declaration scope are
    never merged into anything else, though they may still be targets.
    Finally rewrites the tree under root with the mapping via _replaceExpressions.
    '''
    _setScopeParents(root)
    info = findVarDeclInfo(root, predeclared)

    lvars = [expr for expr in info if isinstance(expr, ast.Local)]
    forbidden = set()
    #If var has any defs which aren't a literal or local, mark it as a leaf node (it can't be merged into something)
    for var in lvars:
        if not all(isinstance(expr, (ast.Local, ast.Literal)) for expr in info[var].defs):
            forbidden.add(var)
        elif info[var].declScope is not None:
            forbidden.add(var)

    sccs = graph_util.tarjanSCC(lvars, lambda var:([] if var in forbidden else info[var].defs))
    #the sccs will be in topological order
    varmap = {}
    for scc in sccs:
        if forbidden.isdisjoint(scc):
            # Resolved targets of all defs coming from outside this component.
            alldefs = []
            for expr in scc:
                for def_ in info[expr].defs:
                    if def_ not in scc:
                        alldefs.append(varmap[def_])
            if len(set(alldefs)) == 1:
                target = alldefs[0]
                if all(var.dtype == target.dtype for var in scc):
                    scope = ast.StatementBlock.join(*(info[var].scope for var in scc))
                    scope = ast.StatementBlock.join(scope, info[target].declScope) #scope is unchanged if declScope is none like usual
                    if info[target].declScope is None or info[target].declScope == scope:
                        for var in scc:
                            varmap[var] = target
                        info[target].scope = ast.StatementBlock.join(scope, info[target].scope)
                        continue
        #fallthrough if merging is impossible
        for var in scc:
            varmap[var] = var
            if len(info[var].defs) > 1:
                forbidden.add(var)
    _preorder(root, partial(_replaceExpressions, rdict=varmap))
# Expression node types that hasSideEffects treats as pure by themselves.
_oktypes = (
    ast.BinaryInfix,
    ast.Local,
    ast.Literal,
    ast.Parenthesis,
    ast.Ternary,
    ast.TypeName,
    ast.UnaryPrefix,
)
def hasSideEffects(expr):
    '''Return True if evaluating expr might have observable effects.

    The check is conservative:
    - Only node types in _oktypes are considered pure.
    - Integer division/remainder counts as effectful unless the divisor is a
      nonzero literal, since it may otherwise throw on division by zero.
    - All operands are checked recursively, so an expression like
      `x.foo() > 0` is reported as effectful even though its top node is a
      BinaryInfix.
    '''
    if not isinstance(expr, _oktypes):
        return True
    #check for division by 0. If it's a float or dividing by nonzero literal, it's ok
    if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('/','%'):
        if expr.dtype not in (objtypes.FloatTT, objtypes.DoubleTT):
            divisor = expr.params[-1]
            if not isinstance(divisor, ast.Literal) or divisor.val == 0:
                return True
    # A node is only pure if every operand is pure too
    return any(hasSideEffects(param) for param in expr.params)
def _inlineVariables(root):
    '''Inline locals that are assigned once and read once into their single use.

    A local qualifies when all of these hold:
    - it has exactly one local-assignment def;
    - it has exactly two occurrences overall (assignment target plus one read);
    - its dtype matches the assigned value's dtype.
    The assignment is removed when its value can be substituted into one of
    the following statements without passing anything hasSideEffects reports
    as effectful. Mutates the tree under root in place.
    '''
    #first find all variables with a single def and use
    defs = collections.defaultdict(list)
    uses = collections.defaultdict(int)

    def visitExprFindDefs(expr):
        if expr.isLocalAssign():
            defs[expr.params[0]].append(expr)
        elif isinstance(expr, ast.Local):
            uses[expr] += 1

    def visitFindDefs(scope, item):
        # Visit every subexpression of the statement's top-level expression.
        if item.expr is not None:
            stack = [item.expr]
            while stack:
                expr = stack.pop()
                visitExprFindDefs(expr)
                stack.extend(expr.params)

    _preorder(root, visitFindDefs)
    #These should have 2 uses since the initial assignment also counts
    replacevars = {k for k,v in defs.items() if len(v)==1 and uses[k]==2 and k.dtype == v[0].params[1].dtype}

    def doReplacement(item, pairs):
        '''Try to substitute item's assigned value for its variable within pairs.

        pairs: (statement, expression) roots to scan in order. Operands are
        walked left to right. Returns True once the use is found and replaced,
        False if a side-effecting node or a non-replaceable use comes first.
        '''
        old, new = item.expr.params
        assert(isinstance(old, ast.Local) and old.dtype == new.dtype)
        # Entries are (True, (canReplace, parent, expr)) to expand a node, or
        # (False, expr) to check a node once all of its children were visited.
        stack = [(True, (True, item2, expr)) for item2, expr in reversed(pairs) if expr is not None]
        while stack:
            recurse, args = stack.pop()

            if recurse:
                canReplace, parent, expr = args
                stack.append((False, expr))

                #For ternaries, we don't want to replace into the conditionally
                #evaluated part, but we still need to check those parts for
                #barriers. For both ternaries and short circuit operators, the
                #first param is always evaluated, so it is safe
                if isinstance(expr, ast.Ternary) or isinstance(expr, ast.BinaryInfix) and expr.opstr in ('&&','||'):
                    for param in reversed(expr.params[1:]):
                        stack.append((True, (False, expr, param)))
                    stack.append((True, (canReplace, expr, expr.params[0])))
                #For assignments, we unroll the LHS arguments, because if assigning
                #to an array or field, we don't want that to serve as a barrier
                elif isinstance(expr, ast.Assignment):
                    left, right = expr.params
                    stack.append((True, (canReplace, expr, right)))
                    if isinstance(left, (ast.ArrayAccess, ast.FieldAccess)):
                        for param in reversed(left.params):
                            stack.append((True, (canReplace, left, param)))
                    else:
                        assert(isinstance(left, ast.Local))
                else:
                    for param in reversed(expr.params):
                        stack.append((True, (canReplace, expr, param)))

                # Reaching the use ends the search either way.
                if expr == old:
                    if canReplace:
                        if isinstance(parent, ast.JavaExpression):
                            params = parent.params = list(parent.params)
                            params[params.index(old)] = new
                        else: #replacing in a top level statement
                            assert(parent.expr == old)
                            parent.expr = new
                    return canReplace
            else:
                expr = args
                if hasSideEffects(expr):
                    return False
        return False

    def visitReplace(scope):
        # Walk statements last to first; newstatements holds the already
        # processed statements that follow the current one, in order.
        newstatements = []
        for item in reversed(scope.statements):
            for sub in item.getScopes():
                visitReplace(sub)

            if isinstance(item.expr, ast.Assignment) and item.expr.params[0] in replacevars:
                # Candidate roots: following statements up to and including the first compound one.
                expr_roots = []
                for item2 in newstatements:
                    #Don't inline into a while condition as it may be evaluated more than once
                    if not isinstance(item2, ast.WhileStatement):
                        expr_roots.append((item2, item2.expr))
                    if item2.getScopes():
                        break
                success = doReplacement(item, expr_roots)
                if success:
                    # The assignment has been folded into its use; drop it.
                    continue
            newstatements.insert(0, item)
        scope.statements = newstatements
    visitReplace(root)
def _createDeclarations(root, predeclared):
    '''Insert local variable declarations for every undeclared Local under root.

    Locals with a declaration scope (predeclared arguments, catch variables)
    are skipped. The rest are handled in textual order:
    - If the first occurrence found is an assignment statement directly in
      the variable's use scope, that statement becomes the declaration.
    - Otherwise a declaration with a dummy initial value is prepended to the
      variable's use scope.
    '''
    _setScopeParents(root)
    info = findVarDeclInfo(root, predeclared)
    localdefs = collections.defaultdict(list)
    newvars = [var for var in info if isinstance(var, ast.Local) and info[var].declScope is None]
    remaining = set(newvars)

    #The compiler treats statements as if they can throw any exception at any time, so
    #it may think variables are not definitely assigned even when they really are.
    #Therefore, we give an unused initial value to every variable declaration
    #TODO - find a better way to handle this
    # long needs its own zero: falling back to null would emit `long x = null;`,
    # which is not valid Java
    _init_d = {objtypes.BoolTT: ast.Literal.FALSE,
            objtypes.IntTT: ast.Literal.ZERO,
            objtypes.LongTT: ast.Literal.LZERO,
            objtypes.FloatTT: ast.Literal.FZERO,
            objtypes.DoubleTT: ast.Literal.DZERO}
    def mdVisitVarUse(var):
        # Hoisted declaration with a dummy initializer at the top of var's use scope.
        decl = ast.VariableDeclarator(ast.TypeName(var.dtype), var)
        right = _init_d.get(var.dtype, ast.Literal.NULL)
        localdefs[info[var].scope].append( ast.LocalDeclarationStatement(decl, right) )
        remaining.remove(var)

    def mdVisitScope(scope):
        if isinstance(scope, ast.StatementBlock):
            for i,stmt in enumerate(scope.statements):
                # A direct assignment in var's own scope becomes its declaration.
                if isinstance(stmt, ast.ExpressionStatement):
                    if isinstance(stmt.expr, ast.Assignment):
                        var, right = stmt.expr.params
                        if var in remaining and scope == info[var].scope:
                            decl = ast.VariableDeclarator(ast.TypeName(var.dtype), var)
                            new = ast.LocalDeclarationStatement(decl, right)
                            scope.statements[i] = new
                            remaining.remove(var)
                # Any other first occurrence forces a hoisted declaration.
                if stmt.expr is not None:
                    top = stmt.expr
                    for expr in top.postFlatIter():
                        if expr in remaining:
                            mdVisitVarUse(expr)
                for sub in stmt.getScopes():
                    mdVisitScope(sub)

    mdVisitScope(root)
    # print remaining
    assert(not remaining)
    assert(None not in localdefs)
    for scope, ldefs in localdefs.items():
        scope.statements = ldefs + scope.statements
def _createTernaries(scope, item):
    '''_preorder callback: turn simple two-way if statements into ternaries.

        if(c) return x; else return y;   ->  return c ? x : y;
        if(c) v = x; else v = y;         ->  v = c ? x : y;

    Both branches must contain exactly one statement and share a jumpKey.
    When the assignment branches do not fall through, the new assignment is
    wrapped in a StatementBlock that keeps their jump. Every statement passed
    in (converted or not) has its expression run through _simplifyExpressions.
    Returns a one-element list with the possibly replaced statement.
    '''
    if isinstance(item, ast.IfStatement) and len(item.getScopes()) == 2:
        block1, block2 = item.getScopes()

        if (len(block1.statements) == len(block2.statements) == 1) and block1.jumpKey == block2.jumpKey:
            s1, s2 = block1.statements[0], block2.statements[0]
            e1, e2 = s1.expr, s2.expr

            if isinstance(s1, ast.ReturnStatement) and isinstance(s2, ast.ReturnStatement):
                # A void return in one branch means both are void (e1 is None).
                expr = None if e1 is None else ast.Ternary(item.expr, e1, e2)
                item = ast.ReturnStatement(expr, s1.tt)
            if isinstance(s1, ast.ExpressionStatement) and isinstance(s2, ast.ExpressionStatement):
                if isinstance(e1, ast.Assignment) and isinstance(e2, ast.Assignment):
                    # if e1.params[0] == e2.params[0] and max(e1.params[1].complexity(), e2.params[1].complexity()) <= 1:
                    if e1.params[0] == e2.params[0]:
                        expr = ast.Ternary(item.expr, e1.params[1], e2.params[1])
                        temp = ast.ExpressionStatement(ast.Assignment(e1.params[0], expr))

                        if not block1.doesFallthrough():
                            assert(not block2.doesFallthrough())
                            item = ast.StatementBlock(item.func, item.continueKey, item.breakKey, [temp], block1.jumpKey)
                        else:
                            item = temp
    if item.expr is not None:
        item.expr = _simplifyExpressions(item.expr)
    return [item]
def _fixExprStatements(scope, item, namegen):
    '''_preorder callback: make every expression statement a legal Java statement.

    Only assignments, object creations, method calls and Dummy expressions are
    kept as bare expression statements. Any other expression is turned into a
    declaration of a fresh local (named via namegen with prefix 'dummy') that
    is initialised to it. Returns a one-element list.
    '''
    allowedBare = (ast.Assignment, ast.ClassInstanceCreation, ast.MethodInvocation, ast.Dummy)
    if not isinstance(item, ast.ExpressionStatement) or isinstance(item.expr, allowedBare):
        return [item]

    value = item.expr
    holder = ast.Local(value.dtype, lambda expr: namegen.getPrefix('dummy'))
    declarator = ast.VariableDeclarator(ast.TypeName(holder.dtype), holder)
    return [ast.LocalDeclarationStatement(declarator, value)]
def _addCastsAndParens(scope, item, env):
    '''_preorder callback: let the statement insert needed casts and parentheses.

    Delegates to item.addCastsAndParens(env). Always returns None, so
    _preorder keeps the statement as-is.
    '''
    fixup = item.addCastsAndParens
    fixup(env)
    return None
def _chooseJump(choices):
    '''Pick which (target, isContinue) key to use for an explicit jump.

    choices: non-empty sequence of keys from _generateJumps' targets map,
        all leading to the required jump key.
    Preference order:
    1. an unlabeled jump (target None);
    2. a target that already has a label;
    3. the first choice.
    '''
    # A plain `break;`/`continue;` needs no label at all.
    for target, isContinue in choices:
        if target is None:
            return target, isContinue
    # Next best: a statement that is already labeled (presumably avoids adding a label).
    labeled = next(((target, isContinue) for target, isContinue in choices if target.label is not None), None)
    if labeled is not None:
        return labeled
    return choices[0]
def _generateJumps(scope, targets=None, fallthroughs=NONE_SET, dryRun=False):
    '''Append explicit break/continue statements wherever a block cannot simply fall through.

    scope: StatementBlock to process; nested scopes are handled recursively.
    targets: OrderedDict mapping (statement or None, isContinue) keys to the
        jump key that a break (isContinue False) or continue (isContinue True)
        to that statement reaches. A None statement stands for an unlabeled
        jump to the innermost loop (or loop/switch, for breaks). Defaults to
        an empty OrderedDict.
    fallthroughs: frozenset of jump keys that scope reaches just by ending;
        must contain None.
    dryRun: if True, only the consistency assertions run and no
        JumpStatement is appended.
    Returns None (generateAST asserts this on its dry runs).
    '''
    # Fresh dict per call instead of one mutable default shared by all calls
    if targets is None:
        targets = collections.OrderedDict()

    assert(None in fallthroughs)
    #breakkey can be None with non-None jumpkey when we're a scope in a switch statement that falls through
    #and the end of the switch statement is unreachable
    assert(scope.breakKey is not None or scope.jumpKey is None or not scope.labelable)
    if scope.jumpKey not in fallthroughs:
        # Falling off the end would not reach the required target: add an explicit jump
        assert(not scope.statements or not isinstance(scope.statements[-1], (ast.ReturnStatement, ast.ThrowStatement)))
        vals = [k for k,v in targets.items() if v == scope.jumpKey]
        assert(vals)
        jump = _chooseJump(vals)
        if not dryRun:
            scope.statements.append(ast.JumpStatement(*jump))

    # Walk backwards so fallthroughs always describes what follows the current statement
    for item in reversed(scope.statements):
        if not item.getScopes():
            fallthroughs = NONE_SET
            continue

        # Keys reached by falling off the end of item's sub-scopes
        if isinstance(item, ast.WhileStatement):
            fallthroughs = frozenset([None, item.continueKey])
        else:
            fallthroughs |= frozenset([item.breakKey])

        newtargets = targets.copy()
        if isinstance(item, ast.WhileStatement):
            newtargets[None, True] = item.continueKey
            newtargets[item, True] = item.continueKey
        if isinstance(item, (ast.WhileStatement, ast.SwitchStatement)):
            newtargets[None, False] = item.breakKey
        newtargets[item, False] = item.breakKey

        for subscope in reversed(item.getScopes()):
            _generateJumps(subscope, newtargets, fallthroughs, dryRun=dryRun)
            # A switch case may fall through into the following case
            if isinstance(item, ast.SwitchStatement):
                fallthroughs = frozenset([None, subscope.continueKey])
        # The statement preceding item reaches it via item.continueKey
        fallthroughs = frozenset([None, item.continueKey])
def _pruneVoidReturn(scope):
    '''Drop a trailing `return;` (a ReturnStatement without a value) from scope.

    Mutates scope.statements in place and returns None.
    '''
    stmts = scope.statements
    if not stmts:
        return
    tail = stmts[-1]
    if isinstance(tail, ast.ReturnStatement) and tail.expr is None:
        del stmts[-1]
def generateAST(method, graph, forbidden_identifiers):
    '''Build the ast2.MethodDef for method.

    method: the method being decompiled; its class_, descriptor, flags, name
        and static attributes are read here.
    graph: the method's SSA graph, or None for methods without code
        (abstract or native), in which case only the signature is built.
    forbidden_identifiers: passed to NameGen for naming generated variables.
    '''
    env = method.class_.env
    namegen = NameGen(forbidden_identifiers)
    class_ = method.class_
    inputTypes = parseMethodDescriptor(method.descriptor, unsynthesize=False)[0]
    tts = objtypes.verifierToSynthetic_seq(inputTypes)

    if graph is not None:
        entryNode, nodes = graphproxy.createGraphProxy(graph)
        if not method.static:
            entryNode.invars[0].name = 'this'

        # Constructors are the methods named '<init>' in the class file
        setree = structuring.structure(entryNode, nodes, (method.name == '<init>'))
        ast_root, varinfo = astgen.createAST(method, graph, setree, namegen)

        argsources = [varinfo.var(entryNode, var) for var in entryNode.invars]
        # The implicit `this` argument is not part of the declared parameter list
        disp_args = argsources if method.static else argsources[1:]
        for expr, tt in zip(disp_args, tts):
            expr.dtype = tt

        decls = [ast.VariableDeclarator(ast.TypeName(expr.dtype), expr) for expr in disp_args]
        ################################################################################################
        ast_root.bases = (ast_root,) #needed for our setScopeParents later

        # print ast_root.print_()
        # Dry runs only check jump consistency (skipped entirely under python -O)
        assert(_generateJumps(ast_root, dryRun=True) is None)
        _preorder(ast_root, _fixObjectCreations)
        boolize.boolizeVars(ast_root, argsources)
        _simplifyBlocks(ast_root)
        assert(_generateJumps(ast_root, dryRun=True) is None)

        _mergeVariables(ast_root, argsources)
        _preorder(ast_root, _createTernaries)
        _inlineVariables(ast_root)
        _simplifyBlocks(ast_root)
        _preorder(ast_root, _createTernaries)
        _inlineVariables(ast_root)
        _simplifyBlocks(ast_root)

        _createDeclarations(ast_root, argsources)
        _preorder(ast_root, partial(_fixExprStatements, namegen=namegen))
        _preorder(ast_root, partial(_addCastsAndParens, env=env))
        _generateJumps(ast_root)
        _pruneVoidReturn(ast_root)
    else: #abstract or native method
        ast_root = None
        argsources = [ast.Local(tt, lambda expr:namegen.getPrefix('arg')) for tt in tts]
        decls = [ast.VariableDeclarator(ast.TypeName(expr.dtype), expr) for expr in argsources]

    flags = method.flags - set(['BRIDGE','SYNTHETIC','VARARGS'])
    if method.name == '<init>': #More arbitrary restrictions. Yay!
        flags = flags - set(['ABSTRACT','STATIC','FINAL','NATIVE','STRICTFP','SYNCHRONIZED'])

    flagstr = ' '.join(map(str.lower, sorted(flags)))
    inputTypes, returnTypes = parseMethodDescriptor(method.descriptor, unsynthesize=False)
    ret_tt = objtypes.verifierToSynthetic(returnTypes[0]) if returnTypes else ('.void',0)
    return ast2.MethodDef(class_, flagstr, method.name, ast.TypeName(ret_tt), decls, ast_root)
© 2015 - 2025 Weber Informatics LLC | Privacy Policy