Replace old code by Python3 implementation from Gaphor

Multimethods are missing now. Those have to be added back.
This commit is contained in:
Arjan Molenaar 2019-11-08 16:35:11 +01:00
parent ee35c528cb
commit bb20f6992b
6 changed files with 641 additions and 729 deletions

View File

@ -2,7 +2,7 @@
This module provides API for event management. There are two APIs provided:
* Global event management API: subscribe, unsubscribe, fire.
* Global event management API: subscribe, unsubscribe, handle.
* Local event management API: Manager
If you run only one instance of your application per Python
@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API
and have one instance of Manager object per application instance.
"""
from collections import namedtuple
from typing import Callable, Set, Type
from generic.registry import Registry
from generic.registry import TypeAxis
from generic.registry import Registry, TypeAxis
__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber")
class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])):
""" Set of handlers for specific type of event.
__all__ = "Manager"
This object stores ``handlers`` for specific event type and
``parents`` reference to handler sets of event's supertypes.
"""
Event = object
Handler = Callable[[object], None]
HandlerSet = Set[Handler]
@property
def all_handlers(self):
""" Iterate over own and supertypes' handlers.
This iterator yields just unique values, so it won't yield the
same handler twice, even if it was registered both for some
event type and its supertype.
"""
seen = set()
seen_add = seen.add
# yield own handlers first
for handler in self.handlers:
seen_add(handler)
yield handler
# yield supertypes' handlers then
for parent in self.parents:
for handler in parent.all_handlers:
if not handler in seen:
seen_add(handler)
yield handler
class Manager(object):
class Manager:
""" Event manager
Provides API for subscribing for and firing events. There's also global
event manager instantiated at module level with functions
:func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased
:func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased
to corresponding methods of class.
"""
def __init__(self):
registry: Registry[HandlerSet]
def __init__(self) -> None:
axes = (("event_type", TypeAxis()),)
self.registry = Registry(*axes)
def subscribe(self, handler, event_type):
def subscribe(self, handler: Handler, event_type: Type[Event]) -> None:
""" Subscribe ``handler`` to specified ``event_type``"""
handler_set = self.registry.get_registration(event_type)
if not handler_set:
if handler_set is None:
handler_set = self._register_handler_set(event_type)
handler_set.handlers.add(handler)
handler_set.add(handler)
def unsubscribe(self, handler, event_type):
def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None:
""" Unsubscribe ``handler`` from ``event_type``"""
handler_set = self.registry.get_registration(event_type)
if handler_set and handler in handler_set.handlers:
handler_set.handlers.remove(handler)
if handler_set and handler in handler_set:
handler_set.remove(handler)
def fire(self, event):
def handle(self, event: Event) -> None:
""" Fire ``event``
All subscribers will be executed with no determined order.
"""
handler_set = self.registry.lookup(event)
for handler in handler_set.all_handlers:
handler(event)
handler_sets = self.registry.query(event)
for handler_set in handler_sets:
if handler_set:
for handler in set(handler_set):
handler(event)
def _register_handler_set(self, event_type):
""" Register new handler set for ``event_type``."""
# Collect handler sets for supertypes
parent_handler_sets = []
parents = event_type.__bases__
for parent in parents:
parent_handlers = self.registry.get_registration(parent)
if parent_handlers is None:
parent_handlers = self._register_handler_set(parent)
parent_handler_sets.append(parent_handlers)
handler_set = HandlerSet(parents=parent_handler_sets, handlers=set())
def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet:
""" Register new handler set for ``event_type``.
"""
handler_set: HandlerSet = set()
self.registry.register(handler_set, event_type)
return handler_set
def subscriber(self, event_type):
def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]:
""" Decorator for subscribing handlers
Works like this:
>>> mymanager = Manager()
>>> class MyEvent():
... pass
>>> @mymanager.subscriber(MyEvent)
... def mysubscriber(evt):
... # handle event
... return
>>> mymanager.fire(MyEvent())
>>> mymanager.handle(MyEvent())
"""
def registrator(func):
def registrator(func: Handler) -> Handler:
self.subscribe(func, event_type)
return func
return registrator
# Global event manager
_global_manager = Manager()
# Global event management API
subscribe = _global_manager.subscribe
unsubscribe = _global_manager.unsubscribe
fire = _global_manager.fire
subscriber = _global_manager.subscriber

View File

@ -1,206 +1,132 @@
""" Multidispatch for functions and methods"""
""" Multidispatch for functions and methods.
This code is a Python 3, slimmed down version of the
generic package by Andrey Popp.
Only the generic function code is left in tact -- no generic methods.
The interface has been made in line with `functools.singledispatch`.
Note that this module does not support annotated functions.
"""
from __future__ import annotations
from typing import cast, Any, Callable, Generic, TypeVar, Union
import functools
import inspect
import types
import threading
from generic.registry import Registry
from generic.registry import TypeAxis
from generic.registry import Registry, TypeAxis
__all__ = ("multifunction", "multimethod", "has_multimethods")
__all__ = "multidispatch"
def multifunction(*argtypes):
""" Declare function as multifunction
T = TypeVar("T", bound=Union[Callable[..., Any], type])
KeyType = Union[type, None]
def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
""" Declare function as multidispatch
This decorator takes ``argtypes`` argument types and replace decorated
function with :class:`.FunctionDispatcher` object, which is responsible for
multiple dispatch feature.
"""
def _replace_with_dispatcher(func):
dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes))
def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
nonlocal argtypes
argspec = inspect.getfullargspec(func)
if not argtypes:
arity = _arity(argspec)
if isinstance(func, type):
# It's a class we deal with:
arity -= 1
argtypes = (object,) * arity
dispatcher = cast(
FunctionDispatcher[T],
functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func),
)
dispatcher.register_rule(func, *argtypes)
return dispatcher
return _replace_with_dispatcher
def multimethod(*argtypes):
""" Declare method as multimethod
This decorator works exactly the same as :func:`.multifunction` decorator
but replaces decorated method with :class:`.MethodDispatcher` object
instead.
Should be used only for decorating methods and enclosing class should have
:func:`.has_multimethods` decorator.
"""
def _replace_with_dispatcher(func):
dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1)
dispatcher.register_unbound_rule(func, *argtypes)
return dispatcher
return _replace_with_dispatcher
def has_multimethods(cls):
""" Declare class as one that have multimethods
Should only be used for decorating classes which have methods decorated with
:func:`.multimethod` decorator.
"""
for name, obj in cls.__dict__.items():
if isinstance(obj, MethodDispatcher):
obj.proceed_unbound_rules(cls)
return cls
class FunctionDispatcher(object):
class FunctionDispatcher(Generic[T]):
""" Multidispatcher for functions
This object dispatch calls to function by its argument types. Usually it is
produced by :func:`.multifunction` decorator.
produced by :func:`.multidispatch` decorator.
You should not manually create objects of this type.
"""
def __init__(self, argspec, params_arity):
registry: Registry[T]
def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
""" Initialize dispatcher with ``argspec`` of type
:class:`inspect.ArgSpec` and ``params_arity`` that represent number
params."""
# Check if we have enough positional arguments for number of type params
if arity(argspec) < params_arity:
raise TypeError("Not enough positional arguments "
"for number of type parameters provided.")
if _arity(argspec) < params_arity:
raise TypeError(
"Not enough positional arguments "
"for number of type parameters provided."
)
self.argspec = argspec
self.params_arity = params_arity
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
self.registry = Registry(*axis)
def check_rule(self, rule, *argtypes):
def check_rule(self, rule: T, *argtypes: KeyType) -> None:
# Check if we have the right number of parametrized types
if len(argtypes) != self.params_arity:
raise TypeError("Wrong number of type parameters.")
raise TypeError(
f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
)
# Check if we have the same argspec (by number of args)
rule_argspec = inspect.getargspec(rule)
if not is_equalent_argspecs(rule_argspec, self.argspec):
raise TypeError("Rule does not conform "
"to previous implementations.")
rule_argspec = inspect.getfullargspec(rule)
left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4])
right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4])
if left_spec != right_spec:
raise TypeError(
f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
)
def register_rule(self, rule, *argtypes):
def register_rule(self, rule: T, *argtypes: KeyType) -> None:
""" Register new ``rule`` for ``argtypes``."""
self.check_rule(rule, *argtypes)
self.registry.register(rule, *argtypes)
def override_rule(self, rule, *argtypes):
""" Override ``rule`` for ``argtypes``."""
self.check_rule(rule, *argtypes)
self.registry.override(rule, *argtypes)
def lookup_rule(self, *args):
""" Lookup rule by ``args``. Returns None if no rule was found."""
args = args[:self.params_arity]
rule = self.registry.lookup(*args)
if rule is None:
raise TypeError("No available rule found for %r" % (args,))
return rule
def when(self, *argtypes):
""" Decorator for registering new case for multifunction
def register(self, *argtypes: KeyType) -> Callable[[T], T]:
""" Decorator for registering new case for multidispatch
New case will be registered for types identified by ``argtypes``. The
length of ``argtypes`` should be equal to the length of ``argtypes``
argument were passed corresponding :func:`.multifunction` call, which
also indicated the number of arguments multifunction dispatches on.
argument were passed corresponding :func:`.multidispatch` call, which
also indicated the number of arguments multidispatch dispatches on.
"""
def register_rule(func):
def register_rule(func: T) -> T:
self.register_rule(func, *argtypes)
return self
return func
return register_rule
@property
def otherwise(self):
""" Decorator which registeres "catch-all" case for multifunction"""
def register_rule(func):
self.register_rule(func, [object]*self.params_arity)
return self
return register_rule
def override(self, *argtypes):
""" Decorator for overriding case for ``argtypes``"""
def override_rule(func):
self.override_rule(func, *argtypes)
return self
return override_rule
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
""" Dispatch call to appropriate rule."""
rule = self.lookup_rule(*args)
trimmed_args = args[: self.params_arity]
rule = self.registry.lookup(*trimmed_args)
if not rule:
raise TypeError(f"No available rule found for {trimmed_args!r}")
return rule(*args, **kwargs)
class MethodDispatcher(FunctionDispatcher):
""" Multiple dispatch for methods
This object dispatch call to method by its class and arguments types.
Usually it is produced by :func:`.multimethod` decorator.
You should not manually create objects of this type.
"""
def __init__(self, argspec, params_arity):
FunctionDispatcher.__init__(self, argspec, params_arity)
# some data, that should be local to thread of execution
self.local = threading.local()
self.local.unbound_rules = []
def register_unbound_rule(self, func, *argtypes):
""" Register unbound rule that should be processed by
``proceed_unbound_rules`` later."""
self.local.unbound_rules.append((argtypes, func))
def proceed_unbound_rules(self, cls):
""" Process all unbound rule by binding them to ``cls`` type."""
for argtypes, func in self.local.unbound_rules:
argtypes = (cls,) + argtypes
self.override_rule(func, *argtypes)
self.local.unbound_rules = []
def __get__(self, obj, cls):
if obj is None:
return self
return types.MethodType(self, obj)
def when(self, *argtypes):
""" Register new case for multimethod for ``argtypes``"""
def make_declaration(meth):
self.register_unbound_rule(meth, *argtypes)
return self
return make_declaration
def override(self, *argtypes):
""" Decorator for overriding case for ``argtypes``"""
return self.when(*argtypes)
@property
def otherwise(self):
""" Decorator which registeres "catch-all" case for multimethod"""
def make_declaration(func):
self.register_unbound_rule(func, [object]*self.params_arity)
return self
return make_declaration
def arity(argspec):
def _arity(argspec: inspect.FullArgSpec) -> int:
""" Determinal positional arity of argspec."""
args = argspec.args if argspec.args else []
defaults = argspec.defaults if argspec.defaults else []
return len(args) - len(defaults)
def is_equalent_argspecs(left, right):
""" Check argspec equalence."""
return map(lambda x: len(x) if x else 0, left) == \
map(lambda x: len(x) if x else 0, right)
def _make_dispatcher(dispacther_cls, func, params_arity):
argspec = inspect.getargspec(func)
wrapper = functools.wraps(func)
dispatcher = wrapper(dispacther_cls(argspec, params_arity))
return dispatcher

View File

@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi.
[1]: http://bitbucket.org/chrisrossi/happy
"""
from __future__ import annotations
__all__ = ("Registry", "SimpleAxis", "TypeAxis")
class Registry(object):
from typing import (
Any,
Dict,
Generic,
KeysView,
List,
Generator,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
K = TypeVar("K")
S = TypeVar("S")
T = TypeVar("T")
V = TypeVar("V")
Axis = Union["SimpleAxis", "TypeAxis"]
class Registry(Generic[T]):
""" Registry implementation."""
def __init__(self, *axes):
self._tree = _TreeNode()
def __init__(self, *axes: Tuple[str, Axis]):
self._tree: _TreeNode[T] = _TreeNode()
self._axes = [axis for name, axis in axes]
self._axes_dict = dict([
(name, (i, axis)) for i, (name, axis) in enumerate(axes)
])
self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)}
def register(self, target, *arg_keys, **kw_keys):
self._register(target, self._align_with_axes(arg_keys, kw_keys), False)
def override(self, target, *arg_keys, **kw_keys):
self._register(target, self._align_with_axes(arg_keys, kw_keys), True)
def _register(self, target, keys, override):
def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None:
tree_node = self._tree
for key in keys:
tree_node = tree_node.setdefault(key, _TreeNode())
for key in self._align_with_axes(arg_keys, kw_keys):
tree_node = tree_node.setdefault(key, _TreeNode[T]())
if not override and not tree_node.target is None:
if not tree_node.target is None:
raise ValueError(
"Registration conflicts with existing registration. Use "
"override method to override.")
f"Registration for {target} conflicts with existing registration {tree_node.target}."
)
tree_node.target = target
def get_registration(self, *arg_keys, **kw_keys):
def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]:
tree_node = self._tree
for key in self._align_with_axes(arg_keys, kw_keys):
if not tree_node.has_key(key):
if not key in tree_node:
return None
tree_node = tree_node[key]
return tree_node.target
def lookup(self, *arg_objs, **kw_objs):
def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]:
return next(self.query(*arg_objs, **kw_objs), None)
def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]:
objs = self._align_with_axes(arg_objs, kw_objs)
axes = self._axes
return self._lookup(self._tree, objs, axes)
return self._query(self._tree, objs, axes)
def _lookup(self, tree_node, objs, axes):
def _query(
self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis]
) -> Generator[Optional[T], None, None]:
""" Recursively traverse registration tree, from left to right, most
specific to least specific, returning the first target found on a
matching node. """
if not objs:
return tree_node.target
yield tree_node.target
else:
obj = objs[0]
obj = objs[0]
# Skip non-participating nodes
if obj is None:
next_node: Optional[_TreeNode[T]] = tree_node.get(None, None)
if next_node is not None:
yield from self._query(next_node, objs[1:], axes[1:])
else:
# Get matches on this axis and iterate from most to least specific
axis = axes[0]
for match_key in axis.matches(obj, tree_node.keys()):
yield from self._query(tree_node[match_key], objs[1:], axes[1:])
# Skip non-participating nodes
if obj is None:
next_node = tree_node.get(None, None)
if next_node is not None:
return self._lookup(next_node, objs[1:], axes[1:])
return None
# Get matches on this axis and iterate from most to least specific
axis = axes[0]
for match_key in axis.matches(obj, tree_node.keys()):
target = self._lookup(tree_node[match_key], objs[1:], axes[1:])
if target is not None:
return target
return None
def _align_with_axes(self, args, kw):
def _align_with_axes(
self, args: Sequence[S], kw: Dict[str, S]
) -> Sequence[Optional[S]]:
""" Create a list matching up all args and kwargs with their
corresponding axes, in order, using ``None`` as a placeholder for
skipped axes. """
axes_dict = self._axes_dict
aligned = [None for i in xrange(len(axes_dict))]
aligned: List[Optional[S]] = [None for i in range(len(axes_dict))]
args_len = len(args)
if args_len + len(kw) > len(aligned):
if args_len + len(kw) > len(aligned):
raise ValueError("Cannot have more arguments than axes.")
for i, arg in enumerate(args):
@ -91,12 +108,13 @@ class Registry(object):
for k, v in kw.items():
i_axis = axes_dict.get(k, None)
if i_axis is None:
raise ValueError("No axis with name: %s" % k)
raise ValueError(f"No axis with name: {k}")
i, axis = i_axis
if aligned[i] is not None:
raise ValueError("Axis defined twice between positional and "
"keyword arguments")
raise ValueError(
"Axis defined twice between positional and " "keyword arguments"
)
aligned[i] = v
@ -106,13 +124,15 @@ class Registry(object):
return aligned
class _TreeNode(dict):
target = None
def __str__(self):
return "<TreeNode %s %s>" % (self.target, dict.__str__(self))
class _TreeNode(Generic[T], Dict[Any, Any]):
target: Optional[T] = None
class SimpleAxis(object):
def __str__(self) -> str:
return f"<TreeNode {self.target} {dict.__str__(self)}>"
class SimpleAxis:
""" A simple axis where the key into the axis is the same as the object to
be matched (aka the identity axis). This axis behaves just like a
dictionary. You might use this axis if you are interested in registering
@ -122,21 +142,23 @@ class SimpleAxis(object):
Subclasses can override the ``get_keys`` method for implementing arbitrary
axes.
"""
def matches(self, obj, keys):
for key in self.get_keys(obj):
def matches(
self, obj: object, keys: KeysView[Optional[object]]
) -> Generator[object, None, None]:
for key in [obj]:
if key in keys:
yield key
yield obj
def get_keys(self, obj):
"""
Return the keys for the given object that could match this axis, from
most specific to least specific. A convenient override point.
"""
return [obj,]
class TypeAxis(SimpleAxis):
class TypeAxis:
""" An axis which matches the class and super classes of an object in
method resolution order.
"""
def get_keys(self, obj):
return type(obj).mro()
def matches(
self, obj: object, keys: KeysView[Optional[type]]
) -> Generator[type, None, None]:
for key in type(obj).mro():
if key in keys:
yield key

View File

@ -1,150 +1,181 @@
""" Tests for :module:`generic.event`."""
import unittest
from __future__ import annotations
__all__ = ("ManagerTests",)
from typing import Callable, List
from generic.event import Manager
class ManagerTests(unittest.TestCase):
def makeHandler(self, effect):
return lambda e: e.effects.append(effect)
def make_handler(effect: object) -> Callable[[Event], None]:
return lambda e: e.effects.append(effect)
def createManager(self):
from generic.event import Manager
return Manager()
def test_subscribe_single_event(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 1)
self.assertTrue("handler1" in e.effects)
def create_manager():
return Manager()
def test_subscribe_via_decorator(self):
events = self.createManager()
events.subscriber(EventA)(self.makeHandler("handler1"))
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 1)
self.assertTrue("handler1" in e.effects)
def test_subscribe_event_inheritance(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventB)
def test_subscribe_single_event():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
e = EventA()
events.handle(e)
assert len(e.effects) == 1
assert "handler1" in e.effects
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 1)
self.assertTrue("handler1" in ea.effects)
eb = EventB()
events.fire(eb)
self.assertEqual(len(eb.effects), 2)
self.assertTrue("handler1" in eb.effects)
self.assertTrue("handler2" in eb.effects)
def test_subscribe_via_decorator():
events = create_manager()
events.subscriber(EventA)(make_handler("handler1"))
e = EventA()
events.handle(e)
assert len(e.effects) == 1
assert "handler1" in e.effects
def test_subscribe_event_multiple_inheritance(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventC)
events.subscribe(self.makeHandler("handler3"), EventD)
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 1)
self.assertTrue("handler1" in ea.effects)
def test_subscribe_event_inheritance():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
events.subscribe(make_handler("handler2"), EventB)
ec = EventC()
events.fire(ec)
self.assertEqual(len(ec.effects), 1)
self.assertTrue("handler2" in ec.effects)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
ed = EventD()
events.fire(ed)
self.assertEqual(len(ed.effects), 3)
self.assertTrue("handler1" in ed.effects)
self.assertTrue("handler2" in ed.effects)
self.assertTrue("handler3" in ed.effects)
eb = EventB()
events.handle(eb)
assert len(eb.effects) == 2
assert "handler1" in eb.effects
assert "handler2" in eb.effects
def test_subscribe_event_malformed_multiple_inheritance(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventD)
events.subscribe(self.makeHandler("handler3"), EventE)
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 1)
self.assertTrue("handler1" in ea.effects)
def test_subscribe_event_multiple_inheritance():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
events.subscribe(make_handler("handler2"), EventC)
events.subscribe(make_handler("handler3"), EventD)
ed = EventD()
events.fire(ed)
self.assertEqual(len(ed.effects), 2)
self.assertTrue("handler1" in ed.effects)
self.assertTrue("handler2" in ed.effects)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
ee = EventE()
events.fire(ee)
self.assertEqual(len(ee.effects), 3)
self.assertTrue("handler1" in ee.effects)
self.assertTrue("handler2" in ee.effects)
self.assertTrue("handler3" in ee.effects)
ec = EventC()
events.handle(ec)
assert len(ec.effects) == 1
assert "handler2" in ec.effects
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), Event)
events.subscribe(self.makeHandler("handler2"), EventB)
ed = EventD()
events.handle(ed)
assert len(ed.effects) == 3
assert "handler1" in ed.effects
assert "handler2" in ed.effects
assert "handler3" in ed.effects
eb = EventB()
events.fire(eb)
self.assertEqual(len(eb.effects), 2)
self.assertTrue("handler1" in eb.effects)
self.assertTrue("handler2" in eb.effects)
def test_unsubscribe_single_event(self):
events = self.createManager()
handler = self.makeHandler("handler1")
events.subscribe(handler, EventA)
events.unsubscribe(handler, EventA)
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 0)
def test_subscribe_no_events():
events = create_manager()
def test_unsubscribe_event_inheritance(self):
events = self.createManager()
handler1 = self.makeHandler("handler1")
handler2 = self.makeHandler("handler2")
events.subscribe(handler1, EventA)
events.subscribe(handler2, EventB)
events.unsubscribe(handler1, EventA)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 0
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 0)
eb = EventB()
events.fire(eb)
self.assertEqual(len(eb.effects), 1)
self.assertTrue("handler2" in eb.effects)
def test_subscribe_base_event():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
class Event(object):
ea = EventB()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
def test_subscribe_event_malformed_multiple_inheritance():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
events.subscribe(make_handler("handler2"), EventD)
events.subscribe(make_handler("handler3"), EventE)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
ed = EventD()
events.handle(ed)
assert len(ed.effects) == 2
assert "handler1" in ed.effects
assert "handler2" in ed.effects
ee = EventE()
events.handle(ee)
assert len(ee.effects) == 3
assert "handler1" in ee.effects
assert "handler2" in ee.effects
assert "handler3" in ee.effects
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro():
events = create_manager()
events.subscribe(make_handler("handler1"), Event)
events.subscribe(make_handler("handler2"), EventB)
eb = EventB()
events.handle(eb)
assert len(eb.effects) == 2
assert "handler1" in eb.effects
assert "handler2" in eb.effects
def test_unsubscribe_single_event():
events = create_manager()
handler = make_handler("handler1")
events.subscribe(handler, EventA)
events.unsubscribe(handler, EventA)
e = EventA()
events.handle(e)
assert len(e.effects) == 0
def test_unsubscribe_event_inheritance():
events = create_manager()
handler1 = make_handler("handler1")
handler2 = make_handler("handler2")
events.subscribe(handler1, EventA)
events.subscribe(handler2, EventB)
events.unsubscribe(handler1, EventA)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 0
eb = EventB()
events.handle(eb)
assert len(eb.effects) == 1
assert "handler2" in eb.effects
class Event:
def __init__(self) -> None:
self.effects: List[object] = []
def __init__(self):
self.effects = []
class EventA(Event):
pass
class EventB(EventA):
pass
class EventC(Event):
pass
class EventD(EventA, EventC):
pass
class EventE(EventD, EventA):
pass

View File

@ -1,270 +1,224 @@
""" Tests for :module:`generic.multidispatch`."""
import unittest
import pytest
__all__ = ("DispatcherTests",)
class DispatcherTests(unittest.TestCase):
def createDispatcher(self, params_arity, args=None, varargs=None,
keywords=None, defaults=None):
from inspect import ArgSpec
from generic.multidispatch import FunctionDispatcher
return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
keywords=keywords,
defaults=defaults), params_arity)
from inspect import FullArgSpec
from generic.multidispatch import multidispatch, FunctionDispatcher
def test_one_argument(self):
dispatcher = self.createDispatcher(1, args=["x"])
def create_dispatcher(
params_arity, args=None, varargs=None, keywords=None, defaults=None
) -> FunctionDispatcher:
dispatcher.register_rule(lambda x: x + 1, int)
self.assertEqual(dispatcher(1), 2)
self.assertRaises(TypeError, dispatcher, "s")
return FunctionDispatcher(
FullArgSpec(
args=args,
varargs=varargs,
varkw=keywords,
defaults=defaults,
kwonlyargs=[],
kwonlydefaults={},
annotations={},
),
params_arity,
)
dispatcher.register_rule(lambda x: x + "1", str)
self.assertEqual(dispatcher(1), 2)
self.assertEqual(dispatcher("1"), "11")
self.assertRaises(TypeError, dispatcher, tuple())
def test_two_arguments(self):
dispatcher = self.createDispatcher(2, args=["x", "y"])
def test_one_argument():
dispatcher = create_dispatcher(1, args=["x"])
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
self.assertEqual(dispatcher(1, 2), 4)
self.assertRaises(TypeError, dispatcher, "s", "ss")
self.assertRaises(TypeError, dispatcher, 1, "ss")
self.assertRaises(TypeError, dispatcher, "s", 2)
dispatcher.register_rule(lambda x: x + 1, int)
assert dispatcher(1) == 2
with pytest.raises(TypeError):
dispatcher("s")
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
self.assertEqual(dispatcher(1, 2), 4)
self.assertEqual(dispatcher("1", "2"), "121")
self.assertRaises(TypeError, dispatcher, "1", 1)
self.assertRaises(TypeError, dispatcher, 1, "1")
dispatcher.register_rule(lambda x: x + "1", str)
assert dispatcher(1) == 2
assert dispatcher("1") == "11"
with pytest.raises(TypeError):
dispatcher(tuple())
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
self.assertEqual(dispatcher(1, 2), 4)
self.assertEqual(dispatcher("1", "2"), "121")
self.assertEqual(dispatcher(1, "2"), "121")
self.assertRaises(TypeError, dispatcher, "1", 1)
def test_bottom_rule(self):
dispatcher = self.createDispatcher(1, args=["x"])
def test_two_arguments():
dispatcher = create_dispatcher(2, args=["x", "y"])
dispatcher.register_rule(lambda x: x, object)
self.assertEqual(dispatcher(1), 1)
self.assertEqual(dispatcher("1"), "1")
self.assertEqual(dispatcher([1]), [1])
self.assertEqual(dispatcher((1,)), (1,))
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
assert dispatcher(1, 2) == 4
with pytest.raises(TypeError):
dispatcher("s", "ss")
with pytest.raises(TypeError):
dispatcher(1, "ss")
with pytest.raises(TypeError):
dispatcher("s", 2)
def test_subtype_evaluation(self):
class Super(object):
pass
class Sub(Super):
pass
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
assert dispatcher(1, 2) == 4
assert dispatcher("1", "2") == "121"
with pytest.raises(TypeError):
dispatcher("1", 1)
with pytest.raises(TypeError):
dispatcher(1, "1")
dispatcher = self.createDispatcher(1, args=["x"])
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
assert dispatcher(1, 2) == 4
assert dispatcher("1", "2") == "121"
assert dispatcher(1, "2") == "121"
with pytest.raises(TypeError):
dispatcher("1", 1)
dispatcher.register_rule(lambda x: x, Super)
o_super = Super()
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), o_sub)
self.assertRaises(TypeError, dispatcher, object())
dispatcher.register_rule(lambda x: (x, x), Sub)
o_super = Super()
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
def test_bottom_rule():
dispatcher = create_dispatcher(1, args=["x"])
def test_register_rule_with_wrong_arity(self):
dispatcher = self.createDispatcher(1, args=["x"])
dispatcher.register_rule(lambda x: x, int)
self.assertRaises(
TypeError,
dispatcher.register_rule, lambda x, y: x, str)
dispatcher.register_rule(lambda x: x, object)
assert dispatcher(1) == 1
assert dispatcher("1") == "1"
assert dispatcher([1]) == [1]
assert dispatcher((1,)) == (1,)
def test_register_rule_with_different_arg_names(self):
dispatcher = self.createDispatcher(1, args=["x"])
dispatcher.register_rule(lambda y: y, int)
self.assertEqual(dispatcher(1), 1)
def test_dispatching_with_varargs(self):
dispatcher = self.createDispatcher(1, args=["x"], varargs="va")
dispatcher.register_rule(lambda x, *va: x, int)
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", 2, 3)
def test_subtype_evaluation():
class Super:
pass
def test_dispatching_with_varkw(self):
dispatcher = self.createDispatcher(1, args=["x"], keywords="vk")
dispatcher.register_rule(lambda x, **vk: x, int)
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)
class Sub(Super):
pass
def test_dispatching_with_kw(self):
dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
dispatcher.register_rule(lambda x, y=1: x, int)
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", k=1)
dispatcher = create_dispatcher(1, args=["x"])
def test_create_dispatcher_with_pos_args_less_multi_arity(self):
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"])
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"],
defaults=["x"])
dispatcher.register_rule(lambda x: x, Super)
o_super = Super()
assert dispatcher(o_super) == o_super
o_sub = Sub()
assert dispatcher(o_sub) == o_sub
with pytest.raises(TypeError):
dispatcher(object())
def test_register_rule_with_wrong_number_types_parameters(self):
dispatcher = self.createDispatcher(1, args=["x", "y"])
self.assertRaises(
TypeError,
dispatcher.register_rule, lambda x, y: x, int, str)
dispatcher.register_rule(lambda x: (x, x), Sub)
o_super = Super()
assert dispatcher(o_super) == o_super
o_sub = Sub()
assert dispatcher(o_sub) == (o_sub, o_sub)
def test_register_rule_with_partial_dispatching(self):
dispatcher = self.createDispatcher(1, args=["x", "y"])
dispatcher.register_rule(lambda x, y: x, int)
self.assertEqual(dispatcher(1, 2), 1)
self.assertEqual(dispatcher(1, "2"), 1)
self.assertRaises(TypeError, dispatcher, "2", 1)
def test_register_rule_with_wrong_arity():
dispatcher = create_dispatcher(1, args=["x"])
dispatcher.register_rule(lambda x: x, int)
with pytest.raises(TypeError):
dispatcher.register_rule(lambda x, y: x, str)
self.assertEqual(dispatcher(1, 2), 1)
self.assertEqual(dispatcher(1, "2"), 1)
self.assertEqual(dispatcher("1", "2"), "1")
self.assertEqual(dispatcher("1", 2), "1")
class MultifunctionTests(unittest.TestCase):
def test_it(self):
from generic.multidispatch import multifunction
def test_register_rule_with_different_arg_names():
dispatcher = create_dispatcher(1, args=["x"])
dispatcher.register_rule(lambda y: y, int)
assert dispatcher(1) == 1
@multifunction(int, str)
def func(x, y):
return str(x) + y
self.assertEqual(func(1, "2"), "12")
self.assertRaises(TypeError, func, 1, 2)
self.assertRaises(TypeError, func, "1", 2)
self.assertRaises(TypeError, func, "1", "2")
def test_dispatching_with_varargs():
dispatcher = create_dispatcher(1, args=["x"], varargs="va")
dispatcher.register_rule(lambda x, *va: x, int)
assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", 2, 3)
@func.when(str, str)
def func(x, y):
return x + y
self.assertEqual(func(1, "2"), "12")
self.assertEqual(func("1", "2"), "12")
self.assertRaises(TypeError, func, 1, 2)
self.assertRaises(TypeError, func, "1", 2)
def test_dispatching_with_varkw():
dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
dispatcher.register_rule(lambda x, **vk: x, int)
assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", a=1, b=2)
def test_overriding(self):
from generic.multidispatch import multifunction
@multifunction(int, str)
def func(x, y):
return str(x) + y
def test_dispatching_with_kw():
dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
dispatcher.register_rule(lambda x, y=1: x, int)
assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", k=1)
self.assertEqual(func(1, "2"), "12")
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
@func.override(int, str)
def func(x, y):
return y + str(x)
def test_create_dispatcher_with_pos_args_less_multi_arity():
with pytest.raises(TypeError):
create_dispatcher(2, args=["x"])
with pytest.raises(TypeError):
create_dispatcher(2, args=["x", "y"], defaults=["x"])
self.assertEqual(func(1, "2"), "21")
class MultimethodTests(unittest.TestCase):
def test_register_rule_with_wrong_number_types_parameters():
dispatcher = create_dispatcher(1, args=["x", "y"])
with pytest.raises(TypeError):
dispatcher.register_rule(lambda x, y: x, int, str)
def test_multimethod(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
def test_register_rule_with_partial_dispatching():
dispatcher = create_dispatcher(1, args=["x", "y"])
dispatcher.register_rule(lambda x, y: x, int)
assert dispatcher(1, 2) == 1
assert dispatcher(1, "2") == 1
with pytest.raises(TypeError):
dispatcher("2", 1)
dispatcher.register_rule(lambda x, y: x, str)
assert dispatcher(1, 2) == 1
assert dispatcher(1, "2") == 1
assert dispatcher("1", "2") == "1"
assert dispatcher("1", 2) == "1"
@multimethod(int)
def foo(self, x):
return x + 1
@foo.when(str)
def foo(self, x):
return x + "1"
def test_default_dispatcher():
@multidispatch(int, str)
def func(x, y):
return str(x) + y
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(Dummy().foo("1"), "11")
self.assertRaises(TypeError, Dummy().foo, [])
assert func(1, "2") == "12"
with pytest.raises(TypeError):
func(1, 2)
with pytest.raises(TypeError):
func("1", 2)
with pytest.raises(TypeError):
func("1", "2")
def test_inheritance(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
def test_multiple_functions():
@multidispatch(int, str)
def func(x, y):
return str(x) + y
@multimethod(int)
def foo(self, x):
return x + 1
@func.register(str, str)
def _(x, y):
return x + y
@foo.when(float)
def foo(self, x):
return x + 1.5
assert func(1, "2") == "12"
assert func("1", "2") == "12"
with pytest.raises(TypeError):
func(1, 2)
with pytest.raises(TypeError):
func("1", 2)
@has_multimethods
class DummySub(Dummy):
@Dummy.foo.when(str)
def foo(self, x):
return x + "1"
def test_default():
@multidispatch()
def func(x, y):
return x + y
@foo.when(tuple)
def foo(self, x):
return x + (1,)
@func.register(str, str)
def _(x, y):
return y + x
@Dummy.foo.when(bool)
def foo(self, x):
return not x
assert func(1, 1) == 2
assert func("1", "2") == "21"
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(Dummy().foo(1.5), 3.0)
self.assertRaises(TypeError, Dummy().foo, "1")
self.assertEqual(DummySub().foo(1), 2)
self.assertEqual(DummySub().foo(1.5), 3.0)
self.assertEqual(DummySub().foo("1"), "11")
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
self.assertEqual(DummySub().foo(True), False)
self.assertRaises(TypeError, DummySub().foo, [])
def test_override(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
def test_on_classes():
@multidispatch()
class A:
def __init__(self, a, b):
self.v = a + b
@has_multimethods
class Dummy(object):
@A.register(str, str) # type: ignore[attr-defined]
class B:
def __init__(self, a, b):
self.v = b + a
@multimethod(str, str)
def foo(self, x, y):
return x + y
@foo.when(str, str)
def foo(self, x, y):
return y + x
self.assertEqual(Dummy().foo("1", "2"), "21")
def test_inheritance_override(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
@multimethod(int)
def foo(self, x):
return x + 1
@has_multimethods
class DummySub(Dummy):
@Dummy.foo.when(int)
def foo(self, x):
return x + 2
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(DummySub().foo(1), 3)
assert A(1, 1).v == 2
assert A("1", "2").v == "21"

View File

@ -1,135 +1,147 @@
""" Tests for :module:`generic.registry`."""
import unittest
import pytest
__all__ = ("RegistryTests",)
from typing import Union
from generic.registry import Registry, SimpleAxis, TypeAxis
class RegistryTests(unittest.TestCase):
def test_one_axis_no_specificity(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('foo', SimpleAxis()))
a = object()
b = object()
registry.register(a)
registry.register(b, 'foo')
self.assertEqual(registry.lookup(), a)
self.assertEqual(registry.lookup('foo'), b)
self.assertEqual(registry.lookup('bar'), None)
def test_two_axes(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
from generic.registry import TypeAxis
registry = Registry(('type', TypeAxis()),
('name', SimpleAxis()))
target1 = Target('one')
registry.register(target1, object)
target2 = Target('two')
registry.register(target2, DummyA)
target3 = Target('three')
registry.register(target3, DummyA, 'foo')
context1 = object()
self.assertEqual(registry.lookup(context1), target1)
context2 = DummyB()
self.assertEqual(registry.lookup(context2), target2)
self.assertEqual(registry.lookup(context2, 'foo'), target3)
target4 = object()
registry.register(target4, DummyB)
self.assertEqual(registry.lookup(context2), target4)
self.assertEqual(registry.lookup(context2, 'foo'), target3)
def test_get_registration(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
from generic.registry import TypeAxis
registry = Registry(('type', TypeAxis()),
('name', SimpleAxis()))
registry.register('one', object)
registry.register('two', DummyA, 'foo')
self.assertEqual(registry.get_registration(object), 'one')
self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two')
self.assertEqual(registry.get_registration(object, 'foo'), None)
self.assertEqual(registry.get_registration(DummyA), None)
def test_register_too_many_keys(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
self.assertRaises(ValueError, registry.register, object(),
'one', 'two')
def test_lookup_too_many_keys(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
self.assertRaises(ValueError, registry.lookup, 'one', 'two')
def test_conflict_error(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
registry.register(object(), name='foo')
self.assertRaises(ValueError, registry.register, object(), 'foo')
def test_override(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
registry.register(1, name='foo')
registry.override(2, name='foo')
self.assertEqual(registry.lookup('foo'), 2)
def test_skip_nodes(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(
('one', SimpleAxis()),
('two', SimpleAxis()),
('three', SimpleAxis())
)
registry.register('foo', one=1, three=3)
self.assertEqual(registry.lookup(1, three=3), 'foo')
def test_miss(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(
('one', SimpleAxis()),
('two', SimpleAxis()),
('three', SimpleAxis())
)
registry.register('foo', 1, 2)
self.assertEqual(registry.lookup(one=1, three=3), None)
def test_bad_lookup(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()),
('grade', SimpleAxis()))
self.assertRaises(ValueError, registry.register, 1, foo=1)
self.assertRaises(ValueError, registry.lookup, foo=1)
self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo')
class DummyA(object):
class DummyA:
pass
class DummyB(DummyA):
pass
class Target(object):
def __init__(self, name):
self.name = name
# Only called if being printed due to a failing test
def __repr__(self): #pragma NO COVERAGE
return "Target('%s')" % self.name
def test_one_axis_no_specificity():
registry: Registry[object] = Registry(("foo", SimpleAxis()))
a = object()
b = object()
registry.register(a)
registry.register(b, "foo")
assert registry.lookup() == a
assert registry.lookup("foo") == b
assert registry.lookup("bar") is None
def test_subtyping_on_axes():
registry: Registry[str] = Registry(("type", TypeAxis()))
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyB)
assert registry.lookup(object()) == target1
assert registry.lookup(DummyA()) == target2
assert registry.lookup(DummyB()) == target3
def test_query_subtyping_on_axes():
registry: Registry[str] = Registry(("type", TypeAxis()))
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyB)
target4 = "four"
registry.register(target4, int)
assert list(registry.query(object())) == [target1]
assert list(registry.query(DummyA())) == [target2, target1]
assert list(registry.query(DummyB())) == [target3, target2, target1]
assert list(registry.query(3)) == [target4, target1]
def test_two_axes():
registry: Registry[Union[str, object]] = Registry(
("type", TypeAxis()), ("name", SimpleAxis())
)
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyA, "foo")
context1 = object()
assert registry.lookup(context1) == target1
context2 = DummyB()
assert registry.lookup(context2) == target2
assert registry.lookup(context2, "foo") == target3
target4 = object()
registry.register(target4, DummyB)
assert registry.lookup(context2) == target4
assert registry.lookup(context2, "foo") == target3
def test_get_registration():
registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis()))
registry.register("one", object)
registry.register("two", DummyA, "foo")
assert registry.get_registration(object) == "one"
assert registry.get_registration(DummyA, "foo") == "two"
assert registry.get_registration(object, "foo") is None
assert registry.get_registration(DummyA) is None
def test_register_too_many_keys():
registry: Registry[type] = Registry(("name", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(object, "one", "two")
def test_lookup_too_many_keys():
registry: Registry[object] = Registry(("name", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(registry.lookup("one", "two"))
def test_conflict_error():
registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis()))
registry.register(object(), name="foo")
with pytest.raises(ValueError):
registry.register(object, "foo")
def test_skip_nodes():
registry: Registry[str] = Registry(
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
)
registry.register("foo", one=1, three=3)
assert registry.lookup(1, three=3) == "foo"
def test_miss():
registry: Registry[str] = Registry(
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
)
registry.register("foo", 1, 2)
assert registry.lookup(one=1, three=3) is None
def test_bad_lookup():
registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(1, foo=1)
with pytest.raises(ValueError):
registry.lookup(foo=1)
with pytest.raises(ValueError):
registry.register(1, "foo", name="foo")