Add simple and naive multimethods implementation.
This commit is contained in:
parent
8fd082aa8a
commit
ac08b7f910
62
generic/multidispatch.py
Normal file
62
generic/multidispatch.py
Normal file
@ -0,0 +1,62 @@
|
||||
""" Multidispatch for functions."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
from generic.registry import Registry
|
||||
from generic.registry import TypeAxis
|
||||
|
||||
__all__ = []
|
||||
|
||||
# function name -> dispatcher
|
||||
dispatchers = {}
|
||||
|
||||
|
||||
def multimethod(*arg_types):
|
||||
global dispatchers
|
||||
def register_rule(func):
|
||||
if func.__name__ in dispatchers:
|
||||
dispatcher = dispatchers[func.__name__]
|
||||
else:
|
||||
dispatcher = functools.wraps(func)(Dispatcher(len(arg_types)))
|
||||
dispatchers[func.__name__] = dispatcher
|
||||
dispatcher.register_rule(func, *arg_types)
|
||||
return dispatcher
|
||||
return register_rule
|
||||
|
||||
|
||||
def reset():
|
||||
""" Reset dispatchers. Useful for testing."""
|
||||
global dispatchers
|
||||
dispatchers = {}
|
||||
|
||||
|
||||
class Dispatcher(object):
|
||||
""" Function call dispatcher based on argument types."""
|
||||
|
||||
def __init__(self, arity):
|
||||
self.arity = arity
|
||||
axis = [("arg_%d" % n, TypeAxis()) for n in range(arity)]
|
||||
self.registry = Registry(*axis)
|
||||
|
||||
def register_rule(self, rule, *args):
|
||||
self.check_rule(rule)
|
||||
self.registry.register(rule, *args)
|
||||
|
||||
def lookup_rule(self, *args):
|
||||
return self.registry.lookup(*args)
|
||||
|
||||
def check_rule(self, rule):
|
||||
argspec = inspect.getargspec(rule)
|
||||
if argspec.defaults:
|
||||
raise NotImplementedError("Keyword argument support "
|
||||
"not implemented yet.")
|
||||
if not len(argspec.args) == self.arity:
|
||||
raise TypeError("Rule does not conform "
|
||||
"to previous implementations.")
|
||||
|
||||
def __call__(self, *args):
|
||||
rule = self.lookup_rule(*args)
|
||||
if rule is None:
|
||||
raise TypeError("No avaible rule found for %r" % (args,))
|
||||
return rule(*args)
|
117
generic/tests/test_multidispatch.py
Normal file
117
generic/tests/test_multidispatch.py
Normal file
@ -0,0 +1,117 @@
|
||||
""" Tests for :module:`generic.multidispatch`."""
|
||||
|
||||
import unittest
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
class DispatcherTests(unittest.TestCase):
|
||||
|
||||
def test_one_argument(self):
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(1)
|
||||
|
||||
dispatcher.register_rule(lambda x: x + 1, int)
|
||||
self.assertEqual(dispatcher(1), 2)
|
||||
self.assertRaises(TypeError, dispatcher.__call__, "s")
|
||||
|
||||
dispatcher.register_rule(lambda x: x + "1", str)
|
||||
self.assertEqual(dispatcher(1), 2)
|
||||
self.assertEqual(dispatcher("1"), "11")
|
||||
self.assertRaises(TypeError, dispatcher.__call__, tuple())
|
||||
|
||||
def test_two_arguments(self):
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(2)
|
||||
|
||||
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertRaises(TypeError, dispatcher.__call__, "s", "ss")
|
||||
self.assertRaises(TypeError, dispatcher.__call__, 1, "ss")
|
||||
self.assertRaises(TypeError, dispatcher.__call__, "s", 2)
|
||||
|
||||
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertEqual(dispatcher("1", "2"), "121")
|
||||
self.assertRaises(TypeError, dispatcher.__call__, "1", 1)
|
||||
self.assertRaises(TypeError, dispatcher.__call__, 1, "1")
|
||||
|
||||
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertEqual(dispatcher("1", "2"), "121")
|
||||
self.assertEqual(dispatcher(1, "2"), "121")
|
||||
self.assertRaises(TypeError, dispatcher.__call__, "1", 1)
|
||||
|
||||
def test_bottom_rule(self):
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(1)
|
||||
|
||||
dispatcher.register_rule(lambda x: x, object)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
self.assertEqual(dispatcher("1"), "1")
|
||||
self.assertEqual(dispatcher([1]), [1])
|
||||
self.assertEqual(dispatcher((1,)), (1,))
|
||||
|
||||
def test_subtype_evaluation(self):
|
||||
class Super(object):
|
||||
pass
|
||||
class Sub(Super):
|
||||
pass
|
||||
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(1)
|
||||
|
||||
dispatcher.register_rule(lambda x: x, Super)
|
||||
o_super = Super()
|
||||
self.assertEqual(dispatcher(o_super), o_super)
|
||||
o_sub = Sub()
|
||||
self.assertEqual(dispatcher(o_sub), o_sub)
|
||||
self.assertRaises(TypeError, dispatcher.__call__, object())
|
||||
|
||||
dispatcher.register_rule(lambda x: (x, x), Sub)
|
||||
o_super = Super()
|
||||
self.assertEqual(dispatcher(o_super), o_super)
|
||||
o_sub = Sub()
|
||||
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
|
||||
|
||||
def test_register_rule_with_different_arity(self):
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(1)
|
||||
dispatcher.register_rule(lambda x: x, int)
|
||||
self.assertRaises(TypeError, dispatcher.register_rule, lambda x, y: x)
|
||||
|
||||
def test_register_rule_wit_kw_args(self):
|
||||
# Keyword args do not supported right now.
|
||||
from generic.multidispatch import Dispatcher
|
||||
dispatcher = Dispatcher(1)
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
dispatcher.register_rule, lambda x=1: x)
|
||||
|
||||
|
||||
class TestMultimethod(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
from generic.multidispatch import reset
|
||||
reset()
|
||||
|
||||
def test_it(self):
|
||||
from generic.multidispatch import multimethod
|
||||
|
||||
@multimethod(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertRaises(TypeError, func, 1, 2)
|
||||
self.assertRaises(TypeError, func, "1", 2)
|
||||
self.assertRaises(TypeError, func, "1", "2")
|
||||
|
||||
@multimethod(str, str)
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertEqual(func("1", "2"), "12")
|
||||
self.assertRaises(TypeError, func, 1, 2)
|
||||
self.assertRaises(TypeError, func, "1", 2)
|
Loading…
x
Reference in New Issue
Block a user