import autopath from pypy.annotation.pairtype import extendabletype, pair, pairtype from pypy.objspace.flow.model import Variable, Constant, SpaceOperation from pypy.rpython.lltype import Signed, Void from pypy.translator.unsimplify import insert_empty_block class Term: __metaclass__ = extendabletype def __init__(self, lltype): self.lltype = lltype def copy(self, memo): if self in memo: return memo[self] else: result = self.__class__.__new__(self.__class__) result.__dict__ = self.__dict__.copy() memo[self] = result self.internal_copy(memo) return result def internal_copy(self, memo): pass class RuntimeValue(Term): def flatten(self, result, result_list): if self not in result: result[self] = True result_list.append(self) class CompiletimeValue(RuntimeValue): def __init__(self, lltype, value): self.lltype = lltype self.value = value self.var = Constant(value) self.var.concretetype = lltype class VirtualStructure(Term): def __init__(self, lltype, fields={}): self.lltype = lltype self.fields = fields.copy() def flatten(self, result, result_list): if self not in result: result[self] = True result_list.append(self) items = self.fields.items() items.sort() for key, term in items: term.flatten(result, result_list) def internal_copy(self, memo): for key in self.fields.keys(): self.fields[key] = self.fields[key].copy(memo) class __extend__(pairtype(Term, Term)): def instanceshape((self, other), memo): return False def unionshape((self, other), memo): assert self.lltype == other.lltype return RuntimeValue(self.lltype) class __extend__(pairtype(Term, RuntimeValue)): def instanceshape((self, other), memo): return True class __extend__(pairtype(VirtualStructure, VirtualStructure)): def instanceshape((self, other), memo): keys1 = self.fields.keys() keys1.sort() keys2 = other.fields.keys() keys2.sort() if keys1 != keys2: return False for key in keys1: if not instance_shape(self.fields[key], other.fields[key], memo): return False return True def unionshape((self, other), memo): assert self.lltype == other.lltype keys1 = self.fields.keys() keys1.sort() keys2 = other.fields.keys() keys2.sort() if keys1 != keys2: return RuntimeValue(self.lltype) else: result = VirtualStructure(self.lltype) memo[self, other] = result for key in keys1: term = union_shape(self.fields[key], other.fields[key], memo) result.fields[key] = term return result def instance_shape(term1, term2, memo=None): if memo is None: memo = {} if memo.setdefault(term2, term1) is not term1: return False return pair(term1, term2).instanceshape(memo) def same_shape(term1, term2): return instance_shape(term1, term2) and instance_shape(term2, term1) def union_shape(term1, term2, memo=None): # return the less general term of which both term1 and term2 are # instance shapes. All Term objects of the result are new. if memo is None: memo = {} memokey = term1, term2 try: return memo[memokey] except KeyError: result = pair(term1, term2).unionshape(memo) memo[memokey] = result return result def copy_shape(term): return union_shape(term, term) def flatten_term(term, klass=Term): d = {} lst = [] term.flatten(d, lst) return [term for term in lst if isinstance(term, klass)] def test_shape(): r1 = RuntimeValue(Signed) r2 = RuntimeValue(Signed) s1 = VirtualStructure(Signed) s2 = VirtualStructure(Signed) assert same_shape(r1, r2) assert not same_shape(s1, r2) assert not same_shape(r1, s2) assert same_shape(s1, s2) assert instance_shape(r1, r2) assert instance_shape(s1, r2) assert not instance_shape(r1, s2) assert instance_shape(s1, s2) # ____________________________________________________________ class AbstractFrame: def __init__(self, block, cells): assert len(cells) == len(block.inputargs) bindings = {} for cell, v in zip(cells, block.inputargs): bindings[v] = cell self.block = block self.inputbindings = VirtualStructure(None, bindings) def frombindings(cls, block, inputbindings): cells = [inputbindings.fields[v] for v in block.inputargs] return cls(block, cells) frombindings = classmethod(frombindings) def merge(self, other): assert self.block is other.block newbindings = union_shape(self.inputbindings, other.inputbindings) return self.frombindings(self.block, newbindings) def copy(self): newbindings = copy_shape(self.inputbindings) return self.frombindings(self.block, newbindings) def __eq__(self, other): return (self.__class__ is other.__class__ and self.block is other.block and same_shape(self.inputbindings, other.inputbindings)) def __ne__(self, other): return not (self == other) def __hash__(self): raise TypeError("Frame objects are not hashable") def flowin(self): self.bindings = copy_shape(self.inputbindings) self.setup_flowin() block = self.block for op in block.operations: self.consider_op(op) result = [] for link in block.exits: result.append(self.flowout_frame(link)) return result def setup_flowin(self): pass def flowout_frame(self, link): cells = [self.binding(v) for v in link.args] frame = self.__class__(link.target, cells) return frame def setbinding(self, v, cell): assert isinstance(v, Variable) self.bindings.fields[v] = cell def binding(self, v): if isinstance(v, Variable): return self.bindings.fields[v] elif isinstance(v, Constant): return CompiletimeValue(v.concretetype, v.value) else: raise TypeError("expected a Variable or Constant, got %r" % (v,)) def consider_op(self, op): consider_meth = getattr(self, 'consider_op_' + op.opname, self.consider_default) consider_meth(op) class AbstractInterpreter: def __init__(self): self.pendingframes = {} self.frames = {} def addframe(self, frame): block = frame.block if block in self.frames: frame = frame.merge(self.frames[block]) if frame == self.frames[block]: return False elif block.operations == (): cells = [RuntimeValue(v.concretetype) for v in block.inputargs] frame = frame.__class__(block, cells) else: frame = frame.copy() self.frames[block] = frame self.pendingframes[block] = frame return True def complete(self): while self.pendingframes: block, frame = self.pendingframes.popitem() for newframe in frame.flowin(): self.addframe(newframe) # ____________________________________________________________ class MallocTrackingFrame(AbstractFrame): def setup_flowin(self): self.newops = [] self.casted_pointers = {} # {Variable: (original_Variable, castop)} self.newinputargs = [] self.newinputargs_term = [] for v in self.block.inputargs: term = self.binding(v) if isinstance(term, RuntimeValue): self.newinputargs.append(v) self.newinputargs_term.append(self.inputbindings.fields[v]) term.var = v for term, inputterm in zip(flatten_term(self.bindings), flatten_term(self.inputbindings)): if isinstance(term, RuntimeValue) and not hasattr(term, 'var'): if isinstance(term, CompiletimeValue): term.var = Constant(term.value) else: term.var = Variable() term.var.concretetype = term.lltype self.newinputargs.append(term.var) self.newinputargs_term.append(inputterm) assert len(self.newinputargs_term) == len(self.newinputargs) #self.flowout_map = {} #for link in block.exits: # self.flowout_map[link] = link.target.inputargs def consider_default(self, op): for v in op.args: cell = self.binding(v) if not isinstance(cell, RuntimeValue): v1 = self.force(cell) newop = SpaceOperation('same_as', [v1], v) self.newops.append(newop) rtvalue = RuntimeValue(op.result.concretetype) rtvalue.var = op.result self.setbinding(op.result, rtvalue) self.newops.append(op) def possibly_not_casted_binding(self, v): try: return self.binding(v) except KeyError: v_original, castop = self.casted_pointers[v] cell = self.binding(v_original) if isinstance(cell, VirtualStructure): return cell else: self.consider_default(castop) return self.binding(v) def consider_op_malloc(self, op): self.setbinding(op.result, VirtualStructure(op.result.concretetype)) def consider_op_setfield(self, op): cell = self.possibly_not_casted_binding(op.args[0]) if isinstance(cell, VirtualStructure): attrname = op.args[0].concretetype, op.args[1].value cell.fields[attrname] = self.binding(op.args[2]) return self.consider_default(op) def consider_op_getfield(self, op): cell = self.possibly_not_casted_binding(op.args[0]) if isinstance(cell, VirtualStructure): attrname = op.args[0].concretetype, op.args[1].value if attrname in cell.fields: cell = cell.fields[attrname] if isinstance(cell, RuntimeValue): newop = SpaceOperation('same_as', [cell.var], op.result) self.newops.append(newop) cell.var = op.result self.setbinding(op.result, cell) return self.consider_default(op) def consider_op_cast_pointer(self, op): # HACK: see if the result is used only in getfield and setfield it = iter(self.block.operations) for op1 in it: if op1 is op: break for op1 in it: # resume just after 'op' if op.result in op1.args[1:]: break if (op.result == op1.args[0] and op1.opname not in ('getfield', 'setfield')): break else: for link in self.block.exits: if op.result in link.args: break else: # not found anywhere else than getfield/setfield operations self.casted_pointers[op.result] = op.args[0], op return # skip the operation self.consider_default(op) def force(self, term): assert isinstance(term, VirtualStructure) pending_fields = [] mapping = {} for subterm in flatten_term(term, klass=VirtualStructure): rtvalue = RuntimeValue(subterm.lltype) rtvalue.var = Variable() rtvalue.var.concretetype = rtvalue.lltype ctype = Constant(rtvalue.lltype) ctype.concretetype = Void newop = SpaceOperation('malloc', [ctype], rtvalue.var) self.newops.append(newop) for key, t_value in subterm.fields.items(): pending_fields.append((rtvalue.var, key, t_value)) subterm.__class__ = rtvalue.__class__ subterm.__dict__ = rtvalue.__dict__ for v, (targetlltype, attrname), t_value in pending_fields: assert isinstance(t_value, RuntimeValue) v_value = t_value.var cname = Constant(attrname) cname.concretetype = Void v_result = Variable() v_result.concretetype = Void if targetlltype != v.concretetype: v1 = Variable() v1.concretetype = targetlltype newop = SpaceOperation('cast_pointer', [v], v1) self.newops.append(newop) v = v1 newop = SpaceOperation('setfield', [v, cname, v_value], v_result) self.newops.append(newop) return term.var def patchblock(self, translator, frames): block = self.block if block.operations == (): return block.inputargs = self.newinputargs block.operations[:] = self.newops def patchlinks(self, translator, frames): block = self.block if block.operations == (): return can_insert_here = block.exitswitch is None and len(block.exits) == 1 for link in block.exits: f_provided = self.flowout_frame(link) f_provided.inputbindings = f_provided.inputbindings.copy({}) f_expected = frames[link.target] mapping = {} consistent = instance_shape(f_provided.inputbindings, f_expected.inputbindings, mapping) assert consistent f_provided.newops = [] outputargs = [] for t_expected in f_expected.newinputargs_term: assert isinstance(t_expected, RuntimeValue) t_provided = mapping[t_expected] if isinstance(t_provided, RuntimeValue): v_provided = t_provided.var else: v_provided = f_provided.force(t_provided) outputargs.append(v_provided) link.args = outputargs if f_provided.newops: if can_insert_here: block.operations.extend(f_provided.newops) else: block = insert_empty_block(translator, link, newops=f_provided.newops) link = block.exits[0] def remove_simple_mallocs(translator, graph): cells = [RuntimeValue(v.concretetype) for v in graph.getargs()] startframe = MallocTrackingFrame(graph.startblock, cells) interpreter = AbstractInterpreter() interpreter.addframe(startframe) interpreter.complete() for frame in interpreter.frames.values(): frame.patchlinks(translator, interpreter.frames) for frame in interpreter.frames.values(): frame.patchblock(translator, interpreter.frames) ## for block, frame in interpreter.frames.items(): ## rtvalues = flatten_term(frame.inputbindings, klass=RuntimeValue) ## rtvalues_set = dict.fromkeys(rtvalues) ## frame.newinputargs = [] ## for v in block.inputargs: ## term = frame.binding(v) ## if term in rtvalues_set: ## rtvalues_set[term] = v ## frame.newinputargs.append(v) ## for term in rtvalues: ## if rtvalues_set[term] is None: ## v = Variable() ## rtvalues_set[term] = v ## frame.newinputargs.append(v) ## #import pdb; pdb.set_trace() ## for block, frame in interpreter.frames.items(): ## ... # ____________________________________________________________ from pypy.translator.translator import Translator from pypy.annotation import model as annmodel from pypy.translator.backendoptimization import remove_same_as ##def fn(x, y): ## s, d = x+y, x-y ## return s*d class T: pass def dummystuff(x): return x def fn(y, x, a, b, c, d, e, f): t = T() t.x = x t.y = y if x > 0: t = dummystuff(t) return t def tget(): t = Translator(fn) a = t.annotate([int]*8) a.simplify() t.specialize() t.checkgraphs() return t t = tget() #for graph in self.flowgraphs.values(): # remove_simple_mallocs(graph) remove_same_as(t.getflowgraph()) remove_simple_mallocs(t, t.getflowgraph()) t.view() t.checkgraphs() #t.view()