"""
The purpose of this module is to hide iterators and the itertools module
from the programmer, without losing the laziness and the memory efficiency.

The single entry point is:

    collect(x) -> collector object
"""

#
# Design goals:
#
#  * collect(x) behaves as much like a list as possible.
#
#  * collect(iterable) asks 'iterable' for more elements as lazily as possible.
#  * as they are based on iterables, collect objects behave lazily only when
#    their elements tend to be accessed in a more or less left-to-right order.
#
#  * collect objects are themselves iterable, of course; if there is no
#    reference left to a collect object but only to iterators derived from
#    it, then each item in the collection should be discarded along the way
#    as soon as all iterators have already returned it.
#
#    So ``for item in collect(iterable)'' is as lazy and memory-efficient as
#    ``for item in iterable''.
#

from __future__ import generators
from itertools import islice
import operator
import sys

__all__ = ['collect']


class collect(object):
    """A collector is a sequence that behaves almost like a list.

    It collects its items lazily, and discards them early if it can prove
    that they will no longer be used.

    The only state is '_head', the first node of a storage chain (see
    storagebase below).  Several collect objects and iterators may share
    parts of the same chain.
    """

    def __init__(self, iterable=()):
        # Choose the storage for the initial contents:
        #  - concrete containers are copied into one eager storagenode;
        #  - another collect object shares its storage chain (no copy);
        #  - a storage node is adopted as-is;
        #  - anything else is wrapped in a lazy thunk over iter(iterable).
        if isinstance(iterable, (list, tuple, dict)):
            self._head = storagenode(list(iterable))
        elif isinstance(iterable, collect):
            self._head = iterable._head
        elif isinstance(iterable, storagebase):
            self._head = iterable
        else:
            self._head = storagethunk(iter(iterable))

    ### read-only sequence methods ###

    def __iter__(self):
        """Iterate over the elements.

        Warning! Unlike lists, this produces exactly the elements that were
        in the collection at the moment the iterator was produced.  Further
        changes to the collect object are not reflected.  (This is probably
        a good thing but it can break compatibility.)
        """
        return self._head.niter()

    def __getitem__(self, index):
        """Return self[index].

        An integer index returns the item and raises IndexError when out of
        range.  A negative integer index forces the whole collection.  A
        slice always returns a new collect object.  That object is lazy when
        possible; negative steps are always computed eagerly.
        """
        if not isinstance(index, slice):
            if index < 0:
                index += len(self)
                if index < 0:
                    raise IndexError("collect index out of range")
            try:
                # Compute at least index+1 items into one node, and keep
                # that node as the new head so the work is not redone.
                self._head = self._head.populate(index+1)
            except StopIteration:
                raise IndexError("collect index out of range")
            return self._head.datalist[index]

        elif index.step is None or index.step > 0:
            # slicing with a positive step
            start = index.start
            stop = index.stop
            step = index.step
            if step is None:
                step = 1
            if start is None:
                start = 0
            elif start < 0:
                start += len(self)   # XXX should be more lazy
                if start < 0:
                    start = 0
            if stop is not None and stop >= 0:
                # Fast path: the slice lies entirely within the items
                # already computed in the head node.
                cache = self._head.force(0)
                if stop <= len(cache):
                    return collect(storagenode(cache[start:stop:step]))
            try:
                head = self._head.forward(start)
            except StopIteration:
                return collect(storageend())
            if stop is None:
                if step == 1:
                    return collect(head)
                else:
                    return collect(islice(head.niter(), 0, None, step))
            elif start < stop:
                return collect(islice(head.niter(), 0, stop-start, step))
            elif stop < 0:
                # Drop the last -stop items without computing the length.
                it = idroptail(head.niter(), -stop)
                if step != 1:
                    it = islice(it, 0, None, step)
                return collect(it)
            else:
                return collect(storageend())

        else:
            # slicing with a negative step -- cannot do it lazily
            upperbound = sys.maxint
            if (index.start is not None and index.start >= 0 and
                    (index.stop is None or index.stop >= 0)):
                upperbound = index.start+1  # don't need more items then
            cache = self._head.force(upperbound)
            return collect(storagenode(cache[index]))

    def __len__(self):
        """Return the number of items.  This forces the whole collection."""
        length = 0
        head = self._head
        while isinstance(head, storagenode):
            length += len(head.datalist)
            head = head.next
        return length + len(head.force())

    def __nonzero__(self):
        """True if there is at least one item (computes at most one)."""
        return bool(self._head.force(1))

    def __repr__(self):
        # Each node's repr() appends its rendered items to 'items' and
        # returns the node to continue from.  When the text is complete it
        # returns None instead; the node that stops the walk has already
        # appended the closing ']' (or '...]' when truncated).
        items = []
        head = self._head
        while head is not None:
            head = head.repr(items)
        return '%s([%s)' % (self.__class__.__name__, ', '.join(items))

    def __str__(self):
        # Unlike repr(), this forces and shows the whole collection.
        return '%s(%r)' % (self.__class__.__name__, self._head.force())

    def __add__(self, other):
        if isinstance(other, (collect, list)):
            return collect(concat(self, other))
        return NotImplemented

    def __radd__(self, other):
        if isinstance(other, (collect, list)):
            return collect(concat(other, self))
        return NotImplemented

    def __mul__(self, count):
        if isinstance(count, (int, long)):
            return collect(repeat(self, count))
        return NotImplemented

    __rmul__ = __mul__

    def _compare_op(op):
        # Build a rich comparison that orders lexicographically, like
        # lists.  It consumes both sides only up to their first difference.
        def compare(self, other):
            if isinstance(other, list):
                other = collect(other)
            if isinstance(other, collect):
                l1, l2 = ifirstdiff(self._head.niter(), other._head.niter())
                return op(l1, l2)
            return NotImplemented
        return compare

    __lt__ = _compare_op(operator.lt)
    __le__ = _compare_op(operator.le)
    __eq__ = _compare_op(operator.eq)
    __ne__ = _compare_op(operator.ne)
    __gt__ = _compare_op(operator.gt)
    __ge__ = _compare_op(operator.ge)
    del _compare_op

    def __hash__(self):
        raise TypeError("collect objects are unhashable")

    ### mutable sequence methods ###

    def _setslice(self, slice, seq, deletion=False):
        """Implement self[slice] = seq, or del self[slice] if 'deletion'.

        Replacing or deleting the whole tail (step 1, no stop) stays lazy in
        both self and seq.  Otherwise, when both bounds are non-negative,
        only the prefix the slice can touch is materialized as a list and the
        rest is re-attached lazily.  In all other cases the whole collection
        is materialized.
        """
        step = slice.step
        if step is None:
            step = 1
        if step == 1 and slice.stop is None:
            # special case: replacing the end of the sequence
            if slice.start in (None, 0):
                self._head = collect(seq)._head
            else:
                self._head = concat(self[:slice.start], seq)
            return
        if step > 0:
            lowerbound = slice.start
            upperbound = slice.stop
            shift = 0
        else:
            lowerbound = slice.stop
            upperbound = slice.start
            shift = 1
        if (upperbound is not None and upperbound >= 0 and
                (lowerbound is None or lowerbound >= 0)):
            # in this case we're sure we don't need the items in
            # the part [max(lowerbound,upperbound)+shift:]
            if lowerbound is not None:
                upperbound = max(upperbound, lowerbound)
            upperbound += shift
            part1 = self[:upperbound]
            part2 = self[upperbound:]
        else:
            part1 = self
            part2 = storageend()
        lst = list(part1)
        if deletion:
            del lst[slice]
        else:
            lst[slice] = seq
        self._head = concat(lst, part2)

    def _inrange(self, index):
        """Prepare the head node for an in-place change at integer 'index'.

        Returns the equivalent non-negative index into self._head.datalist,
        or raises IndexError.  If the head node is also referenced elsewhere
        (e.g. by another collect object or an iterator sharing the chain), it
        is first replaced by a private copy.  The change then stays invisible
        through those references.
        """
        if index < 0:
            index += len(self)
            if index < 0:
                raise IndexError("collect index out of range")
        try:
            self._head = self._head.populate(index+1)
        except StopIteration:
            raise IndexError("collect index out of range")
        # An unshared node is seen twice by getrefcount(): through the
        # '_head' attribute and as getrefcount's own argument (CPython).
        if sys.getrefcount(self._head) > 2:
            self._head = storagenode(self._head.datalist[:], self._head.next)
        return index

    def __setitem__(self, index, obj):
        if not isinstance(index, slice):
            index = self._inrange(index)
            self._head.datalist[index] = obj
        else:
            self._setslice(index, obj)

    def __delitem__(self, index):
        if not isinstance(index, slice):
            index = self._inrange(index)
            del self._head.datalist[index]
        else:
            self._setslice(index, [], deletion=True)

    def append(self, x):
        newnode = storagenode([x])
        self._head = concat(self, newnode)

    def extend(self, x):
        self._head = concat(self, x)

    def __iadd__(self, other):
        self._head = concat(self, other)
        return self

    def __imul__(self, count):
        if isinstance(count, (int, long)):
            self._head = repeat(self, count)
            return self
        return NotImplemented

    def count(self, x):
        return operator.countOf(self, x)

    def index(self, x, i=0, j=None):
        """Return the position of the first x in self[i:j].

        Raises ValueError (from operator.indexOf) if x is not found.
        """
        if i < 0:
            i += len(self)
            if i < 0:
                i = 0
        sl = self[slice(i, j)]
        return operator.indexOf(sl, x) + i

    def insert(self, i, x):
        self[i:i] = [x]

    def pop(self, i=-1):
        x = self[i]
        del self[i]
        return x

    def remove(self, x):
        del self[self.index(x)]

    def reverse(self):
        lst = list(self)
        lst.reverse()
        self[:] = lst

    def sort(self, cmp=None, key=None, reverse=False):
        """Lazy sort function.

        Warning! Unlike list.sort(), this method will continue to perform
        comparisons between the elements even after it returned!  All items
        are read and heapified immediately, but they are popped in order only
        as the sorted collection is consumed.

        Like list.sort(), the sort is stable, also with reverse=True: equal
        elements keep their original relative order.
        """
        from __builtin__ import cmp as bltincmp
        from heapq import heapify, heappop
        if cmp is None:
            if key is None and not reverse:
                # easy case where we don't need the wrapper class; the
                # original position breaks ties, which makes it stable
                lst = []
                for item in self:
                    lst.append((item, len(lst)))

                def inorder_gen(heap):
                    while heap:
                        yield heappop(heap)[0]
                heapify(lst)
                self[:] = inorder_gen(lst)
                return
            cmp = bltincmp

        class wrapper:
            if key is None:
                def __init__(self, item, index):
                    self.key = self.item = item
                    self.index = index
            else:
                def __init__(self, item, index):
                    self.item = item
                    self.key = key(item)
                    self.index = index
            if reverse:
                def __cmp__(self, other):
                    # Larger keys first.  Ties still keep the original
                    # order, as list.sort(reverse=True) guarantees.
                    return (cmp(other.key, self.key) or
                            bltincmp(self.index, other.index))
            else:
                def __cmp__(self, other):
                    return (cmp(self.key, other.key) or
                            bltincmp(self.index, other.index))

        lst = []
        for item in self:
            lst.append(wrapper(item, len(lst)))

        def inorder_gen(heap):
            while heap:
                yield heappop(heap).item
        heapify(lst)
        self[:] = inorder_gen(lst)

# ____________________________________________________________
#
# Implementation data structures
#


class storagebase(object):
    """Base class of the nodes of a collect object's storage chain.

    A chain is a singly-linked list of nodes, each of which is one of:
      * storagenode  -- holds computed items in 'datalist' and links to
                        the 'next' node;
      * storagethunk -- a not-yet-computed continuation backed by an
                        iterator; resolve() turns it *in place* into a
                        storagenode or a storageend;
      * storageend   -- the end of the chain.
    Methods that walk the chain raise StopIteration when they hit its end.
    """
    datalist = None

    def niter(self):
        """Generate the items of the chain starting at this node."""
        while True:
            # At the end of the chain getnode() raises StopIteration, which
            # (with Python 2 generator semantics) ends this generator.
            node = self.getnode()
            result = node.datalist[::-1]
            self = node.next
            del node
            while result:
                yield result.pop()   # try hard not to keep a ref to the data

    def force(self, needed=sys.maxint):
        """Compute items and return the datalist of the first non-empty node.

        That node first absorbs the following nodes (in place) until it
        holds at least 'needed' items or the chain is exhausted.  The default
        forces everything.  Returns [] for an empty chain.  With needed <= 0,
        nothing is computed and this node's own datalist (or []) is returned.
        """
        if needed <= 0:
            return self.datalist or []
        try:
            while not self.datalist:
                self = self.resolve()
        except StopIteration:
            return []
        try:
            while len(self.datalist) < needed:
                self.merge()
        except StopIteration:
            pass
        return self.datalist

    def populate(self, needed):
        """Return the first non-empty node, grown to at least 'needed' items.

        Unlike force(), the StopIteration raised when the chain holds fewer
        than 'needed' items is propagated.
        """
        while not self.datalist:
            self = self.resolve()
        while len(self.datalist) < needed:
            self.merge()
        return self

    def forward(self, n):
        """Return a chain of the items that follow the first n ones.

        The tail of the chain is shared; only the node in which the split
        falls is copied.  Raises StopIteration if there are fewer than n
        items.
        """
        while n > 0:
            node = self.getnode()
            if n < len(node.datalist):
                return storagenode(node.datalist[n:], node.next)
            n -= len(node.datalist)
            self = node.next
        return self

    def getnode(self):
        """Return the first node with a non-empty datalist, resolving lazy
        nodes as needed.  Raises StopIteration at the end of the chain."""
        while not self.datalist:
            self = self.resolve()
        return self

    def known_empty(self):
        """True only if the chain is known to be empty without computing
        anything.  This base version makes no claim."""
        return False


class storageend(storagebase):
    """Terminal node of a storage chain."""

    def resolve(self):
        raise StopIteration

    def repr(self, items):
        # close the list display; returning None stops collect.__repr__
        if not items:
            items.append('')
        items[-1] += ']'

    def known_empty(self):
        return True


class storagenode(storagebase):
    """Node holding already-computed items, followed by 'next'."""

    # The default 'next' is a single shared storageend instance; this is
    # harmless because storageend objects carry no state.
    def __init__(self, datalist, next=storageend()):
        self.datalist = datalist
        self.next = next

    def resolve(self):
        # only called on an empty node: skip it
        assert self.datalist == []
        return self.next

    def repr(self, items):
        for x in self.datalist:
            items.append(repr(x))
        return self.next

    def merge(self):
        """Move the items of the next non-empty node into this node's
        datalist (in place) and unlink that node.  Raises StopIteration at
        the end of the chain."""
        nextnode = self.next.getnode()
        self.datalist += nextnode.datalist
        self.next = nextnode.next

    def known_empty(self):
        while not self.datalist:
            self = self.next
            if not isinstance(self, storagenode):
                return self.known_empty()
        return False


class storagethunk(storagebase):
    """Lazy node.

    'moredata' is an iterator producing the items of this node.
    'afterwards' is None, or a zero-argument callable returning the chain
    that follows once 'moredata' is exhausted.
    """

    def __init__(self, moredata, afterwards=None):
        self.moredata = moredata
        self.afterwards = afterwards

    def resolve(self):
        """Compute one item and turn self into a concrete node, in place.

        self becomes one of:
          * a storagenode holding that item and followed by a fresh thunk
            over the same iterator;
          * if the iterator is exhausted, an empty storagenode linked to
            afterwards();
          * if the iterator is exhausted and there is no 'afterwards', a
            storageend.
        Returns self.
        """
        try:
            data = self.moredata.next()
        except StopIteration:
            if self.afterwards is None:
                self.__class__ = storageend
            else:
                self.datalist = []
                self.next = self.afterwards()
                self.__class__ = storagenode
        else:
            self.datalist = [data]
            self.next = storagethunk(self.moredata, self.afterwards)
            self.__class__ = storagenode
        del self.afterwards
        del self.moredata
        return self

    def repr(self, items):
        # Render lazily: compute more items only while fewer than 3 are
        # shown; otherwise truncate and close the list display.
        if len(items) < 3:
            return self.resolve()
        items[-1] += '...]'
def concat(x, y):
    """Return a storage chain producing the items of x followed by those of y.

    x and y can each be either a storage chain or an iterable (a collect
    object, a list, an iterator...).  A list, tuple or dict x is copied
    eagerly.  Otherwise x's items are pulled lazily, and y's chain is only
    linked in once x is exhausted.
    """
    if not isinstance(y, storagebase):
        y = collect(y)._head
    if isinstance(x, (list, tuple, dict)):
        return storagenode(list(x), y)
    if isinstance(x, collect):
        x = x._head
    if isinstance(x, storagebase):
        if y.known_empty():
            return x   # nothing to append: share x's chain unchanged
        x = x.niter()
    else:
        x = iter(x)
    return storagethunk(x, lambda: y)


def repeat(x, n):
    """Return a storage chain producing the items of x, n times over.

    x can be a collect object, a storage chain, or an iterable that can be
    iterated several times (iter(x) is called once per repetition).
    n <= 0 gives an empty chain.
    """
    if n <= 0:
        return storageend()
    if isinstance(x, collect):
        x = x._head
    if isinstance(x, storagebase):
        if n == 1 or x.known_empty():
            return x
        x1 = x.niter()
    else:
        x1 = iter(x)
    # the remaining n-1 repetitions are only built once this one is used up
    return storagethunk(x1, lambda: repeat(x, n-1))


def ifirstdiff(it1, it2):
    """Advance two iterators in lockstep up to their first difference.

    Returns a pair (d1, d2).  Each element is either a 1-tuple holding the
    item found at the first differing position, or () if that iterator was
    exhausted there.  If both iterators end together, the result is
    ((), ()).  Comparing d1 with d2 orders the two sequences
    lexicographically.
    """
    while True:
        try:
            data1 = (it1.next(),)
        except StopIteration:
            data1 = ()
        try:
            data2 = (it2.next(),)
        except StopIteration:
            return data1, ()
        if data1 != data2:
            return data1, data2


def ikeeptail(iterable, keep):
    """Yield the last 'keep' elements of 'iterable', in their original order.

    Yields nothing if keep <= 0.  At most 'keep' elements are buffered.
    """
    if keep <= 0:
        return
    window = []   # ring buffer of the most recent elements
    oldest = 0    # position of the oldest element once the ring is full
    for value in iterable:
        if len(window) < keep:
            window.append(value)
        else:
            window[oldest] = value
            oldest = (oldest+1) % keep
    for x in window[oldest:]:
        yield x
    for x in window[:oldest]:
        yield x


def idroptail(iterable, skip):
    """Yield the elements of 'iterable' except the last 'skip' ones.

    Every element is yielded if skip <= 0.  The input is consumed lazily,
    with at most 'skip' elements buffered at any time.
    """
    if skip <= 0:
        for x in iterable:
            yield x
        return
    window = []   # ring buffer of the 'skip' elements not yet released
    oldest = 0
    for value in iterable:
        if len(window) < skip:
            window.append(value)
        else:
            # 'value' proves the oldest buffered element is not in the tail
            yield window[oldest]
            window[oldest] = value
            oldest = (oldest+1) % skip

# ____________________________________________________________

if __name__ == '__main__':
    # Interactive playground: enum() prints every value it produces on
    # stderr (in red on an ANSI terminal), making the laziness visible.
    import sys

    def ansi_print(text, esc, file=None):
        if file is None:
            file = sys.stderr
        text = text.rstrip()
        if sys.platform != "win32" and file.isatty():
            text = ('\x1b[%sm' % esc +
                    text +
                    '\x1b[0m')     # ANSI color code "reset"
        file.write(text + '\n')

    def enum(*args):
        for x in xrange(*args):
            ansi_print(str(x), '31')
            yield x

    c = collect(enum(0, 50, 2))
    d = collect(enum(10))
    e = collect('akuhfquiewhfxuiewmhuiew')