Overriding for rules.

This commit is contained in:
Andrey Popp 2010-07-19 00:39:21 +04:00
parent 3ab1e27ff4
commit 25dd639a8c
2 changed files with 63 additions and 10 deletions

View File

@ -55,8 +55,7 @@ class Dispatcher(object):
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
self.registry = Registry(*axis)
def register_rule(self, rule, *argtypes):
""" Register new ``rule`` for ``argtypes``."""
def check_rule(self, rule, *argtypes):
# Check if we have the right number of parametrized types
if len(argtypes) != self.params_arity:
raise TypeError("Wrong number of type parameters.")
@ -67,8 +66,16 @@ class Dispatcher(object):
raise TypeError("Rule does not conform "
"to previous implementations.")
def register_rule(self, rule, *argtypes):
""" Register new ``rule`` for ``argtypes``."""
self.check_rule(rule, *argtypes)
self.registry.register(rule, *argtypes)
def override_rule(self, rule, *argtypes):
""" Override ``rule`` for ``argtypes``."""
self.check_rule(rule, *argtypes)
self.registry.override(rule, *argtypes)
def lookup_rule(self, *args):
""" Lookup rule by ``args``. Returns None if no rule was found."""
args = args[:self.params_arity]
@ -84,6 +91,12 @@ class Dispatcher(object):
return self
return register_rule
def override(self, *argtypes):
def override_rule(func):
self.override_rule(func, *argtypes)
return self
return override_rule
def __call__(self, *args, **kwargs):
""" Dispatch call to appropriate rule."""
rule = self.lookup_rule(*args)
@ -105,7 +118,7 @@ class MethodDispatcher(Dispatcher):
def proceed_unbound_rules(self, cls):
for argtypes, func in self.local.unbound_rules:
argtypes = (cls,) + argtypes
self.register_rule(func, *argtypes)
self.override_rule(func, *argtypes)
self.local.unbound_rules = []
def __get__(self, obj, cls):
@ -119,6 +132,8 @@ class MethodDispatcher(Dispatcher):
return self
return make_declaration
override = when
def arity(argspec):
""" Determinal positional arity of argspec."""

View File

@ -156,21 +156,21 @@ class MultifunctionTests(unittest.TestCase):
self.assertRaises(TypeError, func, "1", 2)
def test_overriding(self):
# XXX: for now, overriding is not allowed and Value error is raised
# open questions are:
# 1. Should we allow overriding by default.
# a. If yes, should it be implicit or explicit (something like
# Dispatcher.override method)
# b. If no -- what exception we should raise.
from generic.multidispatch import multifunction
@multifunction(int, str)
def func(x, y):
return str(x) + y
self.assertEqual(func(1, "2"), "12")
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
@func.override(int, str)
def func(x, y):
return y + str(x)
self.assertEqual(func(1, "2"), "21")
class MultimethodTests(unittest.TestCase):
@ -232,3 +232,41 @@ class MultimethodTests(unittest.TestCase):
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
self.assertEqual(DummySub().foo(True), False)
self.assertRaises(TypeError, DummySub().foo, [])
def test_override(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
@multimethod(str, str)
def foo(self, x, y):
return x + y
@foo.when(str, str)
def foo(self, x, y):
return y + x
self.assertEqual(Dummy().foo("1", "2"), "21")
def test_inheritance_override(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
@multimethod(int)
def foo(self, x):
return x + 1
@has_multimethods
class DummySub(Dummy):
@Dummy.foo.when(int)
def foo(self, x):
return x + 2
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(DummySub().foo(1), 3)