Favorites support
[gonvert] / src / util / overloading.py
1 #!/usr/bin/env python
2 import new
3
4 # Make the environment more like Python 3.0
5 __metaclass__ = type
6 from itertools import izip as zip
7 import textwrap
8 import inspect
9
10
11 __all__ = [
12         "AnyType",
13         "overloaded"
14 ]
15
16
17 AnyType = object
18
19
20 class overloaded:
21         """
22         Dynamically overloaded functions.
23
24         This is an implementation of (dynamically, or run-time) overloaded
25         functions; also known as generic functions or multi-methods.
26
27         The dispatch algorithm uses the types of all argument for dispatch,
28         similar to (compile-time) overloaded functions or methods in C++ and
29         Java.
30
31         Most of the complexity in the algorithm comes from the need to support
32         subclasses in call signatures.  For example, if an function is
33         registered for a signature (T1, T2), then a call with a signature (S1,
34         S2) is acceptable, assuming that S1 is a subclass of T1, S2 a subclass
35         of T2, and there are no other more specific matches (see below).
36
37         If there are multiple matches and one of those doesn't *dominate* all
38         others, the match is deemed ambiguous and an exception is raised.  A
39         subtlety here: if, after removing the dominated matches, there are
40         still multiple matches left, but they all map to the same function,
41         then the match is not deemed ambiguous and that function is used.
42         Read the method find_func() below for details.
43
44         @note Python 2.5 is required due to the use of predicates any() and all().
45         @note only supports positional arguments
46
47         @author http://www.artima.com/weblogs/viewpost.jsp?thread=155514
48
49         >>> import misc
50         >>> misc.validate_decorator (overloaded)
51         >>>
52         >>>
53         >>>
54         >>>
55         >>> #################
56         >>> #Basics, with reusing names and without
57         >>> @overloaded
58         ... def foo(x):
59         ...     "prints x"
60         ...     print x
61         ...
62         >>> @foo.register(int)
63         ... def foo(x):
64         ...     "prints the hex representation of x"
65         ...     print hex(x)
66         ...
67         >>> from types import DictType
68         >>> @foo.register(DictType)
69         ... def foo_dict(x):
70         ...     "prints the keys of x"
71         ...     print [k for k in x.iterkeys()]
72         ...
73         >>> #combines all of the doc strings to help keep track of the specializations
74         >>> foo.__doc__  # doctest: +ELLIPSIS
75         "prints x\\n\\n...overloading.foo (<type 'int'>):\\n\\tprints the hex representation of x\\n\\n...overloading.foo_dict (<type 'dict'>):\\n\\tprints the keys of x"
76         >>> foo ("text")
77         text
78         >>> foo (10) #calling the specialized foo
79         0xa
80         >>> foo ({3:5, 6:7}) #calling the specialization foo_dict
81         [3, 6]
82         >>> foo_dict ({3:5, 6:7}) #with using a unique name, you still have the option of calling the function directly
83         [3, 6]
84         >>>
85         >>>
86         >>>
87         >>>
88         >>> #################
89         >>> #Multiple arguments, accessing the default, and function finding
90         >>> @overloaded
91         ... def two_arg (x, y):
92         ...     print x,y
93         ...
94         >>> @two_arg.register(int, int)
95         ... def two_arg_int_int (x, y):
96         ...     print hex(x), hex(y)
97         ...
98         >>> @two_arg.register(float, int)
99         ... def two_arg_float_int (x, y):
100         ...     print x, hex(y)
101         ...
102         >>> @two_arg.register(int, float)
103         ... def two_arg_int_float (x, y):
104         ...     print hex(x), y
105         ...
106         >>> two_arg.__doc__ # doctest: +ELLIPSIS
107         "...overloading.two_arg_int_int (<type 'int'>, <type 'int'>):\\n\\n...overloading.two_arg_float_int (<type 'float'>, <type 'int'>):\\n\\n...overloading.two_arg_int_float (<type 'int'>, <type 'float'>):"
108         >>> two_arg(9, 10)
109         0x9 0xa
110         >>> two_arg(9.0, 10)
111         9.0 0xa
112         >>> two_arg(15, 16.0)
113         0xf 16.0
114         >>> two_arg.default_func(9, 10)
115         9 10
116         >>> two_arg.find_func ((int, float)) == two_arg_int_float
117         True
118         >>> (int, float) in two_arg
119         True
120         >>> (str, int) in two_arg
121         False
122         >>>
123         >>>
124         >>>
125         >>> #################
126         >>> #wildcard
127         >>> @two_arg.register(AnyType, str)
128         ... def two_arg_any_str (x, y):
129         ...     print x, y.lower()
130         ...
131         >>> two_arg("Hello", "World")
132         Hello world
133         >>> two_arg(500, "World")
134         500 world
135         """
136
137         def __init__(self, default_func):
138                 # Decorator to declare new overloaded function.
139                 self.registry = {}
140                 self.cache = {}
141                 self.default_func = default_func
142                 self.__name__ = self.default_func.__name__
143                 self.__doc__ = self.default_func.__doc__
144                 self.__dict__.update (self.default_func.__dict__)
145
146         def __get__(self, obj, type=None):
147                 if obj is None:
148                         return self
149                 return new.instancemethod(self, obj)
150
151         def register(self, *types):
152                 """
153                 Decorator to register an implementation for a specific set of types.
154
155                 .register(t1, t2)(f) is equivalent to .register_func((t1, t2), f).
156                 """
157
158                 def helper(func):
159                         self.register_func(types, func)
160
161                         originalDoc = self.__doc__ if self.__doc__ is not None else ""
162                         typeNames = ", ".join ([str(type) for type in types])
163                         typeNames = "".join ([func.__module__+".", func.__name__, " (", typeNames, "):"])
164                         overloadedDoc = ""
165                         if func.__doc__ is not None:
166                                 overloadedDoc = textwrap.fill (func.__doc__, width=60, initial_indent="\t", subsequent_indent="\t")
167                         self.__doc__ = "\n".join ([originalDoc, "", typeNames, overloadedDoc]).strip()
168
169                         new_func = func
170
171                         #Masking the function, so we want to take on its traits
172                         if func.__name__ == self.__name__:
173                                 self.__dict__.update (func.__dict__)
174                                 new_func = self
175                         return new_func
176
177                 return helper
178
179         def register_func(self, types, func):
180                 """Helper to register an implementation."""
181                 self.registry[tuple(types)] = func
182                 self.cache = {} # Clear the cache (later we can optimize this).
183
184         def __call__(self, *args):
185                 """Call the overloaded function."""
186                 types = tuple(map(type, args))
187                 func = self.cache.get(types)
188                 if func is None:
189                         self.cache[types] = func = self.find_func(types)
190                 return func(*args)
191
192         def __contains__ (self, types):
193                 return self.find_func(types) is not self.default_func
194
195         def find_func(self, types):
196                 """Find the appropriate overloaded function; don't call it.
197
198                 @note This won't work for old-style classes or classes without __mro__
199                 """
200                 func = self.registry.get(types)
201                 if func is not None:
202                         # Easy case -- direct hit in registry.
203                         return func
204
205                 # Phillip Eby suggests to use issubclass() instead of __mro__.
206                 # There are advantages and disadvantages.
207
208                 # I can't help myself -- this is going to be intense functional code.
209                 # Find all possible candidate signatures.
210                 mros = tuple(inspect.getmro(t) for t in types)
211                 n = len(mros)
212                 candidates = [sig for sig in self.registry
213                                 if len(sig) == n and
214                                         all(t in mro for t, mro in zip(sig, mros))]
215
216                 if not candidates:
217                         # No match at all -- use the default function.
218                         return self.default_func
219                 elif len(candidates) == 1:
220                         # Unique match -- that's an easy case.
221                         return self.registry[candidates[0]]
222
223                 # More than one match -- weed out the subordinate ones.
224
225                 def dominates(dom, sub,
226                                 orders=tuple(dict((t, i) for i, t in enumerate(mro))
227                                                         for mro in mros)):
228                         # Predicate to decide whether dom strictly dominates sub.
229                         # Strict domination is defined as domination without equality.
230                         # The arguments dom and sub are type tuples of equal length.
231                         # The orders argument is a precomputed auxiliary data structure
232                         # giving dicts of ordering information corresponding to the
233                         # positions in the type tuples.
234                         # A type d dominates a type s iff order[d] <= order[s].
235                         # A type tuple (d1, d2, ...) dominates a type tuple of equal length
236                         # (s1, s2, ...) iff d1 dominates s1, d2 dominates s2, etc.
237                         if dom is sub:
238                                 return False
239                         return all(order[d] <= order[s] for d, s, order in zip(dom, sub, orders))
240
241                 # I suppose I could inline dominates() but it wouldn't get any clearer.
242                 candidates = [cand
243                                 for cand in candidates
244                                         if not any(dominates(dom, cand) for dom in candidates)]
245                 if len(candidates) == 1:
246                         # There's exactly one candidate left.
247                         return self.registry[candidates[0]]
248
249                 # Perhaps these multiple candidates all have the same implementation?
250                 funcs = set(self.registry[cand] for cand in candidates)
251                 if len(funcs) == 1:
252                         return funcs.pop()
253
254                 # No, the situation is irreducibly ambiguous.
255                 raise TypeError("ambigous call; types=%r; candidates=%r" %
256                                                 (types, candidates))