Updating from skeleton
[gc-dialer] / src / util / concurrent.py
index 503a1b4..f5f6e1d 100644 (file)
@@ -7,6 +7,97 @@ import errno
 import time
 import functools
 import contextlib
+import logging
+
+import misc
+
+
+_moduleLogger = logging.getLogger(__name__)
+
+
+class AsyncTaskQueue(object):
+
+       def __init__(self, taskPool):
+               self._asyncs = []
+               self._taskPool = taskPool
+
+       def add_async(self, func):
+               self.flush()
+               a = AsyncGeneratorTask(self._taskPool, func)
+               self._asyncs.append(a)
+               return a
+
+       def flush(self):
+               self._asyncs = [a for a in self._asyncs if not a.isDone]
+
+
+class AsyncGeneratorTask(object):
+
+       def __init__(self, pool, func):
+               self._pool = pool
+               self._func = func
+               self._run = None
+               self._isDone = False
+
+       @property
+       def isDone(self):
+               return self._isDone
+
+       def start(self, *args, **kwds):
+               assert self._run is None, "Task already started"
+               self._run = self._func(*args, **kwds)
+               trampoline, args, kwds = self._run.send(None) # priming the function
+               self._pool.add_task(
+                       trampoline,
+                       args,
+                       kwds,
+                       self.on_success,
+                       self.on_error,
+               )
+
+       @misc.log_exception(_moduleLogger)
+       def on_success(self, result):
+               _moduleLogger.debug("Processing success for: %r", self._func)
+               try:
+                       trampoline, args, kwds = self._run.send(result)
+               except StopIteration, e:
+                       self._isDone = True
+               else:
+                       self._pool.add_task(
+                               trampoline,
+                               args,
+                               kwds,
+                               self.on_success,
+                               self.on_error,
+                       )
+
+       @misc.log_exception(_moduleLogger)
+       def on_error(self, error):
+               _moduleLogger.debug("Processing error for: %r", self._func)
+               try:
+                       trampoline, args, kwds = self._run.throw(error)
+               except StopIteration, e:
+                       self._isDone = True
+               else:
+                       self._pool.add_task(
+                               trampoline,
+                               args,
+                               kwds,
+                               self.on_success,
+                               self.on_error,
+                       )
+
+       def __repr__(self):
+               return "<async %s at 0x%x>" % (self._func.__name__, id(self))
+
+       def __hash__(self):
+               return hash(self._func)
+
+       def __eq__(self, other):
+               return self._func == other._func
+
+       def __ne__(self, other):
+               return self._func != other._func
 
 
 def synchronized(lock):