Replace old code by Python3 implementation from Gaphor

Multimethods are missing now. Those have to be added back.
This commit is contained in:
Arjan Molenaar 2019-11-08 16:35:11 +01:00
parent ee35c528cb
commit bb20f6992b
6 changed files with 641 additions and 729 deletions

View File

@ -2,7 +2,7 @@
This module provides API for event management. There are two APIs provided: This module provides API for event management. There are two APIs provided:
* Global event management API: subscribe, unsubscribe, fire. * Global event management API: subscribe, unsubscribe, handle.
* Local event management API: Manager * Local event management API: Manager
If you run only one instance of your application per Python If you run only one instance of your application per Python
@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API
and have one instance of Manager object per application instance. and have one instance of Manager object per application instance.
""" """
from collections import namedtuple from typing import Callable, Set, Type
from generic.registry import Registry from generic.registry import Registry, TypeAxis
from generic.registry import TypeAxis
__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber")
class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])): __all__ = "Manager"
""" Set of handlers for specific type of event.
This object stores ``handlers`` for specific event type and Event = object
``parents`` reference to handler sets of event's supertypes. Handler = Callable[[object], None]
""" HandlerSet = Set[Handler]
@property
def all_handlers(self):
""" Iterate over own and supertypes' handlers.
This iterator yields just unique values, so it won't yield the class Manager:
same handler twice, even if it was registered both for some
event type and its supertype.
"""
seen = set()
seen_add = seen.add
# yield own handlers first
for handler in self.handlers:
seen_add(handler)
yield handler
# yield supertypes' handlers then
for parent in self.parents:
for handler in parent.all_handlers:
if not handler in seen:
seen_add(handler)
yield handler
class Manager(object):
""" Event manager """ Event manager
Provides API for subscribing for and firing events. There's also global Provides API for subscribing for and firing events. There's also global
event manager instantiated at module level with functions event manager instantiated at module level with functions
:func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased :func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased
to corresponding methods of class. to corresponding methods of class.
""" """
def __init__(self): registry: Registry[HandlerSet]
def __init__(self) -> None:
axes = (("event_type", TypeAxis()),) axes = (("event_type", TypeAxis()),)
self.registry = Registry(*axes) self.registry = Registry(*axes)
def subscribe(self, handler, event_type): def subscribe(self, handler: Handler, event_type: Type[Event]) -> None:
""" Subscribe ``handler`` to specified ``event_type``""" """ Subscribe ``handler`` to specified ``event_type``"""
handler_set = self.registry.get_registration(event_type) handler_set = self.registry.get_registration(event_type)
if not handler_set: if handler_set is None:
handler_set = self._register_handler_set(event_type) handler_set = self._register_handler_set(event_type)
handler_set.handlers.add(handler) handler_set.add(handler)
def unsubscribe(self, handler, event_type): def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None:
""" Unsubscribe ``handler`` from ``event_type``""" """ Unsubscribe ``handler`` from ``event_type``"""
handler_set = self.registry.get_registration(event_type) handler_set = self.registry.get_registration(event_type)
if handler_set and handler in handler_set.handlers: if handler_set and handler in handler_set:
handler_set.handlers.remove(handler) handler_set.remove(handler)
def fire(self, event): def handle(self, event: Event) -> None:
""" Fire ``event`` """ Fire ``event``
All subscribers will be executed with no determined order. All subscribers will be executed with no determined order.
""" """
handler_set = self.registry.lookup(event) handler_sets = self.registry.query(event)
for handler in handler_set.all_handlers: for handler_set in handler_sets:
handler(event) if handler_set:
for handler in set(handler_set):
handler(event)
def _register_handler_set(self, event_type): def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet:
""" Register new handler set for ``event_type``.""" """ Register new handler set for ``event_type``.
# Collect handler sets for supertypes """
parent_handler_sets = [] handler_set: HandlerSet = set()
parents = event_type.__bases__
for parent in parents:
parent_handlers = self.registry.get_registration(parent)
if parent_handlers is None:
parent_handlers = self._register_handler_set(parent)
parent_handler_sets.append(parent_handlers)
handler_set = HandlerSet(parents=parent_handler_sets, handlers=set())
self.registry.register(handler_set, event_type) self.registry.register(handler_set, event_type)
return handler_set return handler_set
def subscriber(self, event_type): def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]:
""" Decorator for subscribing handlers """ Decorator for subscribing handlers
Works like this: Works like this:
>>> mymanager = Manager()
>>> class MyEvent():
... pass
>>> @mymanager.subscriber(MyEvent) >>> @mymanager.subscriber(MyEvent)
... def mysubscriber(evt): ... def mysubscriber(evt):
... # handle event ... # handle event
... return ... return
>>> mymanager.fire(MyEvent()) >>> mymanager.handle(MyEvent())
""" """
def registrator(func):
def registrator(func: Handler) -> Handler:
self.subscribe(func, event_type) self.subscribe(func, event_type)
return func return func
return registrator return registrator
# Global event manager
_global_manager = Manager()
# Global event management API
subscribe = _global_manager.subscribe
unsubscribe = _global_manager.unsubscribe
fire = _global_manager.fire
subscriber = _global_manager.subscriber

View File

@ -1,206 +1,132 @@
""" Multidispatch for functions and methods""" """ Multidispatch for functions and methods.
This code is a Python 3, slimmed down version of the
generic package by Andrey Popp.
Only the generic function code is left in tact -- no generic methods.
The interface has been made in line with `functools.singledispatch`.
Note that this module does not support annotated functions.
"""
from __future__ import annotations
from typing import cast, Any, Callable, Generic, TypeVar, Union
import functools import functools
import inspect import inspect
import types
import threading
from generic.registry import Registry from generic.registry import Registry, TypeAxis
from generic.registry import TypeAxis
__all__ = ("multifunction", "multimethod", "has_multimethods") __all__ = "multidispatch"
def multifunction(*argtypes): T = TypeVar("T", bound=Union[Callable[..., Any], type])
""" Declare function as multifunction KeyType = Union[type, None]
def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
""" Declare function as multidispatch
This decorator takes ``argtypes`` argument types and replace decorated This decorator takes ``argtypes`` argument types and replace decorated
function with :class:`.FunctionDispatcher` object, which is responsible for function with :class:`.FunctionDispatcher` object, which is responsible for
multiple dispatch feature. multiple dispatch feature.
""" """
def _replace_with_dispatcher(func):
dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes)) def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
nonlocal argtypes
argspec = inspect.getfullargspec(func)
if not argtypes:
arity = _arity(argspec)
if isinstance(func, type):
# It's a class we deal with:
arity -= 1
argtypes = (object,) * arity
dispatcher = cast(
FunctionDispatcher[T],
functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func),
)
dispatcher.register_rule(func, *argtypes) dispatcher.register_rule(func, *argtypes)
return dispatcher return dispatcher
return _replace_with_dispatcher return _replace_with_dispatcher
def multimethod(*argtypes):
""" Declare method as multimethod
This decorator works exactly the same as :func:`.multifunction` decorator class FunctionDispatcher(Generic[T]):
but replaces decorated method with :class:`.MethodDispatcher` object
instead.
Should be used only for decorating methods and enclosing class should have
:func:`.has_multimethods` decorator.
"""
def _replace_with_dispatcher(func):
dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1)
dispatcher.register_unbound_rule(func, *argtypes)
return dispatcher
return _replace_with_dispatcher
def has_multimethods(cls):
""" Declare class as one that have multimethods
Should only be used for decorating classes which have methods decorated with
:func:`.multimethod` decorator.
"""
for name, obj in cls.__dict__.items():
if isinstance(obj, MethodDispatcher):
obj.proceed_unbound_rules(cls)
return cls
class FunctionDispatcher(object):
""" Multidispatcher for functions """ Multidispatcher for functions
This object dispatch calls to function by its argument types. Usually it is This object dispatch calls to function by its argument types. Usually it is
produced by :func:`.multifunction` decorator. produced by :func:`.multidispatch` decorator.
You should not manually create objects of this type. You should not manually create objects of this type.
""" """
def __init__(self, argspec, params_arity): registry: Registry[T]
def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
""" Initialize dispatcher with ``argspec`` of type """ Initialize dispatcher with ``argspec`` of type
:class:`inspect.ArgSpec` and ``params_arity`` that represent number :class:`inspect.ArgSpec` and ``params_arity`` that represent number
params.""" params."""
# Check if we have enough positional arguments for number of type params # Check if we have enough positional arguments for number of type params
if arity(argspec) < params_arity: if _arity(argspec) < params_arity:
raise TypeError("Not enough positional arguments " raise TypeError(
"for number of type parameters provided.") "Not enough positional arguments "
"for number of type parameters provided."
)
self.argspec = argspec self.argspec = argspec
self.params_arity = params_arity self.params_arity = params_arity
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)] axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
self.registry = Registry(*axis) self.registry = Registry(*axis)
def check_rule(self, rule, *argtypes): def check_rule(self, rule: T, *argtypes: KeyType) -> None:
# Check if we have the right number of parametrized types # Check if we have the right number of parametrized types
if len(argtypes) != self.params_arity: if len(argtypes) != self.params_arity:
raise TypeError("Wrong number of type parameters.") raise TypeError(
f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
)
# Check if we have the same argspec (by number of args) # Check if we have the same argspec (by number of args)
rule_argspec = inspect.getargspec(rule) rule_argspec = inspect.getfullargspec(rule)
if not is_equalent_argspecs(rule_argspec, self.argspec): left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4])
raise TypeError("Rule does not conform " right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4])
"to previous implementations.") if left_spec != right_spec:
raise TypeError(
f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
)
def register_rule(self, rule, *argtypes): def register_rule(self, rule: T, *argtypes: KeyType) -> None:
""" Register new ``rule`` for ``argtypes``.""" """ Register new ``rule`` for ``argtypes``."""
self.check_rule(rule, *argtypes) self.check_rule(rule, *argtypes)
self.registry.register(rule, *argtypes) self.registry.register(rule, *argtypes)
def override_rule(self, rule, *argtypes): def register(self, *argtypes: KeyType) -> Callable[[T], T]:
""" Override ``rule`` for ``argtypes``.""" """ Decorator for registering new case for multidispatch
self.check_rule(rule, *argtypes)
self.registry.override(rule, *argtypes)
def lookup_rule(self, *args):
""" Lookup rule by ``args``. Returns None if no rule was found."""
args = args[:self.params_arity]
rule = self.registry.lookup(*args)
if rule is None:
raise TypeError("No available rule found for %r" % (args,))
return rule
def when(self, *argtypes):
""" Decorator for registering new case for multifunction
New case will be registered for types identified by ``argtypes``. The New case will be registered for types identified by ``argtypes``. The
length of ``argtypes`` should be equal to the length of ``argtypes`` length of ``argtypes`` should be equal to the length of ``argtypes``
argument were passed corresponding :func:`.multifunction` call, which argument were passed corresponding :func:`.multidispatch` call, which
also indicated the number of arguments multifunction dispatches on. also indicated the number of arguments multidispatch dispatches on.
""" """
def register_rule(func):
def register_rule(func: T) -> T:
self.register_rule(func, *argtypes) self.register_rule(func, *argtypes)
return self return func
return register_rule return register_rule
@property def __call__(self, *args: Any, **kwargs: Any) -> Any:
def otherwise(self):
""" Decorator which registeres "catch-all" case for multifunction"""
def register_rule(func):
self.register_rule(func, [object]*self.params_arity)
return self
return register_rule
def override(self, *argtypes):
""" Decorator for overriding case for ``argtypes``"""
def override_rule(func):
self.override_rule(func, *argtypes)
return self
return override_rule
def __call__(self, *args, **kwargs):
""" Dispatch call to appropriate rule.""" """ Dispatch call to appropriate rule."""
rule = self.lookup_rule(*args) trimmed_args = args[: self.params_arity]
rule = self.registry.lookup(*trimmed_args)
if not rule:
raise TypeError(f"No available rule found for {trimmed_args!r}")
return rule(*args, **kwargs) return rule(*args, **kwargs)
class MethodDispatcher(FunctionDispatcher):
""" Multiple dispatch for methods
This object dispatch call to method by its class and arguments types. def _arity(argspec: inspect.FullArgSpec) -> int:
Usually it is produced by :func:`.multimethod` decorator.
You should not manually create objects of this type.
"""
def __init__(self, argspec, params_arity):
FunctionDispatcher.__init__(self, argspec, params_arity)
# some data, that should be local to thread of execution
self.local = threading.local()
self.local.unbound_rules = []
def register_unbound_rule(self, func, *argtypes):
""" Register unbound rule that should be processed by
``proceed_unbound_rules`` later."""
self.local.unbound_rules.append((argtypes, func))
def proceed_unbound_rules(self, cls):
""" Process all unbound rule by binding them to ``cls`` type."""
for argtypes, func in self.local.unbound_rules:
argtypes = (cls,) + argtypes
self.override_rule(func, *argtypes)
self.local.unbound_rules = []
def __get__(self, obj, cls):
if obj is None:
return self
return types.MethodType(self, obj)
def when(self, *argtypes):
""" Register new case for multimethod for ``argtypes``"""
def make_declaration(meth):
self.register_unbound_rule(meth, *argtypes)
return self
return make_declaration
def override(self, *argtypes):
""" Decorator for overriding case for ``argtypes``"""
return self.when(*argtypes)
@property
def otherwise(self):
""" Decorator which registeres "catch-all" case for multimethod"""
def make_declaration(func):
self.register_unbound_rule(func, [object]*self.params_arity)
return self
return make_declaration
def arity(argspec):
""" Determinal positional arity of argspec.""" """ Determinal positional arity of argspec."""
args = argspec.args if argspec.args else [] args = argspec.args if argspec.args else []
defaults = argspec.defaults if argspec.defaults else [] defaults = argspec.defaults if argspec.defaults else []
return len(args) - len(defaults) return len(args) - len(defaults)
def is_equalent_argspecs(left, right):
""" Check argspec equalence."""
return map(lambda x: len(x) if x else 0, left) == \
map(lambda x: len(x) if x else 0, right)
def _make_dispatcher(dispacther_cls, func, params_arity):
argspec = inspect.getargspec(func)
wrapper = functools.wraps(func)
dispatcher = wrapper(dispacther_cls(argspec, params_arity))
return dispatcher

View File

@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi.
[1]: http://bitbucket.org/chrisrossi/happy [1]: http://bitbucket.org/chrisrossi/happy
""" """
from __future__ import annotations
__all__ = ("Registry", "SimpleAxis", "TypeAxis") __all__ = ("Registry", "SimpleAxis", "TypeAxis")
class Registry(object): from typing import (
Any,
Dict,
Generic,
KeysView,
List,
Generator,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
K = TypeVar("K")
S = TypeVar("S")
T = TypeVar("T")
V = TypeVar("V")
Axis = Union["SimpleAxis", "TypeAxis"]
class Registry(Generic[T]):
""" Registry implementation.""" """ Registry implementation."""
def __init__(self, *axes): def __init__(self, *axes: Tuple[str, Axis]):
self._tree = _TreeNode() self._tree: _TreeNode[T] = _TreeNode()
self._axes = [axis for name, axis in axes] self._axes = [axis for name, axis in axes]
self._axes_dict = dict([ self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)}
(name, (i, axis)) for i, (name, axis) in enumerate(axes)
])
def register(self, target, *arg_keys, **kw_keys): def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None:
self._register(target, self._align_with_axes(arg_keys, kw_keys), False)
def override(self, target, *arg_keys, **kw_keys):
self._register(target, self._align_with_axes(arg_keys, kw_keys), True)
def _register(self, target, keys, override):
tree_node = self._tree tree_node = self._tree
for key in keys: for key in self._align_with_axes(arg_keys, kw_keys):
tree_node = tree_node.setdefault(key, _TreeNode()) tree_node = tree_node.setdefault(key, _TreeNode[T]())
if not override and not tree_node.target is None: if not tree_node.target is None:
raise ValueError( raise ValueError(
"Registration conflicts with existing registration. Use " f"Registration for {target} conflicts with existing registration {tree_node.target}."
"override method to override.") )
tree_node.target = target tree_node.target = target
def get_registration(self, *arg_keys, **kw_keys): def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]:
tree_node = self._tree tree_node = self._tree
for key in self._align_with_axes(arg_keys, kw_keys): for key in self._align_with_axes(arg_keys, kw_keys):
if not tree_node.has_key(key): if not key in tree_node:
return None return None
tree_node = tree_node[key] tree_node = tree_node[key]
return tree_node.target return tree_node.target
def lookup(self, *arg_objs, **kw_objs): def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]:
return next(self.query(*arg_objs, **kw_objs), None)
def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]:
objs = self._align_with_axes(arg_objs, kw_objs) objs = self._align_with_axes(arg_objs, kw_objs)
axes = self._axes axes = self._axes
return self._lookup(self._tree, objs, axes) return self._query(self._tree, objs, axes)
def _lookup(self, tree_node, objs, axes): def _query(
self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis]
) -> Generator[Optional[T], None, None]:
""" Recursively traverse registration tree, from left to right, most """ Recursively traverse registration tree, from left to right, most
specific to least specific, returning the first target found on a specific to least specific, returning the first target found on a
matching node. """ matching node. """
if not objs: if not objs:
return tree_node.target yield tree_node.target
else:
obj = objs[0]
obj = objs[0] # Skip non-participating nodes
if obj is None:
next_node: Optional[_TreeNode[T]] = tree_node.get(None, None)
if next_node is not None:
yield from self._query(next_node, objs[1:], axes[1:])
else:
# Get matches on this axis and iterate from most to least specific
axis = axes[0]
for match_key in axis.matches(obj, tree_node.keys()):
yield from self._query(tree_node[match_key], objs[1:], axes[1:])
# Skip non-participating nodes def _align_with_axes(
if obj is None: self, args: Sequence[S], kw: Dict[str, S]
next_node = tree_node.get(None, None) ) -> Sequence[Optional[S]]:
if next_node is not None:
return self._lookup(next_node, objs[1:], axes[1:])
return None
# Get matches on this axis and iterate from most to least specific
axis = axes[0]
for match_key in axis.matches(obj, tree_node.keys()):
target = self._lookup(tree_node[match_key], objs[1:], axes[1:])
if target is not None:
return target
return None
def _align_with_axes(self, args, kw):
""" Create a list matching up all args and kwargs with their """ Create a list matching up all args and kwargs with their
corresponding axes, in order, using ``None`` as a placeholder for corresponding axes, in order, using ``None`` as a placeholder for
skipped axes. """ skipped axes. """
axes_dict = self._axes_dict axes_dict = self._axes_dict
aligned = [None for i in xrange(len(axes_dict))] aligned: List[Optional[S]] = [None for i in range(len(axes_dict))]
args_len = len(args) args_len = len(args)
if args_len + len(kw) > len(aligned): if args_len + len(kw) > len(aligned):
raise ValueError("Cannot have more arguments than axes.") raise ValueError("Cannot have more arguments than axes.")
for i, arg in enumerate(args): for i, arg in enumerate(args):
@ -91,12 +108,13 @@ class Registry(object):
for k, v in kw.items(): for k, v in kw.items():
i_axis = axes_dict.get(k, None) i_axis = axes_dict.get(k, None)
if i_axis is None: if i_axis is None:
raise ValueError("No axis with name: %s" % k) raise ValueError(f"No axis with name: {k}")
i, axis = i_axis i, axis = i_axis
if aligned[i] is not None: if aligned[i] is not None:
raise ValueError("Axis defined twice between positional and " raise ValueError(
"keyword arguments") "Axis defined twice between positional and " "keyword arguments"
)
aligned[i] = v aligned[i] = v
@ -106,13 +124,15 @@ class Registry(object):
return aligned return aligned
class _TreeNode(dict):
target = None
def __str__(self): class _TreeNode(Generic[T], Dict[Any, Any]):
return "<TreeNode %s %s>" % (self.target, dict.__str__(self)) target: Optional[T] = None
class SimpleAxis(object): def __str__(self) -> str:
return f"<TreeNode {self.target} {dict.__str__(self)}>"
class SimpleAxis:
""" A simple axis where the key into the axis is the same as the object to """ A simple axis where the key into the axis is the same as the object to
be matched (aka the identity axis). This axis behaves just like a be matched (aka the identity axis). This axis behaves just like a
dictionary. You might use this axis if you are interested in registering dictionary. You might use this axis if you are interested in registering
@ -122,21 +142,23 @@ class SimpleAxis(object):
Subclasses can override the ``get_keys`` method for implementing arbitrary Subclasses can override the ``get_keys`` method for implementing arbitrary
axes. axes.
""" """
def matches(self, obj, keys):
for key in self.get_keys(obj): def matches(
self, obj: object, keys: KeysView[Optional[object]]
) -> Generator[object, None, None]:
for key in [obj]:
if key in keys: if key in keys:
yield key yield obj
def get_keys(self, obj):
"""
Return the keys for the given object that could match this axis, from
most specific to least specific. A convenient override point.
"""
return [obj,]
class TypeAxis(SimpleAxis): class TypeAxis:
""" An axis which matches the class and super classes of an object in """ An axis which matches the class and super classes of an object in
method resolution order. method resolution order.
""" """
def get_keys(self, obj):
return type(obj).mro() def matches(
self, obj: object, keys: KeysView[Optional[type]]
) -> Generator[type, None, None]:
for key in type(obj).mro():
if key in keys:
yield key

View File

@ -1,150 +1,181 @@
""" Tests for :module:`generic.event`.""" """ Tests for :module:`generic.event`."""
import unittest from __future__ import annotations
__all__ = ("ManagerTests",) from typing import Callable, List
from generic.event import Manager
class ManagerTests(unittest.TestCase):
def makeHandler(self, effect): def make_handler(effect: object) -> Callable[[Event], None]:
return lambda e: e.effects.append(effect) return lambda e: e.effects.append(effect)
def createManager(self):
from generic.event import Manager
return Manager()
def test_subscribe_single_event(self): def create_manager():
events = self.createManager() return Manager()
events.subscribe(self.makeHandler("handler1"), EventA)
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 1)
self.assertTrue("handler1" in e.effects)
def test_subscribe_via_decorator(self):
events = self.createManager()
events.subscriber(EventA)(self.makeHandler("handler1"))
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 1)
self.assertTrue("handler1" in e.effects)
def test_subscribe_event_inheritance(self): def test_subscribe_single_event():
events = self.createManager() events = create_manager()
events.subscribe(self.makeHandler("handler1"), EventA) events.subscribe(make_handler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventB) e = EventA()
events.handle(e)
assert len(e.effects) == 1
assert "handler1" in e.effects
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 1)
self.assertTrue("handler1" in ea.effects)
eb = EventB() def test_subscribe_via_decorator():
events.fire(eb) events = create_manager()
self.assertEqual(len(eb.effects), 2) events.subscriber(EventA)(make_handler("handler1"))
self.assertTrue("handler1" in eb.effects) e = EventA()
self.assertTrue("handler2" in eb.effects) events.handle(e)
assert len(e.effects) == 1
assert "handler1" in e.effects
def test_subscribe_event_multiple_inheritance(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventC)
events.subscribe(self.makeHandler("handler3"), EventD)
ea = EventA() def test_subscribe_event_inheritance():
events.fire(ea) events = create_manager()
self.assertEqual(len(ea.effects), 1) events.subscribe(make_handler("handler1"), EventA)
self.assertTrue("handler1" in ea.effects) events.subscribe(make_handler("handler2"), EventB)
ec = EventC() ea = EventA()
events.fire(ec) events.handle(ea)
self.assertEqual(len(ec.effects), 1) assert len(ea.effects) == 1
self.assertTrue("handler2" in ec.effects) assert "handler1" in ea.effects
ed = EventD() eb = EventB()
events.fire(ed) events.handle(eb)
self.assertEqual(len(ed.effects), 3) assert len(eb.effects) == 2
self.assertTrue("handler1" in ed.effects) assert "handler1" in eb.effects
self.assertTrue("handler2" in ed.effects) assert "handler2" in eb.effects
self.assertTrue("handler3" in ed.effects)
def test_subscribe_event_malformed_multiple_inheritance(self):
events = self.createManager()
events.subscribe(self.makeHandler("handler1"), EventA)
events.subscribe(self.makeHandler("handler2"), EventD)
events.subscribe(self.makeHandler("handler3"), EventE)
ea = EventA() def test_subscribe_event_multiple_inheritance():
events.fire(ea) events = create_manager()
self.assertEqual(len(ea.effects), 1) events.subscribe(make_handler("handler1"), EventA)
self.assertTrue("handler1" in ea.effects) events.subscribe(make_handler("handler2"), EventC)
events.subscribe(make_handler("handler3"), EventD)
ed = EventD() ea = EventA()
events.fire(ed) events.handle(ea)
self.assertEqual(len(ed.effects), 2) assert len(ea.effects) == 1
self.assertTrue("handler1" in ed.effects) assert "handler1" in ea.effects
self.assertTrue("handler2" in ed.effects)
ee = EventE() ec = EventC()
events.fire(ee) events.handle(ec)
self.assertEqual(len(ee.effects), 3) assert len(ec.effects) == 1
self.assertTrue("handler1" in ee.effects) assert "handler2" in ec.effects
self.assertTrue("handler2" in ee.effects)
self.assertTrue("handler3" in ee.effects)
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self): ed = EventD()
events = self.createManager() events.handle(ed)
events.subscribe(self.makeHandler("handler1"), Event) assert len(ed.effects) == 3
events.subscribe(self.makeHandler("handler2"), EventB) assert "handler1" in ed.effects
assert "handler2" in ed.effects
assert "handler3" in ed.effects
eb = EventB()
events.fire(eb)
self.assertEqual(len(eb.effects), 2)
self.assertTrue("handler1" in eb.effects)
self.assertTrue("handler2" in eb.effects)
def test_unsubscribe_single_event(self): def test_subscribe_no_events():
events = self.createManager() events = create_manager()
handler = self.makeHandler("handler1")
events.subscribe(handler, EventA)
events.unsubscribe(handler, EventA)
e = EventA()
events.fire(e)
self.assertEqual(len(e.effects), 0)
def test_unsubscribe_event_inheritance(self): ea = EventA()
events = self.createManager() events.handle(ea)
handler1 = self.makeHandler("handler1") assert len(ea.effects) == 0
handler2 = self.makeHandler("handler2")
events.subscribe(handler1, EventA)
events.subscribe(handler2, EventB)
events.unsubscribe(handler1, EventA)
ea = EventA()
events.fire(ea)
self.assertEqual(len(ea.effects), 0)
eb = EventB() def test_subscribe_base_event():
events.fire(eb) events = create_manager()
self.assertEqual(len(eb.effects), 1) events.subscribe(make_handler("handler1"), EventA)
self.assertTrue("handler2" in eb.effects)
class Event(object): ea = EventB()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
def test_subscribe_event_malformed_multiple_inheritance():
events = create_manager()
events.subscribe(make_handler("handler1"), EventA)
events.subscribe(make_handler("handler2"), EventD)
events.subscribe(make_handler("handler3"), EventE)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 1
assert "handler1" in ea.effects
ed = EventD()
events.handle(ed)
assert len(ed.effects) == 2
assert "handler1" in ed.effects
assert "handler2" in ed.effects
ee = EventE()
events.handle(ee)
assert len(ee.effects) == 3
assert "handler1" in ee.effects
assert "handler2" in ee.effects
assert "handler3" in ee.effects
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro():
events = create_manager()
events.subscribe(make_handler("handler1"), Event)
events.subscribe(make_handler("handler2"), EventB)
eb = EventB()
events.handle(eb)
assert len(eb.effects) == 2
assert "handler1" in eb.effects
assert "handler2" in eb.effects
def test_unsubscribe_single_event():
events = create_manager()
handler = make_handler("handler1")
events.subscribe(handler, EventA)
events.unsubscribe(handler, EventA)
e = EventA()
events.handle(e)
assert len(e.effects) == 0
def test_unsubscribe_event_inheritance():
events = create_manager()
handler1 = make_handler("handler1")
handler2 = make_handler("handler2")
events.subscribe(handler1, EventA)
events.subscribe(handler2, EventB)
events.unsubscribe(handler1, EventA)
ea = EventA()
events.handle(ea)
assert len(ea.effects) == 0
eb = EventB()
events.handle(eb)
assert len(eb.effects) == 1
assert "handler2" in eb.effects
class Event:
def __init__(self) -> None:
self.effects: List[object] = []
def __init__(self):
self.effects = []
class EventA(Event): class EventA(Event):
pass pass
class EventB(EventA): class EventB(EventA):
pass pass
class EventC(Event): class EventC(Event):
pass pass
class EventD(EventA, EventC): class EventD(EventA, EventC):
pass pass
class EventE(EventD, EventA): class EventE(EventD, EventA):
pass pass

View File

@ -1,270 +1,224 @@
""" Tests for :module:`generic.multidispatch`.""" """ Tests for :module:`generic.multidispatch`."""
import unittest import pytest
__all__ = ("DispatcherTests",) from inspect import FullArgSpec
from generic.multidispatch import multidispatch, FunctionDispatcher
class DispatcherTests(unittest.TestCase):
def createDispatcher(self, params_arity, args=None, varargs=None,
keywords=None, defaults=None):
from inspect import ArgSpec
from generic.multidispatch import FunctionDispatcher
return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
keywords=keywords,
defaults=defaults), params_arity)
def test_one_argument(self): def create_dispatcher(
dispatcher = self.createDispatcher(1, args=["x"]) params_arity, args=None, varargs=None, keywords=None, defaults=None
) -> FunctionDispatcher:
dispatcher.register_rule(lambda x: x + 1, int) return FunctionDispatcher(
self.assertEqual(dispatcher(1), 2) FullArgSpec(
self.assertRaises(TypeError, dispatcher, "s") args=args,
varargs=varargs,
varkw=keywords,
defaults=defaults,
kwonlyargs=[],
kwonlydefaults={},
annotations={},
),
params_arity,
)
dispatcher.register_rule(lambda x: x + "1", str)
self.assertEqual(dispatcher(1), 2)
self.assertEqual(dispatcher("1"), "11")
self.assertRaises(TypeError, dispatcher, tuple())
def test_two_arguments(self): def test_one_argument():
dispatcher = self.createDispatcher(2, args=["x", "y"]) dispatcher = create_dispatcher(1, args=["x"])
dispatcher.register_rule(lambda x, y: x + y + 1, int, int) dispatcher.register_rule(lambda x: x + 1, int)
self.assertEqual(dispatcher(1, 2), 4) assert dispatcher(1) == 2
self.assertRaises(TypeError, dispatcher, "s", "ss") with pytest.raises(TypeError):
self.assertRaises(TypeError, dispatcher, 1, "ss") dispatcher("s")
self.assertRaises(TypeError, dispatcher, "s", 2)
dispatcher.register_rule(lambda x, y: x + y + "1", str, str) dispatcher.register_rule(lambda x: x + "1", str)
self.assertEqual(dispatcher(1, 2), 4) assert dispatcher(1) == 2
self.assertEqual(dispatcher("1", "2"), "121") assert dispatcher("1") == "11"
self.assertRaises(TypeError, dispatcher, "1", 1) with pytest.raises(TypeError):
self.assertRaises(TypeError, dispatcher, 1, "1") dispatcher(tuple())
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
self.assertEqual(dispatcher(1, 2), 4)
self.assertEqual(dispatcher("1", "2"), "121")
self.assertEqual(dispatcher(1, "2"), "121")
self.assertRaises(TypeError, dispatcher, "1", 1)
def test_bottom_rule(self): def test_two_arguments():
dispatcher = self.createDispatcher(1, args=["x"]) dispatcher = create_dispatcher(2, args=["x", "y"])
dispatcher.register_rule(lambda x: x, object) dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
self.assertEqual(dispatcher(1), 1) assert dispatcher(1, 2) == 4
self.assertEqual(dispatcher("1"), "1") with pytest.raises(TypeError):
self.assertEqual(dispatcher([1]), [1]) dispatcher("s", "ss")
self.assertEqual(dispatcher((1,)), (1,)) with pytest.raises(TypeError):
dispatcher(1, "ss")
with pytest.raises(TypeError):
dispatcher("s", 2)
def test_subtype_evaluation(self): dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
class Super(object): assert dispatcher(1, 2) == 4
pass assert dispatcher("1", "2") == "121"
class Sub(Super): with pytest.raises(TypeError):
pass dispatcher("1", 1)
with pytest.raises(TypeError):
dispatcher(1, "1")
dispatcher = self.createDispatcher(1, args=["x"]) dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
assert dispatcher(1, 2) == 4
assert dispatcher("1", "2") == "121"
assert dispatcher(1, "2") == "121"
with pytest.raises(TypeError):
dispatcher("1", 1)
dispatcher.register_rule(lambda x: x, Super)
o_super = Super()
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), o_sub)
self.assertRaises(TypeError, dispatcher, object())
dispatcher.register_rule(lambda x: (x, x), Sub) def test_bottom_rule():
o_super = Super() dispatcher = create_dispatcher(1, args=["x"])
self.assertEqual(dispatcher(o_super), o_super)
o_sub = Sub()
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
def test_register_rule_with_wrong_arity(self): dispatcher.register_rule(lambda x: x, object)
dispatcher = self.createDispatcher(1, args=["x"]) assert dispatcher(1) == 1
dispatcher.register_rule(lambda x: x, int) assert dispatcher("1") == "1"
self.assertRaises( assert dispatcher([1]) == [1]
TypeError, assert dispatcher((1,)) == (1,)
dispatcher.register_rule, lambda x, y: x, str)
def test_register_rule_with_different_arg_names(self):
dispatcher = self.createDispatcher(1, args=["x"])
dispatcher.register_rule(lambda y: y, int)
self.assertEqual(dispatcher(1), 1)
def test_dispatching_with_varargs(self): def test_subtype_evaluation():
dispatcher = self.createDispatcher(1, args=["x"], varargs="va") class Super:
dispatcher.register_rule(lambda x, *va: x, int) pass
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", 2, 3)
def test_dispatching_with_varkw(self): class Sub(Super):
dispatcher = self.createDispatcher(1, args=["x"], keywords="vk") pass
dispatcher.register_rule(lambda x, **vk: x, int)
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)
def test_dispatching_with_kw(self): dispatcher = create_dispatcher(1, args=["x"])
dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
dispatcher.register_rule(lambda x, y=1: x, int)
self.assertEqual(dispatcher(1), 1)
self.assertRaises(TypeError, dispatcher, "1", k=1)
def test_create_dispatcher_with_pos_args_less_multi_arity(self): dispatcher.register_rule(lambda x: x, Super)
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"]) o_super = Super()
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"], assert dispatcher(o_super) == o_super
defaults=["x"]) o_sub = Sub()
assert dispatcher(o_sub) == o_sub
with pytest.raises(TypeError):
dispatcher(object())
def test_register_rule_with_wrong_number_types_parameters(self): dispatcher.register_rule(lambda x: (x, x), Sub)
dispatcher = self.createDispatcher(1, args=["x", "y"]) o_super = Super()
self.assertRaises( assert dispatcher(o_super) == o_super
TypeError, o_sub = Sub()
dispatcher.register_rule, lambda x, y: x, int, str) assert dispatcher(o_sub) == (o_sub, o_sub)
def test_register_rule_with_partial_dispatching(self):
dispatcher = self.createDispatcher(1, args=["x", "y"]) def test_register_rule_with_wrong_arity():
dispatcher.register_rule(lambda x, y: x, int) dispatcher = create_dispatcher(1, args=["x"])
self.assertEqual(dispatcher(1, 2), 1) dispatcher.register_rule(lambda x: x, int)
self.assertEqual(dispatcher(1, "2"), 1) with pytest.raises(TypeError):
self.assertRaises(TypeError, dispatcher, "2", 1)
dispatcher.register_rule(lambda x, y: x, str) dispatcher.register_rule(lambda x, y: x, str)
self.assertEqual(dispatcher(1, 2), 1)
self.assertEqual(dispatcher(1, "2"), 1)
self.assertEqual(dispatcher("1", "2"), "1")
self.assertEqual(dispatcher("1", 2), "1")
class MultifunctionTests(unittest.TestCase):
def test_it(self): def test_register_rule_with_different_arg_names():
from generic.multidispatch import multifunction dispatcher = create_dispatcher(1, args=["x"])
dispatcher.register_rule(lambda y: y, int)
assert dispatcher(1) == 1
@multifunction(int, str)
def func(x, y):
return str(x) + y
self.assertEqual(func(1, "2"), "12") def test_dispatching_with_varargs():
self.assertRaises(TypeError, func, 1, 2) dispatcher = create_dispatcher(1, args=["x"], varargs="va")
self.assertRaises(TypeError, func, "1", 2) dispatcher.register_rule(lambda x, *va: x, int)
self.assertRaises(TypeError, func, "1", "2") assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", 2, 3)
@func.when(str, str)
def func(x, y):
return x + y
self.assertEqual(func(1, "2"), "12") def test_dispatching_with_varkw():
self.assertEqual(func("1", "2"), "12") dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
self.assertRaises(TypeError, func, 1, 2) dispatcher.register_rule(lambda x, **vk: x, int)
self.assertRaises(TypeError, func, "1", 2) assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", a=1, b=2)
def test_overriding(self):
from generic.multidispatch import multifunction
@multifunction(int, str) def test_dispatching_with_kw():
def func(x, y): dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
return str(x) + y dispatcher.register_rule(lambda x, y=1: x, int)
assert dispatcher(1) == 1
with pytest.raises(TypeError):
dispatcher("1", k=1)
self.assertEqual(func(1, "2"), "12")
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
@func.override(int, str) def test_create_dispatcher_with_pos_args_less_multi_arity():
def func(x, y): with pytest.raises(TypeError):
return y + str(x) create_dispatcher(2, args=["x"])
with pytest.raises(TypeError):
create_dispatcher(2, args=["x", "y"], defaults=["x"])
self.assertEqual(func(1, "2"), "21")
class MultimethodTests(unittest.TestCase): def test_register_rule_with_wrong_number_types_parameters():
dispatcher = create_dispatcher(1, args=["x", "y"])
with pytest.raises(TypeError):
dispatcher.register_rule(lambda x, y: x, int, str)
def test_multimethod(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods def test_register_rule_with_partial_dispatching():
class Dummy(object): dispatcher = create_dispatcher(1, args=["x", "y"])
dispatcher.register_rule(lambda x, y: x, int)
assert dispatcher(1, 2) == 1
assert dispatcher(1, "2") == 1
with pytest.raises(TypeError):
dispatcher("2", 1)
dispatcher.register_rule(lambda x, y: x, str)
assert dispatcher(1, 2) == 1
assert dispatcher(1, "2") == 1
assert dispatcher("1", "2") == "1"
assert dispatcher("1", 2) == "1"
@multimethod(int)
def foo(self, x):
return x + 1
@foo.when(str) def test_default_dispatcher():
def foo(self, x): @multidispatch(int, str)
return x + "1" def func(x, y):
return str(x) + y
self.assertEqual(Dummy().foo(1), 2) assert func(1, "2") == "12"
self.assertEqual(Dummy().foo("1"), "11") with pytest.raises(TypeError):
self.assertRaises(TypeError, Dummy().foo, []) func(1, 2)
with pytest.raises(TypeError):
func("1", 2)
with pytest.raises(TypeError):
func("1", "2")
def test_inheritance(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods def test_multiple_functions():
class Dummy(object): @multidispatch(int, str)
def func(x, y):
return str(x) + y
@multimethod(int) @func.register(str, str)
def foo(self, x): def _(x, y):
return x + 1 return x + y
@foo.when(float) assert func(1, "2") == "12"
def foo(self, x): assert func("1", "2") == "12"
return x + 1.5 with pytest.raises(TypeError):
func(1, 2)
with pytest.raises(TypeError):
func("1", 2)
@has_multimethods
class DummySub(Dummy):
@Dummy.foo.when(str) def test_default():
def foo(self, x): @multidispatch()
return x + "1" def func(x, y):
return x + y
@foo.when(tuple) @func.register(str, str)
def foo(self, x): def _(x, y):
return x + (1,) return y + x
@Dummy.foo.when(bool) assert func(1, 1) == 2
def foo(self, x): assert func("1", "2") == "21"
return not x
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(Dummy().foo(1.5), 3.0)
self.assertRaises(TypeError, Dummy().foo, "1")
self.assertEqual(DummySub().foo(1), 2)
self.assertEqual(DummySub().foo(1.5), 3.0)
self.assertEqual(DummySub().foo("1"), "11")
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
self.assertEqual(DummySub().foo(True), False)
self.assertRaises(TypeError, DummySub().foo, [])
def test_override(self): def test_on_classes():
from generic.multidispatch import multimethod @multidispatch()
from generic.multidispatch import has_multimethods class A:
def __init__(self, a, b):
self.v = a + b
@has_multimethods @A.register(str, str) # type: ignore[attr-defined]
class Dummy(object): class B:
def __init__(self, a, b):
self.v = b + a
@multimethod(str, str) assert A(1, 1).v == 2
def foo(self, x, y): assert A("1", "2").v == "21"
return x + y
@foo.when(str, str)
def foo(self, x, y):
return y + x
self.assertEqual(Dummy().foo("1", "2"), "21")
def test_inheritance_override(self):
from generic.multidispatch import multimethod
from generic.multidispatch import has_multimethods
@has_multimethods
class Dummy(object):
@multimethod(int)
def foo(self, x):
return x + 1
@has_multimethods
class DummySub(Dummy):
@Dummy.foo.when(int)
def foo(self, x):
return x + 2
self.assertEqual(Dummy().foo(1), 2)
self.assertEqual(DummySub().foo(1), 3)

View File

@ -1,135 +1,147 @@
""" Tests for :module:`generic.registry`.""" """ Tests for :module:`generic.registry`."""
import unittest import pytest
__all__ = ("RegistryTests",) from typing import Union
from generic.registry import Registry, SimpleAxis, TypeAxis
class RegistryTests(unittest.TestCase):
def test_one_axis_no_specificity(self): class DummyA:
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('foo', SimpleAxis()))
a = object()
b = object()
registry.register(a)
registry.register(b, 'foo')
self.assertEqual(registry.lookup(), a)
self.assertEqual(registry.lookup('foo'), b)
self.assertEqual(registry.lookup('bar'), None)
def test_two_axes(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
from generic.registry import TypeAxis
registry = Registry(('type', TypeAxis()),
('name', SimpleAxis()))
target1 = Target('one')
registry.register(target1, object)
target2 = Target('two')
registry.register(target2, DummyA)
target3 = Target('three')
registry.register(target3, DummyA, 'foo')
context1 = object()
self.assertEqual(registry.lookup(context1), target1)
context2 = DummyB()
self.assertEqual(registry.lookup(context2), target2)
self.assertEqual(registry.lookup(context2, 'foo'), target3)
target4 = object()
registry.register(target4, DummyB)
self.assertEqual(registry.lookup(context2), target4)
self.assertEqual(registry.lookup(context2, 'foo'), target3)
def test_get_registration(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
from generic.registry import TypeAxis
registry = Registry(('type', TypeAxis()),
('name', SimpleAxis()))
registry.register('one', object)
registry.register('two', DummyA, 'foo')
self.assertEqual(registry.get_registration(object), 'one')
self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two')
self.assertEqual(registry.get_registration(object, 'foo'), None)
self.assertEqual(registry.get_registration(DummyA), None)
def test_register_too_many_keys(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
self.assertRaises(ValueError, registry.register, object(),
'one', 'two')
def test_lookup_too_many_keys(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
self.assertRaises(ValueError, registry.lookup, 'one', 'two')
def test_conflict_error(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
registry.register(object(), name='foo')
self.assertRaises(ValueError, registry.register, object(), 'foo')
def test_override(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()))
registry.register(1, name='foo')
registry.override(2, name='foo')
self.assertEqual(registry.lookup('foo'), 2)
def test_skip_nodes(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(
('one', SimpleAxis()),
('two', SimpleAxis()),
('three', SimpleAxis())
)
registry.register('foo', one=1, three=3)
self.assertEqual(registry.lookup(1, three=3), 'foo')
def test_miss(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(
('one', SimpleAxis()),
('two', SimpleAxis()),
('three', SimpleAxis())
)
registry.register('foo', 1, 2)
self.assertEqual(registry.lookup(one=1, three=3), None)
def test_bad_lookup(self):
from generic.registry import Registry
from generic.registry import SimpleAxis
registry = Registry(('name', SimpleAxis()),
('grade', SimpleAxis()))
self.assertRaises(ValueError, registry.register, 1, foo=1)
self.assertRaises(ValueError, registry.lookup, foo=1)
self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo')
class DummyA(object):
pass pass
class DummyB(DummyA): class DummyB(DummyA):
pass pass
class Target(object):
def __init__(self, name):
self.name = name
# Only called if being printed due to a failing test def test_one_axis_no_specificity():
def __repr__(self): #pragma NO COVERAGE registry: Registry[object] = Registry(("foo", SimpleAxis()))
return "Target('%s')" % self.name a = object()
b = object()
registry.register(a)
registry.register(b, "foo")
assert registry.lookup() == a
assert registry.lookup("foo") == b
assert registry.lookup("bar") is None
def test_subtyping_on_axes():
registry: Registry[str] = Registry(("type", TypeAxis()))
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyB)
assert registry.lookup(object()) == target1
assert registry.lookup(DummyA()) == target2
assert registry.lookup(DummyB()) == target3
def test_query_subtyping_on_axes():
registry: Registry[str] = Registry(("type", TypeAxis()))
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyB)
target4 = "four"
registry.register(target4, int)
assert list(registry.query(object())) == [target1]
assert list(registry.query(DummyA())) == [target2, target1]
assert list(registry.query(DummyB())) == [target3, target2, target1]
assert list(registry.query(3)) == [target4, target1]
def test_two_axes():
registry: Registry[Union[str, object]] = Registry(
("type", TypeAxis()), ("name", SimpleAxis())
)
target1 = "one"
registry.register(target1, object)
target2 = "two"
registry.register(target2, DummyA)
target3 = "three"
registry.register(target3, DummyA, "foo")
context1 = object()
assert registry.lookup(context1) == target1
context2 = DummyB()
assert registry.lookup(context2) == target2
assert registry.lookup(context2, "foo") == target3
target4 = object()
registry.register(target4, DummyB)
assert registry.lookup(context2) == target4
assert registry.lookup(context2, "foo") == target3
def test_get_registration():
registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis()))
registry.register("one", object)
registry.register("two", DummyA, "foo")
assert registry.get_registration(object) == "one"
assert registry.get_registration(DummyA, "foo") == "two"
assert registry.get_registration(object, "foo") is None
assert registry.get_registration(DummyA) is None
def test_register_too_many_keys():
registry: Registry[type] = Registry(("name", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(object, "one", "two")
def test_lookup_too_many_keys():
registry: Registry[object] = Registry(("name", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(registry.lookup("one", "two"))
def test_conflict_error():
registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis()))
registry.register(object(), name="foo")
with pytest.raises(ValueError):
registry.register(object, "foo")
def test_skip_nodes():
registry: Registry[str] = Registry(
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
)
registry.register("foo", one=1, three=3)
assert registry.lookup(1, three=3) == "foo"
def test_miss():
registry: Registry[str] = Registry(
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
)
registry.register("foo", 1, 2)
assert registry.lookup(one=1, three=3) is None
def test_bad_lookup():
registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis()))
with pytest.raises(ValueError):
registry.register(1, foo=1)
with pytest.raises(ValueError):
registry.lookup(foo=1)
with pytest.raises(ValueError):
registry.register(1, "foo", name="foo")