From bb20f6992bb409e8032d8736e4527c7f14b2ce87 Mon Sep 17 00:00:00 2001 From: Arjan Molenaar <gaphor@gmail.com> Date: Fri, 8 Nov 2019 16:35:11 +0100 Subject: [PATCH] Replace old code by Python3 implementation from Gaphor Multimethods are missing now. Those have to be added back. --- generic/event.py | 105 ++++------ generic/multidispatch.py | 224 +++++++-------------- generic/registry.py | 150 ++++++++------ tests/test_event.py | 247 +++++++++++++---------- tests/test_multidispatch.py | 384 ++++++++++++++++-------------------- tests/test_registry.py | 260 ++++++++++++------------ 6 files changed, 641 insertions(+), 729 deletions(-) diff --git a/generic/event.py b/generic/event.py index 98abe7f..628df8c 100644 --- a/generic/event.py +++ b/generic/event.py @@ -2,7 +2,7 @@ This module provides API for event management. There are two APIs provided: -* Global event management API: subscribe, unsubscribe, fire. +* Global event management API: subscribe, unsubscribe, handle. * Local event management API: Manager If you run only one instance of your application per Python @@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API and have one instance of Manager object per application instance. """ -from collections import namedtuple +from typing import Callable, Set, Type -from generic.registry import Registry -from generic.registry import TypeAxis +from generic.registry import Registry, TypeAxis -__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber") -class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])): - """ Set of handlers for specific type of event. +__all__ = "Manager" - This object stores ``handlers`` for specific event type and - ``parents`` reference to handler sets of event's supertypes. - """ +Event = object +Handler = Callable[[object], None] +HandlerSet = Set[Handler] - @property - def all_handlers(self): - """ Iterate over own and supertypes' handlers. - This iterator yields just unique values, so it won't yield the - same handler twice, even if it was registered both for some - event type and its supertype. - """ - seen = set() - seen_add = seen.add - - # yield own handlers first - for handler in self.handlers: - seen_add(handler) - yield handler - - # yield supertypes' handlers then - for parent in self.parents: - for handler in parent.all_handlers: - if not handler in seen: - seen_add(handler) - yield handler - -class Manager(object): +class Manager: """ Event manager Provides API for subscribing for and firing events. There's also global event manager instantiated at module level with functions - :func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased + :func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased to corresponding methods of class. """ - def __init__(self): + registry: Registry[HandlerSet] + + def __init__(self) -> None: axes = (("event_type", TypeAxis()),) self.registry = Registry(*axes) - def subscribe(self, handler, event_type): + def subscribe(self, handler: Handler, event_type: Type[Event]) -> None: """ Subscribe ``handler`` to specified ``event_type``""" handler_set = self.registry.get_registration(event_type) - if not handler_set: + if handler_set is None: handler_set = self._register_handler_set(event_type) - handler_set.handlers.add(handler) + handler_set.add(handler) - def unsubscribe(self, handler, event_type): + def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None: """ Unsubscribe ``handler`` from ``event_type``""" handler_set = self.registry.get_registration(event_type) - if handler_set and handler in handler_set.handlers: - handler_set.handlers.remove(handler) + if handler_set and handler in handler_set: + handler_set.remove(handler) - def fire(self, event): + def handle(self, event: Event) -> None: """ Fire ``event`` All subscribers will be executed with no determined order. """ - handler_set = self.registry.lookup(event) - for handler in handler_set.all_handlers: - handler(event) + handler_sets = self.registry.query(event) + for handler_set in handler_sets: + if handler_set: + for handler in set(handler_set): + handler(event) - def _register_handler_set(self, event_type): - """ Register new handler set for ``event_type``.""" - # Collect handler sets for supertypes - parent_handler_sets = [] - parents = event_type.__bases__ - for parent in parents: - parent_handlers = self.registry.get_registration(parent) - if parent_handlers is None: - parent_handlers = self._register_handler_set(parent) - parent_handler_sets.append(parent_handlers) - - handler_set = HandlerSet(parents=parent_handler_sets, handlers=set()) + def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet: + """ Register new handler set for ``event_type``. + """ + handler_set: HandlerSet = set() self.registry.register(handler_set, event_type) return handler_set - def subscriber(self, event_type): + def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]: """ Decorator for subscribing handlers Works like this: + >>> mymanager = Manager() + >>> class MyEvent(): + ... pass >>> @mymanager.subscriber(MyEvent) ... def mysubscriber(evt): ... # handle event ... return - >>> mymanager.fire(MyEvent()) + >>> mymanager.handle(MyEvent()) """ - def registrator(func): + + def registrator(func: Handler) -> Handler: self.subscribe(func, event_type) return func + return registrator - -# Global event manager -_global_manager = Manager() - -# Global event management API -subscribe = _global_manager.subscribe -unsubscribe = _global_manager.unsubscribe -fire = _global_manager.fire -subscriber = _global_manager.subscriber diff --git a/generic/multidispatch.py b/generic/multidispatch.py index 294aed9..3b4bb3e 100644 --- a/generic/multidispatch.py +++ b/generic/multidispatch.py @@ -1,206 +1,132 @@ -""" Multidispatch for functions and methods""" +""" Multidispatch for functions and methods. + +This code is a Python 3, slimmed down version of the +generic package by Andrey Popp. + +Only the generic function code is left in tact -- no generic methods. +The interface has been made in line with `functools.singledispatch`. + +Note that this module does not support annotated functions. +""" + +from __future__ import annotations + +from typing import cast, Any, Callable, Generic, TypeVar, Union import functools import inspect -import types -import threading -from generic.registry import Registry -from generic.registry import TypeAxis +from generic.registry import Registry, TypeAxis -__all__ = ("multifunction", "multimethod", "has_multimethods") +__all__ = "multidispatch" -def multifunction(*argtypes): - """ Declare function as multifunction +T = TypeVar("T", bound=Union[Callable[..., Any], type]) +KeyType = Union[type, None] + + +def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]: + """ Declare function as multidispatch This decorator takes ``argtypes`` argument types and replace decorated function with :class:`.FunctionDispatcher` object, which is responsible for multiple dispatch feature. """ - def _replace_with_dispatcher(func): - dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes)) + + def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]: + nonlocal argtypes + argspec = inspect.getfullargspec(func) + if not argtypes: + arity = _arity(argspec) + if isinstance(func, type): + # It's a class we deal with: + arity -= 1 + argtypes = (object,) * arity + + dispatcher = cast( + FunctionDispatcher[T], + functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func), + ) dispatcher.register_rule(func, *argtypes) return dispatcher + return _replace_with_dispatcher -def multimethod(*argtypes): - """ Declare method as multimethod - This decorator works exactly the same as :func:`.multifunction` decorator - but replaces decorated method with :class:`.MethodDispatcher` object - instead. - - Should be used only for decorating methods and enclosing class should have - :func:`.has_multimethods` decorator. - """ - def _replace_with_dispatcher(func): - dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1) - dispatcher.register_unbound_rule(func, *argtypes) - return dispatcher - return _replace_with_dispatcher - -def has_multimethods(cls): - """ Declare class as one that have multimethods - - Should only be used for decorating classes which have methods decorated with - :func:`.multimethod` decorator. - """ - for name, obj in cls.__dict__.items(): - if isinstance(obj, MethodDispatcher): - obj.proceed_unbound_rules(cls) - return cls - -class FunctionDispatcher(object): +class FunctionDispatcher(Generic[T]): """ Multidispatcher for functions This object dispatch calls to function by its argument types. Usually it is - produced by :func:`.multifunction` decorator. + produced by :func:`.multidispatch` decorator. You should not manually create objects of this type. """ - def __init__(self, argspec, params_arity): + registry: Registry[T] + + def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None: """ Initialize dispatcher with ``argspec`` of type :class:`inspect.ArgSpec` and ``params_arity`` that represent number params.""" # Check if we have enough positional arguments for number of type params - if arity(argspec) < params_arity: - raise TypeError("Not enough positional arguments " - "for number of type parameters provided.") + if _arity(argspec) < params_arity: + raise TypeError( + "Not enough positional arguments " + "for number of type parameters provided." + ) self.argspec = argspec self.params_arity = params_arity - axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)] + axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)] self.registry = Registry(*axis) - def check_rule(self, rule, *argtypes): + def check_rule(self, rule: T, *argtypes: KeyType) -> None: # Check if we have the right number of parametrized types if len(argtypes) != self.params_arity: - raise TypeError("Wrong number of type parameters.") + raise TypeError( + f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}." + ) # Check if we have the same argspec (by number of args) - rule_argspec = inspect.getargspec(rule) - if not is_equalent_argspecs(rule_argspec, self.argspec): - raise TypeError("Rule does not conform " - "to previous implementations.") + rule_argspec = inspect.getfullargspec(rule) + left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4]) + right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4]) + if left_spec != right_spec: + raise TypeError( + f"Rule does not conform to previous implementations: {left_spec} != {right_spec}." + ) - def register_rule(self, rule, *argtypes): + def register_rule(self, rule: T, *argtypes: KeyType) -> None: """ Register new ``rule`` for ``argtypes``.""" self.check_rule(rule, *argtypes) self.registry.register(rule, *argtypes) - def override_rule(self, rule, *argtypes): - """ Override ``rule`` for ``argtypes``.""" - self.check_rule(rule, *argtypes) - self.registry.override(rule, *argtypes) - - def lookup_rule(self, *args): - """ Lookup rule by ``args``. Returns None if no rule was found.""" - args = args[:self.params_arity] - rule = self.registry.lookup(*args) - if rule is None: - raise TypeError("No available rule found for %r" % (args,)) - return rule - - def when(self, *argtypes): - """ Decorator for registering new case for multifunction + def register(self, *argtypes: KeyType) -> Callable[[T], T]: + """ Decorator for registering new case for multidispatch New case will be registered for types identified by ``argtypes``. The length of ``argtypes`` should be equal to the length of ``argtypes`` - argument were passed corresponding :func:`.multifunction` call, which - also indicated the number of arguments multifunction dispatches on. + argument were passed corresponding :func:`.multidispatch` call, which + also indicated the number of arguments multidispatch dispatches on. """ - def register_rule(func): + + def register_rule(func: T) -> T: self.register_rule(func, *argtypes) - return self + return func + return register_rule - @property - def otherwise(self): - """ Decorator which registeres "catch-all" case for multifunction""" - def register_rule(func): - self.register_rule(func, [object]*self.params_arity) - return self - return register_rule - - def override(self, *argtypes): - """ Decorator for overriding case for ``argtypes``""" - def override_rule(func): - self.override_rule(func, *argtypes) - return self - return override_rule - - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Dispatch call to appropriate rule.""" - rule = self.lookup_rule(*args) + trimmed_args = args[: self.params_arity] + rule = self.registry.lookup(*trimmed_args) + if not rule: + raise TypeError(f"No available rule found for {trimmed_args!r}") return rule(*args, **kwargs) -class MethodDispatcher(FunctionDispatcher): - """ Multiple dispatch for methods - This object dispatch call to method by its class and arguments types. - Usually it is produced by :func:`.multimethod` decorator. - - You should not manually create objects of this type. - """ - - def __init__(self, argspec, params_arity): - FunctionDispatcher.__init__(self, argspec, params_arity) - - # some data, that should be local to thread of execution - self.local = threading.local() - self.local.unbound_rules = [] - - def register_unbound_rule(self, func, *argtypes): - """ Register unbound rule that should be processed by - ``proceed_unbound_rules`` later.""" - self.local.unbound_rules.append((argtypes, func)) - - def proceed_unbound_rules(self, cls): - """ Process all unbound rule by binding them to ``cls`` type.""" - for argtypes, func in self.local.unbound_rules: - argtypes = (cls,) + argtypes - self.override_rule(func, *argtypes) - self.local.unbound_rules = [] - - def __get__(self, obj, cls): - if obj is None: - return self - return types.MethodType(self, obj) - - def when(self, *argtypes): - """ Register new case for multimethod for ``argtypes``""" - def make_declaration(meth): - self.register_unbound_rule(meth, *argtypes) - return self - return make_declaration - - def override(self, *argtypes): - """ Decorator for overriding case for ``argtypes``""" - return self.when(*argtypes) - - @property - def otherwise(self): - """ Decorator which registeres "catch-all" case for multimethod""" - def make_declaration(func): - self.register_unbound_rule(func, [object]*self.params_arity) - return self - return make_declaration - -def arity(argspec): +def _arity(argspec: inspect.FullArgSpec) -> int: """ Determinal positional arity of argspec.""" args = argspec.args if argspec.args else [] defaults = argspec.defaults if argspec.defaults else [] return len(args) - len(defaults) - -def is_equalent_argspecs(left, right): - """ Check argspec equalence.""" - return map(lambda x: len(x) if x else 0, left) == \ - map(lambda x: len(x) if x else 0, right) - -def _make_dispatcher(dispacther_cls, func, params_arity): - argspec = inspect.getargspec(func) - wrapper = functools.wraps(func) - dispatcher = wrapper(dispacther_cls(argspec, params_arity)) - return dispatcher diff --git a/generic/registry.py b/generic/registry.py index 982d53c..9143f7b 100644 --- a/generic/registry.py +++ b/generic/registry.py @@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi. [1]: http://bitbucket.org/chrisrossi/happy """ +from __future__ import annotations + __all__ = ("Registry", "SimpleAxis", "TypeAxis") -class Registry(object): +from typing import ( + Any, + Dict, + Generic, + KeysView, + List, + Generator, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +K = TypeVar("K") +S = TypeVar("S") +T = TypeVar("T") +V = TypeVar("V") +Axis = Union["SimpleAxis", "TypeAxis"] + + +class Registry(Generic[T]): """ Registry implementation.""" - def __init__(self, *axes): - self._tree = _TreeNode() + def __init__(self, *axes: Tuple[str, Axis]): + self._tree: _TreeNode[T] = _TreeNode() self._axes = [axis for name, axis in axes] - self._axes_dict = dict([ - (name, (i, axis)) for i, (name, axis) in enumerate(axes) - ]) + self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)} - def register(self, target, *arg_keys, **kw_keys): - self._register(target, self._align_with_axes(arg_keys, kw_keys), False) - - def override(self, target, *arg_keys, **kw_keys): - self._register(target, self._align_with_axes(arg_keys, kw_keys), True) - - def _register(self, target, keys, override): + def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None: tree_node = self._tree - for key in keys: - tree_node = tree_node.setdefault(key, _TreeNode()) + for key in self._align_with_axes(arg_keys, kw_keys): + tree_node = tree_node.setdefault(key, _TreeNode[T]()) - if not override and not tree_node.target is None: + if not tree_node.target is None: raise ValueError( - "Registration conflicts with existing registration. Use " - "override method to override.") + f"Registration for {target} conflicts with existing registration {tree_node.target}." + ) tree_node.target = target - def get_registration(self, *arg_keys, **kw_keys): + def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]: tree_node = self._tree for key in self._align_with_axes(arg_keys, kw_keys): - if not tree_node.has_key(key): + if not key in tree_node: return None tree_node = tree_node[key] return tree_node.target - def lookup(self, *arg_objs, **kw_objs): + def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]: + return next(self.query(*arg_objs, **kw_objs), None) + + def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]: objs = self._align_with_axes(arg_objs, kw_objs) axes = self._axes - return self._lookup(self._tree, objs, axes) + return self._query(self._tree, objs, axes) - def _lookup(self, tree_node, objs, axes): + def _query( + self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis] + ) -> Generator[Optional[T], None, None]: """ Recursively traverse registration tree, from left to right, most specific to least specific, returning the first target found on a matching node. """ if not objs: - return tree_node.target + yield tree_node.target + else: + obj = objs[0] - obj = objs[0] + # Skip non-participating nodes + if obj is None: + next_node: Optional[_TreeNode[T]] = tree_node.get(None, None) + if next_node is not None: + yield from self._query(next_node, objs[1:], axes[1:]) + else: + # Get matches on this axis and iterate from most to least specific + axis = axes[0] + for match_key in axis.matches(obj, tree_node.keys()): + yield from self._query(tree_node[match_key], objs[1:], axes[1:]) - # Skip non-participating nodes - if obj is None: - next_node = tree_node.get(None, None) - if next_node is not None: - return self._lookup(next_node, objs[1:], axes[1:]) - return None - - # Get matches on this axis and iterate from most to least specific - axis = axes[0] - for match_key in axis.matches(obj, tree_node.keys()): - target = self._lookup(tree_node[match_key], objs[1:], axes[1:]) - if target is not None: - return target - - return None - - def _align_with_axes(self, args, kw): + def _align_with_axes( + self, args: Sequence[S], kw: Dict[str, S] + ) -> Sequence[Optional[S]]: """ Create a list matching up all args and kwargs with their corresponding axes, in order, using ``None`` as a placeholder for skipped axes. """ axes_dict = self._axes_dict - aligned = [None for i in xrange(len(axes_dict))] + aligned: List[Optional[S]] = [None for i in range(len(axes_dict))] args_len = len(args) - if args_len + len(kw) > len(aligned): + if args_len + len(kw) > len(aligned): raise ValueError("Cannot have more arguments than axes.") for i, arg in enumerate(args): @@ -91,12 +108,13 @@ class Registry(object): for k, v in kw.items(): i_axis = axes_dict.get(k, None) if i_axis is None: - raise ValueError("No axis with name: %s" % k) + raise ValueError(f"No axis with name: {k}") i, axis = i_axis if aligned[i] is not None: - raise ValueError("Axis defined twice between positional and " - "keyword arguments") + raise ValueError( + "Axis defined twice between positional and " "keyword arguments" + ) aligned[i] = v @@ -106,13 +124,15 @@ class Registry(object): return aligned -class _TreeNode(dict): - target = None - def __str__(self): - return "<TreeNode %s %s>" % (self.target, dict.__str__(self)) +class _TreeNode(Generic[T], Dict[Any, Any]): + target: Optional[T] = None -class SimpleAxis(object): + def __str__(self) -> str: + return f"<TreeNode {self.target} {dict.__str__(self)}>" + + +class SimpleAxis: """ A simple axis where the key into the axis is the same as the object to be matched (aka the identity axis). This axis behaves just like a dictionary. You might use this axis if you are interested in registering @@ -122,21 +142,23 @@ class SimpleAxis(object): Subclasses can override the ``get_keys`` method for implementing arbitrary axes. """ - def matches(self, obj, keys): - for key in self.get_keys(obj): + + def matches( + self, obj: object, keys: KeysView[Optional[object]] + ) -> Generator[object, None, None]: + for key in [obj]: if key in keys: - yield key + yield obj - def get_keys(self, obj): - """ - Return the keys for the given object that could match this axis, from - most specific to least specific. A convenient override point. - """ - return [obj,] -class TypeAxis(SimpleAxis): +class TypeAxis: """ An axis which matches the class and super classes of an object in method resolution order. """ - def get_keys(self, obj): - return type(obj).mro() + + def matches( + self, obj: object, keys: KeysView[Optional[type]] + ) -> Generator[type, None, None]: + for key in type(obj).mro(): + if key in keys: + yield key diff --git a/tests/test_event.py b/tests/test_event.py index e967bdd..63609d8 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,150 +1,181 @@ """ Tests for :module:`generic.event`.""" -import unittest +from __future__ import annotations -__all__ = ("ManagerTests",) +from typing import Callable, List +from generic.event import Manager -class ManagerTests(unittest.TestCase): - def makeHandler(self, effect): - return lambda e: e.effects.append(effect) +def make_handler(effect: object) -> Callable[[Event], None]: + return lambda e: e.effects.append(effect) - def createManager(self): - from generic.event import Manager - return Manager() - def test_subscribe_single_event(self): - events = self.createManager() - events.subscribe(self.makeHandler("handler1"), EventA) - e = EventA() - events.fire(e) - self.assertEqual(len(e.effects), 1) - self.assertTrue("handler1" in e.effects) +def create_manager(): + return Manager() - def test_subscribe_via_decorator(self): - events = self.createManager() - events.subscriber(EventA)(self.makeHandler("handler1")) - e = EventA() - events.fire(e) - self.assertEqual(len(e.effects), 1) - self.assertTrue("handler1" in e.effects) - def test_subscribe_event_inheritance(self): - events = self.createManager() - events.subscribe(self.makeHandler("handler1"), EventA) - events.subscribe(self.makeHandler("handler2"), EventB) +def test_subscribe_single_event(): + events = create_manager() + events.subscribe(make_handler("handler1"), EventA) + e = EventA() + events.handle(e) + assert len(e.effects) == 1 + assert "handler1" in e.effects - ea = EventA() - events.fire(ea) - self.assertEqual(len(ea.effects), 1) - self.assertTrue("handler1" in ea.effects) - eb = EventB() - events.fire(eb) - self.assertEqual(len(eb.effects), 2) - self.assertTrue("handler1" in eb.effects) - self.assertTrue("handler2" in eb.effects) +def test_subscribe_via_decorator(): + events = create_manager() + events.subscriber(EventA)(make_handler("handler1")) + e = EventA() + events.handle(e) + assert len(e.effects) == 1 + assert "handler1" in e.effects - def test_subscribe_event_multiple_inheritance(self): - events = self.createManager() - events.subscribe(self.makeHandler("handler1"), EventA) - events.subscribe(self.makeHandler("handler2"), EventC) - events.subscribe(self.makeHandler("handler3"), EventD) - ea = EventA() - events.fire(ea) - self.assertEqual(len(ea.effects), 1) - self.assertTrue("handler1" in ea.effects) +def test_subscribe_event_inheritance(): + events = create_manager() + events.subscribe(make_handler("handler1"), EventA) + events.subscribe(make_handler("handler2"), EventB) - ec = EventC() - events.fire(ec) - self.assertEqual(len(ec.effects), 1) - self.assertTrue("handler2" in ec.effects) + ea = EventA() + events.handle(ea) + assert len(ea.effects) == 1 + assert "handler1" in ea.effects - ed = EventD() - events.fire(ed) - self.assertEqual(len(ed.effects), 3) - self.assertTrue("handler1" in ed.effects) - self.assertTrue("handler2" in ed.effects) - self.assertTrue("handler3" in ed.effects) + eb = EventB() + events.handle(eb) + assert len(eb.effects) == 2 + assert "handler1" in eb.effects + assert "handler2" in eb.effects - def test_subscribe_event_malformed_multiple_inheritance(self): - events = self.createManager() - events.subscribe(self.makeHandler("handler1"), EventA) - events.subscribe(self.makeHandler("handler2"), EventD) - events.subscribe(self.makeHandler("handler3"), EventE) - ea = EventA() - events.fire(ea) - self.assertEqual(len(ea.effects), 1) - self.assertTrue("handler1" in ea.effects) +def test_subscribe_event_multiple_inheritance(): + events = create_manager() + events.subscribe(make_handler("handler1"), EventA) + events.subscribe(make_handler("handler2"), EventC) + events.subscribe(make_handler("handler3"), EventD) - ed = EventD() - events.fire(ed) - self.assertEqual(len(ed.effects), 2) - self.assertTrue("handler1" in ed.effects) - self.assertTrue("handler2" in ed.effects) + ea = EventA() + events.handle(ea) + assert len(ea.effects) == 1 + assert "handler1" in ea.effects - ee = EventE() - events.fire(ee) - self.assertEqual(len(ee.effects), 3) - self.assertTrue("handler1" in ee.effects) - self.assertTrue("handler2" in ee.effects) - self.assertTrue("handler3" in ee.effects) + ec = EventC() + events.handle(ec) + assert len(ec.effects) == 1 + assert "handler2" in ec.effects - def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self): - events = self.createManager() - events.subscribe(self.makeHandler("handler1"), Event) - events.subscribe(self.makeHandler("handler2"), EventB) + ed = EventD() + events.handle(ed) + assert len(ed.effects) == 3 + assert "handler1" in ed.effects + assert "handler2" in ed.effects + assert "handler3" in ed.effects - eb = EventB() - events.fire(eb) - self.assertEqual(len(eb.effects), 2) - self.assertTrue("handler1" in eb.effects) - self.assertTrue("handler2" in eb.effects) - def test_unsubscribe_single_event(self): - events = self.createManager() - handler = self.makeHandler("handler1") - events.subscribe(handler, EventA) - events.unsubscribe(handler, EventA) - e = EventA() - events.fire(e) - self.assertEqual(len(e.effects), 0) +def test_subscribe_no_events(): + events = create_manager() - def test_unsubscribe_event_inheritance(self): - events = self.createManager() - handler1 = self.makeHandler("handler1") - handler2 = self.makeHandler("handler2") - events.subscribe(handler1, EventA) - events.subscribe(handler2, EventB) - events.unsubscribe(handler1, EventA) + ea = EventA() + events.handle(ea) + assert len(ea.effects) == 0 - ea = EventA() - events.fire(ea) - self.assertEqual(len(ea.effects), 0) - eb = EventB() - events.fire(eb) - self.assertEqual(len(eb.effects), 1) - self.assertTrue("handler2" in eb.effects) +def test_subscribe_base_event(): + events = create_manager() + events.subscribe(make_handler("handler1"), EventA) -class Event(object): + ea = EventB() + events.handle(ea) + assert len(ea.effects) == 1 + assert "handler1" in ea.effects + + +def test_subscribe_event_malformed_multiple_inheritance(): + events = create_manager() + events.subscribe(make_handler("handler1"), EventA) + events.subscribe(make_handler("handler2"), EventD) + events.subscribe(make_handler("handler3"), EventE) + + ea = EventA() + events.handle(ea) + assert len(ea.effects) == 1 + assert "handler1" in ea.effects + + ed = EventD() + events.handle(ed) + assert len(ed.effects) == 2 + assert "handler1" in ed.effects + assert "handler2" in ed.effects + + ee = EventE() + events.handle(ee) + assert len(ee.effects) == 3 + assert "handler1" in ee.effects + assert "handler2" in ee.effects + assert "handler3" in ee.effects + + +def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(): + events = create_manager() + events.subscribe(make_handler("handler1"), Event) + events.subscribe(make_handler("handler2"), EventB) + + eb = EventB() + events.handle(eb) + assert len(eb.effects) == 2 + assert "handler1" in eb.effects + assert "handler2" in eb.effects + + +def test_unsubscribe_single_event(): + events = create_manager() + handler = make_handler("handler1") + events.subscribe(handler, EventA) + events.unsubscribe(handler, EventA) + e = EventA() + events.handle(e) + assert len(e.effects) == 0 + + +def test_unsubscribe_event_inheritance(): + events = create_manager() + handler1 = make_handler("handler1") + handler2 = make_handler("handler2") + events.subscribe(handler1, EventA) + events.subscribe(handler2, EventB) + events.unsubscribe(handler1, EventA) + + ea = EventA() + events.handle(ea) + assert len(ea.effects) == 0 + + eb = EventB() + events.handle(eb) + assert len(eb.effects) == 1 + assert "handler2" in eb.effects + + +class Event: + def __init__(self) -> None: + self.effects: List[object] = [] - def __init__(self): - self.effects = [] class EventA(Event): pass + class EventB(EventA): pass + class EventC(Event): pass + class EventD(EventA, EventC): pass + class EventE(EventD, EventA): pass diff --git a/tests/test_multidispatch.py b/tests/test_multidispatch.py index bf4b02a..a99d5a0 100644 --- a/tests/test_multidispatch.py +++ b/tests/test_multidispatch.py @@ -1,270 +1,224 @@ """ Tests for :module:`generic.multidispatch`.""" -import unittest +import pytest -__all__ = ("DispatcherTests",) - -class DispatcherTests(unittest.TestCase): - - def createDispatcher(self, params_arity, args=None, varargs=None, - keywords=None, defaults=None): - from inspect import ArgSpec - from generic.multidispatch import FunctionDispatcher - return FunctionDispatcher(ArgSpec(args=args, varargs=varargs, - keywords=keywords, - defaults=defaults), params_arity) +from inspect import FullArgSpec +from generic.multidispatch import multidispatch, FunctionDispatcher - def test_one_argument(self): - dispatcher = self.createDispatcher(1, args=["x"]) +def create_dispatcher( + params_arity, args=None, varargs=None, keywords=None, defaults=None +) -> FunctionDispatcher: - dispatcher.register_rule(lambda x: x + 1, int) - self.assertEqual(dispatcher(1), 2) - self.assertRaises(TypeError, dispatcher, "s") + return FunctionDispatcher( + FullArgSpec( + args=args, + varargs=varargs, + varkw=keywords, + defaults=defaults, + kwonlyargs=[], + kwonlydefaults={}, + annotations={}, + ), + params_arity, + ) - dispatcher.register_rule(lambda x: x + "1", str) - self.assertEqual(dispatcher(1), 2) - self.assertEqual(dispatcher("1"), "11") - self.assertRaises(TypeError, dispatcher, tuple()) - def test_two_arguments(self): - dispatcher = self.createDispatcher(2, args=["x", "y"]) +def test_one_argument(): + dispatcher = create_dispatcher(1, args=["x"]) - dispatcher.register_rule(lambda x, y: x + y + 1, int, int) - self.assertEqual(dispatcher(1, 2), 4) - self.assertRaises(TypeError, dispatcher, "s", "ss") - self.assertRaises(TypeError, dispatcher, 1, "ss") - self.assertRaises(TypeError, dispatcher, "s", 2) + dispatcher.register_rule(lambda x: x + 1, int) + assert dispatcher(1) == 2 + with pytest.raises(TypeError): + dispatcher("s") - dispatcher.register_rule(lambda x, y: x + y + "1", str, str) - self.assertEqual(dispatcher(1, 2), 4) - self.assertEqual(dispatcher("1", "2"), "121") - self.assertRaises(TypeError, dispatcher, "1", 1) - self.assertRaises(TypeError, dispatcher, 1, "1") + dispatcher.register_rule(lambda x: x + "1", str) + assert dispatcher(1) == 2 + assert dispatcher("1") == "11" + with pytest.raises(TypeError): + dispatcher(tuple()) - dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str) - self.assertEqual(dispatcher(1, 2), 4) - self.assertEqual(dispatcher("1", "2"), "121") - self.assertEqual(dispatcher(1, "2"), "121") - self.assertRaises(TypeError, dispatcher, "1", 1) - def test_bottom_rule(self): - dispatcher = self.createDispatcher(1, args=["x"]) +def test_two_arguments(): + dispatcher = create_dispatcher(2, args=["x", "y"]) - dispatcher.register_rule(lambda x: x, object) - self.assertEqual(dispatcher(1), 1) - self.assertEqual(dispatcher("1"), "1") - self.assertEqual(dispatcher([1]), [1]) - self.assertEqual(dispatcher((1,)), (1,)) + dispatcher.register_rule(lambda x, y: x + y + 1, int, int) + assert dispatcher(1, 2) == 4 + with pytest.raises(TypeError): + dispatcher("s", "ss") + with pytest.raises(TypeError): + dispatcher(1, "ss") + with pytest.raises(TypeError): + dispatcher("s", 2) - def test_subtype_evaluation(self): - class Super(object): - pass - class Sub(Super): - pass + dispatcher.register_rule(lambda x, y: x + y + "1", str, str) + assert dispatcher(1, 2) == 4 + assert dispatcher("1", "2") == "121" + with pytest.raises(TypeError): + dispatcher("1", 1) + with pytest.raises(TypeError): + dispatcher(1, "1") - dispatcher = self.createDispatcher(1, args=["x"]) + dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str) + assert dispatcher(1, 2) == 4 + assert dispatcher("1", "2") == "121" + assert dispatcher(1, "2") == "121" + with pytest.raises(TypeError): + dispatcher("1", 1) - dispatcher.register_rule(lambda x: x, Super) - o_super = Super() - self.assertEqual(dispatcher(o_super), o_super) - o_sub = Sub() - self.assertEqual(dispatcher(o_sub), o_sub) - self.assertRaises(TypeError, dispatcher, object()) - dispatcher.register_rule(lambda x: (x, x), Sub) - o_super = Super() - self.assertEqual(dispatcher(o_super), o_super) - o_sub = Sub() - self.assertEqual(dispatcher(o_sub), (o_sub, o_sub)) +def test_bottom_rule(): + dispatcher = create_dispatcher(1, args=["x"]) - def test_register_rule_with_wrong_arity(self): - dispatcher = self.createDispatcher(1, args=["x"]) - dispatcher.register_rule(lambda x: x, int) - self.assertRaises( - TypeError, - dispatcher.register_rule, lambda x, y: x, str) + dispatcher.register_rule(lambda x: x, object) + assert dispatcher(1) == 1 + assert dispatcher("1") == "1" + assert dispatcher([1]) == [1] + assert dispatcher((1,)) == (1,) - def test_register_rule_with_different_arg_names(self): - dispatcher = self.createDispatcher(1, args=["x"]) - dispatcher.register_rule(lambda y: y, int) - self.assertEqual(dispatcher(1), 1) - def test_dispatching_with_varargs(self): - dispatcher = self.createDispatcher(1, args=["x"], varargs="va") - dispatcher.register_rule(lambda x, *va: x, int) - self.assertEqual(dispatcher(1), 1) - self.assertRaises(TypeError, dispatcher, "1", 2, 3) +def test_subtype_evaluation(): + class Super: + pass - def test_dispatching_with_varkw(self): - dispatcher = self.createDispatcher(1, args=["x"], keywords="vk") - dispatcher.register_rule(lambda x, **vk: x, int) - self.assertEqual(dispatcher(1), 1) - self.assertRaises(TypeError, dispatcher, "1", a=1, b=2) + class Sub(Super): + pass - def test_dispatching_with_kw(self): - dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"]) - dispatcher.register_rule(lambda x, y=1: x, int) - self.assertEqual(dispatcher(1), 1) - self.assertRaises(TypeError, dispatcher, "1", k=1) + dispatcher = create_dispatcher(1, args=["x"]) - def test_create_dispatcher_with_pos_args_less_multi_arity(self): - self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"]) - self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"], - defaults=["x"]) + dispatcher.register_rule(lambda x: x, Super) + o_super = Super() + assert dispatcher(o_super) == o_super + o_sub = Sub() + assert dispatcher(o_sub) == o_sub + with pytest.raises(TypeError): + dispatcher(object()) - def test_register_rule_with_wrong_number_types_parameters(self): - dispatcher = self.createDispatcher(1, args=["x", "y"]) - self.assertRaises( - TypeError, - dispatcher.register_rule, lambda x, y: x, int, str) + dispatcher.register_rule(lambda x: (x, x), Sub) + o_super = Super() + assert dispatcher(o_super) == o_super + o_sub = Sub() + assert dispatcher(o_sub) == (o_sub, o_sub) - def test_register_rule_with_partial_dispatching(self): - dispatcher = self.createDispatcher(1, args=["x", "y"]) - dispatcher.register_rule(lambda x, y: x, int) - self.assertEqual(dispatcher(1, 2), 1) - self.assertEqual(dispatcher(1, "2"), 1) - self.assertRaises(TypeError, dispatcher, "2", 1) + +def test_register_rule_with_wrong_arity(): + dispatcher = create_dispatcher(1, args=["x"]) + dispatcher.register_rule(lambda x: x, int) + with pytest.raises(TypeError): dispatcher.register_rule(lambda x, y: x, str) - self.assertEqual(dispatcher(1, 2), 1) - self.assertEqual(dispatcher(1, "2"), 1) - self.assertEqual(dispatcher("1", "2"), "1") - self.assertEqual(dispatcher("1", 2), "1") -class MultifunctionTests(unittest.TestCase): - def test_it(self): - from generic.multidispatch import multifunction +def test_register_rule_with_different_arg_names(): + dispatcher = create_dispatcher(1, args=["x"]) + dispatcher.register_rule(lambda y: y, int) + assert dispatcher(1) == 1 - @multifunction(int, str) - def func(x, y): - return str(x) + y - self.assertEqual(func(1, "2"), "12") - self.assertRaises(TypeError, func, 1, 2) - self.assertRaises(TypeError, func, "1", 2) - self.assertRaises(TypeError, func, "1", "2") +def test_dispatching_with_varargs(): + dispatcher = create_dispatcher(1, args=["x"], varargs="va") + dispatcher.register_rule(lambda x, *va: x, int) + assert dispatcher(1) == 1 + with pytest.raises(TypeError): + dispatcher("1", 2, 3) - @func.when(str, str) - def func(x, y): - return x + y - self.assertEqual(func(1, "2"), "12") - self.assertEqual(func("1", "2"), "12") - self.assertRaises(TypeError, func, 1, 2) - self.assertRaises(TypeError, func, "1", 2) +def test_dispatching_with_varkw(): + dispatcher = create_dispatcher(1, args=["x"], keywords="vk") + dispatcher.register_rule(lambda x, **vk: x, int) + assert dispatcher(1) == 1 + with pytest.raises(TypeError): + dispatcher("1", a=1, b=2) - def test_overriding(self): - from generic.multidispatch import multifunction - @multifunction(int, str) - def func(x, y): - return str(x) + y +def test_dispatching_with_kw(): + dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"]) + dispatcher.register_rule(lambda x, y=1: x, int) + assert dispatcher(1) == 1 + with pytest.raises(TypeError): + dispatcher("1", k=1) - self.assertEqual(func(1, "2"), "12") - self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x)) - @func.override(int, str) - def func(x, y): - return y + str(x) +def test_create_dispatcher_with_pos_args_less_multi_arity(): + with pytest.raises(TypeError): + create_dispatcher(2, args=["x"]) + with pytest.raises(TypeError): + create_dispatcher(2, args=["x", "y"], defaults=["x"]) - self.assertEqual(func(1, "2"), "21") -class MultimethodTests(unittest.TestCase): +def test_register_rule_with_wrong_number_types_parameters(): + dispatcher = create_dispatcher(1, args=["x", "y"]) + with pytest.raises(TypeError): + dispatcher.register_rule(lambda x, y: x, int, str) - def test_multimethod(self): - from generic.multidispatch import multimethod - from generic.multidispatch import has_multimethods - @has_multimethods - class Dummy(object): +def test_register_rule_with_partial_dispatching(): + dispatcher = create_dispatcher(1, args=["x", "y"]) + dispatcher.register_rule(lambda x, y: x, int) + assert dispatcher(1, 2) == 1 + assert dispatcher(1, "2") == 1 + with pytest.raises(TypeError): + dispatcher("2", 1) + dispatcher.register_rule(lambda x, y: x, str) + assert dispatcher(1, 2) == 1 + assert dispatcher(1, "2") == 1 + assert dispatcher("1", "2") == "1" + assert dispatcher("1", 2) == "1" - @multimethod(int) - def foo(self, x): - return x + 1 - @foo.when(str) - def foo(self, x): - return x + "1" +def test_default_dispatcher(): + @multidispatch(int, str) + def func(x, y): + return str(x) + y - self.assertEqual(Dummy().foo(1), 2) - self.assertEqual(Dummy().foo("1"), "11") - self.assertRaises(TypeError, Dummy().foo, []) + assert func(1, "2") == "12" + with pytest.raises(TypeError): + func(1, 2) + with pytest.raises(TypeError): + func("1", 2) + with pytest.raises(TypeError): + func("1", "2") - def test_inheritance(self): - from generic.multidispatch import multimethod - from generic.multidispatch import has_multimethods - @has_multimethods - class Dummy(object): +def test_multiple_functions(): + @multidispatch(int, str) + def func(x, y): + return str(x) + y - @multimethod(int) - def foo(self, x): - return x + 1 + @func.register(str, str) + def _(x, y): + return x + y - @foo.when(float) - def foo(self, x): - return x + 1.5 + assert func(1, "2") == "12" + assert func("1", "2") == "12" + with pytest.raises(TypeError): + func(1, 2) + with pytest.raises(TypeError): + func("1", 2) - @has_multimethods - class DummySub(Dummy): - @Dummy.foo.when(str) - def foo(self, x): - return x + "1" +def test_default(): + @multidispatch() + def func(x, y): + return x + y - @foo.when(tuple) - def foo(self, x): - return x + (1,) + @func.register(str, str) + def _(x, y): + return y + x - @Dummy.foo.when(bool) - def foo(self, x): - return not x + assert func(1, 1) == 2 + assert func("1", "2") == "21" - self.assertEqual(Dummy().foo(1), 2) - self.assertEqual(Dummy().foo(1.5), 3.0) - self.assertRaises(TypeError, Dummy().foo, "1") - self.assertEqual(DummySub().foo(1), 2) - self.assertEqual(DummySub().foo(1.5), 3.0) - self.assertEqual(DummySub().foo("1"), "11") - self.assertEqual(DummySub().foo((1,2)), (1,2,1)) - self.assertEqual(DummySub().foo(True), False) - self.assertRaises(TypeError, DummySub().foo, []) - def test_override(self): - from generic.multidispatch import multimethod - from generic.multidispatch import has_multimethods +def test_on_classes(): + @multidispatch() + class A: + def __init__(self, a, b): + self.v = a + b - @has_multimethods - class Dummy(object): + @A.register(str, str) # type: ignore[attr-defined] + class B: + def __init__(self, a, b): + self.v = b + a - @multimethod(str, str) - def foo(self, x, y): - return x + y - - @foo.when(str, str) - def foo(self, x, y): - return y + x - - self.assertEqual(Dummy().foo("1", "2"), "21") - - def test_inheritance_override(self): - from generic.multidispatch import multimethod - from generic.multidispatch import has_multimethods - - @has_multimethods - class Dummy(object): - - @multimethod(int) - def foo(self, x): - return x + 1 - - @has_multimethods - class DummySub(Dummy): - - @Dummy.foo.when(int) - def foo(self, x): - return x + 2 - - self.assertEqual(Dummy().foo(1), 2) - self.assertEqual(DummySub().foo(1), 3) + assert A(1, 1).v == 2 + assert A("1", "2").v == "21" diff --git a/tests/test_registry.py b/tests/test_registry.py index eda2b80..d02384f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,135 +1,147 @@ """ Tests for :module:`generic.registry`.""" -import unittest +import pytest -__all__ = ("RegistryTests",) +from typing import Union +from generic.registry import Registry, SimpleAxis, TypeAxis -class RegistryTests(unittest.TestCase): - def test_one_axis_no_specificity(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('foo', SimpleAxis())) - a = object() - b = object() - registry.register(a) - registry.register(b, 'foo') - - self.assertEqual(registry.lookup(), a) - self.assertEqual(registry.lookup('foo'), b) - self.assertEqual(registry.lookup('bar'), None) - - def test_two_axes(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - from generic.registry import TypeAxis - registry = Registry(('type', TypeAxis()), - ('name', SimpleAxis())) - - target1 = Target('one') - registry.register(target1, object) - - target2 = Target('two') - registry.register(target2, DummyA) - - target3 = Target('three') - registry.register(target3, DummyA, 'foo') - - context1 = object() - self.assertEqual(registry.lookup(context1), target1) - - context2 = DummyB() - self.assertEqual(registry.lookup(context2), target2) - self.assertEqual(registry.lookup(context2, 'foo'), target3) - - target4 = object() - registry.register(target4, DummyB) - - self.assertEqual(registry.lookup(context2), target4) - self.assertEqual(registry.lookup(context2, 'foo'), target3) - - def test_get_registration(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - from generic.registry import TypeAxis - registry = Registry(('type', TypeAxis()), - ('name', SimpleAxis())) - registry.register('one', object) - registry.register('two', DummyA, 'foo') - self.assertEqual(registry.get_registration(object), 'one') - self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two') - self.assertEqual(registry.get_registration(object, 'foo'), None) - self.assertEqual(registry.get_registration(DummyA), None) - - def test_register_too_many_keys(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('name', SimpleAxis())) - self.assertRaises(ValueError, registry.register, object(), - 'one', 'two') - - def test_lookup_too_many_keys(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('name', SimpleAxis())) - self.assertRaises(ValueError, registry.lookup, 'one', 'two') - - def test_conflict_error(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('name', SimpleAxis())) - registry.register(object(), name='foo') - self.assertRaises(ValueError, registry.register, object(), 'foo') - - def test_override(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('name', SimpleAxis())) - registry.register(1, name='foo') - registry.override(2, name='foo') - self.assertEqual(registry.lookup('foo'), 2) - - def test_skip_nodes(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry( - ('one', SimpleAxis()), - ('two', SimpleAxis()), - ('three', SimpleAxis()) - ) - registry.register('foo', one=1, three=3) - self.assertEqual(registry.lookup(1, three=3), 'foo') - - def test_miss(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry( - ('one', SimpleAxis()), - ('two', SimpleAxis()), - ('three', SimpleAxis()) - ) - registry.register('foo', 1, 2) - self.assertEqual(registry.lookup(one=1, three=3), None) - - def test_bad_lookup(self): - from generic.registry import Registry - from generic.registry import SimpleAxis - registry = Registry(('name', SimpleAxis()), - ('grade', SimpleAxis())) - self.assertRaises(ValueError, registry.register, 1, foo=1) - self.assertRaises(ValueError, registry.lookup, foo=1) - self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo') - -class DummyA(object): +class DummyA: pass + class DummyB(DummyA): pass -class Target(object): - def __init__(self, name): - self.name = name - # Only called if being printed due to a failing test - def __repr__(self): #pragma NO COVERAGE - return "Target('%s')" % self.name +def test_one_axis_no_specificity(): + registry: Registry[object] = Registry(("foo", SimpleAxis())) + a = object() + b = object() + registry.register(a) + registry.register(b, "foo") + + assert registry.lookup() == a + assert registry.lookup("foo") == b + assert registry.lookup("bar") is None + + +def test_subtyping_on_axes(): + registry: Registry[str] = Registry(("type", TypeAxis())) + + target1 = "one" + registry.register(target1, object) + + target2 = "two" + registry.register(target2, DummyA) + + target3 = "three" + registry.register(target3, DummyB) + + assert registry.lookup(object()) == target1 + assert registry.lookup(DummyA()) == target2 + assert registry.lookup(DummyB()) == target3 + + +def test_query_subtyping_on_axes(): + registry: Registry[str] = Registry(("type", TypeAxis())) + + target1 = "one" + registry.register(target1, object) + + target2 = "two" + registry.register(target2, DummyA) + + target3 = "three" + registry.register(target3, DummyB) + + target4 = "four" + registry.register(target4, int) + + assert list(registry.query(object())) == [target1] + assert list(registry.query(DummyA())) == [target2, target1] + assert list(registry.query(DummyB())) == [target3, target2, target1] + assert list(registry.query(3)) == [target4, target1] + + +def test_two_axes(): + registry: Registry[Union[str, object]] = Registry( + ("type", TypeAxis()), ("name", SimpleAxis()) + ) + + target1 = "one" + registry.register(target1, object) + + target2 = "two" + registry.register(target2, DummyA) + + target3 = "three" + registry.register(target3, DummyA, "foo") + + context1 = object() + assert registry.lookup(context1) == target1 + + context2 = DummyB() + assert registry.lookup(context2) == target2 + assert registry.lookup(context2, "foo") == target3 + + target4 = object() + registry.register(target4, DummyB) + + assert registry.lookup(context2) == target4 + assert registry.lookup(context2, "foo") == target3 + + +def test_get_registration(): + registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis())) + registry.register("one", object) + registry.register("two", DummyA, "foo") + assert registry.get_registration(object) == "one" + assert registry.get_registration(DummyA, "foo") == "two" + assert registry.get_registration(object, "foo") is None + assert registry.get_registration(DummyA) is None + + +def test_register_too_many_keys(): + registry: Registry[type] = Registry(("name", SimpleAxis())) + with pytest.raises(ValueError): + registry.register(object, "one", "two") + + +def test_lookup_too_many_keys(): + registry: Registry[object] = Registry(("name", SimpleAxis())) + with pytest.raises(ValueError): + registry.register(registry.lookup("one", "two")) + + +def test_conflict_error(): + registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis())) + registry.register(object(), name="foo") + with pytest.raises(ValueError): + registry.register(object, "foo") + + +def test_skip_nodes(): + registry: Registry[str] = Registry( + ("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis()) + ) + registry.register("foo", one=1, three=3) + assert registry.lookup(1, three=3) == "foo" + + +def test_miss(): + registry: Registry[str] = Registry( + ("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis()) + ) + registry.register("foo", 1, 2) + assert registry.lookup(one=1, three=3) is None + + +def test_bad_lookup(): + registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis())) + with pytest.raises(ValueError): + registry.register(1, foo=1) + with pytest.raises(ValueError): + registry.lookup(foo=1) + with pytest.raises(ValueError): + registry.register(1, "foo", name="foo")