9ab75bf40dd5cec9e11f204ec421e35f2d9d4837
[doneit] / src / coroutines.py
1 #!/usr/bin/env python
2
3 """
4 Uses for generators
5 * Pull pipelining (iterators)
6 * Push pipelining (coroutines)
7 * State machines (coroutines)
8 * "Cooperative multitasking" (coroutines)
9 * Algorithm -> Object transform for cohesiveness (for example context managers) (coroutines)
10
11 Design considerations
12 * When should a stage pass on exceptions or have it thrown within it?
13 * When should a stage pass on GeneratorExits?
14 * Is there a way to either turn a push generator into an iterator or to use
15         comprehension syntax for push generators? (I doubt it)
16 * When should a stage try to send data in both directions?
17 * Since pull generators (generators), push generators (coroutines), subroutines, and coroutines are all coroutines, should we rename the push generators (e.g. to signals/slots) to avoid confusion, and then refer to two-way generators as coroutines?
18 ** If so, provide s* and co* implementations of the functions
19 """
20
21 import threading
22 import Queue
23 import pickle
24 import functools
25 import itertools
26 import xml.sax
27 import xml.parsers.expat
28
29
def autostart(func):
    """Decorate a generator function so that its generators come back primed.

    Calling the decorated function creates the generator and immediately
    advances it to its first ``yield``. Callers can then ``send()`` values
    right away instead of having to advance it by hand first. The wrapper
    keeps *func*'s name and docstring via ``functools.wraps``.

    >>> @autostart
    ... def grep_sink(pattern):
    ...     print "Looking for %s" % pattern
    ...     while True:
    ...         line = yield
    ...         if pattern in line:
    ...             print line,
    >>> g = grep_sink("python")
    Looking for python
    >>> g.send("Yeah but no but yeah but no")
    >>> g.send("A series of tubes")
    >>> g.send("python generators rock!")
    python generators rock!
    >>> g.close()
    """
    @functools.wraps(func)
    def primed(*args, **kwargs):
        generator = func(*args, **kwargs)
        # send(None) is the generator protocol's equivalent of advancing once;
        # it runs the body up to the first ``yield``.
        generator.send(None)
        return generator

    return primed
55
56
@autostart
def printer_sink(format="%s"):
    """Sink that prints every received item through a %-format template.

    Each item is wrapped in a 1-tuple before formatting. Tuples and other
    values are therefore always treated as a single format argument.

    >>> pr = printer_sink("%r")
    >>> pr.send("Hello")
    'Hello'
    >>> pr.send("5")
    '5'
    >>> pr.send(5)
    5
    >>> p = printer_sink()
    >>> p.send("Hello")
    Hello
    >>> p.send("World")
    World
    >>> # p.throw(RuntimeError, "Goodbye")
    >>> # p.send("Meh")
    >>> # p.close()
    """
    while True:
        received = (yield)
        print(format % (received,))
79
80
@autostart
def null_sink():
    """Sink that accepts and discards every item sent to it.

    Useful as a terminal stage wherever a target is required but the values
    are not wanted, e.g. to pick up any slack at the end of a ``cochain``.
    """
    while True:
        yield
88
89
def itr_source(itr, target):
    """Drive a push pipeline from an ordinary iterable.

    Every element of *itr* is sent into the coroutine *target*, in order.

    >>> itr_source(xrange(2), printer_sink())
    0
    1
    """
    for element in itr:
        target.send(element)
98
99
100 @autostart
101 def cofilter(predicate, target):
102         """
103         >>> p = printer_sink()
104         >>> cf = cofilter(None, p)
105         >>> cf.send("")
106         >>> cf.send("Hello")
107         Hello
108         >>> cf.send([])
109         >>> cf.send([1, 2])
110         [1, 2]
111         >>> cf.send(False)
112         >>> cf.send(True)
113         True
114         >>> cf.send(0)
115         >>> cf.send(1)
116         1
117         >>> # cf.throw(RuntimeError, "Goodbye")
118         >>> # cf.send(False)
119         >>> # cf.send(True)
120         >>> # cf.close()
121         """
122         if predicate is None:
123                 predicate = bool
124
125         while True:
126                 try:
127                         item = yield
128                         if predicate(item):
129                                 target.send(item)
130                 except StandardError, e:
131                         target.throw(e.__class__, e.message)
132
133
@autostart
def comap(function, target):
    """Push-side ``map``: send ``function(item)`` to *target* for each item.

    >>> p = printer_sink()
    >>> cm = comap(lambda x: x+1, p)
    >>> cm.send(0)
    1
    >>> cm.send(1.0)
    2.0
    >>> cm.send(-2)
    -1
    >>> # cm.throw(RuntimeError, "Goodbye")
    >>> # cm.send(0)
    >>> # cm.send(1.0)
    >>> # cm.close()
    """
    while True:
        value = (yield)
        result = function(value)
        target.send(result)
154
155
def func_sink(function):
    """Return a sink that calls *function* on every item it receives.

    Whatever *function* returns is passed on to a ``null_sink`` and thus
    discarded; only the call's side effects matter.
    """
    discard = null_sink()
    return comap(function, discard)
158
159
@autostart
def append_sink(l):
    """Sink that appends every received item to the caller's list *l*.

    The list is mutated in place, so the caller sees results through its own
    reference.

    >>> l = []
    >>> apps = append_sink(l)
    >>> apps.send(1)
    >>> apps.send(2)
    >>> apps.send(3)
    >>> print l
    [1, 2, 3]
    """
    while True:
        received = (yield)
        l.append(received)
174
175
@autostart
def last_n_sink(l, n = 1):
    """Keep only the *n* most recently received items in the list *l*.

    *l* is emptied when the sink is created and is then updated in place.
    Items are kept oldest-first. A non-positive *n* keeps nothing.

    >>> l = []
    >>> lns = last_n_sink(l)
    >>> lns.send(1)
    >>> lns.send(2)
    >>> lns.send(3)
    >>> print l
    [3]
    >>> l2 = []
    >>> keep_two = last_n_sink(l2, 2)
    >>> for i in (1, 2, 3):
    ...     keep_two.send(i)
    >>> print l2
    [2, 3]
    >>> l0 = []
    >>> keep_none = last_n_sink(l0, 0)
    >>> keep_none.send(1)
    >>> print l0
    []
    """
    del l[:]
    while True:
        item = yield
        l.append(item)
        # Trim from the front after appending so the newest items survive.
        # Trimming *before* the append always left at least one item, even
        # for n == 0.
        excess = len(l) - n
        if excess > 0:
            del l[:excess]
194
195
@autostart
def coreduce(target, function, initializer = None):
    """Running reduction: send the accumulated value to *target* after each item.

    When *initializer* is ``None``, the first item seeds the accumulator.
    Otherwise ``function(accumulator, item)`` is applied from the very first
    item. Because ``None`` doubles as "no initializer", it cannot itself be
    used as a starting value.

    >>> reduceResult = []
    >>> lns = last_n_sink(reduceResult)
    >>> cr = coreduce(lns, lambda x, y: x + y, 0)
    >>> cr.send(1)
    >>> cr.send(2)
    >>> cr.send(3)
    >>> print reduceResult
    [6]
    >>> cr = coreduce(lns, lambda x, y: x + y)
    >>> cr.send(1)
    >>> cr.send(2)
    >>> cr.send(3)
    >>> print reduceResult
    [6]
    """
    seeded = initializer is not None
    acc = initializer
    while True:
        value = (yield)
        if seeded:
            acc = function(acc, value)
        else:
            acc = value
            seeded = True
        target.send(acc)
224
225
226 @autostart
227 def cotee(targets):
228         """
229         Takes a sequence of coroutines and sends the received items to all of them
230
231         >>> ct = cotee((printer_sink("1 %s"), printer_sink("2 %s")))
232         >>> ct.send("Hello")
233         1 Hello
234         2 Hello
235         >>> ct.send("World")
236         1 World
237         2 World
238         >>> # ct.throw(RuntimeError, "Goodbye")
239         >>> # ct.send("Meh")
240         >>> # ct.close()
241         """
242         while True:
243                 try:
244                         item = yield
245                         for target in targets:
246                                 target.send(item)
247                 except StandardError, e:
248                         for target in targets:
249                                 target.throw(e.__class__, e.message)
250
251
252 class CoTee(object):
253         """
254         >>> ct = CoTee()
255         >>> ct.register_sink(printer_sink("1 %s"))
256         >>> ct.register_sink(printer_sink("2 %s"))
257         >>> ct.stage.send("Hello")
258         1 Hello
259         2 Hello
260         >>> ct.stage.send("World")
261         1 World
262         2 World
263         >>> ct.register_sink(printer_sink("3 %s"))
264         >>> ct.stage.send("Foo")
265         1 Foo
266         2 Foo
267         3 Foo
268         >>> # ct.stage.throw(RuntimeError, "Goodbye")
269         >>> # ct.stage.send("Meh")
270         >>> # ct.stage.close()
271         """
272
273         def __init__(self):
274                 self.stage = self._stage()
275                 self._targets = []
276
277         def register_sink(self, sink):
278                 self._targets.append(sink)
279
280         def unregister_sink(self, sink):
281                 self._targets.remove(sink)
282
283         def restart(self):
284                 self.stage = self._stage()
285
286         @autostart
287         def _stage(self):
288                 while True:
289                         try:
290                                 item = yield
291                                 for target in self._targets:
292                                         target.send(item)
293                         except StandardError, e:
294                                 for target in self._targets:
295                                         target.throw(e.__class__, e.message)
296
297
def _flush_queue(queue):
    """Yield the items currently in *queue* until it is empty, never blocking.

    ``get_nowait()`` is used rather than an ``empty()`` check followed by
    ``get()``. With that pair, another consumer can take the last item
    between the two calls, leaving ``get()`` blocked forever.
    """
    while True:
        try:
            item = queue.get_nowait()
        except Queue.Empty:
            return
        yield item
301
302
@autostart
def cocount(target, start = 0):
    """Send a running counter, starting at *start*, for every received item.

    The received items themselves are ignored; only their arrival counts.

    >>> cc = cocount(printer_sink("%s"))
    >>> cc.send("a")
    0
    >>> cc.send(None)
    1
    >>> cc.send([])
    2
    >>> cc.send(0)
    3
    """
    for index in itertools.count(start):
        yield
        target.send(index)
319
320
@autostart
def coenumerate(target, start = 0):
    """Push-side ``enumerate``: send ``(index, item)`` pairs, numbering from *start*.

    >>> ce = coenumerate(printer_sink("%r"))
    >>> ce.send("a")
    (0, 'a')
    >>> ce.send(None)
    (1, None)
    >>> ce.send([])
    (2, [])
    >>> ce.send(0)
    (3, 0)
    """
    for index in itertools.count(start):
        received = (yield)
        pair = (index, received)
        target.send(pair)
338
339
@autostart
def corepeat(target, elem):
    """Send the constant *elem* to *target* once per received item.

    The received items are ignored.

    >>> cr = corepeat(printer_sink("%s"), "Hello World")
    >>> cr.send("a")
    Hello World
    >>> cr.send(None)
    Hello World
    >>> cr.send([])
    Hello World
    >>> cr.send(0)
    Hello World
    """
    while True:
        yield
        target.send(elem)
356
357
@autostart
def cointercept(target, elems):
    """Replace received items, one for one, with successive elements of *elems*.

    The received items are ignored. Once *elems* is used up, the next
    ``send`` finishes the coroutine and raises ``StopIteration``.
    ``cochain`` uses that to move on to its next stage.

    >>> cr = cointercept(printer_sink("%s"), [1, 2, 3, 4])
    >>> cr.send("a")
    1
    >>> cr.send(None)
    2
    >>> cr.send([])
    3
    >>> cr.send(0)
    4
    >>> cr.send("Bye")
    Traceback (most recent call last):
        ...
    StopIteration
    """
    # Wait for the first item before touching *elems*.
    yield
    for replacement in elems:
        target.send(replacement)
        # Each further item triggers the next replacement.
        yield
382
383
@autostart
def codropwhile(target, pred):
    """Discard items while *pred* holds, then forward everything to *target*.

    This follows ``itertools.dropwhile``. The first item for which *pred* is
    false is forwarded too, and *pred* is not consulted again afterwards.

    >>> cdw = codropwhile(printer_sink("%s"), lambda x: x)
    >>> cdw.send([0, 1, 2])
    >>> cdw.send(1)
    >>> cdw.send(True)
    >>> cdw.send(False)
    False
    >>> cdw.send([0, 1, 2])
    [0, 1, 2]
    >>> cdw.send(1)
    1
    >>> cdw.send(True)
    True
    """
    while True:
        item = yield
        if not pred(item):
            break
    # The item that ended the dropping phase belongs to the output; skipping
    # this send would silently lose it.
    target.send(item)

    while True:
        item = yield
        target.send(item)
407
408
@autostart
def cotakewhile(target, pred):
    """Forward items while *pred* holds; swallow everything from the first failure.

    The item that fails *pred* is not forwarded. After that, *pred* is never
    called again and all further items are discarded.

    >>> ctw = cotakewhile(printer_sink("%s"), lambda x: x)
    >>> ctw.send([0, 1, 2])
    [0, 1, 2]
    >>> ctw.send(1)
    1
    >>> ctw.send(True)
    True
    >>> ctw.send(False)
    >>> ctw.send([0, 1, 2])
    >>> ctw.send(1)
    >>> ctw.send(True)
    """
    taking = True
    while True:
        value = (yield)
        if not taking:
            continue
        if pred(value):
            target.send(value)
        else:
            taking = False
432
433
@autostart
def coslice(target, lower, upper):
    """Forward only the items at positions ``lower <= i < upper`` (0-based).

    The first *lower* items are dropped and the next ``upper - lower`` are
    forwarded. Everything after that is swallowed.

    >>> cs = coslice(printer_sink("%r"), 3, 5)
    >>> cs.send("0")
    >>> cs.send("1")
    >>> cs.send("2")
    >>> cs.send("3")
    '3'
    >>> cs.send("4")
    '4'
    >>> cs.send("5")
    >>> cs.send("6")
    """
    # Skip the prefix.
    for _ in xrange(lower):
        yield
    # Pass the window through.
    for _ in xrange(upper - lower):
        value = (yield)
        target.send(value)
    # Swallow the rest.
    while True:
        yield
455
456
@autostart
def cochain(targets):
    """Feed items to each target in turn, moving on when the current one finishes.

    A target raises ``StopIteration`` from ``send`` when it finishes. The
    item that caused that is replayed into the next target. When the last
    target finishes, the chain itself finishes and the pending item is
    dropped. ``close()`` on the chain is not propagated to the targets.

    >>> cr = cointercept(printer_sink("good %s"), [1, 2, 3, 4])
    >>> cc = cochain([cr, printer_sink("end %s")])
    >>> cc.send("a")
    good 1
    >>> cc.send(None)
    good 2
    >>> cc.send([])
    good 3
    >>> cc.send(0)
    good 4
    >>> cc.send("Bye")
    end Bye
    """
    pending = []
    for stage in targets:
        try:
            # First replay whatever the previous stage could not take.
            while pending:
                value = pending.pop()
                stage.send(value)
            # Then pass new items straight through until this stage finishes.
            while True:
                value = yield
                stage.send(value)
        except StopIteration:
            pending.append(value)
484
485
486 @autostart
487 def queue_sink(queue):
488         """
489         >>> q = Queue.Queue()
490         >>> qs = queue_sink(q)
491         >>> qs.send("Hello")
492         >>> qs.send("World")
493         >>> qs.throw(RuntimeError, "Goodbye")
494         >>> qs.send("Meh")
495         >>> qs.close()
496         >>> print [i for i in _flush_queue(q)]
497         [(None, 'Hello'), (None, 'World'), (<type 'exceptions.RuntimeError'>, 'Goodbye'), (None, 'Meh'), (<type 'exceptions.GeneratorExit'>, None)]
498         """
499         while True:
500                 try:
501                         item = yield
502                         queue.put((None, item))
503                 except StandardError, e:
504                         queue.put((e.__class__, e.message))
505                 except GeneratorExit:
506                         queue.put((GeneratorExit, None))
507                         raise
508
509
def decode_item(item, target):
    """Replay one ``(kind, payload)`` record onto the coroutine *target*.

    * ``kind is None``: ``target.send(payload)``.
    * ``kind is GeneratorExit``: ``target.close()``. The payload is never
      read, so a 1-tuple ``(GeneratorExit,)`` is accepted.
    * Anything else: ``target.throw(kind, payload)``.

    Returns True only for the close record, i.e. when the stream is done.
    """
    kind = item[0]
    if kind is GeneratorExit:
        target.close()
        return True
    if kind is None:
        target.send(item[1])
    else:
        target.throw(kind, item[1])
    return False
520
521
def queue_source(queue, target):
    """Block on *queue*, replaying its records into *target* until a close record.

    Records use the ``(kind, payload)`` format produced by ``queue_sink`` and
    are decoded by ``decode_item``.

    >>> q = Queue.Queue()
    >>> for i in [
    ...     (None, 'Hello'),
    ...     (None, 'World'),
    ...     (GeneratorExit, None),
    ...     ]:
    ...     q.put(i)
    >>> qs = queue_source(q, printer_sink())
    Hello
    World
    """
    while not decode_item(queue.get(), target):
        pass
539
540
def threaded_stage(target, thread_factory=threading.Thread):
    """Run *target* on a worker thread, fed through a ``Queue.Queue``.

    A thread made by *thread_factory* is started immediately. It runs
    ``queue_source`` over a fresh queue, replaying messages into *target*.

    The return value is a ``functools.partial`` of ``queue_sink`` bound to
    that queue, not a sink itself. Call it, in the producing thread, to
    obtain the sink that feeds the worker. Closing that sink enqueues a
    ``GeneratorExit`` record, which closes *target* and ends the worker's
    loop.
    """
    channel = Queue.Queue()
    worker_body = functools.partial(queue_source, channel, target)
    worker = thread_factory(target=worker_body)
    worker.start()
    # The sink is created lazily by the caller, in the caller's thread.
    return functools.partial(queue_sink, channel)
549
550
551 @autostart
552 def pickle_sink(f):
553         while True:
554                 try:
555                         item = yield
556                         pickle.dump((None, item), f)
557                 except StandardError, e:
558                         pickle.dump((e.__class__, e.message), f)
559                 except GeneratorExit:
560                         pickle.dump((GeneratorExit, ), f)
561                         raise
562                 except StopIteration:
563                         f.close()
564                         return
565
566
def pickle_source(f, target):
    """Replay records written by ``pickle_sink`` from file *f* into *target*.

    Records are unpickled one at a time and passed to ``decode_item`` until a
    close record is seen. An ``EOFError``, normally end of file before any
    close record, closes *target* as well.

    NOTE: ``pickle.load`` can execute arbitrary code; only read streams that
    come from a trusted producer.
    """
    try:
        while not decode_item(pickle.load(f), target):
            pass
    except EOFError:
        target.close()
575
576
class EventHandler(xml.sax.ContentHandler, object):
    """SAX handler that turns parse events into tagged tuples sent to *target*.

    Events sent:

    * ``(START, (name, attrs))``, where *attrs* is a plain dict of the
      element's attributes.
    * ``(TEXT, text)``.
    * ``(END, name)``.

    Per the SAX contract, one run of character data may be reported in
    several ``characters()`` calls, so several TEXT tuples can arrive for it.
    """

    START = "start"
    TEXT = "text"
    END = "end"

    def __init__(self, target):
        # ``object`` comes after the (classic, in Python 2) ContentHandler base:
        # the MRO is unchanged for every method used here, and the order stays
        # valid where ContentHandler is itself a new-style class.
        xml.sax.ContentHandler.__init__(self)
        self._target = target

    def startElement(self, name, attrs):
        # Build the dict through the public Attributes interface instead of
        # the private ``attrs._attrs``.
        self._target.send((self.START, (name, dict(attrs.items()))))

    def characters(self, text):
        self._target.send((self.TEXT, text))

    def endElement(self, name):
        self._target.send((self.END, name))
596
597
598 def expat_parse(f, target):
599         parser = xml.parsers.expat.ParserCreate()
600         parser.buffer_size = 65536
601         parser.buffer_text = True
602         parser.returns_unicode = False
603         parser.StartElementHandler = lambda name, attrs: target.send(('start', (name, attrs)))
604         parser.EndElementHandler = lambda name: target.send(('end', name))
605         parser.CharacterDataHandler = lambda data: target.send(('text', data))
606         parser.ParseFile(f)
607
608
if __name__ == "__main__":
    # Running the module directly executes its embedded doctests.
    from doctest import testmod
    testmod()