generic/tests/test_multidispatch.py
Dan Yeaw b33155a79f
Enable docformatter pre-commit checks
Signed-off-by: Dan Yeaw <dan@yeaw.me>
2020-09-15 19:50:55 -04:00

226 lines
5.8 KiB
Python

"""Tests for :mod:`generic.multidispatch`."""
from inspect import FullArgSpec
import pytest
from generic.multidispatch import FunctionDispatcher, multidispatch
def create_dispatcher(
    params_arity, args=None, varargs=None, keywords=None, defaults=None
) -> FunctionDispatcher:
    """Build a ``FunctionDispatcher`` from a hand-written argument spec.

    The given positional names, ``*args`` name, ``**kwargs`` name and
    defaults are packed into a ``FullArgSpec`` whose keyword-only and
    annotation fields are always empty; that spec is handed to
    ``FunctionDispatcher`` together with ``params_arity``.
    """
    spec = FullArgSpec(
        args=args,
        varargs=varargs,
        varkw=keywords,
        defaults=defaults,
        kwonlyargs=[],
        kwonlydefaults={},
        annotations={},
    )
    return FunctionDispatcher(spec, params_arity)
def test_one_argument():
    """Dispatch on one argument as an int rule, then a str rule, is added."""
    dispatch = create_dispatcher(1, args=["x"])

    # Only the int rule exists: str input has nothing to dispatch to.
    dispatch.register_rule(lambda x: x + 1, int)
    assert dispatch(1) == 2
    with pytest.raises(TypeError):
        dispatch("s")

    # Adding the str rule keeps the int rule working; tuples stay unmatched.
    dispatch.register_rule(lambda x: x + "1", str)
    assert dispatch(1) == 2
    assert dispatch("1") == "11"
    with pytest.raises(TypeError):
        dispatch(())
def test_two_arguments():
    """Dispatch on two arguments as rules are added one signature at a time."""
    dispatch = create_dispatcher(2, args=["x", "y"])

    # Step 1: only (int, int) is registered.
    dispatch.register_rule(lambda x, y: x + y + 1, int, int)
    assert dispatch(1, 2) == 4
    for unmatched in (("s", "ss"), (1, "ss"), ("s", 2)):
        with pytest.raises(TypeError):
            dispatch(*unmatched)

    # Step 2: (str, str) is added; mixed signatures are still unmatched.
    dispatch.register_rule(lambda x, y: x + y + "1", str, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    for unmatched in (("1", 1), (1, "1")):
        with pytest.raises(TypeError):
            dispatch(*unmatched)

    # Step 3: (int, str) is added; (str, int) remains the only gap.
    dispatch.register_rule(lambda x, y: str(x) + y + "1", int, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    assert dispatch(1, "2") == "121"
    with pytest.raises(TypeError):
        dispatch("1", 1)
def test_bottom_rule():
    """A rule registered for ``object`` handles every value tried here."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, object)

    # The rule is the identity, so each input must come back unchanged.
    for value in (1, "1", [1], (1,)):
        assert dispatch(value) == value
def test_subtype_evaluation():
    """A subclass uses the base-class rule until it gets a rule of its own."""

    class Super:
        pass

    class Sub(Super):
        pass

    dispatch = create_dispatcher(1, args=["x"])

    # Only the base-class rule exists: both instances go through it, while
    # a plain ``object`` has no applicable rule.
    dispatch.register_rule(lambda x: x, Super)
    base_obj = Super()
    assert dispatch(base_obj) == base_obj
    derived_obj = Sub()
    assert dispatch(derived_obj) == derived_obj
    with pytest.raises(TypeError):
        dispatch(object())

    # After a dedicated rule is added for the subclass, it is used for
    # subclass instances; base-class instances keep the original rule.
    dispatch.register_rule(lambda x: (x, x), Sub)
    base_obj = Super()
    assert dispatch(base_obj) == base_obj
    derived_obj = Sub()
    assert dispatch(derived_obj) == (derived_obj, derived_obj)
def test_register_rule_with_wrong_arity():
    """Registering a two-argument rule on a one-argument dispatcher fails."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, int)

    def two_arg_rule(x, y):
        return x

    with pytest.raises(TypeError):
        dispatch.register_rule(two_arg_rule, str)
def test_register_rule_with_different_arg_names():
    """A rule may name its parameter differently from the dispatcher spec."""
    dispatch = create_dispatcher(1, args=["x"])

    # The spec calls the argument ``x``; the rule calls it ``y``.
    dispatch.register_rule(lambda y: y, int)

    assert dispatch(1) == 1
def test_dispatching_with_varargs():
    """A rule taking ``*va`` is dispatched on its one typed argument."""
    dispatch = create_dispatcher(1, args=["x"], varargs="va")
    dispatch.register_rule(lambda x, *va: x, int)

    assert dispatch(1) == 1

    # No rule is registered for str, extra positional arguments or not.
    extra_positional = (2, 3)
    with pytest.raises(TypeError):
        dispatch("1", *extra_positional)
def test_dispatching_with_varkw():
    """A rule taking ``**vk`` is dispatched on its one typed argument."""
    dispatch = create_dispatcher(1, args=["x"], keywords="vk")
    dispatch.register_rule(lambda x, **vk: x, int)

    assert dispatch(1) == 1

    # No rule is registered for str, extra keyword arguments or not.
    extra_keywords = {"a": 1, "b": 2}
    with pytest.raises(TypeError):
        dispatch("1", **extra_keywords)
def test_dispatching_with_kw():
    """A rule with a defaulted second parameter is dispatched on the first.

    The dispatcher has arity 1 over ``(x, y)`` with one default, and the
    only rule is registered for ``int``. An ``int`` first argument reaches
    the rule; a ``str`` first argument must be rejected with ``TypeError``.
    """
    dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
    dispatcher.register_rule(lambda x, y=1: x, int)
    assert dispatcher(1) == 1
    # Use the keyword the rule really accepts (``y``). With an unknown
    # keyword the call would raise TypeError even if dispatch wrongly
    # matched the int rule, so the check could never fail for the
    # intended reason (no rule registered for str).
    with pytest.raises(TypeError):
        dispatcher("1", y=1)
def test_create_dispatcher_with_pos_args_less_multi_arity():
    """Creating a dispatcher fails for the two arity-2 specs tried here."""
    rejected_specs = (
        # A single positional argument for an arity of two.
        {"args": ["x"]},
        # Two positional arguments, one of which carries a default.
        {"args": ["x", "y"], "defaults": ["x"]},
    )
    for spec_kwargs in rejected_specs:
        with pytest.raises(TypeError):
            create_dispatcher(2, **spec_kwargs)
def test_register_rule_with_wrong_number_types_parameters():
    """Registering two types on an arity-1 dispatcher raises ``TypeError``."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    def rule(x, y):
        return x

    with pytest.raises(TypeError):
        dispatch.register_rule(rule, int, str)
def test_register_rule_with_partial_dispatching():
    """With arity 1 over two arguments, only the first selects the rule."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    # int rule only: the type of the second argument is irrelevant.
    dispatch.register_rule(lambda x, y: x, int)
    for second in (2, "2"):
        assert dispatch(1, second) == 1
    with pytest.raises(TypeError):
        dispatch("2", 1)

    # str rule added: both first-argument types are now handled.
    dispatch.register_rule(lambda x, y: x, str)
    for second in (2, "2"):
        assert dispatch(1, second) == 1
    for second in ("2", 2):
        assert dispatch("1", second) == "1"
def test_default_dispatcher():
    """``@multidispatch(int, str)`` accepts only an (int, str) call."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    assert func(1, "2") == "12"

    # Every other combination of int/str arguments is rejected.
    for mismatched in ((1, 2), ("1", 2), ("1", "2")):
        with pytest.raises(TypeError):
            func(*mismatched)
def test_multiple_functions():
    """A second implementation added via ``register`` extends dispatch."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    @func.register(str, str)
    def _(x, y):
        return x + y

    # Both registered signatures are served.
    for matched in ((1, "2"), ("1", "2")):
        assert func(*matched) == "12"

    # Signatures with an int second argument have no implementation.
    for mismatched in ((1, 2), ("1", 2)):
        with pytest.raises(TypeError):
            func(*mismatched)
def test_default():
    """With ``multidispatch()`` the decorated function handles unmatched calls."""

    @multidispatch()
    def func(x, y):
        return x + y

    @func.register(str, str)
    def _(x, y):
        return y + x

    # (int, int) has no registered rule, so the decorated function runs.
    default_result = func(1, 1)
    assert default_result == 2

    # (str, str) goes to the registered rule, which reverses the operands.
    registered_result = func("1", "2")
    assert registered_result == "21"
def test_on_classes():
    """``multidispatch`` applied to classes dispatches on constructor args."""

    @multidispatch()
    class A:
        def __init__(self, a, b):
            self.v = a + b

    @A.register(str, str)  # type: ignore[attr-defined]
    class B:
        def __init__(self, a, b):
            self.v = b + a

    # (int, int) has no registered rule, so ``A`` itself is constructed.
    from_ints = A(1, 1)
    assert from_ints.v == 2

    # (str, str) is routed to ``B``, which concatenates in reverse order.
    from_strs = A("1", "2")
    assert from_strs.v == "21"