Replace old code by Python3 implementation from Gaphor
Multimethods are missing now. Those have to be added back.
This commit is contained in:
parent
ee35c528cb
commit
bb20f6992b
105
generic/event.py
105
generic/event.py
@ -2,7 +2,7 @@
|
||||
|
||||
This module provides API for event management. There are two APIs provided:
|
||||
|
||||
* Global event management API: subscribe, unsubscribe, fire.
|
||||
* Global event management API: subscribe, unsubscribe, handle.
|
||||
* Local event management API: Manager
|
||||
|
||||
If you run only one instance of your application per Python
|
||||
@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API
|
||||
and have one instance of Manager object per application instance.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Callable, Set, Type
|
||||
|
||||
from generic.registry import Registry
|
||||
from generic.registry import TypeAxis
|
||||
from generic.registry import Registry, TypeAxis
|
||||
|
||||
__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber")
|
||||
|
||||
class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])):
|
||||
""" Set of handlers for specific type of event.
|
||||
__all__ = "Manager"
|
||||
|
||||
This object stores ``handlers`` for specific event type and
|
||||
``parents`` reference to handler sets of event's supertypes.
|
||||
"""
|
||||
Event = object
|
||||
Handler = Callable[[object], None]
|
||||
HandlerSet = Set[Handler]
|
||||
|
||||
@property
|
||||
def all_handlers(self):
|
||||
""" Iterate over own and supertypes' handlers.
|
||||
|
||||
This iterator yields just unique values, so it won't yield the
|
||||
same handler twice, even if it was registered both for some
|
||||
event type and its supertype.
|
||||
"""
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
|
||||
# yield own handlers first
|
||||
for handler in self.handlers:
|
||||
seen_add(handler)
|
||||
yield handler
|
||||
|
||||
# yield supertypes' handlers then
|
||||
for parent in self.parents:
|
||||
for handler in parent.all_handlers:
|
||||
if not handler in seen:
|
||||
seen_add(handler)
|
||||
yield handler
|
||||
|
||||
class Manager(object):
|
||||
class Manager:
|
||||
""" Event manager
|
||||
|
||||
Provides API for subscribing for and firing events. There's also global
|
||||
event manager instantiated at module level with functions
|
||||
:func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased
|
||||
:func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased
|
||||
to corresponding methods of class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
registry: Registry[HandlerSet]
|
||||
|
||||
def __init__(self) -> None:
|
||||
axes = (("event_type", TypeAxis()),)
|
||||
self.registry = Registry(*axes)
|
||||
|
||||
def subscribe(self, handler, event_type):
|
||||
def subscribe(self, handler: Handler, event_type: Type[Event]) -> None:
|
||||
""" Subscribe ``handler`` to specified ``event_type``"""
|
||||
handler_set = self.registry.get_registration(event_type)
|
||||
if not handler_set:
|
||||
if handler_set is None:
|
||||
handler_set = self._register_handler_set(event_type)
|
||||
handler_set.handlers.add(handler)
|
||||
handler_set.add(handler)
|
||||
|
||||
def unsubscribe(self, handler, event_type):
|
||||
def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None:
|
||||
""" Unsubscribe ``handler`` from ``event_type``"""
|
||||
handler_set = self.registry.get_registration(event_type)
|
||||
if handler_set and handler in handler_set.handlers:
|
||||
handler_set.handlers.remove(handler)
|
||||
if handler_set and handler in handler_set:
|
||||
handler_set.remove(handler)
|
||||
|
||||
def fire(self, event):
|
||||
def handle(self, event: Event) -> None:
|
||||
""" Fire ``event``
|
||||
|
||||
All subscribers will be executed with no determined order.
|
||||
"""
|
||||
handler_set = self.registry.lookup(event)
|
||||
for handler in handler_set.all_handlers:
|
||||
handler(event)
|
||||
handler_sets = self.registry.query(event)
|
||||
for handler_set in handler_sets:
|
||||
if handler_set:
|
||||
for handler in set(handler_set):
|
||||
handler(event)
|
||||
|
||||
def _register_handler_set(self, event_type):
|
||||
""" Register new handler set for ``event_type``."""
|
||||
# Collect handler sets for supertypes
|
||||
parent_handler_sets = []
|
||||
parents = event_type.__bases__
|
||||
for parent in parents:
|
||||
parent_handlers = self.registry.get_registration(parent)
|
||||
if parent_handlers is None:
|
||||
parent_handlers = self._register_handler_set(parent)
|
||||
parent_handler_sets.append(parent_handlers)
|
||||
|
||||
handler_set = HandlerSet(parents=parent_handler_sets, handlers=set())
|
||||
def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet:
|
||||
""" Register new handler set for ``event_type``.
|
||||
"""
|
||||
handler_set: HandlerSet = set()
|
||||
self.registry.register(handler_set, event_type)
|
||||
return handler_set
|
||||
|
||||
def subscriber(self, event_type):
|
||||
def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]:
|
||||
""" Decorator for subscribing handlers
|
||||
|
||||
Works like this:
|
||||
|
||||
>>> mymanager = Manager()
|
||||
>>> class MyEvent():
|
||||
... pass
|
||||
>>> @mymanager.subscriber(MyEvent)
|
||||
... def mysubscriber(evt):
|
||||
... # handle event
|
||||
... return
|
||||
|
||||
>>> mymanager.fire(MyEvent())
|
||||
>>> mymanager.handle(MyEvent())
|
||||
|
||||
"""
|
||||
def registrator(func):
|
||||
|
||||
def registrator(func: Handler) -> Handler:
|
||||
self.subscribe(func, event_type)
|
||||
return func
|
||||
|
||||
return registrator
|
||||
|
||||
# Global event manager
|
||||
_global_manager = Manager()
|
||||
|
||||
# Global event management API
|
||||
subscribe = _global_manager.subscribe
|
||||
unsubscribe = _global_manager.unsubscribe
|
||||
fire = _global_manager.fire
|
||||
subscriber = _global_manager.subscriber
|
||||
|
@ -1,206 +1,132 @@
|
||||
""" Multidispatch for functions and methods"""
|
||||
""" Multidispatch for functions and methods.
|
||||
|
||||
This code is a Python 3, slimmed down version of the
|
||||
generic package by Andrey Popp.
|
||||
|
||||
Only the generic function code is left in tact -- no generic methods.
|
||||
The interface has been made in line with `functools.singledispatch`.
|
||||
|
||||
Note that this module does not support annotated functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast, Any, Callable, Generic, TypeVar, Union
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
import threading
|
||||
|
||||
from generic.registry import Registry
|
||||
from generic.registry import TypeAxis
|
||||
from generic.registry import Registry, TypeAxis
|
||||
|
||||
__all__ = ("multifunction", "multimethod", "has_multimethods")
|
||||
__all__ = "multidispatch"
|
||||
|
||||
def multifunction(*argtypes):
|
||||
""" Declare function as multifunction
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], type])
|
||||
KeyType = Union[type, None]
|
||||
|
||||
|
||||
def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
|
||||
""" Declare function as multidispatch
|
||||
|
||||
This decorator takes ``argtypes`` argument types and replace decorated
|
||||
function with :class:`.FunctionDispatcher` object, which is responsible for
|
||||
multiple dispatch feature.
|
||||
"""
|
||||
def _replace_with_dispatcher(func):
|
||||
dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes))
|
||||
|
||||
def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
|
||||
nonlocal argtypes
|
||||
argspec = inspect.getfullargspec(func)
|
||||
if not argtypes:
|
||||
arity = _arity(argspec)
|
||||
if isinstance(func, type):
|
||||
# It's a class we deal with:
|
||||
arity -= 1
|
||||
argtypes = (object,) * arity
|
||||
|
||||
dispatcher = cast(
|
||||
FunctionDispatcher[T],
|
||||
functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func),
|
||||
)
|
||||
dispatcher.register_rule(func, *argtypes)
|
||||
return dispatcher
|
||||
|
||||
return _replace_with_dispatcher
|
||||
|
||||
def multimethod(*argtypes):
|
||||
""" Declare method as multimethod
|
||||
|
||||
This decorator works exactly the same as :func:`.multifunction` decorator
|
||||
but replaces decorated method with :class:`.MethodDispatcher` object
|
||||
instead.
|
||||
|
||||
Should be used only for decorating methods and enclosing class should have
|
||||
:func:`.has_multimethods` decorator.
|
||||
"""
|
||||
def _replace_with_dispatcher(func):
|
||||
dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1)
|
||||
dispatcher.register_unbound_rule(func, *argtypes)
|
||||
return dispatcher
|
||||
return _replace_with_dispatcher
|
||||
|
||||
def has_multimethods(cls):
|
||||
""" Declare class as one that have multimethods
|
||||
|
||||
Should only be used for decorating classes which have methods decorated with
|
||||
:func:`.multimethod` decorator.
|
||||
"""
|
||||
for name, obj in cls.__dict__.items():
|
||||
if isinstance(obj, MethodDispatcher):
|
||||
obj.proceed_unbound_rules(cls)
|
||||
return cls
|
||||
|
||||
class FunctionDispatcher(object):
|
||||
class FunctionDispatcher(Generic[T]):
|
||||
""" Multidispatcher for functions
|
||||
|
||||
This object dispatch calls to function by its argument types. Usually it is
|
||||
produced by :func:`.multifunction` decorator.
|
||||
produced by :func:`.multidispatch` decorator.
|
||||
|
||||
You should not manually create objects of this type.
|
||||
"""
|
||||
|
||||
def __init__(self, argspec, params_arity):
|
||||
registry: Registry[T]
|
||||
|
||||
def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
|
||||
""" Initialize dispatcher with ``argspec`` of type
|
||||
:class:`inspect.ArgSpec` and ``params_arity`` that represent number
|
||||
params."""
|
||||
# Check if we have enough positional arguments for number of type params
|
||||
if arity(argspec) < params_arity:
|
||||
raise TypeError("Not enough positional arguments "
|
||||
"for number of type parameters provided.")
|
||||
if _arity(argspec) < params_arity:
|
||||
raise TypeError(
|
||||
"Not enough positional arguments "
|
||||
"for number of type parameters provided."
|
||||
)
|
||||
|
||||
self.argspec = argspec
|
||||
self.params_arity = params_arity
|
||||
|
||||
axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
|
||||
axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
|
||||
self.registry = Registry(*axis)
|
||||
|
||||
def check_rule(self, rule, *argtypes):
|
||||
def check_rule(self, rule: T, *argtypes: KeyType) -> None:
|
||||
# Check if we have the right number of parametrized types
|
||||
if len(argtypes) != self.params_arity:
|
||||
raise TypeError("Wrong number of type parameters.")
|
||||
raise TypeError(
|
||||
f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
|
||||
)
|
||||
|
||||
# Check if we have the same argspec (by number of args)
|
||||
rule_argspec = inspect.getargspec(rule)
|
||||
if not is_equalent_argspecs(rule_argspec, self.argspec):
|
||||
raise TypeError("Rule does not conform "
|
||||
"to previous implementations.")
|
||||
rule_argspec = inspect.getfullargspec(rule)
|
||||
left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4])
|
||||
right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4])
|
||||
if left_spec != right_spec:
|
||||
raise TypeError(
|
||||
f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
|
||||
)
|
||||
|
||||
def register_rule(self, rule, *argtypes):
|
||||
def register_rule(self, rule: T, *argtypes: KeyType) -> None:
|
||||
""" Register new ``rule`` for ``argtypes``."""
|
||||
self.check_rule(rule, *argtypes)
|
||||
self.registry.register(rule, *argtypes)
|
||||
|
||||
def override_rule(self, rule, *argtypes):
|
||||
""" Override ``rule`` for ``argtypes``."""
|
||||
self.check_rule(rule, *argtypes)
|
||||
self.registry.override(rule, *argtypes)
|
||||
|
||||
def lookup_rule(self, *args):
|
||||
""" Lookup rule by ``args``. Returns None if no rule was found."""
|
||||
args = args[:self.params_arity]
|
||||
rule = self.registry.lookup(*args)
|
||||
if rule is None:
|
||||
raise TypeError("No available rule found for %r" % (args,))
|
||||
return rule
|
||||
|
||||
def when(self, *argtypes):
|
||||
""" Decorator for registering new case for multifunction
|
||||
def register(self, *argtypes: KeyType) -> Callable[[T], T]:
|
||||
""" Decorator for registering new case for multidispatch
|
||||
|
||||
New case will be registered for types identified by ``argtypes``. The
|
||||
length of ``argtypes`` should be equal to the length of ``argtypes``
|
||||
argument were passed corresponding :func:`.multifunction` call, which
|
||||
also indicated the number of arguments multifunction dispatches on.
|
||||
argument were passed corresponding :func:`.multidispatch` call, which
|
||||
also indicated the number of arguments multidispatch dispatches on.
|
||||
"""
|
||||
def register_rule(func):
|
||||
|
||||
def register_rule(func: T) -> T:
|
||||
self.register_rule(func, *argtypes)
|
||||
return self
|
||||
return func
|
||||
|
||||
return register_rule
|
||||
|
||||
@property
|
||||
def otherwise(self):
|
||||
""" Decorator which registeres "catch-all" case for multifunction"""
|
||||
def register_rule(func):
|
||||
self.register_rule(func, [object]*self.params_arity)
|
||||
return self
|
||||
return register_rule
|
||||
|
||||
def override(self, *argtypes):
|
||||
""" Decorator for overriding case for ``argtypes``"""
|
||||
def override_rule(func):
|
||||
self.override_rule(func, *argtypes)
|
||||
return self
|
||||
return override_rule
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
""" Dispatch call to appropriate rule."""
|
||||
rule = self.lookup_rule(*args)
|
||||
trimmed_args = args[: self.params_arity]
|
||||
rule = self.registry.lookup(*trimmed_args)
|
||||
if not rule:
|
||||
raise TypeError(f"No available rule found for {trimmed_args!r}")
|
||||
return rule(*args, **kwargs)
|
||||
|
||||
class MethodDispatcher(FunctionDispatcher):
|
||||
""" Multiple dispatch for methods
|
||||
|
||||
This object dispatch call to method by its class and arguments types.
|
||||
Usually it is produced by :func:`.multimethod` decorator.
|
||||
|
||||
You should not manually create objects of this type.
|
||||
"""
|
||||
|
||||
def __init__(self, argspec, params_arity):
|
||||
FunctionDispatcher.__init__(self, argspec, params_arity)
|
||||
|
||||
# some data, that should be local to thread of execution
|
||||
self.local = threading.local()
|
||||
self.local.unbound_rules = []
|
||||
|
||||
def register_unbound_rule(self, func, *argtypes):
|
||||
""" Register unbound rule that should be processed by
|
||||
``proceed_unbound_rules`` later."""
|
||||
self.local.unbound_rules.append((argtypes, func))
|
||||
|
||||
def proceed_unbound_rules(self, cls):
|
||||
""" Process all unbound rule by binding them to ``cls`` type."""
|
||||
for argtypes, func in self.local.unbound_rules:
|
||||
argtypes = (cls,) + argtypes
|
||||
self.override_rule(func, *argtypes)
|
||||
self.local.unbound_rules = []
|
||||
|
||||
def __get__(self, obj, cls):
|
||||
if obj is None:
|
||||
return self
|
||||
return types.MethodType(self, obj)
|
||||
|
||||
def when(self, *argtypes):
|
||||
""" Register new case for multimethod for ``argtypes``"""
|
||||
def make_declaration(meth):
|
||||
self.register_unbound_rule(meth, *argtypes)
|
||||
return self
|
||||
return make_declaration
|
||||
|
||||
def override(self, *argtypes):
|
||||
""" Decorator for overriding case for ``argtypes``"""
|
||||
return self.when(*argtypes)
|
||||
|
||||
@property
|
||||
def otherwise(self):
|
||||
""" Decorator which registeres "catch-all" case for multimethod"""
|
||||
def make_declaration(func):
|
||||
self.register_unbound_rule(func, [object]*self.params_arity)
|
||||
return self
|
||||
return make_declaration
|
||||
|
||||
def arity(argspec):
|
||||
def _arity(argspec: inspect.FullArgSpec) -> int:
|
||||
""" Determinal positional arity of argspec."""
|
||||
args = argspec.args if argspec.args else []
|
||||
defaults = argspec.defaults if argspec.defaults else []
|
||||
return len(args) - len(defaults)
|
||||
|
||||
def is_equalent_argspecs(left, right):
|
||||
""" Check argspec equalence."""
|
||||
return map(lambda x: len(x) if x else 0, left) == \
|
||||
map(lambda x: len(x) if x else 0, right)
|
||||
|
||||
def _make_dispatcher(dispacther_cls, func, params_arity):
|
||||
argspec = inspect.getargspec(func)
|
||||
wrapper = functools.wraps(func)
|
||||
dispatcher = wrapper(dispacther_cls(argspec, params_arity))
|
||||
return dispatcher
|
||||
|
@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi.
|
||||
[1]: http://bitbucket.org/chrisrossi/happy
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ("Registry", "SimpleAxis", "TypeAxis")
|
||||
|
||||
class Registry(object):
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
KeysView,
|
||||
List,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
K = TypeVar("K")
|
||||
S = TypeVar("S")
|
||||
T = TypeVar("T")
|
||||
V = TypeVar("V")
|
||||
Axis = Union["SimpleAxis", "TypeAxis"]
|
||||
|
||||
|
||||
class Registry(Generic[T]):
|
||||
""" Registry implementation."""
|
||||
|
||||
def __init__(self, *axes):
|
||||
self._tree = _TreeNode()
|
||||
def __init__(self, *axes: Tuple[str, Axis]):
|
||||
self._tree: _TreeNode[T] = _TreeNode()
|
||||
self._axes = [axis for name, axis in axes]
|
||||
self._axes_dict = dict([
|
||||
(name, (i, axis)) for i, (name, axis) in enumerate(axes)
|
||||
])
|
||||
self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)}
|
||||
|
||||
def register(self, target, *arg_keys, **kw_keys):
|
||||
self._register(target, self._align_with_axes(arg_keys, kw_keys), False)
|
||||
|
||||
def override(self, target, *arg_keys, **kw_keys):
|
||||
self._register(target, self._align_with_axes(arg_keys, kw_keys), True)
|
||||
|
||||
def _register(self, target, keys, override):
|
||||
def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None:
|
||||
tree_node = self._tree
|
||||
for key in keys:
|
||||
tree_node = tree_node.setdefault(key, _TreeNode())
|
||||
for key in self._align_with_axes(arg_keys, kw_keys):
|
||||
tree_node = tree_node.setdefault(key, _TreeNode[T]())
|
||||
|
||||
if not override and not tree_node.target is None:
|
||||
if not tree_node.target is None:
|
||||
raise ValueError(
|
||||
"Registration conflicts with existing registration. Use "
|
||||
"override method to override.")
|
||||
f"Registration for {target} conflicts with existing registration {tree_node.target}."
|
||||
)
|
||||
|
||||
tree_node.target = target
|
||||
|
||||
def get_registration(self, *arg_keys, **kw_keys):
|
||||
def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]:
|
||||
tree_node = self._tree
|
||||
for key in self._align_with_axes(arg_keys, kw_keys):
|
||||
if not tree_node.has_key(key):
|
||||
if not key in tree_node:
|
||||
return None
|
||||
tree_node = tree_node[key]
|
||||
|
||||
return tree_node.target
|
||||
|
||||
def lookup(self, *arg_objs, **kw_objs):
|
||||
def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]:
|
||||
return next(self.query(*arg_objs, **kw_objs), None)
|
||||
|
||||
def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]:
|
||||
objs = self._align_with_axes(arg_objs, kw_objs)
|
||||
axes = self._axes
|
||||
return self._lookup(self._tree, objs, axes)
|
||||
return self._query(self._tree, objs, axes)
|
||||
|
||||
def _lookup(self, tree_node, objs, axes):
|
||||
def _query(
|
||||
self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis]
|
||||
) -> Generator[Optional[T], None, None]:
|
||||
""" Recursively traverse registration tree, from left to right, most
|
||||
specific to least specific, returning the first target found on a
|
||||
matching node. """
|
||||
if not objs:
|
||||
return tree_node.target
|
||||
yield tree_node.target
|
||||
else:
|
||||
obj = objs[0]
|
||||
|
||||
obj = objs[0]
|
||||
# Skip non-participating nodes
|
||||
if obj is None:
|
||||
next_node: Optional[_TreeNode[T]] = tree_node.get(None, None)
|
||||
if next_node is not None:
|
||||
yield from self._query(next_node, objs[1:], axes[1:])
|
||||
else:
|
||||
# Get matches on this axis and iterate from most to least specific
|
||||
axis = axes[0]
|
||||
for match_key in axis.matches(obj, tree_node.keys()):
|
||||
yield from self._query(tree_node[match_key], objs[1:], axes[1:])
|
||||
|
||||
# Skip non-participating nodes
|
||||
if obj is None:
|
||||
next_node = tree_node.get(None, None)
|
||||
if next_node is not None:
|
||||
return self._lookup(next_node, objs[1:], axes[1:])
|
||||
return None
|
||||
|
||||
# Get matches on this axis and iterate from most to least specific
|
||||
axis = axes[0]
|
||||
for match_key in axis.matches(obj, tree_node.keys()):
|
||||
target = self._lookup(tree_node[match_key], objs[1:], axes[1:])
|
||||
if target is not None:
|
||||
return target
|
||||
|
||||
return None
|
||||
|
||||
def _align_with_axes(self, args, kw):
|
||||
def _align_with_axes(
|
||||
self, args: Sequence[S], kw: Dict[str, S]
|
||||
) -> Sequence[Optional[S]]:
|
||||
""" Create a list matching up all args and kwargs with their
|
||||
corresponding axes, in order, using ``None`` as a placeholder for
|
||||
skipped axes. """
|
||||
axes_dict = self._axes_dict
|
||||
aligned = [None for i in xrange(len(axes_dict))]
|
||||
aligned: List[Optional[S]] = [None for i in range(len(axes_dict))]
|
||||
|
||||
args_len = len(args)
|
||||
if args_len + len(kw) > len(aligned):
|
||||
if args_len + len(kw) > len(aligned):
|
||||
raise ValueError("Cannot have more arguments than axes.")
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
@ -91,12 +108,13 @@ class Registry(object):
|
||||
for k, v in kw.items():
|
||||
i_axis = axes_dict.get(k, None)
|
||||
if i_axis is None:
|
||||
raise ValueError("No axis with name: %s" % k)
|
||||
raise ValueError(f"No axis with name: {k}")
|
||||
|
||||
i, axis = i_axis
|
||||
if aligned[i] is not None:
|
||||
raise ValueError("Axis defined twice between positional and "
|
||||
"keyword arguments")
|
||||
raise ValueError(
|
||||
"Axis defined twice between positional and " "keyword arguments"
|
||||
)
|
||||
|
||||
aligned[i] = v
|
||||
|
||||
@ -106,13 +124,15 @@ class Registry(object):
|
||||
|
||||
return aligned
|
||||
|
||||
class _TreeNode(dict):
|
||||
target = None
|
||||
|
||||
def __str__(self):
|
||||
return "<TreeNode %s %s>" % (self.target, dict.__str__(self))
|
||||
class _TreeNode(Generic[T], Dict[Any, Any]):
|
||||
target: Optional[T] = None
|
||||
|
||||
class SimpleAxis(object):
|
||||
def __str__(self) -> str:
|
||||
return f"<TreeNode {self.target} {dict.__str__(self)}>"
|
||||
|
||||
|
||||
class SimpleAxis:
|
||||
""" A simple axis where the key into the axis is the same as the object to
|
||||
be matched (aka the identity axis). This axis behaves just like a
|
||||
dictionary. You might use this axis if you are interested in registering
|
||||
@ -122,21 +142,23 @@ class SimpleAxis(object):
|
||||
Subclasses can override the ``get_keys`` method for implementing arbitrary
|
||||
axes.
|
||||
"""
|
||||
def matches(self, obj, keys):
|
||||
for key in self.get_keys(obj):
|
||||
|
||||
def matches(
|
||||
self, obj: object, keys: KeysView[Optional[object]]
|
||||
) -> Generator[object, None, None]:
|
||||
for key in [obj]:
|
||||
if key in keys:
|
||||
yield key
|
||||
yield obj
|
||||
|
||||
def get_keys(self, obj):
|
||||
"""
|
||||
Return the keys for the given object that could match this axis, from
|
||||
most specific to least specific. A convenient override point.
|
||||
"""
|
||||
return [obj,]
|
||||
|
||||
class TypeAxis(SimpleAxis):
|
||||
class TypeAxis:
|
||||
""" An axis which matches the class and super classes of an object in
|
||||
method resolution order.
|
||||
"""
|
||||
def get_keys(self, obj):
|
||||
return type(obj).mro()
|
||||
|
||||
def matches(
|
||||
self, obj: object, keys: KeysView[Optional[type]]
|
||||
) -> Generator[type, None, None]:
|
||||
for key in type(obj).mro():
|
||||
if key in keys:
|
||||
yield key
|
||||
|
@ -1,150 +1,181 @@
|
||||
""" Tests for :module:`generic.event`."""
|
||||
|
||||
import unittest
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ("ManagerTests",)
|
||||
from typing import Callable, List
|
||||
from generic.event import Manager
|
||||
|
||||
class ManagerTests(unittest.TestCase):
|
||||
|
||||
def makeHandler(self, effect):
|
||||
return lambda e: e.effects.append(effect)
|
||||
def make_handler(effect: object) -> Callable[[Event], None]:
|
||||
return lambda e: e.effects.append(effect)
|
||||
|
||||
def createManager(self):
|
||||
from generic.event import Manager
|
||||
return Manager()
|
||||
|
||||
def test_subscribe_single_event(self):
|
||||
events = self.createManager()
|
||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
||||
e = EventA()
|
||||
events.fire(e)
|
||||
self.assertEqual(len(e.effects), 1)
|
||||
self.assertTrue("handler1" in e.effects)
|
||||
def create_manager():
|
||||
return Manager()
|
||||
|
||||
def test_subscribe_via_decorator(self):
|
||||
events = self.createManager()
|
||||
events.subscriber(EventA)(self.makeHandler("handler1"))
|
||||
e = EventA()
|
||||
events.fire(e)
|
||||
self.assertEqual(len(e.effects), 1)
|
||||
self.assertTrue("handler1" in e.effects)
|
||||
|
||||
def test_subscribe_event_inheritance(self):
|
||||
events = self.createManager()
|
||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
||||
events.subscribe(self.makeHandler("handler2"), EventB)
|
||||
def test_subscribe_single_event():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), EventA)
|
||||
e = EventA()
|
||||
events.handle(e)
|
||||
assert len(e.effects) == 1
|
||||
assert "handler1" in e.effects
|
||||
|
||||
ea = EventA()
|
||||
events.fire(ea)
|
||||
self.assertEqual(len(ea.effects), 1)
|
||||
self.assertTrue("handler1" in ea.effects)
|
||||
|
||||
eb = EventB()
|
||||
events.fire(eb)
|
||||
self.assertEqual(len(eb.effects), 2)
|
||||
self.assertTrue("handler1" in eb.effects)
|
||||
self.assertTrue("handler2" in eb.effects)
|
||||
def test_subscribe_via_decorator():
|
||||
events = create_manager()
|
||||
events.subscriber(EventA)(make_handler("handler1"))
|
||||
e = EventA()
|
||||
events.handle(e)
|
||||
assert len(e.effects) == 1
|
||||
assert "handler1" in e.effects
|
||||
|
||||
def test_subscribe_event_multiple_inheritance(self):
|
||||
events = self.createManager()
|
||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
||||
events.subscribe(self.makeHandler("handler2"), EventC)
|
||||
events.subscribe(self.makeHandler("handler3"), EventD)
|
||||
|
||||
ea = EventA()
|
||||
events.fire(ea)
|
||||
self.assertEqual(len(ea.effects), 1)
|
||||
self.assertTrue("handler1" in ea.effects)
|
||||
def test_subscribe_event_inheritance():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), EventA)
|
||||
events.subscribe(make_handler("handler2"), EventB)
|
||||
|
||||
ec = EventC()
|
||||
events.fire(ec)
|
||||
self.assertEqual(len(ec.effects), 1)
|
||||
self.assertTrue("handler2" in ec.effects)
|
||||
ea = EventA()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 1
|
||||
assert "handler1" in ea.effects
|
||||
|
||||
ed = EventD()
|
||||
events.fire(ed)
|
||||
self.assertEqual(len(ed.effects), 3)
|
||||
self.assertTrue("handler1" in ed.effects)
|
||||
self.assertTrue("handler2" in ed.effects)
|
||||
self.assertTrue("handler3" in ed.effects)
|
||||
eb = EventB()
|
||||
events.handle(eb)
|
||||
assert len(eb.effects) == 2
|
||||
assert "handler1" in eb.effects
|
||||
assert "handler2" in eb.effects
|
||||
|
||||
def test_subscribe_event_malformed_multiple_inheritance(self):
|
||||
events = self.createManager()
|
||||
events.subscribe(self.makeHandler("handler1"), EventA)
|
||||
events.subscribe(self.makeHandler("handler2"), EventD)
|
||||
events.subscribe(self.makeHandler("handler3"), EventE)
|
||||
|
||||
ea = EventA()
|
||||
events.fire(ea)
|
||||
self.assertEqual(len(ea.effects), 1)
|
||||
self.assertTrue("handler1" in ea.effects)
|
||||
def test_subscribe_event_multiple_inheritance():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), EventA)
|
||||
events.subscribe(make_handler("handler2"), EventC)
|
||||
events.subscribe(make_handler("handler3"), EventD)
|
||||
|
||||
ed = EventD()
|
||||
events.fire(ed)
|
||||
self.assertEqual(len(ed.effects), 2)
|
||||
self.assertTrue("handler1" in ed.effects)
|
||||
self.assertTrue("handler2" in ed.effects)
|
||||
ea = EventA()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 1
|
||||
assert "handler1" in ea.effects
|
||||
|
||||
ee = EventE()
|
||||
events.fire(ee)
|
||||
self.assertEqual(len(ee.effects), 3)
|
||||
self.assertTrue("handler1" in ee.effects)
|
||||
self.assertTrue("handler2" in ee.effects)
|
||||
self.assertTrue("handler3" in ee.effects)
|
||||
ec = EventC()
|
||||
events.handle(ec)
|
||||
assert len(ec.effects) == 1
|
||||
assert "handler2" in ec.effects
|
||||
|
||||
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self):
|
||||
events = self.createManager()
|
||||
events.subscribe(self.makeHandler("handler1"), Event)
|
||||
events.subscribe(self.makeHandler("handler2"), EventB)
|
||||
ed = EventD()
|
||||
events.handle(ed)
|
||||
assert len(ed.effects) == 3
|
||||
assert "handler1" in ed.effects
|
||||
assert "handler2" in ed.effects
|
||||
assert "handler3" in ed.effects
|
||||
|
||||
eb = EventB()
|
||||
events.fire(eb)
|
||||
self.assertEqual(len(eb.effects), 2)
|
||||
self.assertTrue("handler1" in eb.effects)
|
||||
self.assertTrue("handler2" in eb.effects)
|
||||
|
||||
def test_unsubscribe_single_event(self):
|
||||
events = self.createManager()
|
||||
handler = self.makeHandler("handler1")
|
||||
events.subscribe(handler, EventA)
|
||||
events.unsubscribe(handler, EventA)
|
||||
e = EventA()
|
||||
events.fire(e)
|
||||
self.assertEqual(len(e.effects), 0)
|
||||
def test_subscribe_no_events():
|
||||
events = create_manager()
|
||||
|
||||
def test_unsubscribe_event_inheritance(self):
|
||||
events = self.createManager()
|
||||
handler1 = self.makeHandler("handler1")
|
||||
handler2 = self.makeHandler("handler2")
|
||||
events.subscribe(handler1, EventA)
|
||||
events.subscribe(handler2, EventB)
|
||||
events.unsubscribe(handler1, EventA)
|
||||
ea = EventA()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 0
|
||||
|
||||
ea = EventA()
|
||||
events.fire(ea)
|
||||
self.assertEqual(len(ea.effects), 0)
|
||||
|
||||
eb = EventB()
|
||||
events.fire(eb)
|
||||
self.assertEqual(len(eb.effects), 1)
|
||||
self.assertTrue("handler2" in eb.effects)
|
||||
def test_subscribe_base_event():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), EventA)
|
||||
|
||||
class Event(object):
|
||||
ea = EventB()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 1
|
||||
assert "handler1" in ea.effects
|
||||
|
||||
|
||||
def test_subscribe_event_malformed_multiple_inheritance():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), EventA)
|
||||
events.subscribe(make_handler("handler2"), EventD)
|
||||
events.subscribe(make_handler("handler3"), EventE)
|
||||
|
||||
ea = EventA()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 1
|
||||
assert "handler1" in ea.effects
|
||||
|
||||
ed = EventD()
|
||||
events.handle(ed)
|
||||
assert len(ed.effects) == 2
|
||||
assert "handler1" in ed.effects
|
||||
assert "handler2" in ed.effects
|
||||
|
||||
ee = EventE()
|
||||
events.handle(ee)
|
||||
assert len(ee.effects) == 3
|
||||
assert "handler1" in ee.effects
|
||||
assert "handler2" in ee.effects
|
||||
assert "handler3" in ee.effects
|
||||
|
||||
|
||||
def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro():
|
||||
events = create_manager()
|
||||
events.subscribe(make_handler("handler1"), Event)
|
||||
events.subscribe(make_handler("handler2"), EventB)
|
||||
|
||||
eb = EventB()
|
||||
events.handle(eb)
|
||||
assert len(eb.effects) == 2
|
||||
assert "handler1" in eb.effects
|
||||
assert "handler2" in eb.effects
|
||||
|
||||
|
||||
def test_unsubscribe_single_event():
|
||||
events = create_manager()
|
||||
handler = make_handler("handler1")
|
||||
events.subscribe(handler, EventA)
|
||||
events.unsubscribe(handler, EventA)
|
||||
e = EventA()
|
||||
events.handle(e)
|
||||
assert len(e.effects) == 0
|
||||
|
||||
|
||||
def test_unsubscribe_event_inheritance():
|
||||
events = create_manager()
|
||||
handler1 = make_handler("handler1")
|
||||
handler2 = make_handler("handler2")
|
||||
events.subscribe(handler1, EventA)
|
||||
events.subscribe(handler2, EventB)
|
||||
events.unsubscribe(handler1, EventA)
|
||||
|
||||
ea = EventA()
|
||||
events.handle(ea)
|
||||
assert len(ea.effects) == 0
|
||||
|
||||
eb = EventB()
|
||||
events.handle(eb)
|
||||
assert len(eb.effects) == 1
|
||||
assert "handler2" in eb.effects
|
||||
|
||||
|
||||
class Event:
|
||||
def __init__(self) -> None:
|
||||
self.effects: List[object] = []
|
||||
|
||||
def __init__(self):
|
||||
self.effects = []
|
||||
|
||||
class EventA(Event):
|
||||
pass
|
||||
|
||||
|
||||
class EventB(EventA):
|
||||
pass
|
||||
|
||||
|
||||
class EventC(Event):
|
||||
pass
|
||||
|
||||
|
||||
class EventD(EventA, EventC):
|
||||
pass
|
||||
|
||||
|
||||
class EventE(EventD, EventA):
|
||||
pass
|
||||
|
@ -1,270 +1,224 @@
|
||||
""" Tests for :module:`generic.multidispatch`."""
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
__all__ = ("DispatcherTests",)
|
||||
|
||||
class DispatcherTests(unittest.TestCase):
|
||||
|
||||
def createDispatcher(self, params_arity, args=None, varargs=None,
|
||||
keywords=None, defaults=None):
|
||||
from inspect import ArgSpec
|
||||
from generic.multidispatch import FunctionDispatcher
|
||||
return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
|
||||
keywords=keywords,
|
||||
defaults=defaults), params_arity)
|
||||
from inspect import FullArgSpec
|
||||
from generic.multidispatch import multidispatch, FunctionDispatcher
|
||||
|
||||
|
||||
def test_one_argument(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"])
|
||||
def create_dispatcher(
|
||||
params_arity, args=None, varargs=None, keywords=None, defaults=None
|
||||
) -> FunctionDispatcher:
|
||||
|
||||
dispatcher.register_rule(lambda x: x + 1, int)
|
||||
self.assertEqual(dispatcher(1), 2)
|
||||
self.assertRaises(TypeError, dispatcher, "s")
|
||||
return FunctionDispatcher(
|
||||
FullArgSpec(
|
||||
args=args,
|
||||
varargs=varargs,
|
||||
varkw=keywords,
|
||||
defaults=defaults,
|
||||
kwonlyargs=[],
|
||||
kwonlydefaults={},
|
||||
annotations={},
|
||||
),
|
||||
params_arity,
|
||||
)
|
||||
|
||||
dispatcher.register_rule(lambda x: x + "1", str)
|
||||
self.assertEqual(dispatcher(1), 2)
|
||||
self.assertEqual(dispatcher("1"), "11")
|
||||
self.assertRaises(TypeError, dispatcher, tuple())
|
||||
|
||||
def test_two_arguments(self):
|
||||
dispatcher = self.createDispatcher(2, args=["x", "y"])
|
||||
def test_one_argument():
|
||||
dispatcher = create_dispatcher(1, args=["x"])
|
||||
|
||||
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertRaises(TypeError, dispatcher, "s", "ss")
|
||||
self.assertRaises(TypeError, dispatcher, 1, "ss")
|
||||
self.assertRaises(TypeError, dispatcher, "s", 2)
|
||||
dispatcher.register_rule(lambda x: x + 1, int)
|
||||
assert dispatcher(1) == 2
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("s")
|
||||
|
||||
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertEqual(dispatcher("1", "2"), "121")
|
||||
self.assertRaises(TypeError, dispatcher, "1", 1)
|
||||
self.assertRaises(TypeError, dispatcher, 1, "1")
|
||||
dispatcher.register_rule(lambda x: x + "1", str)
|
||||
assert dispatcher(1) == 2
|
||||
assert dispatcher("1") == "11"
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher(tuple())
|
||||
|
||||
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
|
||||
self.assertEqual(dispatcher(1, 2), 4)
|
||||
self.assertEqual(dispatcher("1", "2"), "121")
|
||||
self.assertEqual(dispatcher(1, "2"), "121")
|
||||
self.assertRaises(TypeError, dispatcher, "1", 1)
|
||||
|
||||
def test_bottom_rule(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"])
|
||||
def test_two_arguments():
|
||||
dispatcher = create_dispatcher(2, args=["x", "y"])
|
||||
|
||||
dispatcher.register_rule(lambda x: x, object)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
self.assertEqual(dispatcher("1"), "1")
|
||||
self.assertEqual(dispatcher([1]), [1])
|
||||
self.assertEqual(dispatcher((1,)), (1,))
|
||||
dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
|
||||
assert dispatcher(1, 2) == 4
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("s", "ss")
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher(1, "ss")
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("s", 2)
|
||||
|
||||
def test_subtype_evaluation(self):
|
||||
class Super(object):
|
||||
pass
|
||||
class Sub(Super):
|
||||
pass
|
||||
dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
|
||||
assert dispatcher(1, 2) == 4
|
||||
assert dispatcher("1", "2") == "121"
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("1", 1)
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher(1, "1")
|
||||
|
||||
dispatcher = self.createDispatcher(1, args=["x"])
|
||||
dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
|
||||
assert dispatcher(1, 2) == 4
|
||||
assert dispatcher("1", "2") == "121"
|
||||
assert dispatcher(1, "2") == "121"
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("1", 1)
|
||||
|
||||
dispatcher.register_rule(lambda x: x, Super)
|
||||
o_super = Super()
|
||||
self.assertEqual(dispatcher(o_super), o_super)
|
||||
o_sub = Sub()
|
||||
self.assertEqual(dispatcher(o_sub), o_sub)
|
||||
self.assertRaises(TypeError, dispatcher, object())
|
||||
|
||||
dispatcher.register_rule(lambda x: (x, x), Sub)
|
||||
o_super = Super()
|
||||
self.assertEqual(dispatcher(o_super), o_super)
|
||||
o_sub = Sub()
|
||||
self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
|
||||
def test_bottom_rule():
|
||||
dispatcher = create_dispatcher(1, args=["x"])
|
||||
|
||||
def test_register_rule_with_wrong_arity(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"])
|
||||
dispatcher.register_rule(lambda x: x, int)
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
dispatcher.register_rule, lambda x, y: x, str)
|
||||
dispatcher.register_rule(lambda x: x, object)
|
||||
assert dispatcher(1) == 1
|
||||
assert dispatcher("1") == "1"
|
||||
assert dispatcher([1]) == [1]
|
||||
assert dispatcher((1,)) == (1,)
|
||||
|
||||
def test_register_rule_with_different_arg_names(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"])
|
||||
dispatcher.register_rule(lambda y: y, int)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
|
||||
def test_dispatching_with_varargs(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"], varargs="va")
|
||||
dispatcher.register_rule(lambda x, *va: x, int)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
self.assertRaises(TypeError, dispatcher, "1", 2, 3)
|
||||
def test_subtype_evaluation():
|
||||
class Super:
|
||||
pass
|
||||
|
||||
def test_dispatching_with_varkw(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x"], keywords="vk")
|
||||
dispatcher.register_rule(lambda x, **vk: x, int)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)
|
||||
class Sub(Super):
|
||||
pass
|
||||
|
||||
def test_dispatching_with_kw(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
|
||||
dispatcher.register_rule(lambda x, y=1: x, int)
|
||||
self.assertEqual(dispatcher(1), 1)
|
||||
self.assertRaises(TypeError, dispatcher, "1", k=1)
|
||||
dispatcher = create_dispatcher(1, args=["x"])
|
||||
|
||||
def test_create_dispatcher_with_pos_args_less_multi_arity(self):
|
||||
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"])
|
||||
self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"],
|
||||
defaults=["x"])
|
||||
dispatcher.register_rule(lambda x: x, Super)
|
||||
o_super = Super()
|
||||
assert dispatcher(o_super) == o_super
|
||||
o_sub = Sub()
|
||||
assert dispatcher(o_sub) == o_sub
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher(object())
|
||||
|
||||
def test_register_rule_with_wrong_number_types_parameters(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x", "y"])
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
dispatcher.register_rule, lambda x, y: x, int, str)
|
||||
dispatcher.register_rule(lambda x: (x, x), Sub)
|
||||
o_super = Super()
|
||||
assert dispatcher(o_super) == o_super
|
||||
o_sub = Sub()
|
||||
assert dispatcher(o_sub) == (o_sub, o_sub)
|
||||
|
||||
def test_register_rule_with_partial_dispatching(self):
|
||||
dispatcher = self.createDispatcher(1, args=["x", "y"])
|
||||
dispatcher.register_rule(lambda x, y: x, int)
|
||||
self.assertEqual(dispatcher(1, 2), 1)
|
||||
self.assertEqual(dispatcher(1, "2"), 1)
|
||||
self.assertRaises(TypeError, dispatcher, "2", 1)
|
||||
|
||||
def test_register_rule_with_wrong_arity():
|
||||
dispatcher = create_dispatcher(1, args=["x"])
|
||||
dispatcher.register_rule(lambda x: x, int)
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher.register_rule(lambda x, y: x, str)
|
||||
self.assertEqual(dispatcher(1, 2), 1)
|
||||
self.assertEqual(dispatcher(1, "2"), 1)
|
||||
self.assertEqual(dispatcher("1", "2"), "1")
|
||||
self.assertEqual(dispatcher("1", 2), "1")
|
||||
|
||||
class MultifunctionTests(unittest.TestCase):
|
||||
|
||||
def test_it(self):
|
||||
from generic.multidispatch import multifunction
|
||||
def test_register_rule_with_different_arg_names():
|
||||
dispatcher = create_dispatcher(1, args=["x"])
|
||||
dispatcher.register_rule(lambda y: y, int)
|
||||
assert dispatcher(1) == 1
|
||||
|
||||
@multifunction(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertRaises(TypeError, func, 1, 2)
|
||||
self.assertRaises(TypeError, func, "1", 2)
|
||||
self.assertRaises(TypeError, func, "1", "2")
|
||||
def test_dispatching_with_varargs():
|
||||
dispatcher = create_dispatcher(1, args=["x"], varargs="va")
|
||||
dispatcher.register_rule(lambda x, *va: x, int)
|
||||
assert dispatcher(1) == 1
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("1", 2, 3)
|
||||
|
||||
@func.when(str, str)
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertEqual(func("1", "2"), "12")
|
||||
self.assertRaises(TypeError, func, 1, 2)
|
||||
self.assertRaises(TypeError, func, "1", 2)
|
||||
def test_dispatching_with_varkw():
|
||||
dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
|
||||
dispatcher.register_rule(lambda x, **vk: x, int)
|
||||
assert dispatcher(1) == 1
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("1", a=1, b=2)
|
||||
|
||||
def test_overriding(self):
|
||||
from generic.multidispatch import multifunction
|
||||
|
||||
@multifunction(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
def test_dispatching_with_kw():
|
||||
dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
|
||||
dispatcher.register_rule(lambda x, y=1: x, int)
|
||||
assert dispatcher(1) == 1
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("1", k=1)
|
||||
|
||||
self.assertEqual(func(1, "2"), "12")
|
||||
self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
|
||||
|
||||
@func.override(int, str)
|
||||
def func(x, y):
|
||||
return y + str(x)
|
||||
def test_create_dispatcher_with_pos_args_less_multi_arity():
|
||||
with pytest.raises(TypeError):
|
||||
create_dispatcher(2, args=["x"])
|
||||
with pytest.raises(TypeError):
|
||||
create_dispatcher(2, args=["x", "y"], defaults=["x"])
|
||||
|
||||
self.assertEqual(func(1, "2"), "21")
|
||||
|
||||
class MultimethodTests(unittest.TestCase):
|
||||
def test_register_rule_with_wrong_number_types_parameters():
|
||||
dispatcher = create_dispatcher(1, args=["x", "y"])
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher.register_rule(lambda x, y: x, int, str)
|
||||
|
||||
def test_multimethod(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
def test_register_rule_with_partial_dispatching():
|
||||
dispatcher = create_dispatcher(1, args=["x", "y"])
|
||||
dispatcher.register_rule(lambda x, y: x, int)
|
||||
assert dispatcher(1, 2) == 1
|
||||
assert dispatcher(1, "2") == 1
|
||||
with pytest.raises(TypeError):
|
||||
dispatcher("2", 1)
|
||||
dispatcher.register_rule(lambda x, y: x, str)
|
||||
assert dispatcher(1, 2) == 1
|
||||
assert dispatcher(1, "2") == 1
|
||||
assert dispatcher("1", "2") == "1"
|
||||
assert dispatcher("1", 2) == "1"
|
||||
|
||||
@multimethod(int)
|
||||
def foo(self, x):
|
||||
return x + 1
|
||||
|
||||
@foo.when(str)
|
||||
def foo(self, x):
|
||||
return x + "1"
|
||||
def test_default_dispatcher():
|
||||
@multidispatch(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
|
||||
self.assertEqual(Dummy().foo(1), 2)
|
||||
self.assertEqual(Dummy().foo("1"), "11")
|
||||
self.assertRaises(TypeError, Dummy().foo, [])
|
||||
assert func(1, "2") == "12"
|
||||
with pytest.raises(TypeError):
|
||||
func(1, 2)
|
||||
with pytest.raises(TypeError):
|
||||
func("1", 2)
|
||||
with pytest.raises(TypeError):
|
||||
func("1", "2")
|
||||
|
||||
def test_inheritance(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
def test_multiple_functions():
|
||||
@multidispatch(int, str)
|
||||
def func(x, y):
|
||||
return str(x) + y
|
||||
|
||||
@multimethod(int)
|
||||
def foo(self, x):
|
||||
return x + 1
|
||||
@func.register(str, str)
|
||||
def _(x, y):
|
||||
return x + y
|
||||
|
||||
@foo.when(float)
|
||||
def foo(self, x):
|
||||
return x + 1.5
|
||||
assert func(1, "2") == "12"
|
||||
assert func("1", "2") == "12"
|
||||
with pytest.raises(TypeError):
|
||||
func(1, 2)
|
||||
with pytest.raises(TypeError):
|
||||
func("1", 2)
|
||||
|
||||
@has_multimethods
|
||||
class DummySub(Dummy):
|
||||
|
||||
@Dummy.foo.when(str)
|
||||
def foo(self, x):
|
||||
return x + "1"
|
||||
def test_default():
|
||||
@multidispatch()
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
@foo.when(tuple)
|
||||
def foo(self, x):
|
||||
return x + (1,)
|
||||
@func.register(str, str)
|
||||
def _(x, y):
|
||||
return y + x
|
||||
|
||||
@Dummy.foo.when(bool)
|
||||
def foo(self, x):
|
||||
return not x
|
||||
assert func(1, 1) == 2
|
||||
assert func("1", "2") == "21"
|
||||
|
||||
self.assertEqual(Dummy().foo(1), 2)
|
||||
self.assertEqual(Dummy().foo(1.5), 3.0)
|
||||
self.assertRaises(TypeError, Dummy().foo, "1")
|
||||
self.assertEqual(DummySub().foo(1), 2)
|
||||
self.assertEqual(DummySub().foo(1.5), 3.0)
|
||||
self.assertEqual(DummySub().foo("1"), "11")
|
||||
self.assertEqual(DummySub().foo((1,2)), (1,2,1))
|
||||
self.assertEqual(DummySub().foo(True), False)
|
||||
self.assertRaises(TypeError, DummySub().foo, [])
|
||||
|
||||
def test_override(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
def test_on_classes():
|
||||
@multidispatch()
|
||||
class A:
|
||||
def __init__(self, a, b):
|
||||
self.v = a + b
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
@A.register(str, str) # type: ignore[attr-defined]
|
||||
class B:
|
||||
def __init__(self, a, b):
|
||||
self.v = b + a
|
||||
|
||||
@multimethod(str, str)
|
||||
def foo(self, x, y):
|
||||
return x + y
|
||||
|
||||
@foo.when(str, str)
|
||||
def foo(self, x, y):
|
||||
return y + x
|
||||
|
||||
self.assertEqual(Dummy().foo("1", "2"), "21")
|
||||
|
||||
def test_inheritance_override(self):
|
||||
from generic.multidispatch import multimethod
|
||||
from generic.multidispatch import has_multimethods
|
||||
|
||||
@has_multimethods
|
||||
class Dummy(object):
|
||||
|
||||
@multimethod(int)
|
||||
def foo(self, x):
|
||||
return x + 1
|
||||
|
||||
@has_multimethods
|
||||
class DummySub(Dummy):
|
||||
|
||||
@Dummy.foo.when(int)
|
||||
def foo(self, x):
|
||||
return x + 2
|
||||
|
||||
self.assertEqual(Dummy().foo(1), 2)
|
||||
self.assertEqual(DummySub().foo(1), 3)
|
||||
assert A(1, 1).v == 2
|
||||
assert A("1", "2").v == "21"
|
||||
|
@ -1,135 +1,147 @@
|
||||
""" Tests for :module:`generic.registry`."""
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
__all__ = ("RegistryTests",)
|
||||
from typing import Union
|
||||
from generic.registry import Registry, SimpleAxis, TypeAxis
|
||||
|
||||
class RegistryTests(unittest.TestCase):
|
||||
|
||||
def test_one_axis_no_specificity(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('foo', SimpleAxis()))
|
||||
a = object()
|
||||
b = object()
|
||||
registry.register(a)
|
||||
registry.register(b, 'foo')
|
||||
|
||||
self.assertEqual(registry.lookup(), a)
|
||||
self.assertEqual(registry.lookup('foo'), b)
|
||||
self.assertEqual(registry.lookup('bar'), None)
|
||||
|
||||
def test_two_axes(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
from generic.registry import TypeAxis
|
||||
registry = Registry(('type', TypeAxis()),
|
||||
('name', SimpleAxis()))
|
||||
|
||||
target1 = Target('one')
|
||||
registry.register(target1, object)
|
||||
|
||||
target2 = Target('two')
|
||||
registry.register(target2, DummyA)
|
||||
|
||||
target3 = Target('three')
|
||||
registry.register(target3, DummyA, 'foo')
|
||||
|
||||
context1 = object()
|
||||
self.assertEqual(registry.lookup(context1), target1)
|
||||
|
||||
context2 = DummyB()
|
||||
self.assertEqual(registry.lookup(context2), target2)
|
||||
self.assertEqual(registry.lookup(context2, 'foo'), target3)
|
||||
|
||||
target4 = object()
|
||||
registry.register(target4, DummyB)
|
||||
|
||||
self.assertEqual(registry.lookup(context2), target4)
|
||||
self.assertEqual(registry.lookup(context2, 'foo'), target3)
|
||||
|
||||
def test_get_registration(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
from generic.registry import TypeAxis
|
||||
registry = Registry(('type', TypeAxis()),
|
||||
('name', SimpleAxis()))
|
||||
registry.register('one', object)
|
||||
registry.register('two', DummyA, 'foo')
|
||||
self.assertEqual(registry.get_registration(object), 'one')
|
||||
self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two')
|
||||
self.assertEqual(registry.get_registration(object, 'foo'), None)
|
||||
self.assertEqual(registry.get_registration(DummyA), None)
|
||||
|
||||
def test_register_too_many_keys(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('name', SimpleAxis()))
|
||||
self.assertRaises(ValueError, registry.register, object(),
|
||||
'one', 'two')
|
||||
|
||||
def test_lookup_too_many_keys(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('name', SimpleAxis()))
|
||||
self.assertRaises(ValueError, registry.lookup, 'one', 'two')
|
||||
|
||||
def test_conflict_error(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('name', SimpleAxis()))
|
||||
registry.register(object(), name='foo')
|
||||
self.assertRaises(ValueError, registry.register, object(), 'foo')
|
||||
|
||||
def test_override(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('name', SimpleAxis()))
|
||||
registry.register(1, name='foo')
|
||||
registry.override(2, name='foo')
|
||||
self.assertEqual(registry.lookup('foo'), 2)
|
||||
|
||||
def test_skip_nodes(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(
|
||||
('one', SimpleAxis()),
|
||||
('two', SimpleAxis()),
|
||||
('three', SimpleAxis())
|
||||
)
|
||||
registry.register('foo', one=1, three=3)
|
||||
self.assertEqual(registry.lookup(1, three=3), 'foo')
|
||||
|
||||
def test_miss(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(
|
||||
('one', SimpleAxis()),
|
||||
('two', SimpleAxis()),
|
||||
('three', SimpleAxis())
|
||||
)
|
||||
registry.register('foo', 1, 2)
|
||||
self.assertEqual(registry.lookup(one=1, three=3), None)
|
||||
|
||||
def test_bad_lookup(self):
|
||||
from generic.registry import Registry
|
||||
from generic.registry import SimpleAxis
|
||||
registry = Registry(('name', SimpleAxis()),
|
||||
('grade', SimpleAxis()))
|
||||
self.assertRaises(ValueError, registry.register, 1, foo=1)
|
||||
self.assertRaises(ValueError, registry.lookup, foo=1)
|
||||
self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo')
|
||||
|
||||
class DummyA(object):
|
||||
class DummyA:
|
||||
pass
|
||||
|
||||
|
||||
class DummyB(DummyA):
|
||||
pass
|
||||
|
||||
class Target(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
# Only called if being printed due to a failing test
|
||||
def __repr__(self): #pragma NO COVERAGE
|
||||
return "Target('%s')" % self.name
|
||||
def test_one_axis_no_specificity():
|
||||
registry: Registry[object] = Registry(("foo", SimpleAxis()))
|
||||
a = object()
|
||||
b = object()
|
||||
registry.register(a)
|
||||
registry.register(b, "foo")
|
||||
|
||||
assert registry.lookup() == a
|
||||
assert registry.lookup("foo") == b
|
||||
assert registry.lookup("bar") is None
|
||||
|
||||
|
||||
def test_subtyping_on_axes():
|
||||
registry: Registry[str] = Registry(("type", TypeAxis()))
|
||||
|
||||
target1 = "one"
|
||||
registry.register(target1, object)
|
||||
|
||||
target2 = "two"
|
||||
registry.register(target2, DummyA)
|
||||
|
||||
target3 = "three"
|
||||
registry.register(target3, DummyB)
|
||||
|
||||
assert registry.lookup(object()) == target1
|
||||
assert registry.lookup(DummyA()) == target2
|
||||
assert registry.lookup(DummyB()) == target3
|
||||
|
||||
|
||||
def test_query_subtyping_on_axes():
|
||||
registry: Registry[str] = Registry(("type", TypeAxis()))
|
||||
|
||||
target1 = "one"
|
||||
registry.register(target1, object)
|
||||
|
||||
target2 = "two"
|
||||
registry.register(target2, DummyA)
|
||||
|
||||
target3 = "three"
|
||||
registry.register(target3, DummyB)
|
||||
|
||||
target4 = "four"
|
||||
registry.register(target4, int)
|
||||
|
||||
assert list(registry.query(object())) == [target1]
|
||||
assert list(registry.query(DummyA())) == [target2, target1]
|
||||
assert list(registry.query(DummyB())) == [target3, target2, target1]
|
||||
assert list(registry.query(3)) == [target4, target1]
|
||||
|
||||
|
||||
def test_two_axes():
|
||||
registry: Registry[Union[str, object]] = Registry(
|
||||
("type", TypeAxis()), ("name", SimpleAxis())
|
||||
)
|
||||
|
||||
target1 = "one"
|
||||
registry.register(target1, object)
|
||||
|
||||
target2 = "two"
|
||||
registry.register(target2, DummyA)
|
||||
|
||||
target3 = "three"
|
||||
registry.register(target3, DummyA, "foo")
|
||||
|
||||
context1 = object()
|
||||
assert registry.lookup(context1) == target1
|
||||
|
||||
context2 = DummyB()
|
||||
assert registry.lookup(context2) == target2
|
||||
assert registry.lookup(context2, "foo") == target3
|
||||
|
||||
target4 = object()
|
||||
registry.register(target4, DummyB)
|
||||
|
||||
assert registry.lookup(context2) == target4
|
||||
assert registry.lookup(context2, "foo") == target3
|
||||
|
||||
|
||||
def test_get_registration():
|
||||
registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis()))
|
||||
registry.register("one", object)
|
||||
registry.register("two", DummyA, "foo")
|
||||
assert registry.get_registration(object) == "one"
|
||||
assert registry.get_registration(DummyA, "foo") == "two"
|
||||
assert registry.get_registration(object, "foo") is None
|
||||
assert registry.get_registration(DummyA) is None
|
||||
|
||||
|
||||
def test_register_too_many_keys():
|
||||
registry: Registry[type] = Registry(("name", SimpleAxis()))
|
||||
with pytest.raises(ValueError):
|
||||
registry.register(object, "one", "two")
|
||||
|
||||
|
||||
def test_lookup_too_many_keys():
|
||||
registry: Registry[object] = Registry(("name", SimpleAxis()))
|
||||
with pytest.raises(ValueError):
|
||||
registry.register(registry.lookup("one", "two"))
|
||||
|
||||
|
||||
def test_conflict_error():
|
||||
registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis()))
|
||||
registry.register(object(), name="foo")
|
||||
with pytest.raises(ValueError):
|
||||
registry.register(object, "foo")
|
||||
|
||||
|
||||
def test_skip_nodes():
|
||||
registry: Registry[str] = Registry(
|
||||
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
|
||||
)
|
||||
registry.register("foo", one=1, three=3)
|
||||
assert registry.lookup(1, three=3) == "foo"
|
||||
|
||||
|
||||
def test_miss():
|
||||
registry: Registry[str] = Registry(
|
||||
("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
|
||||
)
|
||||
registry.register("foo", 1, 2)
|
||||
assert registry.lookup(one=1, three=3) is None
|
||||
|
||||
|
||||
def test_bad_lookup():
|
||||
registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis()))
|
||||
with pytest.raises(ValueError):
|
||||
registry.register(1, foo=1)
|
||||
with pytest.raises(ValueError):
|
||||
registry.lookup(foo=1)
|
||||
with pytest.raises(ValueError):
|
||||
registry.register(1, "foo", name="foo")
|
||||
|
Loading…
Reference in New Issue
Block a user