Replace old code by Python3 implementation from Gaphor
Multimethods are missing now. Those have to be added back.
This commit is contained in:
parent
ee35c528cb
commit
bb20f6992b
105
generic/event.py
105
generic/event.py
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
This module provides API for event management. There are two APIs provided:
|
This module provides API for event management. There are two APIs provided:
|
||||||
|
|
||||||
* Global event management API: subscribe, unsubscribe, fire.
|
* Global event management API: subscribe, unsubscribe, handle.
|
||||||
* Local event management API: Manager
|
* Local event management API: Manager
|
||||||
|
|
||||||
If you run only one instance of your application per Python
|
If you run only one instance of your application per Python
|
||||||
@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API
|
|||||||
and have one instance of Manager object per application instance.
|
and have one instance of Manager object per application instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import namedtuple
|
from typing import Callable, Set, Type
|
||||||
|
|
||||||
from generic.registry import Registry
|
from generic.registry import Registry, TypeAxis
|
||||||
from generic.registry import TypeAxis
|
|
||||||
|
|
||||||
__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber")
|
|
||||||
|
|
||||||
class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])):
|
__all__ = "Manager"
|
||||||
""" Set of handlers for specific type of event.
|
|
||||||
|
|
||||||
This object stores ``handlers`` for specific event type and
|
Event = object
|
||||||
``parents`` reference to handler sets of event's supertypes.
|
Handler = Callable[[object], None]
|
||||||
"""
|
HandlerSet = Set[Handler]
|
||||||
|
|
||||||
@property
|
|
||||||
def all_handlers(self):
|
|
||||||
""" Iterate over own and supertypes' handlers.
|
|
||||||
|
|
||||||
This iterator yields just unique values, so it won't yield the
|
class Manager:
|
||||||
same handler twice, even if it was registered both for some
|
|
||||||
event type and its supertype.
|
|
||||||
"""
|
|
||||||
seen = set()
|
|
||||||
seen_add = seen.add
|
|
||||||
|
|
||||||
# yield own handlers first
|
|
||||||
for handler in self.handlers:
|
|
||||||
seen_add(handler)
|
|
||||||
yield handler
|
|
||||||
|
|
||||||
# yield supertypes' handlers then
|
|
||||||
for parent in self.parents:
|
|
||||||
for handler in parent.all_handlers:
|
|
||||||
if not handler in seen:
|
|
||||||
seen_add(handler)
|
|
||||||
yield handler
|
|
||||||
|
|
||||||
class Manager(object):
|
|
||||||
""" Event manager
|
""" Event manager
|
||||||
|
|
||||||
Provides API for subscribing for and firing events. There's also global
|
Provides API for subscribing for and firing events. There's also global
|
||||||
event manager instantiated at module level with functions
|
event manager instantiated at module level with functions
|
||||||
:func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased
|
:func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased
|
||||||
to corresponding methods of class.
|
to corresponding methods of class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
registry: Registry[HandlerSet]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
axes = (("event_type", TypeAxis()),)
|
axes = (("event_type", TypeAxis()),)
|
||||||
self.registry = Registry(*axes)
|
self.registry = Registry(*axes)
|
||||||
|
|
||||||
def subscribe(self, handler, event_type):
|
def subscribe(self, handler: Handler, event_type: Type[Event]) -> None:
|
||||||
""" Subscribe ``handler`` to specified ``event_type``"""
|
""" Subscribe ``handler`` to specified ``event_type``"""
|
||||||
handler_set = self.registry.get_registration(event_type)
|
handler_set = self.registry.get_registration(event_type)
|
||||||
if not handler_set:
|
if handler_set is None:
|
||||||
handler_set = self._register_handler_set(event_type)
|
handler_set = self._register_handler_set(event_type)
|
||||||
handler_set.handlers.add(handler)
|
handler_set.add(handler)
|
||||||
|
|
||||||
def unsubscribe(self, handler, event_type):
|
def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None:
|
||||||
""" Unsubscribe ``handler`` from ``event_type``"""
|
""" Unsubscribe ``handler`` from ``event_type``"""
|
||||||
handler_set = self.registry.get_registration(event_type)
|
handler_set = self.registry.get_registration(event_type)
|
||||||
if handler_set and handler in handler_set.handlers:
|
if handler_set and handler in handler_set:
|
||||||
handler_set.handlers.remove(handler)
|
handler_set.remove(handler)
|
||||||
|
|
||||||
def fire(self, event):
|
def handle(self, event: Event) -> None:
|
||||||
""" Fire ``event``
|
""" Fire ``event``
|
||||||
|
|
||||||
All subscribers will be executed with no determined order.
|
All subscribers will be executed with no determined order.
|
||||||
"""
|
"""
|
||||||
handler_set = self.registry.lookup(event)
|
handler_sets = self.registry.query(event)
|
||||||
for handler in handler_set.all_handlers:
|
for handler_set in handler_sets:
|
||||||
handler(event)
|
if handler_set:
|
||||||
|
for handler in set(handler_set):
|
||||||
|
handler(event)
|
||||||
|
|
||||||
def _register_handler_set(self, event_type):
|
def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet:
|
||||||
""" Register new handler set for ``event_type``."""
|
""" Register new handler set for ``event_type``.
|
||||||
# Collect handler sets for supertypes
|
"""
|
||||||
parent_handler_sets = []
|
handler_set: HandlerSet = set()
|
||||||
parents = event_type.__bases__
|
|
||||||
for parent in parents:
|
|
||||||
parent_handlers = self.registry.get_registration(parent)
|
|
||||||
if parent_handlers is None:
|
|
||||||
parent_handlers = self._register_handler_set(parent)
|
|
||||||
parent_handler_sets.append(parent_handlers)
|
|
||||||
|
|
||||||
handler_set = HandlerSet(parents=parent_handler_sets, handlers=set())
|
|
||||||
self.registry.register(handler_set, event_type)
|
self.registry.register(handler_set, event_type)
|
||||||
return handler_set
|
return handler_set
|
||||||
|
|
||||||
def subscriber(self, event_type):
|
def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]:
|
||||||
""" Decorator for subscribing handlers
|
""" Decorator for subscribing handlers
|
||||||
|
|
||||||
Works like this:
|
Works like this:
|
||||||
|
|
||||||
|
>>> mymanager = Manager()
|
||||||
|
>>> class MyEvent():
|
||||||
|
... pass
|
||||||
>>> @mymanager.subscriber(MyEvent)
|
>>> @mymanager.subscriber(MyEvent)
|
||||||
... def mysubscriber(evt):
|
... def mysubscriber(evt):
|
||||||
... # handle event
|
... # handle event
|
||||||
... return
|
... return
|
||||||
|
|
||||||
>>> mymanager.fire(MyEvent())
|
>>> mymanager.handle(MyEvent())
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def registrator(func):
|
|
||||||
|
def registrator(func: Handler) -> Handler:
|
||||||
self.subscribe(func, event_type)
|
self.subscribe(func, event_type)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return registrator
|
return registrator
|
||||||
|
|
||||||
# Global event manager
|
|
||||||
_global_manager = Manager()
|
|
||||||
|
|
||||||
# Global event management API
|
|
||||||
subscribe = _global_manager.subscribe
|
|
||||||
unsubscribe = _global_manager.unsubscribe
|
|
||||||
fire = _global_manager.fire
|
|
||||||
subscriber = _global_manager.subscriber
|
|
||||||
|
@ -1,206 +1,132 @@
|
|||||||
""" Multidispatch for functions and methods"""
|
""" Multidispatch for functions and methods.
|
||||||
|
|
||||||
|
This code is a Python 3, slimmed down version of the
|
||||||
|
generic package by Andrey Popp.
|
||||||
|
|
||||||
|
Only the generic function code is left in tact -- no generic methods.
|
||||||
|
The interface has been made in line with `functools.singledispatch`.
|
||||||
|
|
||||||
|
Note that this module does not support annotated functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import cast, Any, Callable, Generic, TypeVar, Union
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from generic.registry import Registry
|
from generic.registry import Registry, TypeAxis
|
||||||
from generic.registry import TypeAxis
|
|
||||||
|
|
||||||
__all__ = ("multifunction", "multimethod", "has_multimethods")
|
__all__ = "multidispatch"
|
||||||
|
|
||||||
def multifunction(*argtypes):
|
T = TypeVar("T", bound=Union[Callable[..., Any], type])
|
||||||
""" Declare function as multifunction
|
KeyType = Union[type, None]
|
||||||
|
|
||||||
|
|
||||||
|
def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
|
||||||
|
""" Declare function as multidispatch
|
||||||
|
|
||||||
This decorator takes ``argtypes`` argument types and replace decorated
|
This decorator takes ``argtypes`` argument types and replace decorated
|
||||||
function with :class:`.FunctionDispatcher` object, which is responsible for
|
function with :class:`.FunctionDispatcher` object, which is responsible for
|
||||||
multiple dispatch feature.
|
multiple dispatch feature.
|
||||||
"""
|
"""
|
||||||
def _replace_with_dispatcher(func):
|
|
||||||
dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes))
|
def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
|
||||||
|
nonlocal argtypes
|
||||||
|
argspec = inspect.getfullargspec(func)
|
||||||
|
if not argtypes:
|
||||||
|
arity = _arity(argspec)
|
||||||
|
if isinstance(func, type):
|
||||||
|
# It's a class we deal with:
|
||||||
|
arity -= 1
|
||||||
|
argtypes = (object,) * arity
|
||||||
|
|
||||||
|
dispatcher = cast(
|
||||||
|
FunctionDispatcher[T],
|
||||||
|
functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func),
|
||||||
|
)
|
||||||
dispatcher.register_rule(func, *argtypes)
|
dispatcher.register_rule(func, *argtypes)
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
||||||
return _replace_with_dispatcher
|
return _replace_with_dispatcher
|
||||||
|
|
||||||
def multimethod(*argtypes):
|
|
||||||
""" Declare method as multimethod
|
|
||||||
|
|
||||||
This decorator works exactly the same as :func:`.multifunction` decorator
|
class FunctionDispatcher(Generic[T]):
|
||||||
but replaces decorated method with :class:`.MethodDispatcher` object
|
|
||||||
instead.
|
|
||||||
|
|
||||||
Should be used only for decorating methods and enclosing class should have
|
|
||||||
:func:`.has_multimethods` decorator.
|
|
||||||
"""
|
|
||||||
def _replace_with_dispatcher(func):
|
|
||||||
dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1)
|
|
||||||
dispatcher.register_unbound_rule(func, *argtypes)
|
|
||||||
return dispatcher
|
|
||||||
return _replace_with_dispatcher
|
|
||||||
|
|
||||||
def has_multimethods(cls):
|
|
||||||
""" Declare class as one that have multimethods
|
|
||||||
|
|
||||||
Should only be used for decorating classes which have methods decorated with
|
|
||||||
:func:`.multimethod` decorator.
|
|
||||||
"""
|
|
||||||
for name, obj in cls.__dict__.items():
|
|
||||||
if isinstance(obj, MethodDispatcher):
|
|
||||||
obj.proceed_unbound_rules(cls)
|
|
||||||
return cls
|
|
||||||
|
|
||||||
class FunctionDispatcher(object):
|
|
||||||
""" Multidispatcher for functions
|
""" Multidispatcher for functions
|
||||||
|
|
||||||
This object dispatch calls to function by its argument types. Usually it is
|
This object dispatch calls to function by its argument types. Usually it is
|
||||||
produced by :func:`.multifunction` decorator.
|
produced by :func:`.multidispatch` decorator.
|
||||||
|
|
||||||
You should not manually create objects of this type.
|
You should not manually create objects of this type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, argspec, params_arity):
|
registry: Registry[T]
|
||||||
|
|
||||||
|
def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
|
||||||
""" Initialize dispatcher with ``argspec`` of type
|
""" Initialize dispatcher with ``argspec`` of type
|
||||||
:class:`inspect.ArgSpec` and ``params_arity`` that represent number
|
:class:`inspect.ArgSpec` and ``params_arity`` that represent number
|
||||||
params."""
|
params."""
|
||||||
# Check if we have enough positional arguments for number of type params
|
# Check if we have enough positional arguments for number of type params
|
||||||
if arity(argspec) < params_arity:
|
if _arity(argspec) < params_arity:
|
||||||
raise TypeError("Not enough positional arguments "
|
raise TypeError(
|
||||||
"for number of type parameters provided.")
|
"Not enough positional arguments "
|
||||||
|
"for number of type parameters provided."
|
||||||
|
)
|
||||||
|
|
||||||
self.argspec = argspec
|
self.argspec = argspec
|
||||||
self.params_arity = params_arity
|
self.params_arity = params_arity
|
||||||
|
|
||||||
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
|
axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
|
||||||
self.registry = Registry(*axis)
|
self.registry = Registry(*axis)
|
||||||
|
|
||||||
def check_rule(self, rule, *argtypes):
|
def check_rule(self, rule: T, *argtypes: KeyType) -> None:
|
||||||
# Check if we have the right number of parametrized types
|
# Check if we have the right number of parametrized types
|
||||||
if len(argtypes) != self.params_arity:
|
if len(argtypes) != self.params_arity:
|
||||||
raise TypeError("Wrong number of type parameters.")
|
raise TypeError(
|
||||||
|
f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
|
||||||
|
)
|
||||||
|
|
||||||
# Check if we have the same argspec (by number of args)
|
# Check if we have the same argspec (by number of args)
|
||||||
rule_argspec = inspect.getargspec(rule)
|
rule_argspec = inspect.getfullargspec(rule)
|
||||||
if not is_equalent_argspecs(rule_argspec, self.argspec):
|
left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4])
|
||||||
raise TypeError("Rule does not conform "
|
right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4])
|
||||||
"to previous implementations.")
|
if left_spec != right_spec:
|
||||||
|
raise TypeError(
|
||||||
|
f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
|
||||||
|
)
|
||||||
|
|
||||||
def register_rule(self, rule, *argtypes):
|
def register_rule(self, rule: T, *argtypes: KeyType) -> None:
|
||||||
""" Register new ``rule`` for ``argtypes``."""
|
""" Register new ``rule`` for ``argtypes``."""
|
||||||
self.check_rule(rule, *argtypes)
|
self.check_rule(rule, *argtypes)
|
||||||
self.registry.register(rule, *argtypes)
|
self.registry.register(rule, *argtypes)
|
||||||
|
|
||||||
def override_rule(self, rule, *argtypes):
|
def register(self, *argtypes: KeyType) -> Callable[[T], T]:
|
||||||
""" Override ``rule`` for ``argtypes``."""
|
""" Decorator for registering new case for multidispatch
|
||||||
self.check_rule(rule, *argtypes)
|
|
||||||
self.registry.override(rule, *argtypes)
|
|
||||||
|
|
||||||
def lookup_rule(self, *args):
|
|
||||||
""" Lookup rule by ``args``. Returns None if no rule was found."""
|
|
||||||
args = args[:self.params_arity]
|
|
||||||
rule = self.registry.lookup(*args)
|
|
||||||
if rule is None:
|
|
||||||
raise TypeError("No available rule found for %r" % (args,))
|
|
||||||
return rule
|
|
||||||
|
|
||||||
def when(self, *argtypes):
|
|
||||||
""" Decorator for registering new case for multifunction
|
|
||||||
|
|
||||||
New case will be registered for types identified by ``argtypes``. The
|
New case will be registered for types identified by ``argtypes``. The
|
||||||
length of ``argtypes`` should be equal to the length of ``argtypes``
|
length of ``argtypes`` should be equal to the length of ``argtypes``
|
||||||
argument were passed corresponding :func:`.multifunction` call, which
|
argument were passed corresponding :func:`.multidispatch` call, which
|
||||||
also indicated the number of arguments multifunction dispatches on.
|
also indicated the number of arguments multidispatch dispatches on.
|
||||||
"""
|
"""
|
||||||
def register_rule(func):
|
|
||||||
|
def register_rule(func: T) -> T:
|
||||||
self.register_rule(func, *argtypes)
|
self.register_rule(func, *argtypes)
|
||||||
return self
|
return func
|
||||||
|
|
||||||
return register_rule
|
return register_rule
|
||||||
|
|
||||||
@property
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
def otherwise(self):
|
|
||||||
""" Decorator which registeres "catch-all" case for multifunction"""
|
|
||||||
def register_rule(func):
|
|
||||||
self.register_rule(func, [object]*self.params_arity)
|
|
||||||
return self
|
|
||||||
return register_rule
|
|
||||||
|
|
||||||
def override(self, *argtypes):
|
|
||||||
""" Decorator for overriding case for ``argtypes``"""
|
|
||||||
def override_rule(func):
|
|
||||||
self.override_rule(func, *argtypes)
|
|
||||||
return self
|
|
||||||
return override_rule
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
""" Dispatch call to appropriate rule."""
|
""" Dispatch call to appropriate rule."""
|
||||||
rule = self.lookup_rule(*args)
|
trimmed_args = args[: self.params_arity]
|
||||||
|
rule = self.registry.lookup(*trimmed_args)
|
||||||
|
if not rule:
|
||||||
|
raise TypeError(f"No available rule found for {trimmed_args!r}")
|
||||||
return rule(*args, **kwargs)
|
return rule(*args, **kwargs)
|
||||||
|
|
||||||
class MethodDispatcher(FunctionDispatcher):
|
|
||||||
""" Multiple dispatch for methods
|
|
||||||
|
|
||||||
This object dispatch call to method by its class and arguments types.
|
def _arity(argspec: inspect.FullArgSpec) -> int:
|
||||||
Usually it is produced by :func:`.multimethod` decorator.
|
|
||||||
|
|
||||||
You should not manually create objects of this type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, argspec, params_arity):
|
|
||||||
FunctionDispatcher.__init__(self, argspec, params_arity)
|
|
||||||
|
|
||||||
# some data, that should be local to thread of execution
|
|
||||||
self.local = threading.local()
|
|
||||||
self.local.unbound_rules = []
|
|
||||||
|
|
||||||
def register_unbound_rule(self, func, *argtypes):
|
|
||||||
""" Register unbound rule that should be processed by
|
|
||||||
``proceed_unbound_rules`` later."""
|
|
||||||
self.local.unbound_rules.append((argtypes, func))
|
|
||||||
|
|
||||||
def proceed_unbound_rules(self, cls):
|
|
||||||
""" Process all unbound rule by binding them to ``cls`` type."""
|
|
||||||
for argtypes, func in self.local.unbound_rules:
|
|
||||||
argtypes = (cls,) + argtypes
|
|
||||||
self.override_rule(func, *argtypes)
|
|
||||||
self.local.unbound_rules = []
|
|
||||||
|
|
||||||
def __get__(self, obj, cls):
|
|
||||||
if obj is None:
|
|
||||||
return self
|
|
||||||
return types.MethodType(self, obj)
|
|
||||||
|
|
||||||
def when(self, *argtypes):
|
|
||||||
""" Register new case for multimethod for ``argtypes``"""
|
|
||||||
def make_declaration(meth):
|
|
||||||
self.register_unbound_rule(meth, *argtypes)
|
|
||||||
return self
|
|
||||||
return make_declaration
|
|
||||||
|
|
||||||
def override(self, *argtypes):
|
|
||||||
""" Decorator for overriding case for ``argtypes``"""
|
|
||||||
return self.when(*argtypes)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def otherwise(self):
|
|
||||||
""" Decorator which registeres "catch-all" case for multimethod"""
|
|
||||||
def make_declaration(func):
|
|
||||||
self.register_unbound_rule(func, [object]*self.params_arity)
|
|
||||||
return self
|
|
||||||
return make_declaration
|
|
||||||
|
|
||||||
def arity(argspec):
|
|
||||||
""" Determinal positional arity of argspec."""
|
""" Determinal positional arity of argspec."""
|
||||||
args = argspec.args if argspec.args else []
|
args = argspec.args if argspec.args else []
|
||||||
defaults = argspec.defaults if argspec.defaults else []
|
defaults = argspec.defaults if argspec.defaults else []
|
||||||
return len(args) - len(defaults)
|
return len(args) - len(defaults)
|
||||||
|
|
||||||
def is_equalent_argspecs(left, right):
|
|
||||||
""" Check argspec equalence."""
|
|
||||||
return map(lambda x: len(x) if x else 0, left) == \
|
|
||||||
map(lambda x: len(x) if x else 0, right)
|
|
||||||
|
|
||||||
def _make_dispatcher(dispacther_cls, func, params_arity):
|
|
||||||
argspec = inspect.getargspec(func)
|
|
||||||
wrapper = functools.wraps(func)
|
|
||||||
dispatcher = wrapper(dispacther_cls(argspec, params_arity))
|
|
||||||
return dispatcher
|
|
||||||
|
@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi.
|
|||||||
[1]: http://bitbucket.org/chrisrossi/happy
|
[1]: http://bitbucket.org/chrisrossi/happy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
__all__ = ("Registry", "SimpleAxis", "TypeAxis")
|
__all__ = ("Registry", "SimpleAxis", "TypeAxis")
|
||||||
|
|
||||||
class Registry(object):
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
KeysView,
|
||||||
|
List,
|
||||||
|
Generator,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
K = TypeVar("K")
|
||||||
|
S = TypeVar("S")
|
||||||
|
T = TypeVar("T")
|
||||||
|
V = TypeVar("V")
|
||||||
|
Axis = Union["SimpleAxis", "TypeAxis"]
|
||||||
|
|
||||||
|
|
||||||
|
class Registry(Generic[T]):
|
||||||
""" Registry implementation."""
|
""" Registry implementation."""
|
||||||
|
|
||||||
def __init__(self, *axes):
|
def __init__(self, *axes: Tuple[str, Axis]):
|
||||||
self._tree = _TreeNode()
|
self._tree: _TreeNode[T] = _TreeNode()
|
||||||
self._axes = [axis for name, axis in axes]
|
self._axes = [axis for name, axis in axes]
|
||||||
self._axes_dict = dict([
|
self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)}
|
||||||
(name, (i, axis)) for i, (name, axis) in enumerate(axes)
|
|
||||||
])
|
|
||||||
|
|
||||||
def register(self, target, *arg_keys, **kw_keys):
|
def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None:
|
||||||
self._register(target, self._align_with_axes(arg_keys, kw_keys), False)
|
|
||||||
|
|
||||||
def override(self, target, *arg_keys, **kw_keys):
|
|
||||||
self._register(target, self._align_with_axes(arg_keys, kw_keys), True)
|
|
||||||
|
|
||||||
def _register(self, target, keys, override):
|
|
||||||
tree_node = self._tree
|
tree_node = self._tree
|
||||||
for key in keys:
|
for key in self._align_with_axes(arg_keys, kw_keys):
|
||||||
tree_node = tree_node.setdefault(key, _TreeNode())
|
tree_node = tree_node.setdefault(key, _TreeNode[T]())
|
||||||
|
|
||||||
if not override and not tree_node.target is None:
|
if not tree_node.target is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Registration conflicts with existing registration. Use "
|
f"Registration for {target} conflicts with existing registration {tree_node.target}."
|
||||||
"override method to override.")
|
)
|
||||||
|
|
||||||
tree_node.target = target
|
tree_node.target = target
|
||||||
|
|
||||||
def get_registration(self, *arg_keys, **kw_keys):
|
def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]:
|
||||||
tree_node = self._tree
|
tree_node = self._tree
|
||||||
for key in self._align_with_axes(arg_keys, kw_keys):
|
for key in self._align_with_axes(arg_keys, kw_keys):
|
||||||
if not tree_node.has_key(key):
|
if not key in tree_node:
|
||||||
return None
|
return None
|
||||||
tree_node = tree_node[key]
|
tree_node = tree_node[key]
|
||||||
|
|
||||||
return tree_node.target
|
return tree_node.target
|
||||||
|
|
||||||
def lookup(self, *arg_objs, **kw_objs):
|
def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]:
|
||||||
|
return next(self.query(*arg_objs, **kw_objs), None)
|
||||||
|
|
||||||
|
def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]:
|
||||||
objs = self._align_with_axes(arg_objs, kw_objs)
|
objs = self._align_with_axes(arg_objs, kw_objs)
|
||||||
axes = self._axes
|
axes = self._axes
|
||||||
return self._lookup(self._tree, objs, axes)
|
return self._query(self._tree, objs, axes)
|
||||||
|
|
||||||
def _lookup(self, tree_node, objs, axes):
|
def _query(
|
||||||
|
self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis]
|
||||||
|
) -> Generator[Optional[T], None, None]:
|
||||||
""" Recursively traverse registration tree, from left to right, most
|
""" Recursively traverse registration tree, from left to right, most
|
||||||
specific to least specific, returning the first target found on a
|
specific to least specific, returning the first target found on a
|
||||||
matching node. """
|
matching node. """
|
||||||
if not objs:
|
if not objs:
|
||||||
return tree_node.target
|
yield tree_node.target
|
||||||
|
else:
|
||||||
|
obj = objs[0]
|
||||||
|
|
||||||
obj = objs[0]
|
# Skip non-participating nodes
|
||||||
|
if obj is None:
|
||||||
|
next_node: Optional[_TreeNode[T]] = tree_node.get(None, None)
|
||||||
|
if next_node is not None:
|
||||||
|
yield from self._query(next_node, objs[1:], axes[1:])
|
||||||
|
else:
|
||||||
|
# Get matches on this axis and iterate from most to least specific
|
||||||
|
axis = axes[0]
|
||||||
|
for match_key in axis.matches(obj, tree_node.keys()):
|
||||||
|
yield from self._query(tree_node[match_key], objs[1:], axes[1:])
|
||||||
|
|
||||||
# Skip non-participating nodes
|
def _align_with_axes(
|
||||||
if obj is None:
|
self, args: Sequence[S], kw: Dict[str, S]
|
||||||
next_node = tree_node.get(None, None)
|
) -> Sequence[Optional[S]]:
|
||||||
if next_node is not None:
|
|
||||||
return self._lookup(next_node, objs[1:], axes[1:])
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Get matches on this axis and iterate from most to least specific
|
|
||||||
axis = axes[0]
|
|
||||||
for match_key in axis.matches(obj, tree_node.keys()):
|
|
||||||
target = self._lookup(tree_node[match_key], objs[1:], axes[1:])
|
|
||||||
if target is not None:
|
|
||||||
return target
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _align_with_axes(self, args, kw):
|
|
||||||
""" Create a list matching up all args and kwargs with their
|
""" Create a list matching up all args and kwargs with their
|
||||||
corresponding axes, in order, using ``None`` as a placeholder for
|
corresponding axes, in order, using ``None`` as a placeholder for
|
||||||
skipped axes. """
|
skipped axes. """
|
||||||
axes_dict = self._axes_dict
|
axes_dict = self._axes_dict
|
||||||
aligned = [None for i in xrange(len(axes_dict))]
|
aligned: List[Optional[S]] = [None for i in range(len(axes_dict))]
|
||||||
|
|
||||||
args_len = len(args)
|
args_len = len(args)
|
||||||
if args_len + len(kw) > len(aligned):
|
if args_len + len(kw) > len(aligned):
|
||||||
raise ValueError("Cannot have more arguments than axes.")
|
raise ValueError("Cannot have more arguments than axes.")
|
||||||
|
|
||||||
for i, arg in enumerate(args):
|
for i, arg in enumerate(args):
|
||||||
@ -91,12 +108,13 @@ class Registry(object):
|
|||||||
for k, v in kw.items():
|
for k, v in kw.items():
|
||||||
i_axis = axes_dict.get(k, None)
|
i_axis = axes_dict.get(k, None)
|
||||||
if i_axis is None:
|
if i_axis is None:
|
||||||
raise ValueError("No axis with name: %s" % k)
|
raise ValueError(f"No axis with name: {k}")
|
||||||
|
|
||||||
i, axis = i_axis
|
i, axis = i_axis
|
||||||
if aligned[i] is not None:
|
if aligned[i] is not None:
|
||||||
raise ValueError("Axis defined twice between positional and "
|
raise ValueError(
|
||||||
"keyword arguments")
|
"Axis defined twice between positional and " "keyword arguments"
|
||||||
|
)
|
||||||
|
|
||||||
aligned[i] = v
|
aligned[i] = v
|
||||||
|
|
||||||
@ -106,13 +124,15 @@ class Registry(object):
|
|||||||
|
|
||||||
return aligned
|
return aligned
|
||||||
|
|
||||||
class _TreeNode(dict):
|
|
||||||
target = None
|
|
||||||
|
|
||||||
def __str__(self):
|
class _TreeNode(Generic[T], Dict[Any, Any]):
|
||||||
return "<TreeNode %s %s>" % (self.target, dict.__str__(self))
|
target: Optional[T] = None
|
||||||
|
|
||||||
class SimpleAxis(object):
|
def __str__(self) -> str:
|
||||||
|
return f"<TreeNode {self.target} {dict.__str__(self)}>"
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleAxis:
|
||||||
""" A simple axis where the key into the axis is the same as the object to
|
""" A simple axis where the key into the axis is the same as the object to
|
||||||
be matched (aka the identity axis). This axis behaves just like a
|
be matched (aka the identity axis). This axis behaves just like a
|
||||||
dictionary. You might use this axis if you are interested in registering
|
dictionary. You might use this axis if you are interested in registering
|
||||||
@ -122,21 +142,23 @@ class SimpleAxis(object):
|
|||||||
Subclasses can override the ``get_keys`` method for implementing arbitrary
|
Subclasses can override the ``get_keys`` method for implementing arbitrary
|
||||||
axes.
|
axes.
|
||||||
"""
|
"""
|
||||||
def matches(self, obj, keys):
|
|
||||||
for key in self.get_keys(obj):
|
def matches(
|
||||||
|
self, obj: object, keys: KeysView[Optional[object]]
|
||||||
|
) -> Generator[object, None, None]:
|
||||||
|
for key in [obj]:
|
||||||
if key in keys:
|
if key in keys:
|
||||||
yield key
|
yield obj
|
||||||
|
|
||||||
def get_keys(self, obj):
|
|
||||||
"""
|
|
||||||
Return the keys for the given object that could match this axis, from
|
|
||||||
most specific to least specific. A convenient override point.
|
|
||||||
"""
|
|
||||||
return [obj,]
|
|
||||||
|
|
||||||
class TypeAxis(SimpleAxis):
|
class TypeAxis:
|
||||||
""" An axis which matches the class and super classes of an object in
|
""" An axis which matches the class and super classes of an object in
|
||||||
method resolution order.
|
method resolution order.
|
||||||
"""
|
"""
|
||||||
def get_keys(self, obj):
|
|
||||||
return type(obj).mro()
|
def matches(
|
||||||
|
self, obj: object, keys: KeysView[Optional[type]]
|
||||||
|
) -> Generator[type, None, None]:
|
||||||
|
for key in type(obj).mro():
|
||||||
|
if key in keys:
|
||||||
|
yield key
|
||||||
|
@ -1,150 +1,181 @@
|
|||||||
""" Tests for :module:`generic.event`."""
|
""" Tests for :module:`generic.event`."""
|
||||||
|
|
||||||
import unittest
|
from __future__ import annotations
|
||||||
|
|
||||||
__all__ = ("ManagerTests",)
|
from typing import Callable, List
|
||||||
|
from generic.event import Manager
|
||||||
|
|
||||||
class ManagerTests(unittest.TestCase):
|
|
||||||
|
|
||||||
def makeHandler(self, effect):
|
def make_handler(effect: object) -> Callable[[Event], None]:
|
||||||
return lambda e: e.effects.append(effect)
|
return lambda e: e.effects.append(effect)
|
||||||
|
|
||||||
def createManager(self):
|
|
||||||
from generic.event import Manager
|
|
||||||
return Manager()
|
|
||||||
|
|
||||||
def test_subscribe_single_event(self):
|
def create_manager():
|
||||||
events = self.createManager()
|
return Manager()
|
||||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
|
||||||
e = EventA()
|
|
||||||
events.fire(e)
|
|
||||||
self.assertEqual(len(e.effects), 1)
|
|
||||||
self.assertTrue("handler1" in e.effects)
|
|
||||||
|
|
||||||
def test_subscribe_via_decorator(self):
|
|
||||||
events = self.createManager()
|
|
||||||
events.subscriber(EventA)(self.makeHandler("handler1"))
|
|
||||||
e = EventA()
|
|
||||||
events.fire(e)
|
|
||||||
self.assertEqual(len(e.effects), 1)
|
|
||||||
self.assertTrue("handler1" in e.effects)
|
|
||||||
|
|
||||||
def test_subscribe_event_inheritance(self):
|
def test_subscribe_single_event():
|
||||||
events = self.createManager()
|
events = create_manager()
|
||||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
events.subscribe(make_handler("handler1"), EventA)
|
||||||
events.subscribe(self.makeHandler("handler2"), EventB)
|
e = EventA()
|
||||||
|
events.handle(e)
|
||||||
|
assert len(e.effects) == 1
|
||||||
|
assert "handler1" in e.effects
|
||||||
|
|
||||||
ea = EventA()
|
|
||||||
events.fire(ea)
|
|
||||||
self.assertEqual(len(ea.effects), 1)
|
|
||||||
self.assertTrue("handler1" in ea.effects)
|
|
||||||
|
|
||||||
eb = EventB()
|
def test_subscribe_via_decorator():
|
||||||
events.fire(eb)
|
events = create_manager()
|
||||||
self.assertEqual(len(eb.effects), 2)
|
events.subscriber(EventA)(make_handler("handler1"))
|
||||||
self.assertTrue("handler1" in eb.effects)
|
e = EventA()
|
||||||
self.assertTrue("handler2" in eb.effects)
|
events.handle(e)
|
||||||
|
assert len(e.effects) == 1
|
||||||
|
assert "handler1" in e.effects
|
||||||
|
|
||||||
def test_subscribe_event_multiple_inheritance(self):
|
|
||||||
events = self.createManager()
|
|
||||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
|
||||||
events.subscribe(self.makeHandler("handler2"), EventC)
|
|
||||||
events.subscribe(self.makeHandler("handler3"), EventD)
|
|
||||||
|
|
||||||
ea = EventA()
|
def test_subscribe_event_inheritance():
|
||||||
events.fire(ea)
|
events = create_manager()
|
||||||
self.assertEqual(len(ea.effects), 1)
|
events.subscribe(make_handler("handler1"), EventA)
|
||||||
self.assertTrue("handler1" in ea.effects)
|
events.subscribe(make_handler("handler2"), EventB)
|
||||||
|
|
||||||
ec = EventC()
|
ea = EventA()
|
||||||
events.fire(ec)
|
events.handle(ea)
|
||||||
self.assertEqual(len(ec.effects), 1)
|
assert len(ea.effects) == 1
|
||||||
self.assertTrue("handler2" in ec.effects)
|
assert "handler1" in ea.effects
|
||||||
|
|
||||||
ed = EventD()
|
eb = EventB()
|
||||||
events.fire(ed)
|
events.handle(eb)
|
||||||
self.assertEqual(len(ed.effects), 3)
|
assert len(eb.effects) == 2
|
||||||
self.assertTrue("handler1" in ed.effects)
|
assert "handler1" in eb.effects
|
||||||
self.assertTrue("handler2" in ed.effects)
|
assert "handler2" in eb.effects
|
||||||
self.assertTrue("handler3" in ed.effects)
|
|
||||||
|
|
||||||
def test_subscribe_event_malformed_multiple_inheritance(self):
|
|
||||||
events = self.createManager()
|
|
||||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
|
||||||
events.subscribe(self.makeHandler("handler2"), EventD)
|
|
||||||
events.subscribe(self.makeHandler("handler3"), EventE)
|
|
||||||
|
|
||||||
ea = EventA()
|
def test_subscribe_event_multiple_inheritance():
|
||||||
events.fire(ea)
|
events = create_manager()
|
||||||
self.assertEqual(len(ea.effects), 1)
|
events.subscribe(make_handler("handler1"), EventA)
|
||||||
self.assertTrue("handler1" in ea.effects)
|
events.subscribe(make_handler("handler2"), EventC)
|
||||||
|
events.subscribe(make_handler("handler3"), EventD)
|
||||||
|
|
||||||
ed = EventD()
|
ea = EventA()
|
||||||
events.fire(ed)
|
events.handle(ea)
|
||||||
self.assertEqual(len(ed.effects), 2)
|
assert len(ea.effects) == 1
|
||||||
self.assertTrue("handler1" in ed.effects)
|
assert "handler1" in ea.effects
|
||||||
self.assertTrue("handler2" in ed.effects)
|
|
||||||
|
|
||||||
ee = EventE()
|
ec = EventC()
|
||||||
events.fire(ee)
|
events.handle(ec)
|
||||||
self.assertEqual(len(ee.effects), 3)
|
assert len(ec.effects) == 1
|
||||||
self.assertTrue("handler1" in ee.effects)
|
assert "handler2" in ec.effects
|
||||||
self.assertTrue("handler2" in ee.effects)
|
|
||||||
self.assertTrue("handler3" in ee.effects)
|
|
||||||
|
|
||||||
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self):
|
ed = EventD()
|
||||||
events = self.createManager()
|
events.handle(ed)
|
||||||
events.subscribe(self.makeHandler("handler1"), Event)
|
assert len(ed.effects) == 3
|
||||||
events.subscribe(self.makeHandler("handler2"), EventB)
|
assert "handler1" in ed.effects
|
||||||
|
assert "handler2" in ed.effects
|
||||||
|
assert "handler3" in ed.effects
|
||||||
|
|
||||||
eb = EventB()
|
|
||||||
events.fire(eb)
|
|
||||||
self.assertEqual(len(eb.effects), 2)
|
|
||||||
self.assertTrue("handler1" in eb.effects)
|
|
||||||
self.assertTrue("handler2" in eb.effects)
|
|
||||||
|
|
||||||
def test_unsubscribe_single_event(self):
|
def test_subscribe_no_events():
|
||||||
events = self.createManager()
|
events = create_manager()
|
||||||
handler = self.makeHandler("handler1")
|
|
||||||
events.subscribe(handler, EventA)
|
|
||||||
events.unsubscribe(handler, EventA)
|
|
||||||
e = EventA()
|
|
||||||
events.fire(e)
|
|
||||||
self.assertEqual(len(e.effects), 0)
|
|
||||||
|
|
||||||
def test_unsubscribe_event_inheritance(self):
|
ea = EventA()
|
||||||
events = self.createManager()
|
events.handle(ea)
|
||||||
handler1 = self.makeHandler("handler1")
|
assert len(ea.effects) == 0
|
||||||
handler2 = self.makeHandler("handler2")
|
|
||||||
events.subscribe(handler1, EventA)
|
|
||||||
events.subscribe(handler2, EventB)
|
|
||||||
events.unsubscribe(handler1, EventA)
|
|
||||||
|
|
||||||
ea = EventA()
|
|
||||||
events.fire(ea)
|
|
||||||
self.assertEqual(len(ea.effects), 0)
|
|
||||||
|
|
||||||
eb = EventB()
|
def test_subscribe_base_event():
|
||||||
events.fire(eb)
|
events = create_manager()
|
||||||
self.assertEqual(len(eb.effects), 1)
|
events.subscribe(make_handler("handler1"), EventA)
|
||||||
self.assertTrue("handler2" in eb.effects)
|
|
||||||
|
|
||||||
class Event(object):
|
ea = EventB()
|
||||||
|
events.handle(ea)
|
||||||
|
assert len(ea.effects) == 1
|
||||||
|
assert "handler1" in ea.effects
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_event_malformed_multiple_inheritance():
|
||||||
|
events = create_manager()
|
||||||
|
events.subscribe(make_handler("handler1"), EventA)
|
||||||
|
events.subscribe(make_handler("handler2"), EventD)
|
||||||
|
events.subscribe(make_handler("handler3"), EventE)
|
||||||
|
|
||||||
|
ea = EventA()
|
||||||
|
events.handle(ea)
|
||||||
|
assert len(ea.effects) == 1
|
||||||
|
assert "handler1" in ea.effects
|
||||||
|
|
||||||
|
ed = EventD()
|
||||||
|
events.handle(ed)
|
||||||
|
assert len(ed.effects) == 2
|
||||||
|
assert "handler1" in ed.effects
|
||||||
|
assert "handler2" in ed.effects
|
||||||
|
|
||||||
|
ee = EventE()
|
||||||
|
events.handle(ee)
|
||||||
|
assert len(ee.effects) == 3
|
||||||
|
assert "handler1" in ee.effects
|
||||||
|
assert "handler2" in ee.effects
|
||||||
|
assert "handler3" in ee.effects
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro():
|
||||||
|
events = create_manager()
|
||||||
|
events.subscribe(make_handler("handler1"), Event)
|
||||||
|
events.subscribe(make_handler("handler2"), EventB)
|
||||||
|
|
||||||
|
eb = EventB()
|
||||||
|
events.handle(eb)
|
||||||
|
assert len(eb.effects) == 2
|
||||||
|
assert "handler1" in eb.effects
|
||||||
|
assert "handler2" in eb.effects
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsubscribe_single_event():
|
||||||
|
events = create_manager()
|
||||||
|
handler = make_handler("handler1")
|
||||||
|
events.subscribe(handler, EventA)
|
||||||
|
events.unsubscribe(handler, EventA)
|
||||||
|
e = EventA()
|
||||||
|
events.handle(e)
|
||||||
|
assert len(e.effects) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsubscribe_event_inheritance():
|
||||||
|
events = create_manager()
|
||||||
|
handler1 = make_handler("handler1")
|
||||||
|
handler2 = make_handler("handler2")
|
||||||
|
events.subscribe(handler1, EventA)
|
||||||
|
events.subscribe(handler2, EventB)
|
||||||
|
events.unsubscribe(handler1, EventA)
|
||||||
|
|
||||||
|
ea = EventA()
|
||||||
|
events.handle(ea)
|
||||||
|
assert len(ea.effects) == 0
|
||||||
|
|
||||||
|
eb = EventB()
|
||||||
|
events.handle(eb)
|
||||||
|
assert len(eb.effects) == 1
|
||||||
|
assert "handler2" in eb.effects
|
||||||
|
|
||||||
|
|
||||||
|
class Event:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.effects: List[object] = []
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.effects = []
|
|
||||||
|
|
||||||
class EventA(Event):
|
class EventA(Event):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EventB(EventA):
|
class EventB(EventA):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EventC(Event):
|
class EventC(Event):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EventD(EventA, EventC):
|
class EventD(EventA, EventC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EventE(EventD, EventA):
|
class EventE(EventD, EventA):
|
||||||
pass
|
pass
|
||||||
|
@ -1,270 +1,224 @@
|
|||||||
""" Tests for :module:`generic.multidispatch`."""
|
""" Tests for :module:`generic.multidispatch`."""
|
||||||
|
|
||||||
import unittest
|
import pytest
|
||||||
|
|
||||||
__all__ = ("DispatcherTests",)
|
from inspect import FullArgSpec
|
||||||
|
from generic.multidispatch import multidispatch, FunctionDispatcher
|
||||||
class DispatcherTests(unittest.TestCase):
|
|
||||||
|
|
||||||
def createDispatcher(self, params_arity, args=None, varargs=None,
|
|
||||||
keywords=None, defaults=None):
|
|
||||||
from inspect import ArgSpec
|
|
||||||
from generic.multidispatch import FunctionDispatcher
|
|
||||||
return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
|
|
||||||
keywords=keywords,
|
|
||||||
defaults=defaults), params_arity)
|
|
||||||
|
|
||||||
|
|
||||||
def test_one_argument(self):
|
def create_dispatcher(
|
||||||
dispatcher = self.createDispatcher(1, args=["x"])
|
params_arity, args=None, varargs=None, keywords=None, defaults=None
|
||||||
|
) -> FunctionDispatcher:
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x: x + 1, int)
|
return FunctionDispatcher(
|
||||||
self.assertEqual(dispatcher(1), 2)
|
FullArgSpec(
|
||||||
self.assertRaises(TypeError, dispatcher, "s")
|
args=args,
|
||||||
|
varargs=varargs,
|
||||||
|
varkw=keywords,
|
||||||
|
defaults=defaults,
|
||||||
|
kwonlyargs=[],
|
||||||
|
kwonlydefaults={},
|
||||||
|
annotations={},
|
||||||
|
),
|
||||||
|
params_arity,
|
||||||
|
)
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x: x + "1", str)
|
|
||||||
self.assertEqual(dispatcher(1), 2)
|
|
||||||
self.assertEqual(dispatcher("1"), "11")
|
|
||||||
self.assertRaises(TypeError, dispatcher, tuple())
|
|
||||||
|
|
||||||
def test_two_arguments(self):
|
def test_one_argument():
|
||||||
dispatcher = self.createDispatcher(2, args=["x", "y"])
|
dispatcher = create_dispatcher(1, args=["x"])
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
|
dispatcher.register_rule(lambda x: x + 1, int)
|
||||||
self.assertEqual(dispatcher(1, 2), 4)
|
assert dispatcher(1) == 2
|
||||||
self.assertRaises(TypeError, dispatcher, "s", "ss")
|
with pytest.raises(TypeError):
|
||||||
self.assertRaises(TypeError, dispatcher, 1, "ss")
|
dispatcher("s")
|
||||||
self.assertRaises(TypeError, dispatcher, "s", 2)
|
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
|
dispatcher.register_rule(lambda x: x + "1", str)
|
||||||
self.assertEqual(dispatcher(1, 2), 4)
|
assert dispatcher(1) == 2
|
||||||
self.assertEqual(dispatcher("1", "2"), "121")
|
assert dispatcher("1") == "11"
|
||||||
self.assertRaises(TypeError, dispatcher, "1", 1)
|
with pytest.raises(TypeError):
|
||||||
self.assertRaises(TypeError, dispatcher, 1, "1")
|
dispatcher(tuple())
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
|
|
||||||
self.assertEqual(dispatcher(1, 2), 4)
|
|
||||||
self.assertEqual(dispatcher("1", "2"), "121")
|
|
||||||
self.assertEqual(dispatcher(1, "2"), "121")
|
|
||||||
self.assertRaises(TypeError, dispatcher, "1", 1)
|
|
||||||
|
|
||||||
def test_bottom_rule(self):
|
def test_two_arguments():
|
||||||
dispatcher = self.createDispatcher(1, args=["x"])
|
dispatcher = create_dispatcher(2, args=["x", "y"])
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x: x, object)
|
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
|
||||||
self.assertEqual(dispatcher(1), 1)
|
assert dispatcher(1, 2) == 4
|
||||||
self.assertEqual(dispatcher("1"), "1")
|
with pytest.raises(TypeError):
|
||||||
self.assertEqual(dispatcher([1]), [1])
|
dispatcher("s", "ss")
|
||||||
self.assertEqual(dispatcher((1,)), (1,))
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher(1, "ss")
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("s", 2)
|
||||||
|
|
||||||
def test_subtype_evaluation(self):
|
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
|
||||||
class Super(object):
|
assert dispatcher(1, 2) == 4
|
||||||
pass
|
assert dispatcher("1", "2") == "121"
|
||||||
class Sub(Super):
|
with pytest.raises(TypeError):
|
||||||
pass
|
dispatcher("1", 1)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher(1, "1")
|
||||||
|
|
||||||
dispatcher = self.createDispatcher(1, args=["x"])
|
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
|
||||||
|
assert dispatcher(1, 2) == 4
|
||||||
|
assert dispatcher("1", "2") == "121"
|
||||||
|
assert dispatcher(1, "2") == "121"
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("1", 1)
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x: x, Super)
|
|
||||||
o_super = Super()
|
|
||||||
self.assertEqual(dispatcher(o_super), o_super)
|
|
||||||
o_sub = Sub()
|
|
||||||
self.assertEqual(dispatcher(o_sub), o_sub)
|
|
||||||
self.assertRaises(TypeError, dispatcher, object())
|
|
||||||
|
|
||||||
dispatcher.register_rule(lambda x: (x, x), Sub)
|
def test_bottom_rule():
|
||||||
o_super = Super()
|
dispatcher = create_dispatcher(1, args=["x"])
|
||||||
self.assertEqual(dispatcher(o_super), o_super)
|
|
||||||
o_sub = Sub()
|
|
||||||
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
|
|
||||||
|
|
||||||
def test_register_rule_with_wrong_arity(self):
|
dispatcher.register_rule(lambda x: x, object)
|
||||||
dispatcher = self.createDispatcher(1, args=["x"])
|
assert dispatcher(1) == 1
|
||||||
dispatcher.register_rule(lambda x: x, int)
|
assert dispatcher("1") == "1"
|
||||||
self.assertRaises(
|
assert dispatcher([1]) == [1]
|
||||||
TypeError,
|
assert dispatcher((1,)) == (1,)
|
||||||
dispatcher.register_rule, lambda x, y: x, str)
|
|
||||||
|
|
||||||
def test_register_rule_with_different_arg_names(self):
|
|
||||||
dispatcher = self.createDispatcher(1, args=["x"])
|
|
||||||
dispatcher.register_rule(lambda y: y, int)
|
|
||||||
self.assertEqual(dispatcher(1), 1)
|
|
||||||
|
|
||||||
def test_dispatching_with_varargs(self):
|
def test_subtype_evaluation():
|
||||||
dispatcher = self.createDispatcher(1, args=["x"], varargs="va")
|
class Super:
|
||||||
dispatcher.register_rule(lambda x, *va: x, int)
|
pass
|
||||||
self.assertEqual(dispatcher(1), 1)
|
|
||||||
self.assertRaises(TypeError, dispatcher, "1", 2, 3)
|
|
||||||
|
|
||||||
def test_dispatching_with_varkw(self):
|
class Sub(Super):
|
||||||
dispatcher = self.createDispatcher(1, args=["x"], keywords="vk")
|
pass
|
||||||
dispatcher.register_rule(lambda x, **vk: x, int)
|
|
||||||
self.assertEqual(dispatcher(1), 1)
|
|
||||||
self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)
|
|
||||||
|
|
||||||
def test_dispatching_with_kw(self):
|
dispatcher = create_dispatcher(1, args=["x"])
|
||||||
dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
|
|
||||||
dispatcher.register_rule(lambda x, y=1: x, int)
|
|
||||||
self.assertEqual(dispatcher(1), 1)
|
|
||||||
self.assertRaises(TypeError, dispatcher, "1", k=1)
|
|
||||||
|
|
||||||
def test_create_dispatcher_with_pos_args_less_multi_arity(self):
|
dispatcher.register_rule(lambda x: x, Super)
|
||||||
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"])
|
o_super = Super()
|
||||||
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"],
|
assert dispatcher(o_super) == o_super
|
||||||
defaults=["x"])
|
o_sub = Sub()
|
||||||
|
assert dispatcher(o_sub) == o_sub
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher(object())
|
||||||
|
|
||||||
def test_register_rule_with_wrong_number_types_parameters(self):
|
dispatcher.register_rule(lambda x: (x, x), Sub)
|
||||||
dispatcher = self.createDispatcher(1, args=["x", "y"])
|
o_super = Super()
|
||||||
self.assertRaises(
|
assert dispatcher(o_super) == o_super
|
||||||
TypeError,
|
o_sub = Sub()
|
||||||
dispatcher.register_rule, lambda x, y: x, int, str)
|
assert dispatcher(o_sub) == (o_sub, o_sub)
|
||||||
|
|
||||||
def test_register_rule_with_partial_dispatching(self):
|
|
||||||
dispatcher = self.createDispatcher(1, args=["x", "y"])
|
def test_register_rule_with_wrong_arity():
|
||||||
dispatcher.register_rule(lambda x, y: x, int)
|
dispatcher = create_dispatcher(1, args=["x"])
|
||||||
self.assertEqual(dispatcher(1, 2), 1)
|
dispatcher.register_rule(lambda x: x, int)
|
||||||
self.assertEqual(dispatcher(1, "2"), 1)
|
with pytest.raises(TypeError):
|
||||||
self.assertRaises(TypeError, dispatcher, "2", 1)
|
|
||||||
dispatcher.register_rule(lambda x, y: x, str)
|
dispatcher.register_rule(lambda x, y: x, str)
|
||||||
self.assertEqual(dispatcher(1, 2), 1)
|
|
||||||
self.assertEqual(dispatcher(1, "2"), 1)
|
|
||||||
self.assertEqual(dispatcher("1", "2"), "1")
|
|
||||||
self.assertEqual(dispatcher("1", 2), "1")
|
|
||||||
|
|
||||||
class MultifunctionTests(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_it(self):
|
def test_register_rule_with_different_arg_names():
|
||||||
from generic.multidispatch import multifunction
|
dispatcher = create_dispatcher(1, args=["x"])
|
||||||
|
dispatcher.register_rule(lambda y: y, int)
|
||||||
|
assert dispatcher(1) == 1
|
||||||
|
|
||||||
@multifunction(int, str)
|
|
||||||
def func(x, y):
|
|
||||||
return str(x) + y
|
|
||||||
|
|
||||||
self.assertEqual(func(1, "2"), "12")
|
def test_dispatching_with_varargs():
|
||||||
self.assertRaises(TypeError, func, 1, 2)
|
dispatcher = create_dispatcher(1, args=["x"], varargs="va")
|
||||||
self.assertRaises(TypeError, func, "1", 2)
|
dispatcher.register_rule(lambda x, *va: x, int)
|
||||||
self.assertRaises(TypeError, func, "1", "2")
|
assert dispatcher(1) == 1
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("1", 2, 3)
|
||||||
|
|
||||||
@func.when(str, str)
|
|
||||||
def func(x, y):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
self.assertEqual(func(1, "2"), "12")
|
def test_dispatching_with_varkw():
|
||||||
self.assertEqual(func("1", "2"), "12")
|
dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
|
||||||
self.assertRaises(TypeError, func, 1, 2)
|
dispatcher.register_rule(lambda x, **vk: x, int)
|
||||||
self.assertRaises(TypeError, func, "1", 2)
|
assert dispatcher(1) == 1
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("1", a=1, b=2)
|
||||||
|
|
||||||
def test_overriding(self):
|
|
||||||
from generic.multidispatch import multifunction
|
|
||||||
|
|
||||||
@multifunction(int, str)
|
def test_dispatching_with_kw():
|
||||||
def func(x, y):
|
dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
|
||||||
return str(x) + y
|
dispatcher.register_rule(lambda x, y=1: x, int)
|
||||||
|
assert dispatcher(1) == 1
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("1", k=1)
|
||||||
|
|
||||||
self.assertEqual(func(1, "2"), "12")
|
|
||||||
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
|
|
||||||
|
|
||||||
@func.override(int, str)
|
def test_create_dispatcher_with_pos_args_less_multi_arity():
|
||||||
def func(x, y):
|
with pytest.raises(TypeError):
|
||||||
return y + str(x)
|
create_dispatcher(2, args=["x"])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
create_dispatcher(2, args=["x", "y"], defaults=["x"])
|
||||||
|
|
||||||
self.assertEqual(func(1, "2"), "21")
|
|
||||||
|
|
||||||
class MultimethodTests(unittest.TestCase):
|
def test_register_rule_with_wrong_number_types_parameters():
|
||||||
|
dispatcher = create_dispatcher(1, args=["x", "y"])
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher.register_rule(lambda x, y: x, int, str)
|
||||||
|
|
||||||
def test_multimethod(self):
|
|
||||||
from generic.multidispatch import multimethod
|
|
||||||
from generic.multidispatch import has_multimethods
|
|
||||||
|
|
||||||
@has_multimethods
|
def test_register_rule_with_partial_dispatching():
|
||||||
class Dummy(object):
|
dispatcher = create_dispatcher(1, args=["x", "y"])
|
||||||
|
dispatcher.register_rule(lambda x, y: x, int)
|
||||||
|
assert dispatcher(1, 2) == 1
|
||||||
|
assert dispatcher(1, "2") == 1
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
dispatcher("2", 1)
|
||||||
|
dispatcher.register_rule(lambda x, y: x, str)
|
||||||
|
assert dispatcher(1, 2) == 1
|
||||||
|
assert dispatcher(1, "2") == 1
|
||||||
|
assert dispatcher("1", "2") == "1"
|
||||||
|
assert dispatcher("1", 2) == "1"
|
||||||
|
|
||||||
@multimethod(int)
|
|
||||||
def foo(self, x):
|
|
||||||
return x + 1
|
|
||||||
|
|
||||||
@foo.when(str)
|
def test_default_dispatcher():
|
||||||
def foo(self, x):
|
@multidispatch(int, str)
|
||||||
return x + "1"
|
def func(x, y):
|
||||||
|
return str(x) + y
|
||||||
|
|
||||||
self.assertEqual(Dummy().foo(1), 2)
|
assert func(1, "2") == "12"
|
||||||
self.assertEqual(Dummy().foo("1"), "11")
|
with pytest.raises(TypeError):
|
||||||
self.assertRaises(TypeError, Dummy().foo, [])
|
func(1, 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
func("1", 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
func("1", "2")
|
||||||
|
|
||||||
def test_inheritance(self):
|
|
||||||
from generic.multidispatch import multimethod
|
|
||||||
from generic.multidispatch import has_multimethods
|
|
||||||
|
|
||||||
@has_multimethods
|
def test_multiple_functions():
|
||||||
class Dummy(object):
|
@multidispatch(int, str)
|
||||||
|
def func(x, y):
|
||||||
|
return str(x) + y
|
||||||
|
|
||||||
@multimethod(int)
|
@func.register(str, str)
|
||||||
def foo(self, x):
|
def _(x, y):
|
||||||
return x + 1
|
return x + y
|
||||||
|
|
||||||
@foo.when(float)
|
assert func(1, "2") == "12"
|
||||||
def foo(self, x):
|
assert func("1", "2") == "12"
|
||||||
return x + 1.5
|
with pytest.raises(TypeError):
|
||||||
|
func(1, 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
func("1", 2)
|
||||||
|
|
||||||
@has_multimethods
|
|
||||||
class DummySub(Dummy):
|
|
||||||
|
|
||||||
@Dummy.foo.when(str)
|
def test_default():
|
||||||
def foo(self, x):
|
@multidispatch()
|
||||||
return x + "1"
|
def func(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
@foo.when(tuple)
|
@func.register(str, str)
|
||||||
def foo(self, x):
|
def _(x, y):
|
||||||
return x + (1,)
|
return y + x
|
||||||
|
|
||||||
@Dummy.foo.when(bool)
|
assert func(1, 1) == 2
|
||||||
def foo(self, x):
|
assert func("1", "2") == "21"
|
||||||
return not x
|
|
||||||
|
|
||||||
self.assertEqual(Dummy().foo(1), 2)
|
|
||||||
self.assertEqual(Dummy().foo(1.5), 3.0)
|
|
||||||
self.assertRaises(TypeError, Dummy().foo, "1")
|
|
||||||
self.assertEqual(DummySub().foo(1), 2)
|
|
||||||
self.assertEqual(DummySub().foo(1.5), 3.0)
|
|
||||||
self.assertEqual(DummySub().foo("1"), "11")
|
|
||||||
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
|
|
||||||
self.assertEqual(DummySub().foo(True), False)
|
|
||||||
self.assertRaises(TypeError, DummySub().foo, [])
|
|
||||||
|
|
||||||
def test_override(self):
|
def test_on_classes():
|
||||||
from generic.multidispatch import multimethod
|
@multidispatch()
|
||||||
from generic.multidispatch import has_multimethods
|
class A:
|
||||||
|
def __init__(self, a, b):
|
||||||
|
self.v = a + b
|
||||||
|
|
||||||
@has_multimethods
|
@A.register(str, str) # type: ignore[attr-defined]
|
||||||
class Dummy(object):
|
class B:
|
||||||
|
def __init__(self, a, b):
|
||||||
|
self.v = b + a
|
||||||
|
|
||||||
@multimethod(str, str)
|
assert A(1, 1).v == 2
|
||||||
def foo(self, x, y):
|
assert A("1", "2").v == "21"
|
||||||
return x + y
|
|
||||||
|
|
||||||
@foo.when(str, str)
|
|
||||||
def foo(self, x, y):
|
|
||||||
return y + x
|
|
||||||
|
|
||||||
self.assertEqual(Dummy().foo("1", "2"), "21")
|
|
||||||
|
|
||||||
def test_inheritance_override(self):
|
|
||||||
from generic.multidispatch import multimethod
|
|
||||||
from generic.multidispatch import has_multimethods
|
|
||||||
|
|
||||||
@has_multimethods
|
|
||||||
class Dummy(object):
|
|
||||||
|
|
||||||
@multimethod(int)
|
|
||||||
def foo(self, x):
|
|
||||||
return x + 1
|
|
||||||
|
|
||||||
@has_multimethods
|
|
||||||
class DummySub(Dummy):
|
|
||||||
|
|
||||||
@Dummy.foo.when(int)
|
|
||||||
def foo(self, x):
|
|
||||||
return x + 2
|
|
||||||
|
|
||||||
self.assertEqual(Dummy().foo(1), 2)
|
|
||||||
self.assertEqual(DummySub().foo(1), 3)
|
|
||||||
|
@ -1,135 +1,147 @@
|
|||||||
""" Tests for :module:`generic.registry`."""
|
""" Tests for :module:`generic.registry`."""
|
||||||
|
|
||||||
import unittest
|
import pytest
|
||||||
|
|
||||||
__all__ = ("RegistryTests",)
|
from typing import Union
|
||||||
|
from generic.registry import Registry, SimpleAxis, TypeAxis
|
||||||
|
|
||||||
class RegistryTests(unittest.TestCase):
|
|
||||||
|
|
||||||
def test_one_axis_no_specificity(self):
|
class DummyA:
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('foo', SimpleAxis()))
|
|
||||||
a = object()
|
|
||||||
b = object()
|
|
||||||
registry.register(a)
|
|
||||||
registry.register(b, 'foo')
|
|
||||||
|
|
||||||
self.assertEqual(registry.lookup(), a)
|
|
||||||
self.assertEqual(registry.lookup('foo'), b)
|
|
||||||
self.assertEqual(registry.lookup('bar'), None)
|
|
||||||
|
|
||||||
def test_two_axes(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
from generic.registry import TypeAxis
|
|
||||||
registry = Registry(('type', TypeAxis()),
|
|
||||||
('name', SimpleAxis()))
|
|
||||||
|
|
||||||
target1 = Target('one')
|
|
||||||
registry.register(target1, object)
|
|
||||||
|
|
||||||
target2 = Target('two')
|
|
||||||
registry.register(target2, DummyA)
|
|
||||||
|
|
||||||
target3 = Target('three')
|
|
||||||
registry.register(target3, DummyA, 'foo')
|
|
||||||
|
|
||||||
context1 = object()
|
|
||||||
self.assertEqual(registry.lookup(context1), target1)
|
|
||||||
|
|
||||||
context2 = DummyB()
|
|
||||||
self.assertEqual(registry.lookup(context2), target2)
|
|
||||||
self.assertEqual(registry.lookup(context2, 'foo'), target3)
|
|
||||||
|
|
||||||
target4 = object()
|
|
||||||
registry.register(target4, DummyB)
|
|
||||||
|
|
||||||
self.assertEqual(registry.lookup(context2), target4)
|
|
||||||
self.assertEqual(registry.lookup(context2, 'foo'), target3)
|
|
||||||
|
|
||||||
def test_get_registration(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
from generic.registry import TypeAxis
|
|
||||||
registry = Registry(('type', TypeAxis()),
|
|
||||||
('name', SimpleAxis()))
|
|
||||||
registry.register('one', object)
|
|
||||||
registry.register('two', DummyA, 'foo')
|
|
||||||
self.assertEqual(registry.get_registration(object), 'one')
|
|
||||||
self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two')
|
|
||||||
self.assertEqual(registry.get_registration(object, 'foo'), None)
|
|
||||||
self.assertEqual(registry.get_registration(DummyA), None)
|
|
||||||
|
|
||||||
def test_register_too_many_keys(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('name', SimpleAxis()))
|
|
||||||
self.assertRaises(ValueError, registry.register, object(),
|
|
||||||
'one', 'two')
|
|
||||||
|
|
||||||
def test_lookup_too_many_keys(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('name', SimpleAxis()))
|
|
||||||
self.assertRaises(ValueError, registry.lookup, 'one', 'two')
|
|
||||||
|
|
||||||
def test_conflict_error(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('name', SimpleAxis()))
|
|
||||||
registry.register(object(), name='foo')
|
|
||||||
self.assertRaises(ValueError, registry.register, object(), 'foo')
|
|
||||||
|
|
||||||
def test_override(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('name', SimpleAxis()))
|
|
||||||
registry.register(1, name='foo')
|
|
||||||
registry.override(2, name='foo')
|
|
||||||
self.assertEqual(registry.lookup('foo'), 2)
|
|
||||||
|
|
||||||
def test_skip_nodes(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(
|
|
||||||
('one', SimpleAxis()),
|
|
||||||
('two', SimpleAxis()),
|
|
||||||
('three', SimpleAxis())
|
|
||||||
)
|
|
||||||
registry.register('foo', one=1, three=3)
|
|
||||||
self.assertEqual(registry.lookup(1, three=3), 'foo')
|
|
||||||
|
|
||||||
def test_miss(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(
|
|
||||||
('one', SimpleAxis()),
|
|
||||||
('two', SimpleAxis()),
|
|
||||||
('three', SimpleAxis())
|
|
||||||
)
|
|
||||||
registry.register('foo', 1, 2)
|
|
||||||
self.assertEqual(registry.lookup(one=1, three=3), None)
|
|
||||||
|
|
||||||
def test_bad_lookup(self):
|
|
||||||
from generic.registry import Registry
|
|
||||||
from generic.registry import SimpleAxis
|
|
||||||
registry = Registry(('name', SimpleAxis()),
|
|
||||||
('grade', SimpleAxis()))
|
|
||||||
self.assertRaises(ValueError, registry.register, 1, foo=1)
|
|
||||||
self.assertRaises(ValueError, registry.lookup, foo=1)
|
|
||||||
self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo')
|
|
||||||
|
|
||||||
class DummyA(object):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DummyB(DummyA):
|
class DummyB(DummyA):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class Target(object):
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
|
|
||||||
# Only called if being printed due to a failing test
|
def test_one_axis_no_specificity():
|
||||||
def __repr__(self): #pragma NO COVERAGE
|
registry: Registry[object] = Registry(("foo", SimpleAxis()))
|
||||||
return "Target('%s')" % self.name
|
a = object()
|
||||||
|
b = object()
|
||||||
|
registry.register(a)
|
||||||
|
registry.register(b, "foo")
|
||||||
|
|
||||||
|
assert registry.lookup() == a
|
||||||
|
assert registry.lookup("foo") == b
|
||||||
|
assert registry.lookup("bar") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_subtyping_on_axes():
|
||||||
|
registry: Registry[str] = Registry(("type", TypeAxis()))
|
||||||
|
|
||||||
|
target1 = "one"
|
||||||
|
registry.register(target1, object)
|
||||||
|
|
||||||
|
target2 = "two"
|
||||||
|
registry.register(target2, DummyA)
|
||||||
|
|
||||||
|
target3 = "three"
|
||||||
|
registry.register(target3, DummyB)
|
||||||
|
|
||||||
|
assert registry.lookup(object()) == target1
|
||||||
|
assert registry.lookup(DummyA()) == target2
|
||||||
|
assert registry.lookup(DummyB()) == target3
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_subtyping_on_axes():
|
||||||
|
registry: Registry[str] = Registry(("type", TypeAxis()))
|
||||||
|
|
||||||
|
target1 = "one"
|
||||||
|
registry.register(target1, object)
|
||||||
|
|
||||||
|
target2 = "two"
|
||||||
|
registry.register(target2, DummyA)
|
||||||
|
|
||||||
|
target3 = "three"
|
||||||
|
registry.register(target3, DummyB)
|
||||||
|
|
||||||
|
target4 = "four"
|
||||||
|
registry.register(target4, int)
|
||||||
|
|
||||||
|
assert list(registry.query(object())) == [target1]
|
||||||
|
assert list(registry.query(DummyA())) == [target2, target1]
|
||||||
|
assert list(registry.query(DummyB())) == [target3, target2, target1]
|
||||||
|
assert list(registry.query(3)) == [target4, target1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_axes():
|
||||||
|
registry: Registry[Union[str, object]] = Registry(
|
||||||
|
("type", TypeAxis()), ("name", SimpleAxis())
|
||||||
|
)
|
||||||
|
|
||||||
|
target1 = "one"
|
||||||
|
registry.register(target1, object)
|
||||||
|
|
||||||
|
target2 = "two"
|
||||||
|
registry.register(target2, DummyA)
|
||||||
|
|
||||||
|
target3 = "three"
|
||||||
|
registry.register(target3, DummyA, "foo")
|
||||||
|
|
||||||
|
context1 = object()
|
||||||
|
assert registry.lookup(context1) == target1
|
||||||
|
|
||||||
|
context2 = DummyB()
|
||||||
|
assert registry.lookup(context2) == target2
|
||||||
|
assert registry.lookup(context2, "foo") == target3
|
||||||
|
|
||||||
|
target4 = object()
|
||||||
|
registry.register(target4, DummyB)
|
||||||
|
|
||||||
|
assert registry.lookup(context2) == target4
|
||||||
|
assert registry.lookup(context2, "foo") == target3
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_registration():
|
||||||
|
registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis()))
|
||||||
|
registry.register("one", object)
|
||||||
|
registry.register("two", DummyA, "foo")
|
||||||
|
assert registry.get_registration(object) == "one"
|
||||||
|
assert registry.get_registration(DummyA, "foo") == "two"
|
||||||
|
assert registry.get_registration(object, "foo") is None
|
||||||
|
assert registry.get_registration(DummyA) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_too_many_keys():
|
||||||
|
registry: Registry[type] = Registry(("name", SimpleAxis()))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.register(object, "one", "two")
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_too_many_keys():
|
||||||
|
registry: Registry[object] = Registry(("name", SimpleAxis()))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.register(registry.lookup("one", "two"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_conflict_error():
|
||||||
|
registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis()))
|
||||||
|
registry.register(object(), name="foo")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.register(object, "foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_skip_nodes():
|
||||||
|
registry: Registry[str] = Registry(
|
||||||
|
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
|
||||||
|
)
|
||||||
|
registry.register("foo", one=1, three=3)
|
||||||
|
assert registry.lookup(1, three=3) == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_miss():
|
||||||
|
registry: Registry[str] = Registry(
|
||||||
|
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
|
||||||
|
)
|
||||||
|
registry.register("foo", 1, 2)
|
||||||
|
assert registry.lookup(one=1, three=3) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_lookup():
|
||||||
|
registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis()))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.register(1, foo=1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.lookup(foo=1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
registry.register(1, "foo", name="foo")
|
||||||
|
Loading…
Reference in New Issue
Block a user