Overriding for rules.
This commit is contained in:
parent
3ab1e27ff4
commit
25dd639a8c
@ -55,8 +55,7 @@ class Dispatcher(object):
|
||||
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
|
||||
self.registry = Registry(*axis)
|
||||
|
||||
def register_rule(self, rule, *argtypes):
|
||||
""" Register new ``rule`` for ``argtypes``."""
|
||||
def check_rule(self, rule, *argtypes):
|
||||
# Check if we have the right number of parametrized types
|
||||
if len(argtypes) != self.params_arity:
|
||||
raise TypeError("Wrong number of type parameters.")
|
||||
@ -67,8 +66,16 @@ class Dispatcher(object):
|
||||
raise TypeError("Rule does not conform "
|
||||
"to previous implementations.")
|
||||
|
||||
def register_rule(self, rule, *argtypes):
|
||||
""" Register new ``rule`` for ``argtypes``."""
|
||||
self.check_rule(rule, *argtypes)
|
||||
self.registry.register(rule, *argtypes)
|
||||
|
||||
def override_rule(self, rule, *argtypes):
|
||||
""" Override ``rule`` for ``argtypes``."""
|
||||
self.check_rule(rule, *argtypes)
|
||||
self.registry.override(rule, *argtypes)
|
||||
|
||||
def lookup_rule(self, *args):
|
||||
""" Lookup rule by ``args``. Returns None if no rule was found."""
|
||||
args = args[:self.params_arity]
|
||||
@ -84,6 +91,12 @@ class Dispatcher(object):
|
||||
return self
|
||||
return register_rule
|
||||
|
||||
def override(self, *argtypes):
|
||||
def override_rule(func):
|
||||
self.override_rule(func, *argtypes)
|
||||
return self
|
||||
return override_rule
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
""" Dispatch call to appropriate rule."""
|
||||
rule = self.lookup_rule(*args)
|
||||
@ -105,7 +118,7 @@ class MethodDispatcher(Dispatcher):
|
||||
def proceed_unbound_rules(self, cls):
|
||||
for argtypes, func in self.local.unbound_rules:
|
||||
argtypes = (cls,) + argtypes
|
||||
self.register_rule(func, *argtypes)
|
||||
self.override_rule(func, *argtypes)
|
||||
self.local.unbound_rules = []
|
||||
|
||||
def __get__(self, obj, cls):
|
||||
@ -119,6 +132,8 @@ class MethodDispatcher(Dispatcher):
|
||||
return self
|
||||
return make_declaration
|
||||
|
||||
override = when
|
||||
|
||||
|
||||
def arity(argspec):
|
||||
""" Determinal positional arity of argspec."""
|
||||
|
@ -156,21 +156,21 @@ class MultifunctionTests(unittest.TestCase):
|
||||
self.assertRaises(TypeError, func, "1", 2)
|
||||
|
||||
def test_overriding(self):
|
||||
# XXX: for now, overriding is not allowed and Value error is raised
|
||||
# open questions are:
|
||||
# 1. Should we allow overriding by default.
|
||||
# a. If yes, should it be implicit or explicit (something like
|
||||
# Dispatcher.override method)
|
||||
# b. If no -- what exception we should raise.
|
||||
|
||||
from generic.multidispatch import multifunction
|
||||
|
||||
@multifunction(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
|
||||
|
||||
@func.override(int, str)
|
||||
def func(x, y):
|
||||
return y + str(x)
|
||||
|
||||
self.assertEqual(func(1, "2"), "21")
|
||||
|
||||
|
||||
class MultimethodTests(unittest.TestCase):
|
||||
|
||||
@ -232,3 +232,41 @@ class MultimethodTests(unittest.TestCase):
|
||||
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
|
||||
self.assertEqual(DummySub().foo(True), False)
|
||||
self.assertRaises(TypeError, DummySub().foo, [])
|
||||
|
||||
def test_override(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
|
||||
@multimethod(str, str)
|
||||
def foo(self, x, y):
|
||||
return x + y
|
||||
|
||||
@foo.when(str, str)
|
||||
def foo(self, x, y):
|
||||
return y + x
|
||||
|
||||
self.assertEqual(Dummy().foo("1", "2"), "21")
|
||||
|
||||
def test_inheritance_override(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
|
||||
@multimethod(int)
|
||||
def foo(self, x):
|
||||
return x + 1
|
||||
|
||||
@has_multimethods
|
||||
class DummySub(Dummy):
|
||||
|
||||
@Dummy.foo.when(int)
|
||||
def foo(self, x):
|
||||
return x + 2
|
||||
|
||||
self.assertEqual(Dummy().foo(1), 2)
|
||||
self.assertEqual(DummySub().foo(1), 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user