Add simple and naive multimethods implementation.

This commit is contained in:
Andrey Popp 2010-07-15 10:39:45 +04:00
parent 8fd082aa8a
commit ac08b7f910
2 changed files with 179 additions and 0 deletions

62
generic/multidispatch.py Normal file
View File

@ -0,0 +1,62 @@
""" Multidispatch for functions."""
import functools
import inspect
from generic.registry import Registry
from generic.registry import TypeAxis
__all__ = []
# function name -> dispatcher
dispatchers = {}
def multimethod(*arg_types):
global dispatchers
def register_rule(func):
if func.__name__ in dispatchers:
dispatcher = dispatchers[func.__name__]
else:
dispatcher = functools.wraps(func)(Dispatcher(len(arg_types)))
dispatchers[func.__name__] = dispatcher
dispatcher.register_rule(func, *arg_types)
return dispatcher
return register_rule
def reset():
""" Reset dispatchers. Useful for testing."""
global dispatchers
dispatchers = {}
class Dispatcher(object):
""" Function call dispatcher based on argument types."""
def __init__(self, arity):
self.arity = arity
axis = [("arg_%d" % n, TypeAxis()) for n in range(arity)]
self.registry = Registry(*axis)
def register_rule(self, rule, *args):
self.check_rule(rule)
self.registry.register(rule, *args)
def lookup_rule(self, *args):
return self.registry.lookup(*args)
def check_rule(self, rule):
argspec = inspect.getargspec(rule)
if argspec.defaults:
raise NotImplementedError("Keyword argument support "
"not implemented yet.")
if not len(argspec.args) == self.arity:
raise TypeError("Rule does not conform "
"to previous implementations.")
def __call__(self, *args):
rule = self.lookup_rule(*args)
if rule is None:
raise TypeError("No avaible rule found for %r" % (args,))
return rule(*args)

View File

@ -0,0 +1,117 @@
""" Tests for :module:`generic.multidispatch`."""
import unittest
__all__ = []
class DispatcherTests(unittest.TestCase):
def test_one_argument(self):
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(1)
dispatcher.register_rule(lambda x: x + 1, int)
self.assertEqual(dispatcher(1), 2)
self.assertRaises(TypeError, dispatcher.__call__, "s")
dispatcher.register_rule(lambda x: x + "1", str)
self.assertEqual(dispatcher(1), 2)
self.assertEqual(dispatcher("1"), "11")
self.assertRaises(TypeError, dispatcher.__call__, tuple())
def test_two_arguments(self):
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(2)
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
self.assertEqual(dispatcher(1, 2), 4)
self.assertRaises(TypeError, dispatcher.__call__, "s", "ss")
self.assertRaises(TypeError, dispatcher.__call__, 1, "ss")
self.assertRaises(TypeError, dispatcher.__call__, "s", 2)
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
self.assertEqual(dispatcher(1, 2), 4)
self.assertEqual(dispatcher("1", "2"), "121")
self.assertRaises(TypeError, dispatcher.__call__, "1", 1)
self.assertRaises(TypeError, dispatcher.__call__, 1, "1")
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
self.assertEqual(dispatcher(1, 2), 4)
self.assertEqual(dispatcher("1", "2"), "121")
self.assertEqual(dispatcher(1, "2"), "121")
self.assertRaises(TypeError, dispatcher.__call__, "1", 1)
def test_bottom_rule(self):
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(1)
dispatcher.register_rule(lambda x: x, object)
self.assertEqual(dispatcher(1), 1)
self.assertEqual(dispatcher("1"), "1")
self.assertEqual(dispatcher([1]), [1])
self.assertEqual(dispatcher((1,)), (1,))
def test_subtype_evaluation(self):
class Super(object):
pass
class Sub(Super):
pass
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(1)
dispatcher.register_rule(lambda x: x, Super)
o_super = Super()
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), o_sub)
self.assertRaises(TypeError, dispatcher.__call__, object())
dispatcher.register_rule(lambda x: (x, x), Sub)
o_super = Super()
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
def test_register_rule_with_different_arity(self):
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(1)
dispatcher.register_rule(lambda x: x, int)
self.assertRaises(TypeError, dispatcher.register_rule, lambda x, y: x)
def test_register_rule_wit_kw_args(self):
# Keyword args do not supported right now.
from generic.multidispatch import Dispatcher
dispatcher = Dispatcher(1)
self.assertRaises(
NotImplementedError,
dispatcher.register_rule, lambda x=1: x)
class TestMultimethod(unittest.TestCase):
def tearDown(self):
from generic.multidispatch import reset
reset()
def test_it(self):
from generic.multidispatch import multimethod
@multimethod(int, str)
def func(x, y):
return str(x) + y
self.assertEqual(func(1, "2"), "12")
self.assertRaises(TypeError, func, 1, 2)
self.assertRaises(TypeError, func, "1", 2)
self.assertRaises(TypeError, func, "1", "2")
@multimethod(str, str)
def func(x, y):
return x + y
self.assertEqual(func(1, "2"), "12")
self.assertEqual(func("1", "2"), "12")
self.assertRaises(TypeError, func, 1, 2)
self.assertRaises(TypeError, func, "1", 2)