generic/tests/test_multidispatch.py
pre-commit-ci[bot] ed866fdf71
[pre-commit.ci] pre-commit autoupdate (#294)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/psf/black: 22.12.0 → 23.1.0](https://github.com/psf/black/compare/22.12.0...23.1.0)
- [github.com/charliermarsh/ruff-pre-commit: v0.0.237 → v0.0.242](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.237...v0.0.242)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-02-06 18:52:10 -05:00

225 lines
5.8 KiB
Python

"""Tests for :module:`generic.multidispatch`."""
from inspect import FullArgSpec
import pytest
from generic.multidispatch import FunctionDispatcher, multidispatch
def create_dispatcher(
    params_arity, args=None, varargs=None, keywords=None, defaults=None
) -> FunctionDispatcher:
    """Build a :class:`FunctionDispatcher` from a hand-assembled argument spec.

    Only the positional names, the ``*args`` name, the ``**kwargs`` name and
    the defaults are configurable; keyword-only arguments, keyword-only
    defaults and annotations are always empty.

    :param params_arity: passed straight through to ``FunctionDispatcher`` as
        the number of type parameters the dispatcher works with.
    :param args: positional parameter names (``FullArgSpec.args``).
    :param varargs: name of the ``*args`` parameter, or ``None``.
    :param keywords: name of the ``**kwargs`` parameter (``FullArgSpec.varkw``).
    :param defaults: defaults of the trailing positional parameters.
    :returns: the dispatcher constructed from that spec.
    """
    spec = FullArgSpec(
        args=args,
        varargs=varargs,
        varkw=keywords,
        defaults=defaults,
        kwonlyargs=[],
        kwonlydefaults={},
        annotations={},
    )
    return FunctionDispatcher(spec, params_arity)
def test_one_argument():
    """Dispatch on a single positional argument, adding rules incrementally."""
    dispatch = create_dispatcher(1, args=["x"])

    def add_one(x):
        return x + 1

    def append_one(x):
        return f"{x}1"

    dispatch.register_rule(add_one, int)
    assert dispatch(1) == 2
    # No rule covers ``str`` yet.
    with pytest.raises(TypeError):
        dispatch("s")

    # A second rule for ``str`` must leave the ``int`` rule untouched.
    dispatch.register_rule(append_one, str)
    assert dispatch(1) == 2
    assert dispatch("1") == "11"
    # ``tuple`` is covered by neither rule.
    with pytest.raises(TypeError):
        dispatch(())
def test_two_arguments():
    """Dispatch on two positional arguments; both types must match a rule."""
    dispatch = create_dispatcher(2, args=["x", "y"])

    dispatch.register_rule(lambda x, y: x + y + 1, int, int)
    assert dispatch(1, 2) == 4
    for bad_args in (("s", "ss"), (1, "ss"), ("s", 2)):
        with pytest.raises(TypeError):
            dispatch(*bad_args)

    dispatch.register_rule(lambda x, y: x + y + "1", str, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    # Mixed argument types still have no matching rule.
    for bad_args in (("1", 1), (1, "1")):
        with pytest.raises(TypeError):
            dispatch(*bad_args)

    dispatch.register_rule(lambda x, y: str(x) + y + "1", int, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    assert dispatch(1, "2") == "121"
    # The reversed combination (str, int) remains unregistered.
    with pytest.raises(TypeError):
        dispatch("1", 1)
def test_bottom_rule():
    """A rule registered for ``object`` accepts arguments of any type."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, object)
    for sample in (1, "1", [1], (1,)):
        assert dispatch(sample) == sample
def test_subtype_evaluation():
    """Subclasses use the base-class rule until a more specific rule exists."""

    class Super:
        pass

    class Sub(Super):
        pass

    dispatch = create_dispatcher(1, args=["x"])

    # Only the base class has a rule, so ``Sub`` instances fall back to it.
    dispatch.register_rule(lambda x: x, Super)
    base = Super()
    assert dispatch(base) == base
    derived = Sub()
    assert dispatch(derived) == derived
    # A plain ``object`` is not a ``Super`` and has no rule.
    with pytest.raises(TypeError):
        dispatch(object())

    # A dedicated ``Sub`` rule takes over for ``Sub`` instances only.
    dispatch.register_rule(lambda x: (x, x), Sub)
    base = Super()
    assert dispatch(base) == base
    derived = Sub()
    assert dispatch(derived) == (derived, derived)
def test_register_rule_with_wrong_arity():
    """A rule with more positional parameters than the spec is rejected."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, int)

    def two_params(x, y):
        return x

    with pytest.raises(TypeError):
        dispatch.register_rule(two_params, str)
def test_register_rule_with_different_arg_names():
    """Rules may name their parameters differently from the dispatcher spec."""
    dispatch = create_dispatcher(1, args=["x"])

    def identity(y):
        return y

    dispatch.register_rule(identity, int)
    assert dispatch(1) == 1
def test_dispatching_with_varargs():
    """Rules taking ``*va`` are selected by the first argument's type only."""
    dispatcher = create_dispatcher(1, args=["x"], varargs="va")
    dispatcher.register_rule(lambda x, *va: x, int)
    assert dispatcher(1) == 1
    # Extra positional arguments must reach the rule's ``*va`` without
    # affecting which rule is chosen.
    assert dispatcher(1, 2, 3) == 1
    with pytest.raises(TypeError):
        dispatcher("1", 2, 3)
def test_dispatching_with_varkw():
    """Rules taking ``**vk`` accept arbitrary keyword arguments."""
    dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
    dispatcher.register_rule(lambda x, **vk: x, int)
    assert dispatcher(1) == 1
    # Keyword arguments must be forwarded to the rule's ``**vk``; only the
    # type of ``x`` decides which rule runs.
    assert dispatcher(1, a=1, b=2) == 1
    with pytest.raises(TypeError):
        dispatcher("1", a=1, b=2)
def test_dispatching_with_kw():
    """A rule whose trailing parameter has a default works with or without it."""
    # ``defaults`` holds default *values*; it mirrors the rule's ``y=1`` below.
    dispatcher = create_dispatcher(1, args=["x", "y"], defaults=(1,))
    dispatcher.register_rule(lambda x, y=1: x, int)
    assert dispatcher(1) == 1
    # The defaulted parameter may also be supplied positionally or by name.
    assert dispatcher(1, 2) == 1
    assert dispatcher(1, y=2) == 1
    with pytest.raises(TypeError):
        dispatcher("1", k=1)
def test_create_dispatcher_with_pos_args_less_multi_arity():
    """Construction fails when too few positionals lack defaults for the arity."""
    too_small_specs = (
        # Only one positional parameter for two dispatch axes.
        {"args": ["x"]},
        # Two positional parameters, but one has a default value.
        {"args": ["x", "y"], "defaults": ["x"]},
    )
    for spec_kwargs in too_small_specs:
        with pytest.raises(TypeError):
            create_dispatcher(2, **spec_kwargs)
def test_register_rule_with_wrong_number_types_parameters():
    """Registering with more type parameters than the arity is rejected."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    def first(x, y):
        return x

    # The arity is 1, but two types (int, str) are supplied.
    with pytest.raises(TypeError):
        dispatch.register_rule(first, int, str)
def test_register_rule_with_partial_dispatching():
    """With arity 1, only the first argument's type selects the rule."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    dispatch.register_rule(lambda x, y: x, int)
    # ``y`` may be of any type.
    assert dispatch(1, 2) == 1
    assert dispatch(1, "2") == 1
    with pytest.raises(TypeError):
        dispatch("2", 1)

    dispatch.register_rule(lambda x, y: x, str)
    for x_value, y_value in ((1, 2), (1, "2"), ("1", "2"), ("1", 2)):
        assert dispatch(x_value, y_value) == x_value
def test_default_dispatcher():
    """A function decorated with explicit types accepts only that combination."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    assert func(1, "2") == "12"
    for rejected in ((1, 2), ("1", 2), ("1", "2")):
        with pytest.raises(TypeError):
            func(*rejected)
def test_multiple_functions():
    """Further implementations can be attached through ``register``."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    @func.register(str, str)
    def _concat(x, y):
        return x + y

    # Both registered combinations resolve ...
    assert func(1, "2") == "12"
    assert func("1", "2") == "12"
    # ... while an int second argument matches neither.
    for rejected in ((1, 2), ("1", 2)):
        with pytest.raises(TypeError):
            func(*rejected)
def test_default():
    """``multidispatch()`` without types keeps the original as the general case."""

    @multidispatch()
    def func(x, y):
        return x + y

    @func.register(str, str)
    def _swapped(x, y):
        return y + x

    # Integers are handled by the original implementation ...
    total = func(1, 1)
    assert total == 2
    # ... while two strings use the registered (str, str) override.
    joined = func("1", "2")
    assert joined == "21"
def test_on_classes():
    """``multidispatch`` can decorate classes; calling ``A`` dispatches on types."""

    @multidispatch()
    class A:
        def __init__(self, a, b):
            self.v = a + b

    @A.register(str, str)  # type: ignore[attr-defined]
    class SwappedSum:
        def __init__(self, a, b):
            self.v = b + a

    # Integers construct the original class ...
    summed = A(1, 1)
    assert summed.v == 2
    # ... while two strings are routed to the registered class.
    swapped = A("1", "2")
    assert swapped.v == "21"