6d802e85a81ca2745d1f0976b8e0c0a9487e02c1
[theonering] / src / util / coroutines.py
1 #!/usr/bin/env python\r
2 \r
3 """\r
4 Uses for generators\r
5 * Pull pipelining (iterators)\r
6 * Push pipelining (coroutines)\r
7 * State machines (coroutines)\r
8 * "Cooperative multitasking" (coroutines)\r
9 * Algorithm -> Object transform for cohesiveness (for example context managers) (coroutines)\r
10 \r
11 Design considerations\r
* When should a stage pass exceptions on, and when should they be thrown into it?
* When should a stage pass on GeneratorExits?
* Is there a way either to turn a push generator into an iterator or to use
        comprehension syntax for push generators? (I doubt it)
* When should a stage try to send data in both directions?
* Since pull generators (generators), push generators (coroutines), subroutines, and coroutines are all coroutines, maybe we should rename the push generators to avoid confusion (e.g. signals/slots) and then refer to two-way generators as coroutines.
** If so, make s* and co* implementations of the functions
19 """\r
20 \r
import functools
import itertools
import pickle
import Queue
import sys
import threading
import xml.parsers.expat
import xml.sax
28 \r
29 \r
def autostart(func):
    """
    Decorate a generator function so each call returns an already-primed coroutine.

    The generator is advanced to its first ``yield`` before it is handed back,
    so callers can ``send()`` values immediately.

    >>> @autostart
    ... def grep_sink(pattern):
    ...     print "Looking for %s" % pattern
    ...     while True:
    ...         line = yield
    ...         if pattern in line:
    ...             print line,
    >>> g = grep_sink("python")
    Looking for python
    >>> g.send("Yeah but no but yeah but no")
    >>> g.send("A series of tubes")
    >>> g.send("python generators rock!")
    python generators rock!
    >>> g.close()
    """

    @functools.wraps(func)
    def primed(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        # Sending None to a fresh generator is equivalent to next(): it runs
        # the body up to the first yield.
        coroutine.send(None)
        return coroutine

    return primed
55 \r
56 \r
@autostart
def printer_sink(format="%s"):
    """
    Sink that prints every received item through the ``format`` template.

    The item is wrapped in a 1-tuple before formatting, so it always fills a
    single ``%`` placeholder.

    >>> pr = printer_sink("%r")
    >>> pr.send("Hello")
    'Hello'
    >>> pr.send("5")
    '5'
    >>> pr.send(5)
    5
    >>> p = printer_sink()
    >>> p.send("Hello")
    Hello
    >>> p.send("World")
    World
    """
    while True:
        received = yield
        line = format % (received, )
        print(line)
79 \r
80 \r
@autostart
def null_sink():
    """
    Sink that accepts and silently discards everything sent to it.

    Good for uses like the last target of ``cochain``, to pick up any slack.
    """
    while True:
        yield
88 \r
89 \r
def itr_source(itr, target):
    """
    Push every element of the iterable ``itr`` into the coroutine ``target``.

    Returns None once ``itr`` is exhausted; anything raised by
    ``target.send`` propagates to the caller.

    >>> itr_source(xrange(2), printer_sink())
    0
    1
    """
    for element in iter(itr):
        target.send(element)
98 \r
99 \r
@autostart
def cofilter(predicate, target):
    """
    Forward to ``target`` only the items for which ``predicate(item)`` is true.

    A ``predicate`` of None means plain truthiness (``bool``).  A
    StandardError thrown into this stage, or raised by ``predicate`` or
    ``target.send``, is re-thrown into ``target`` with its original type,
    value and traceback.

    >>> p = printer_sink()
    >>> cf = cofilter(None, p)
    >>> cf.send("")
    >>> cf.send("Hello")
    Hello
    >>> cf.send([])
    >>> cf.send([1, 2])
    [1, 2]
    >>> cf.send(False)
    >>> cf.send(True)
    True
    >>> cf.send(0)
    >>> cf.send(1)
    1
    """
    if predicate is None:
        predicate = bool

    while True:
        try:
            item = yield
            if predicate(item):
                target.send(item)
        except StandardError:
            # Re-throw the original exception object and traceback instead of
            # rebuilding it from ``e.message``.  That attribute is empty for
            # multi-argument exceptions and is deprecated since Python 2.6.
            # Exception types whose constructors need several arguments
            # cannot be rebuilt from a single string at all.
            target.throw(*sys.exc_info())
132 \r
133 \r
@autostart
def comap(function, target):
    """
    Send ``function(item)`` to ``target`` for every received item.

    A StandardError thrown into this stage, or raised by ``function`` or
    ``target.send``, is re-thrown into ``target`` with its original type,
    value and traceback.

    >>> p = printer_sink()
    >>> cm = comap(lambda x: x+1, p)
    >>> cm.send(0)
    1
    >>> cm.send(1.0)
    2.0
    >>> cm.send(-2)
    -1
    """
    while True:
        try:
            item = yield
            mappedItem = function(item)
            target.send(mappedItem)
        except StandardError:
            # Re-throw the original exception object and traceback instead of
            # rebuilding it from ``e.message``.  That attribute is empty for
            # multi-argument exceptions and is deprecated since Python 2.6.
            # Exception types whose constructors need several arguments
            # cannot be rebuilt from a single string at all.
            target.throw(*sys.exc_info())
157 \r
158 \r
@autostart
def append_sink(l):
    """
    Sink that appends every received item to the caller-owned list ``l``.

    >>> l = []
    >>> apps = append_sink(l)
    >>> apps.send(1)
    >>> apps.send(2)
    >>> apps.send(3)
    >>> print l
    [1, 2, 3]
    """
    while True:
        received = yield
        l.append(received)
173 \r
174 \r
@autostart
def last_n_sink(l, n = 1):
    """
    Keep the list ``l`` holding only the ``n`` most recently received items.

    ``l`` is emptied when the sink is created and is then updated in place,
    so the caller can read it at any time.  It never holds more than ``n``
    items and stays empty when ``n <= 0``.

    >>> l = []
    >>> lns = last_n_sink(l)
    >>> lns.send(1)
    >>> lns.send(2)
    >>> lns.send(3)
    >>> print l
    [3]
    >>> lns = last_n_sink(l, 2)
    >>> for i in xrange(5): lns.send(i)
    >>> print l
    [3, 4]
    >>> lns = last_n_sink(l, 0)
    >>> lns.send(1)
    >>> print l
    []
    """
    del l[:]
    # With n <= 0 the trimming arithmetic below would still leave one item.
    keepNothing = n <= 0
    while True:
        item = yield
        if keepNothing:
            continue
        # Trim before appending so the list never exceeds n items.
        extraCount = len(l) - n + 1
        if 0 < extraCount:
            del l[0:extraCount]
        l.append(item)
193 \r
194 \r
@autostart
def coreduce(target, function, initializer = None):
    """
    Running ``reduce``: after each item, send the accumulated value to ``target``.

    With an ``initializer``, the first item is combined with it.  Without one,
    the first item becomes the accumulator unchanged.  An initializer of None
    cannot be told apart from "no initializer".

    >>> reduceResult = []
    >>> lns = last_n_sink(reduceResult)
    >>> cr = coreduce(lns, lambda x, y: x + y, 0)
    >>> cr.send(1)
    >>> cr.send(2)
    >>> cr.send(3)
    >>> print reduceResult
    [6]
    >>> cr = coreduce(lns, lambda x, y: x + y)
    >>> cr.send(1)
    >>> cr.send(2)
    >>> cr.send(3)
    >>> print reduceResult
    [6]
    """
    accumulator = initializer
    seeded = initializer is not None
    while True:
        value = yield
        if seeded:
            accumulator = function(accumulator, value)
        else:
            accumulator = value
            seeded = True
        target.send(accumulator)
223 \r
224 \r
@autostart
def cotee(targets):
    """
    Send every received item to each coroutine in ``targets``, in order.

    ``targets`` is iterated again on every send, so it should be a
    re-iterable sequence rather than a one-shot iterator.  A StandardError
    thrown into this stage, or raised while broadcasting, is re-thrown into
    every target with its original type, value and traceback.

    >>> ct = cotee((printer_sink("1 %s"), printer_sink("2 %s")))
    >>> ct.send("Hello")
    1 Hello
    2 Hello
    >>> ct.send("World")
    1 World
    2 World
    """
    while True:
        try:
            item = yield
            for target in targets:
                target.send(item)
        except StandardError:
            # Forward the original exception object and traceback rather than
            # rebuilding it from ``e.message``.  That attribute is empty for
            # multi-argument exceptions and deprecated since Python 2.6.
            # Capture it once so every target receives the same triple.
            excInfo = sys.exc_info()
            for target in targets:
                target.throw(*excInfo)
249 \r
250 \r
class CoTee(object):
    """
    Broadcast stage whose set of downstream sinks can change at runtime.

    Items sent to ``stage`` are forwarded to every registered sink, in
    registration order.  A StandardError thrown into ``stage``, or raised
    while broadcasting, is re-thrown into every registered sink with its
    original type, value and traceback.

    >>> ct = CoTee()
    >>> ct.register_sink(printer_sink("1 %s"))
    >>> ct.register_sink(printer_sink("2 %s"))
    >>> ct.stage.send("Hello")
    1 Hello
    2 Hello
    >>> ct.stage.send("World")
    1 World
    2 World
    >>> ct.register_sink(printer_sink("3 %s"))
    >>> ct.stage.send("Foo")
    1 Foo
    2 Foo
    3 Foo
    """

    def __init__(self):
        # Create the target list before the (autostarted) stage coroutine, so
        # the stage never depends on how far priming runs.
        self._targets = []
        self.stage = self._stage()

    def register_sink(self, sink):
        """Append ``sink`` to the list of broadcast targets."""
        self._targets.append(sink)

    def unregister_sink(self, sink):
        """Remove ``sink``; raises ValueError if it is not registered."""
        self._targets.remove(sink)

    def restart(self):
        """Replace ``stage`` with a new coroutine; the old one is not closed."""
        self.stage = self._stage()

    @autostart
    def _stage(self):
        while True:
            try:
                item = yield
                for target in self._targets:
                    target.send(item)
            except StandardError:
                # Forward the original exception object and traceback rather
                # than rebuilding it from ``e.message``.  That attribute is
                # empty for multi-argument exceptions and deprecated since
                # Python 2.6.  Capture it once for all targets.
                excInfo = sys.exc_info()
                for target in self._targets:
                    target.throw(*excInfo)
295 \r
296 \r
297 def _flush_queue(queue):\r
298         while not queue.empty():\r
299                 yield queue.get()\r
300 \r
301 \r
@autostart
def cocount(target, start = 0):
    """
    Ignore received values and send a running counter to ``target`` instead.

    The first send produces ``start``, the next ``start + 1``, and so on.

    >>> cc = cocount(printer_sink("%s"))
    >>> cc.send("a")
    0
    >>> cc.send(None)
    1
    >>> cc.send([])
    2
    >>> cc.send(0)
    3
    """
    for index in itertools.count(start):
        yield  # the received value only triggers the next count
        target.send(index)
318 \r
319 \r
@autostart
def coenumerate(target, start = 0):
    """
    Push-based ``enumerate``: send ``(index, item)`` pairs to ``target``.

    Indices begin at ``start`` and increase by one per received item.

    >>> ce = coenumerate(printer_sink("%r"))
    >>> ce.send("a")
    (0, 'a')
    >>> ce.send(None)
    (1, None)
    >>> ce.send([])
    (2, [])
    >>> ce.send(0)
    (3, 0)
    """
    for index in itertools.count(start):
        received = yield
        target.send((index, received))
337 \r
338 \r
@autostart
def corepeat(target, elem):
    """
    Replace every received value with the constant ``elem``.

    >>> cr = corepeat(printer_sink("%s"), "Hello World")
    >>> cr.send("a")
    Hello World
    >>> cr.send(None)
    Hello World
    >>> cr.send([])
    Hello World
    >>> cr.send(0)
    Hello World
    """
    for constant in itertools.repeat(elem):
        yield  # the received value is discarded
        target.send(constant)
355 \r
356 \r
@autostart
def cointercept(target, elems):
    """
    Replace each received value with the next element of ``elems``.

    Every send triggers exactly one ``target.send`` of the next element.
    Once ``elems`` is exhausted, the following send raises StopIteration.

    >>> cr = cointercept(printer_sink("%s"), [1, 2, 3, 4])
    >>> cr.send("a")
    1
    >>> cr.send(None)
    2
    >>> cr.send([])
    3
    >>> cr.send(0)
    4
    >>> cr.send("Bye")
    Traceback (most recent call last):
      ...
    StopIteration
    """
    yield  # wait for the first trigger; its value is discarded
    for replacement in elems:
        target.send(replacement)
        yield
381 \r
382 \r
@autostart
def codropwhile(target, pred):
    """
    Push-based ``itertools.dropwhile``: discard items while ``pred(item)`` holds.

    The first item for which ``pred`` is false is forwarded to ``target``, as
    is every later item; ``pred`` is not consulted again after that.

    >>> cdw = codropwhile(printer_sink("%s"), lambda x: x)
    >>> cdw.send([0, 1, 2])
    >>> cdw.send(1)
    >>> cdw.send(True)
    >>> cdw.send(False)
    False
    >>> cdw.send([0, 1, 2])
    [0, 1, 2]
    >>> cdw.send(1)
    1
    >>> cdw.send(True)
    True
    """
    while True:
        item = yield
        if not pred(item):
            # As with itertools.dropwhile, the item that ends the dropping
            # phase is passed through rather than swallowed.
            target.send(item)
            break

    while True:
        item = yield
        target.send(item)
406 \r
407 \r
@autostart
def cotakewhile(target, pred):
    """
    Push-based ``itertools.takewhile``: forward items while ``pred(item)`` holds.

    The first failing item, and everything after it, is silently discarded.

    >>> ctw = cotakewhile(printer_sink("%s"), lambda x: x)
    >>> ctw.send([0, 1, 2])
    [0, 1, 2]
    >>> ctw.send(1)
    1
    >>> ctw.send(True)
    True
    >>> ctw.send(False)
    >>> ctw.send([0, 1, 2])
    >>> ctw.send(1)
    >>> ctw.send(True)
    """
    item = yield
    while pred(item):
        target.send(item)
        item = yield

    # The predicate failed once: swallow everything from here on.
    while True:
        yield
431 \r
432 \r
@autostart
def coslice(target, lower, upper):
    """
    Push-based slice: forward only the items at positions ``lower``..``upper - 1``.

    Positions are 0-based counts of received items; everything outside the
    window is discarded.

    >>> cs = coslice(printer_sink("%r"), 3, 5)
    >>> cs.send("0")
    >>> cs.send("1")
    >>> cs.send("2")
    >>> cs.send("3")
    '3'
    >>> cs.send("4")
    '4'
    >>> cs.send("5")
    >>> cs.send("6")
    """
    # Phase 1: skip the first ``lower`` items.
    for _ in xrange(lower):
        yield
    # Phase 2: forward the next ``upper - lower`` items.
    for _ in xrange(upper - lower):
        wanted = yield
        target.send(wanted)
    # Phase 3: discard everything else.
    while True:
        yield
454 \r
455 \r
@autostart
def cochain(targets):
    """
    Feed items to each coroutine in ``targets`` in turn, moving on when one stops.

    Items go to the current target until its ``send`` raises StopIteration.
    The item that triggered it is then re-sent to the next target.  The send
    that exhausts the final target raises StopIteration, so ending the chain
    with ``null_sink()`` absorbs any leftovers.  With no targets at all,
    creating the chain raises StopIteration.

    >>> cr = cointercept(printer_sink("good %s"), [1, 2, 3, 4])
    >>> cc = cochain([cr, printer_sink("end %s")])
    >>> cc.send("a")
    good 1
    >>> cc.send(None)
    good 2
    >>> cc.send([])
    good 3
    >>> cc.send(0)
    good 4
    >>> cc.send("Bye")
    end Bye
    """
    # Holds the item a finished target refused, to replay on the next one.
    carried = []
    for current in targets:
        try:
            while carried:
                pending = carried.pop()
                current.send(pending)
            while True:
                pending = yield
                current.send(pending)
        except StopIteration:
            carried.append(pending)
483 \r
484 \r
485 @autostart\r
486 def queue_sink(queue):\r
487         """\r
488         >>> q = Queue.Queue()\r
489         >>> qs = queue_sink(q)\r
490         >>> qs.send("Hello")\r
491         >>> qs.send("World")\r
492         >>> qs.throw(RuntimeError, "Goodbye")\r
493         >>> qs.send("Meh")\r
494         >>> qs.close()\r
495         >>> print [i for i in _flush_queue(q)]\r
496         [(None, 'Hello'), (None, 'World'), (<type 'exceptions.RuntimeError'>, 'Goodbye'), (None, 'Meh'), (<type 'exceptions.GeneratorExit'>, None)]\r
497         """\r
498         while True:\r
499                 try:\r
500                         item = yield\r
501                         queue.put((None, item))\r
502                 except StandardError, e:\r
503                         queue.put((e.__class__, e.message))\r
504                 except GeneratorExit:\r
505                         queue.put((GeneratorExit, None))\r
506                         raise\r
507 \r
508 \r
def decode_item(item, target):
    """
    Replay one record produced by ``queue_sink``/``pickle_sink`` onto ``target``.

    The record kind is ``item[0]``:
    - None: ``item[1]`` is sent to ``target``;
    - GeneratorExit: ``target`` is closed;
    - anything else: it is thrown into ``target`` with ``item[1]`` as the value.

    ``item[1]`` is only read when needed, so 1-tuple close records also work.

    Returns True only when ``target`` was closed, i.e. the stream is finished.
    """
    kind = item[0]
    if kind is None:
        target.send(item[1])
        return False
    if kind is GeneratorExit:
        target.close()
        return True
    target.throw(kind, item[1])
    return False
519 \r
520 \r
def queue_source(queue, target):
    """
    Drain records from ``queue`` into ``target`` until a close record arrives.

    Each record is taken with ``queue.get()`` and replayed with
    ``decode_item``.  The function returns as soon as ``decode_item`` reports
    that ``target`` has been closed.

    >>> q = Queue.Queue()
    >>> for i in [
    ...     (None, 'Hello'),
    ...     (None, 'World'),
    ...     (GeneratorExit, None),
    ...     ]:
    ...     q.put(i)
    >>> qs = queue_source(q, printer_sink())
    Hello
    World
    """
    while True:
        record = queue.get()
        if decode_item(record, target):
            break
538 \r
539 \r
def threaded_stage(target, thread_factory = threading.Thread):
    """
    Run ``target`` on a worker thread fed through a ``Queue.Queue``.

    A thread built by ``thread_factory`` runs ``queue_source`` against
    ``target``.  The worker returns once a close record, written by the
    sink's ``close()``, has been decoded.

    The return value is not the sink itself.  It is a zero-argument
    ``functools.partial`` of ``queue_sink``; call it to create the sink
    coroutine used in the current thread.
    """
    inbox = Queue.Queue()
    worker = thread_factory(
        target=functools.partial(queue_source, inbox, target))
    worker.start()
    # The caller instantiates the sink in its own (current) thread.
    return functools.partial(queue_sink, inbox)
548 \r
549 \r
550 @autostart\r
551 def pickle_sink(f):\r
552         while True:\r
553                 try:\r
554                         item = yield\r
555                         pickle.dump((None, item), f)\r
556                 except StandardError, e:\r
557                         pickle.dump((e.__class__, e.message), f)\r
558                 except GeneratorExit:\r
559                         pickle.dump((GeneratorExit, ), f)\r
560                         raise\r
561                 except StopIteration:\r
562                         f.close()\r
563                         return\r
564 \r
565 \r
def pickle_source(f, target):
    """
    Replay records written by ``pickle_sink`` from ``f`` into ``target``.

    Stops after a close record has been decoded.  If ``f`` runs out first
    (EOFError), ``target`` is closed instead.

    WARNING: ``pickle.load`` can execute arbitrary code.  Only use this on
    files produced by a trusted ``pickle_sink``, never on external data.
    """
    try:
        while True:
            if decode_item(pickle.load(f), target):
                break
    except EOFError:
        target.close()
574 \r
575 \r
# ContentHandler is listed before ``object``.  The reverse order only works
# while ContentHandler is a classic class (Python 2).  Once it derives from
# object itself (Python 3), no consistent method resolution order exists.
class EventHandler(xml.sax.ContentHandler, object):
    """
    SAX content handler that pushes parse events into the coroutine ``target``.

    Each event is sent as a ``(kind, payload)`` tuple:
    - ``(START, (name, attributes_dict))``
    - ``(TEXT, text)``
    - ``(END, name)``
    """

    START = "start"
    TEXT = "text"
    END = "end"

    def __init__(self, target):
        object.__init__(self)
        xml.sax.ContentHandler.__init__(self)
        self._target = target

    def startElement(self, name, attrs):
        # Build the dict through the public Attributes interface instead of
        # reaching into AttributesImpl's private ``_attrs``.
        self._target.send((self.START, (name, dict(attrs.items()))))

    def characters(self, text):
        self._target.send((self.TEXT, text))

    def endElement(self, name):
        self._target.send((self.END, name))
595 \r
596 \r
def expat_parse(f, target):
    """
    Parse the XML file object ``f`` with expat, pushing events into ``target``.

    Events are ``('start', (name, attrs))``, ``('text', data)`` and
    ``('end', name)``.  Character data is buffered (64 KiB buffer).
    ``returns_unicode`` is switched off so that, on Python 2, handlers
    receive ``str`` rather than ``unicode``.
    """
    def on_start(name, attrs):
        target.send(('start', (name, attrs)))

    def on_end(name):
        target.send(('end', name))

    def on_text(data):
        target.send(('text', data))

    parser = xml.parsers.expat.ParserCreate()
    parser.buffer_size = 65536
    parser.buffer_text = True
    parser.returns_unicode = False
    parser.StartElementHandler = on_start
    parser.EndElementHandler = on_end
    parser.CharacterDataHandler = on_text
    parser.ParseFile(f)
606 \r
607 \r
if __name__ == "__main__":
    # Executing this module directly runs the doctests in its docstrings.
    import doctest
    doctest.testmod()