From bb20f6992bb409e8032d8736e4527c7f14b2ce87 Mon Sep 17 00:00:00 2001
From: Arjan Molenaar <gaphor@gmail.com>
Date: Fri, 8 Nov 2019 16:35:11 +0100
Subject: [PATCH] Replace old code by Python3 implementation from Gaphor

Multimethods are missing now. Those have to be added back.
---
 generic/event.py            | 105 ++++------
 generic/multidispatch.py    | 224 +++++++--------------
 generic/registry.py         | 150 ++++++++------
 tests/test_event.py         | 247 +++++++++++++----------
 tests/test_multidispatch.py | 384 ++++++++++++++++--------------------
 tests/test_registry.py      | 260 ++++++++++++------------
 6 files changed, 641 insertions(+), 729 deletions(-)

diff --git a/generic/event.py b/generic/event.py
index 98abe7f..628df8c 100644
--- a/generic/event.py
+++ b/generic/event.py
@@ -2,7 +2,7 @@
 
 This module provides API for event management. There are two APIs provided:
 
-* Global event management API: subscribe, unsubscribe, fire.
+* Global event management API: subscribe, unsubscribe, handle.
 * Local event management API: Manager
 
 If you run only one instance of your application per Python
@@ -12,116 +12,83 @@ to have different configurations for them -- you should use local API
 and have one instance of Manager object per application instance.
 """
 
-from collections import namedtuple
+from typing import Callable, Set, Type
 
-from generic.registry import Registry
-from generic.registry import TypeAxis
+from generic.registry import Registry, TypeAxis
 
-__all__ = ("Manager", "subscribe", "unsubscribe", "fire", "subscriber")
 
-class HandlerSet(namedtuple("HandlerSet", ["parents", "handlers"])):
-    """ Set of handlers for specific type of event.
+__all__ = "Manager"
 
-    This object stores ``handlers`` for specific event type and
-    ``parents`` reference to handler sets of event's supertypes.
-    """
+Event = object
+Handler = Callable[[object], None]
+HandlerSet = Set[Handler]
 
-    @property
-    def all_handlers(self):
-        """ Iterate over own and supertypes' handlers.
 
-        This iterator yields just unique values, so it won't yield the
-        same handler twice, even if it was registered both for some
-        event type and its supertype.
-        """
-        seen = set()
-        seen_add = seen.add
-
-        # yield own handlers first
-        for handler in self.handlers:
-            seen_add(handler)
-            yield handler
-
-        # yield supertypes' handlers then
-        for parent in self.parents:
-            for handler in parent.all_handlers:
-                if not handler in seen:
-                    seen_add(handler)
-                    yield handler
-
-class Manager(object):
+class Manager:
     """ Event manager
 
     Provides API for subscribing for and firing events. There's also global
     event manager instantiated at module level with functions
-    :func:`.subscribe`, :func:`.fire` and decorator :func:`.subscriber` aliased
+    :func:`.subscribe`, :func:`.handle` and decorator :func:`.subscriber` aliased
     to corresponding methods of class.
     """
 
-    def __init__(self):
+    registry: Registry[HandlerSet]
+
+    def __init__(self) -> None:
         axes = (("event_type", TypeAxis()),)
         self.registry = Registry(*axes)
 
-    def subscribe(self, handler, event_type):
+    def subscribe(self, handler: Handler, event_type: Type[Event]) -> None:
         """ Subscribe ``handler`` to specified ``event_type``"""
         handler_set = self.registry.get_registration(event_type)
-        if not handler_set:
+        if handler_set is None:
             handler_set = self._register_handler_set(event_type)
-        handler_set.handlers.add(handler)
+        handler_set.add(handler)
 
-    def unsubscribe(self, handler, event_type):
+    def unsubscribe(self, handler: Handler, event_type: Type[Event]) -> None:
         """ Unsubscribe ``handler`` from ``event_type``"""
         handler_set = self.registry.get_registration(event_type)
-        if handler_set and handler in handler_set.handlers:
-            handler_set.handlers.remove(handler)
+        if handler_set and handler in handler_set:
+            handler_set.remove(handler)
 
-    def fire(self, event):
+    def handle(self, event: Event) -> None:
         """ Fire ``event``
 
         All subscribers will be executed with no determined order.
         """
-        handler_set = self.registry.lookup(event)
-        for handler in handler_set.all_handlers:
-            handler(event)
+        handler_sets = self.registry.query(event)
+        for handler_set in handler_sets:
+            if handler_set:
+                for handler in set(handler_set):
+                    handler(event)
 
-    def _register_handler_set(self, event_type):
-        """ Register new handler set for ``event_type``."""
-        # Collect handler sets for supertypes
-        parent_handler_sets = []
-        parents = event_type.__bases__
-        for parent in parents:
-            parent_handlers = self.registry.get_registration(parent)
-            if parent_handlers is None:
-                parent_handlers = self._register_handler_set(parent)
-            parent_handler_sets.append(parent_handlers)
-
-        handler_set = HandlerSet(parents=parent_handler_sets, handlers=set())
+    def _register_handler_set(self, event_type: Type[Event]) -> HandlerSet:
+        """ Register new handler set for ``event_type``.
+        """
+        handler_set: HandlerSet = set()
         self.registry.register(handler_set, event_type)
         return handler_set
 
-    def subscriber(self, event_type):
+    def subscriber(self, event_type: Type[Event]) -> Callable[[Handler], Handler]:
         """ Decorator for subscribing handlers
 
         Works like this:
 
+            >>> mymanager = Manager()
+            >>> class MyEvent():
+            ...     pass
             >>> @mymanager.subscriber(MyEvent)
             ... def mysubscriber(evt):
             ...     # handle event
             ...     return
 
-            >>> mymanager.fire(MyEvent())
+            >>> mymanager.handle(MyEvent())
 
         """
-        def registrator(func):
+
+        def registrator(func: Handler) -> Handler:
             self.subscribe(func, event_type)
             return func
+
         return registrator
-
-# Global event manager
-_global_manager = Manager()
-
-# Global event management API
-subscribe = _global_manager.subscribe
-unsubscribe = _global_manager.unsubscribe
-fire = _global_manager.fire
-subscriber = _global_manager.subscriber
diff --git a/generic/multidispatch.py b/generic/multidispatch.py
index 294aed9..3b4bb3e 100644
--- a/generic/multidispatch.py
+++ b/generic/multidispatch.py
@@ -1,206 +1,132 @@
-""" Multidispatch for functions and methods"""
+""" Multidispatch for functions and methods.
+
+This code is a Python 3, slimmed down version of the
+generic package by Andrey Popp.
+
+Only the generic function code is left in tact -- no generic methods.
+The interface has been made in line with `functools.singledispatch`.
+
+Note that this module does not support annotated functions.
+"""
+
+from __future__ import annotations
+
+from typing import cast, Any, Callable, Generic, TypeVar, Union
 
 import functools
 import inspect
-import types
-import threading
 
-from generic.registry import Registry
-from generic.registry import TypeAxis
+from generic.registry import Registry, TypeAxis
 
-__all__ = ("multifunction", "multimethod", "has_multimethods")
+__all__ = "multidispatch"
 
-def multifunction(*argtypes):
-    """ Declare function as multifunction
+T = TypeVar("T", bound=Union[Callable[..., Any], type])
+KeyType = Union[type, None]
+
+
+def multidispatch(*argtypes: KeyType) -> Callable[[T], FunctionDispatcher[T]]:
+    """ Declare function as multidispatch
 
     This decorator takes ``argtypes`` argument types and replace decorated
     function with :class:`.FunctionDispatcher` object, which is responsible for
     multiple dispatch feature.
     """
-    def _replace_with_dispatcher(func):
-        dispatcher = _make_dispatcher(FunctionDispatcher, func, len(argtypes))
+
+    def _replace_with_dispatcher(func: T) -> FunctionDispatcher[T]:
+        nonlocal argtypes
+        argspec = inspect.getfullargspec(func)
+        if not argtypes:
+            arity = _arity(argspec)
+            if isinstance(func, type):
+                # It's a class we deal with:
+                arity -= 1
+            argtypes = (object,) * arity
+
+        dispatcher = cast(
+            FunctionDispatcher[T],
+            functools.update_wrapper(FunctionDispatcher(argspec, len(argtypes)), func),
+        )
         dispatcher.register_rule(func, *argtypes)
         return dispatcher
+
     return _replace_with_dispatcher
 
-def multimethod(*argtypes):
-    """ Declare method as multimethod
 
-    This decorator works exactly the same as :func:`.multifunction` decorator
-    but replaces decorated method with :class:`.MethodDispatcher` object
-    instead.
-
-    Should be used only for decorating methods and enclosing class should have
-    :func:`.has_multimethods` decorator.
-    """
-    def _replace_with_dispatcher(func):
-        dispatcher = _make_dispatcher(MethodDispatcher, func, len(argtypes) + 1)
-        dispatcher.register_unbound_rule(func, *argtypes)
-        return dispatcher
-    return _replace_with_dispatcher
-
-def has_multimethods(cls):
-    """ Declare class as one that have multimethods
-
-    Should only be used for decorating classes which have methods decorated with
-    :func:`.multimethod` decorator.
-    """
-    for name, obj in cls.__dict__.items():
-        if isinstance(obj, MethodDispatcher):
-            obj.proceed_unbound_rules(cls)
-    return cls
-
-class FunctionDispatcher(object):
+class FunctionDispatcher(Generic[T]):
     """ Multidispatcher for functions
 
     This object dispatch calls to function by its argument types. Usually it is
-    produced by :func:`.multifunction` decorator.
+    produced by :func:`.multidispatch` decorator.
 
     You should not manually create objects of this type.
     """
 
-    def __init__(self, argspec, params_arity):
+    registry: Registry[T]
+
+    def __init__(self, argspec: inspect.FullArgSpec, params_arity: int) -> None:
         """ Initialize dispatcher with ``argspec`` of type
         :class:`inspect.ArgSpec` and ``params_arity`` that represent number
         params."""
         # Check if we have enough positional arguments for number of type params
-        if arity(argspec) < params_arity:
-            raise TypeError("Not enough positional arguments "
-                            "for number of type parameters provided.")
+        if _arity(argspec) < params_arity:
+            raise TypeError(
+                "Not enough positional arguments "
+                "for number of type parameters provided."
+            )
 
         self.argspec = argspec
         self.params_arity = params_arity
 
-        axis = [("arg_%d" % n, TypeAxis()) for n in range(params_arity)]
+        axis = [(f"arg_{n:d}", TypeAxis()) for n in range(params_arity)]
         self.registry = Registry(*axis)
 
-    def check_rule(self, rule, *argtypes):
+    def check_rule(self, rule: T, *argtypes: KeyType) -> None:
         # Check if we have the right number of parametrized types
         if len(argtypes) != self.params_arity:
-            raise TypeError("Wrong number of type parameters.")
+            raise TypeError(
+                f"Wrong number of type parameters: have {len(argtypes)}, expected {self.params_arity}."
+            )
 
         # Check if we have the same argspec (by number of args)
-        rule_argspec = inspect.getargspec(rule)
-        if not is_equalent_argspecs(rule_argspec, self.argspec):
-            raise TypeError("Rule does not conform "
-                            "to previous implementations.")
+        rule_argspec = inspect.getfullargspec(rule)
+        left_spec = tuple(x and len(x) or 0 for x in rule_argspec[:4])
+        right_spec = tuple(x and len(x) or 0 for x in self.argspec[:4])
+        if left_spec != right_spec:
+            raise TypeError(
+                f"Rule does not conform to previous implementations: {left_spec} != {right_spec}."
+            )
 
-    def register_rule(self, rule, *argtypes):
+    def register_rule(self, rule: T, *argtypes: KeyType) -> None:
         """ Register new ``rule`` for ``argtypes``."""
         self.check_rule(rule, *argtypes)
         self.registry.register(rule, *argtypes)
 
-    def override_rule(self, rule, *argtypes):
-        """ Override ``rule`` for ``argtypes``."""
-        self.check_rule(rule, *argtypes)
-        self.registry.override(rule, *argtypes)
-
-    def lookup_rule(self, *args):
-        """ Lookup rule by ``args``. Returns None if no rule was found."""
-        args = args[:self.params_arity]
-        rule = self.registry.lookup(*args)
-        if rule is None:
-            raise TypeError("No available rule found for %r" % (args,))
-        return rule
-
-    def when(self, *argtypes):
-        """ Decorator for registering new case for multifunction
+    def register(self, *argtypes: KeyType) -> Callable[[T], T]:
+        """ Decorator for registering new case for multidispatch
 
         New case will be registered for types identified by ``argtypes``. The
         length of ``argtypes`` should be equal to the length of ``argtypes``
-        argument were passed corresponding :func:`.multifunction` call, which
-        also indicated the number of arguments multifunction dispatches on.
+        argument were passed corresponding :func:`.multidispatch` call, which
+        also indicated the number of arguments multidispatch dispatches on.
         """
-        def register_rule(func):
+
+        def register_rule(func: T) -> T:
             self.register_rule(func, *argtypes)
-            return self
+            return func
+
         return register_rule
 
-    @property
-    def otherwise(self):
-        """ Decorator which registeres "catch-all" case for multifunction"""
-        def register_rule(func):
-            self.register_rule(func, [object]*self.params_arity)
-            return self
-        return register_rule
-
-    def override(self, *argtypes):
-        """ Decorator for overriding case for ``argtypes``"""
-        def override_rule(func):
-            self.override_rule(func, *argtypes)
-            return self
-        return override_rule
-
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
         """ Dispatch call to appropriate rule."""
-        rule = self.lookup_rule(*args)
+        trimmed_args = args[: self.params_arity]
+        rule = self.registry.lookup(*trimmed_args)
+        if not rule:
+            raise TypeError(f"No available rule found for {trimmed_args!r}")
         return rule(*args, **kwargs)
 
-class MethodDispatcher(FunctionDispatcher):
-    """ Multiple dispatch for methods
 
-    This object dispatch call to method by its class and arguments types.
-    Usually it is produced by :func:`.multimethod` decorator.
-
-    You should not manually create objects of this type.
-    """
-
-    def __init__(self, argspec, params_arity):
-        FunctionDispatcher.__init__(self, argspec, params_arity)
-
-        # some data, that should be local to thread of execution
-        self.local = threading.local()
-        self.local.unbound_rules = []
-
-    def register_unbound_rule(self, func, *argtypes):
-        """ Register unbound rule that should be processed by
-        ``proceed_unbound_rules`` later."""
-        self.local.unbound_rules.append((argtypes, func))
-
-    def proceed_unbound_rules(self, cls):
-        """ Process all unbound rule by binding them to ``cls`` type."""
-        for argtypes, func in self.local.unbound_rules:
-            argtypes = (cls,) + argtypes
-            self.override_rule(func, *argtypes)
-        self.local.unbound_rules = []
-
-    def __get__(self, obj, cls):
-        if obj is None:
-            return self
-        return types.MethodType(self, obj)
-
-    def when(self, *argtypes):
-        """ Register new case for multimethod for ``argtypes``"""
-        def make_declaration(meth):
-            self.register_unbound_rule(meth, *argtypes)
-            return self
-        return make_declaration
-
-    def override(self, *argtypes):
-        """ Decorator for overriding case for ``argtypes``"""
-        return self.when(*argtypes)
-
-    @property
-    def otherwise(self):
-        """ Decorator which registeres "catch-all" case for multimethod"""
-        def make_declaration(func):
-            self.register_unbound_rule(func, [object]*self.params_arity)
-            return self
-        return make_declaration
-
-def arity(argspec):
+def _arity(argspec: inspect.FullArgSpec) -> int:
     """ Determinal positional arity of argspec."""
     args = argspec.args if argspec.args else []
     defaults = argspec.defaults if argspec.defaults else []
     return len(args) - len(defaults)
-
-def is_equalent_argspecs(left, right):
-    """ Check argspec equalence."""
-    return map(lambda x: len(x) if x else 0, left) == \
-           map(lambda x: len(x) if x else 0, right)
-
-def _make_dispatcher(dispacther_cls, func, params_arity):
-    argspec = inspect.getargspec(func)
-    wrapper = functools.wraps(func)
-    dispatcher = wrapper(dispacther_cls(argspec, params_arity))
-    return dispatcher
diff --git a/generic/registry.py b/generic/registry.py
index 982d53c..9143f7b 100644
--- a/generic/registry.py
+++ b/generic/registry.py
@@ -5,84 +5,101 @@ This implementation was borrowed from happy[1] project by Chris Rossi.
 [1]: http://bitbucket.org/chrisrossi/happy
 """
 
+from __future__ import annotations
+
 __all__ = ("Registry", "SimpleAxis", "TypeAxis")
 
-class Registry(object):
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    KeysView,
+    List,
+    Generator,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
+
+K = TypeVar("K")
+S = TypeVar("S")
+T = TypeVar("T")
+V = TypeVar("V")
+Axis = Union["SimpleAxis", "TypeAxis"]
+
+
+class Registry(Generic[T]):
     """ Registry implementation."""
 
-    def __init__(self, *axes):
-        self._tree = _TreeNode()
+    def __init__(self, *axes: Tuple[str, Axis]):
+        self._tree: _TreeNode[T] = _TreeNode()
         self._axes = [axis for name, axis in axes]
-        self._axes_dict = dict([
-            (name, (i, axis)) for i, (name, axis) in enumerate(axes)
-        ])
+        self._axes_dict = {name: (i, axis) for i, (name, axis) in enumerate(axes)}
 
-    def register(self, target, *arg_keys, **kw_keys):
-        self._register(target, self._align_with_axes(arg_keys, kw_keys), False)
-
-    def override(self, target, *arg_keys, **kw_keys):
-        self._register(target, self._align_with_axes(arg_keys, kw_keys), True)
-
-    def _register(self, target, keys, override):
+    def register(self, target: T, *arg_keys: K, **kw_keys: K) -> None:
         tree_node = self._tree
-        for key in keys:
-            tree_node = tree_node.setdefault(key, _TreeNode())
+        for key in self._align_with_axes(arg_keys, kw_keys):
+            tree_node = tree_node.setdefault(key, _TreeNode[T]())
 
-        if not override and not tree_node.target is None:
+        if not tree_node.target is None:
             raise ValueError(
-                "Registration conflicts with existing registration.  Use "
-                "override method to override.")
+                f"Registration for {target} conflicts with existing registration {tree_node.target}."
+            )
 
         tree_node.target = target
 
-    def get_registration(self, *arg_keys, **kw_keys):
+    def get_registration(self, *arg_keys: K, **kw_keys: K) -> Optional[T]:
         tree_node = self._tree
         for key in self._align_with_axes(arg_keys, kw_keys):
-            if not tree_node.has_key(key):
+            if not key in tree_node:
                 return None
             tree_node = tree_node[key]
 
         return tree_node.target
 
-    def lookup(self, *arg_objs, **kw_objs):
+    def lookup(self, *arg_objs: V, **kw_objs: V) -> Optional[T]:
+        return next(self.query(*arg_objs, **kw_objs), None)
+
+    def query(self, *arg_objs: V, **kw_objs: V) -> Generator[Optional[T], None, None]:
         objs = self._align_with_axes(arg_objs, kw_objs)
         axes = self._axes
-        return self._lookup(self._tree, objs, axes)
+        return self._query(self._tree, objs, axes)
 
-    def _lookup(self, tree_node, objs, axes):
+    def _query(
+        self, tree_node: _TreeNode[T], objs: Sequence[Optional[V]], axes: Sequence[Axis]
+    ) -> Generator[Optional[T], None, None]:
         """ Recursively traverse registration tree, from left to right, most
         specific to least specific, returning the first target found on a
         matching node.  """
         if not objs:
-            return tree_node.target
+            yield tree_node.target
+        else:
+            obj = objs[0]
 
-        obj = objs[0]
+            # Skip non-participating nodes
+            if obj is None:
+                next_node: Optional[_TreeNode[T]] = tree_node.get(None, None)
+                if next_node is not None:
+                    yield from self._query(next_node, objs[1:], axes[1:])
+            else:
+                # Get matches on this axis and iterate from most to least specific
+                axis = axes[0]
+                for match_key in axis.matches(obj, tree_node.keys()):
+                    yield from self._query(tree_node[match_key], objs[1:], axes[1:])
 
-        # Skip non-participating nodes
-        if obj is None:
-            next_node = tree_node.get(None, None)
-            if next_node is not None:
-                return self._lookup(next_node, objs[1:], axes[1:])
-            return None
-
-        # Get matches on this axis and iterate from most to least specific
-        axis = axes[0]
-        for match_key in axis.matches(obj, tree_node.keys()):
-            target = self._lookup(tree_node[match_key], objs[1:], axes[1:])
-            if target is not None:
-                return target
-
-        return None
-
-    def _align_with_axes(self, args, kw):
+    def _align_with_axes(
+        self, args: Sequence[S], kw: Dict[str, S]
+    ) -> Sequence[Optional[S]]:
         """ Create a list matching up all args and kwargs with their
         corresponding axes, in order, using ``None`` as a placeholder for
         skipped axes.  """
         axes_dict = self._axes_dict
-        aligned = [None for i in xrange(len(axes_dict))]
+        aligned: List[Optional[S]] = [None for i in range(len(axes_dict))]
 
         args_len = len(args)
-        if args_len + len(kw) >  len(aligned):
+        if args_len + len(kw) > len(aligned):
             raise ValueError("Cannot have more arguments than axes.")
 
         for i, arg in enumerate(args):
@@ -91,12 +108,13 @@ class Registry(object):
         for k, v in kw.items():
             i_axis = axes_dict.get(k, None)
             if i_axis is None:
-                raise ValueError("No axis with name: %s" % k)
+                raise ValueError(f"No axis with name: {k}")
 
             i, axis = i_axis
             if aligned[i] is not None:
-                raise ValueError("Axis defined twice between positional and "
-                                 "keyword arguments")
+                raise ValueError(
+                    "Axis defined twice between positional and " "keyword arguments"
+                )
 
             aligned[i] = v
 
@@ -106,13 +124,15 @@ class Registry(object):
 
         return aligned
 
-class _TreeNode(dict):
-    target = None
 
-    def __str__(self):
-        return "<TreeNode %s %s>" % (self.target, dict.__str__(self))
+class _TreeNode(Generic[T], Dict[Any, Any]):
+    target: Optional[T] = None
 
-class SimpleAxis(object):
+    def __str__(self) -> str:
+        return f"<TreeNode {self.target} {dict.__str__(self)}>"
+
+
+class SimpleAxis:
     """ A simple axis where the key into the axis is the same as the object to
     be matched (aka the identity axis). This axis behaves just like a
     dictionary.  You might use this axis if you are interested in registering
@@ -122,21 +142,23 @@ class SimpleAxis(object):
     Subclasses can override the ``get_keys`` method for implementing arbitrary
     axes.
     """
-    def matches(self, obj, keys):
-        for key in self.get_keys(obj):
+
+    def matches(
+        self, obj: object, keys: KeysView[Optional[object]]
+    ) -> Generator[object, None, None]:
+        for key in [obj]:
             if key in keys:
-                yield key
+                yield obj
 
-    def get_keys(self, obj):
-        """
-        Return the keys for the given object that could match this axis, from
-        most specific to least specific.  A convenient override point.
-        """
-        return [obj,]
 
-class TypeAxis(SimpleAxis):
+class TypeAxis:
     """ An axis which matches the class and super classes of an object in
     method resolution order.
     """
-    def get_keys(self, obj):
-        return type(obj).mro()
+
+    def matches(
+        self, obj: object, keys: KeysView[Optional[type]]
+    ) -> Generator[type, None, None]:
+        for key in type(obj).mro():
+            if key in keys:
+                yield key
diff --git a/tests/test_event.py b/tests/test_event.py
index e967bdd..63609d8 100644
--- a/tests/test_event.py
+++ b/tests/test_event.py
@@ -1,150 +1,181 @@
 """ Tests for :module:`generic.event`."""
 
-import unittest
+from __future__ import annotations
 
-__all__ = ("ManagerTests",)
+from typing import Callable, List
+from generic.event import Manager
 
-class ManagerTests(unittest.TestCase):
 
-    def makeHandler(self, effect):
-        return lambda e: e.effects.append(effect)
+def make_handler(effect: object) -> Callable[[Event], None]:
+    return lambda e: e.effects.append(effect)
 
-    def createManager(self):
-        from generic.event import Manager
-        return Manager()
 
-    def test_subscribe_single_event(self):
-        events = self.createManager()
-        events.subscribe(self.makeHandler("handler1"), EventA)
-        e = EventA()
-        events.fire(e)
-        self.assertEqual(len(e.effects), 1)
-        self.assertTrue("handler1" in e.effects)
+def create_manager():
+    return Manager()
 
-    def test_subscribe_via_decorator(self):
-        events = self.createManager()
-        events.subscriber(EventA)(self.makeHandler("handler1"))
-        e = EventA()
-        events.fire(e)
-        self.assertEqual(len(e.effects), 1)
-        self.assertTrue("handler1" in e.effects)
 
-    def test_subscribe_event_inheritance(self):
-        events = self.createManager()
-        events.subscribe(self.makeHandler("handler1"), EventA)
-        events.subscribe(self.makeHandler("handler2"), EventB)
+def test_subscribe_single_event():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), EventA)
+    e = EventA()
+    events.handle(e)
+    assert len(e.effects) == 1
+    assert "handler1" in e.effects
 
-        ea = EventA()
-        events.fire(ea)
-        self.assertEqual(len(ea.effects), 1)
-        self.assertTrue("handler1" in ea.effects)
 
-        eb = EventB()
-        events.fire(eb)
-        self.assertEqual(len(eb.effects), 2)
-        self.assertTrue("handler1" in eb.effects)
-        self.assertTrue("handler2" in eb.effects)
+def test_subscribe_via_decorator():
+    events = create_manager()
+    events.subscriber(EventA)(make_handler("handler1"))
+    e = EventA()
+    events.handle(e)
+    assert len(e.effects) == 1
+    assert "handler1" in e.effects
 
-    def test_subscribe_event_multiple_inheritance(self):
-        events = self.createManager()
-        events.subscribe(self.makeHandler("handler1"), EventA)
-        events.subscribe(self.makeHandler("handler2"), EventC)
-        events.subscribe(self.makeHandler("handler3"), EventD)
 
-        ea = EventA()
-        events.fire(ea)
-        self.assertEqual(len(ea.effects), 1)
-        self.assertTrue("handler1" in ea.effects)
+def test_subscribe_event_inheritance():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), EventA)
+    events.subscribe(make_handler("handler2"), EventB)
 
-        ec = EventC()
-        events.fire(ec)
-        self.assertEqual(len(ec.effects), 1)
-        self.assertTrue("handler2" in ec.effects)
+    ea = EventA()
+    events.handle(ea)
+    assert len(ea.effects) == 1
+    assert "handler1" in ea.effects
 
-        ed = EventD()
-        events.fire(ed)
-        self.assertEqual(len(ed.effects), 3)
-        self.assertTrue("handler1" in ed.effects)
-        self.assertTrue("handler2" in ed.effects)
-        self.assertTrue("handler3" in ed.effects)
+    eb = EventB()
+    events.handle(eb)
+    assert len(eb.effects) == 2
+    assert "handler1" in eb.effects
+    assert "handler2" in eb.effects
 
-    def test_subscribe_event_malformed_multiple_inheritance(self):
-        events = self.createManager()
-        events.subscribe(self.makeHandler("handler1"), EventA)
-        events.subscribe(self.makeHandler("handler2"), EventD)
-        events.subscribe(self.makeHandler("handler3"), EventE)
 
-        ea = EventA()
-        events.fire(ea)
-        self.assertEqual(len(ea.effects), 1)
-        self.assertTrue("handler1" in ea.effects)
+def test_subscribe_event_multiple_inheritance():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), EventA)
+    events.subscribe(make_handler("handler2"), EventC)
+    events.subscribe(make_handler("handler3"), EventD)
 
-        ed = EventD()
-        events.fire(ed)
-        self.assertEqual(len(ed.effects), 2)
-        self.assertTrue("handler1" in ed.effects)
-        self.assertTrue("handler2" in ed.effects)
+    ea = EventA()
+    events.handle(ea)
+    assert len(ea.effects) == 1
+    assert "handler1" in ea.effects
 
-        ee = EventE()
-        events.fire(ee)
-        self.assertEqual(len(ee.effects), 3)
-        self.assertTrue("handler1" in ee.effects)
-        self.assertTrue("handler2" in ee.effects)
-        self.assertTrue("handler3" in ee.effects)
+    ec = EventC()
+    events.handle(ec)
+    assert len(ec.effects) == 1
+    assert "handler2" in ec.effects
 
-    def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro(self):
-        events = self.createManager()
-        events.subscribe(self.makeHandler("handler1"), Event)
-        events.subscribe(self.makeHandler("handler2"), EventB)
+    ed = EventD()
+    events.handle(ed)
+    assert len(ed.effects) == 3
+    assert "handler1" in ed.effects
+    assert "handler2" in ed.effects
+    assert "handler3" in ed.effects
 
-        eb = EventB()
-        events.fire(eb)
-        self.assertEqual(len(eb.effects), 2)
-        self.assertTrue("handler1" in eb.effects)
-        self.assertTrue("handler2" in eb.effects)
 
-    def test_unsubscribe_single_event(self):
-        events = self.createManager()
-        handler = self.makeHandler("handler1")
-        events.subscribe(handler, EventA)
-        events.unsubscribe(handler, EventA)
-        e = EventA()
-        events.fire(e)
-        self.assertEqual(len(e.effects), 0)
+def test_subscribe_no_events():
+    events = create_manager()
 
-    def test_unsubscribe_event_inheritance(self):
-        events = self.createManager()
-        handler1 = self.makeHandler("handler1")
-        handler2 = self.makeHandler("handler2")
-        events.subscribe(handler1, EventA)
-        events.subscribe(handler2, EventB)
-        events.unsubscribe(handler1, EventA)
+    ea = EventA()
+    events.handle(ea)
+    assert len(ea.effects) == 0
 
-        ea = EventA()
-        events.fire(ea)
-        self.assertEqual(len(ea.effects), 0)
 
-        eb = EventB()
-        events.fire(eb)
-        self.assertEqual(len(eb.effects), 1)
-        self.assertTrue("handler2" in eb.effects)
+def test_subscribe_base_event():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), EventA)
 
-class Event(object):
+    ea = EventB()
+    events.handle(ea)
+    assert len(ea.effects) == 1
+    assert "handler1" in ea.effects
+
+
+def test_subscribe_event_malformed_multiple_inheritance():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), EventA)
+    events.subscribe(make_handler("handler2"), EventD)
+    events.subscribe(make_handler("handler3"), EventE)
+
+    ea = EventA()
+    events.handle(ea)
+    assert len(ea.effects) == 1
+    assert "handler1" in ea.effects
+
+    ed = EventD()
+    events.handle(ed)
+    assert len(ed.effects) == 2
+    assert "handler1" in ed.effects
+    assert "handler2" in ed.effects
+
+    ee = EventE()
+    events.handle(ee)
+    assert len(ee.effects) == 3
+    assert "handler1" in ee.effects
+    assert "handler2" in ee.effects
+    assert "handler3" in ee.effects
+
+
+def test_subscribe_event_with_no_subscribers_in_the_middle_of_mro():
+    events = create_manager()
+    events.subscribe(make_handler("handler1"), Event)
+    events.subscribe(make_handler("handler2"), EventB)
+
+    eb = EventB()
+    events.handle(eb)
+    assert len(eb.effects) == 2
+    assert "handler1" in eb.effects
+    assert "handler2" in eb.effects
+
+
+def test_unsubscribe_single_event():
+    events = create_manager()
+    handler = make_handler("handler1")
+    events.subscribe(handler, EventA)
+    events.unsubscribe(handler, EventA)
+    e = EventA()
+    events.handle(e)
+    assert len(e.effects) == 0
+
+
+def test_unsubscribe_event_inheritance():
+    events = create_manager()
+    handler1 = make_handler("handler1")
+    handler2 = make_handler("handler2")
+    events.subscribe(handler1, EventA)
+    events.subscribe(handler2, EventB)
+    events.unsubscribe(handler1, EventA)
+
+    ea = EventA()
+    events.handle(ea)
+    assert len(ea.effects) == 0
+
+    eb = EventB()
+    events.handle(eb)
+    assert len(eb.effects) == 1
+    assert "handler2" in eb.effects
+
+
+class Event:
+    def __init__(self) -> None:
+        self.effects: List[object] = []
 
-    def __init__(self):
-        self.effects = []
 
 class EventA(Event):
     pass
 
+
 class EventB(EventA):
     pass
 
+
 class EventC(Event):
     pass
 
+
 class EventD(EventA, EventC):
     pass
 
+
 class EventE(EventD, EventA):
     pass
diff --git a/tests/test_multidispatch.py b/tests/test_multidispatch.py
index bf4b02a..a99d5a0 100644
--- a/tests/test_multidispatch.py
+++ b/tests/test_multidispatch.py
@@ -1,270 +1,224 @@
 """ Tests for :module:`generic.multidispatch`."""
 
-import unittest
+import pytest
 
-__all__ = ("DispatcherTests",)
-
-class DispatcherTests(unittest.TestCase):
-
-    def createDispatcher(self, params_arity, args=None, varargs=None,
-                         keywords=None, defaults=None):
-        from inspect import ArgSpec
-        from generic.multidispatch import FunctionDispatcher
-        return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
-                                          keywords=keywords,
-                                          defaults=defaults), params_arity)
+from inspect import FullArgSpec
+from generic.multidispatch import multidispatch, FunctionDispatcher
 
 
-    def test_one_argument(self):
-        dispatcher = self.createDispatcher(1, args=["x"])
+def create_dispatcher(
+    params_arity, args=None, varargs=None, keywords=None, defaults=None
+) -> FunctionDispatcher:
 
-        dispatcher.register_rule(lambda x: x + 1, int)
-        self.assertEqual(dispatcher(1), 2)
-        self.assertRaises(TypeError, dispatcher, "s")
+    return FunctionDispatcher(
+        FullArgSpec(
+            args=args,
+            varargs=varargs,
+            varkw=keywords,
+            defaults=defaults,
+            kwonlyargs=[],
+            kwonlydefaults={},
+            annotations={},
+        ),
+        params_arity,
+    )
 
-        dispatcher.register_rule(lambda x: x + "1", str)
-        self.assertEqual(dispatcher(1), 2)
-        self.assertEqual(dispatcher("1"), "11")
-        self.assertRaises(TypeError, dispatcher, tuple())
 
-    def test_two_arguments(self):
-        dispatcher = self.createDispatcher(2, args=["x", "y"])
+def test_one_argument():
+    dispatcher = create_dispatcher(1, args=["x"])
 
-        dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
-        self.assertEqual(dispatcher(1, 2), 4)
-        self.assertRaises(TypeError, dispatcher, "s", "ss")
-        self.assertRaises(TypeError, dispatcher, 1, "ss")
-        self.assertRaises(TypeError, dispatcher, "s", 2)
+    dispatcher.register_rule(lambda x: x + 1, int)
+    assert dispatcher(1) == 2
+    with pytest.raises(TypeError):
+        dispatcher("s")
 
-        dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
-        self.assertEqual(dispatcher(1, 2), 4)
-        self.assertEqual(dispatcher("1", "2"), "121")
-        self.assertRaises(TypeError, dispatcher, "1", 1)
-        self.assertRaises(TypeError, dispatcher, 1, "1")
+    dispatcher.register_rule(lambda x: x + "1", str)
+    assert dispatcher(1) == 2
+    assert dispatcher("1") == "11"
+    with pytest.raises(TypeError):
+        dispatcher(tuple())
 
-        dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
-        self.assertEqual(dispatcher(1, 2), 4)
-        self.assertEqual(dispatcher("1", "2"), "121")
-        self.assertEqual(dispatcher(1, "2"), "121")
-        self.assertRaises(TypeError, dispatcher, "1", 1)
 
-    def test_bottom_rule(self):
-        dispatcher = self.createDispatcher(1, args=["x"])
+def test_two_arguments():
+    dispatcher = create_dispatcher(2, args=["x", "y"])
 
-        dispatcher.register_rule(lambda x: x, object)
-        self.assertEqual(dispatcher(1), 1)
-        self.assertEqual(dispatcher("1"), "1")
-        self.assertEqual(dispatcher([1]), [1])
-        self.assertEqual(dispatcher((1,)), (1,))
+    dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
+    assert dispatcher(1, 2) == 4
+    with pytest.raises(TypeError):
+        dispatcher("s", "ss")
+    with pytest.raises(TypeError):
+        dispatcher(1, "ss")
+    with pytest.raises(TypeError):
+        dispatcher("s", 2)
 
-    def test_subtype_evaluation(self):
-        class Super(object):
-            pass
-        class Sub(Super):
-            pass
+    dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
+    assert dispatcher(1, 2) == 4
+    assert dispatcher("1", "2") == "121"
+    with pytest.raises(TypeError):
+        dispatcher("1", 1)
+    with pytest.raises(TypeError):
+        dispatcher(1, "1")
 
-        dispatcher = self.createDispatcher(1, args=["x"])
+    dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
+    assert dispatcher(1, 2) == 4
+    assert dispatcher("1", "2") == "121"
+    assert dispatcher(1, "2") == "121"
+    with pytest.raises(TypeError):
+        dispatcher("1", 1)
 
-        dispatcher.register_rule(lambda x: x, Super)
-        o_super = Super()
-        self.assertEqual(dispatcher(o_super), o_super)
-        o_sub = Sub()
-        self.assertEqual(dispatcher(o_sub), o_sub)
-        self.assertRaises(TypeError, dispatcher, object())
 
-        dispatcher.register_rule(lambda x: (x, x), Sub)
-        o_super = Super()
-        self.assertEqual(dispatcher(o_super), o_super)
-        o_sub = Sub()
-        self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))
+def test_bottom_rule():
+    dispatcher = create_dispatcher(1, args=["x"])
 
-    def test_register_rule_with_wrong_arity(self):
-        dispatcher = self.createDispatcher(1, args=["x"])
-        dispatcher.register_rule(lambda x: x, int)
-        self.assertRaises(
-            TypeError,
-            dispatcher.register_rule, lambda x, y: x, str)
+    dispatcher.register_rule(lambda x: x, object)
+    assert dispatcher(1) == 1
+    assert dispatcher("1") == "1"
+    assert dispatcher([1]) == [1]
+    assert dispatcher((1,)) == (1,)
 
-    def test_register_rule_with_different_arg_names(self):
-        dispatcher = self.createDispatcher(1, args=["x"])
-        dispatcher.register_rule(lambda y: y, int)
-        self.assertEqual(dispatcher(1), 1)
 
-    def test_dispatching_with_varargs(self):
-        dispatcher = self.createDispatcher(1, args=["x"], varargs="va")
-        dispatcher.register_rule(lambda x, *va: x, int)
-        self.assertEqual(dispatcher(1), 1)
-        self.assertRaises(TypeError, dispatcher, "1", 2, 3)
+def test_subtype_evaluation():
+    class Super:
+        pass
 
-    def test_dispatching_with_varkw(self):
-        dispatcher = self.createDispatcher(1, args=["x"], keywords="vk")
-        dispatcher.register_rule(lambda x, **vk: x, int)
-        self.assertEqual(dispatcher(1), 1)
-        self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)
+    class Sub(Super):
+        pass
 
-    def test_dispatching_with_kw(self):
-        dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
-        dispatcher.register_rule(lambda x, y=1: x, int)
-        self.assertEqual(dispatcher(1), 1)
-        self.assertRaises(TypeError, dispatcher, "1", k=1)
+    dispatcher = create_dispatcher(1, args=["x"])
 
-    def test_create_dispatcher_with_pos_args_less_multi_arity(self):
-        self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"])
-        self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"],
-                         defaults=["x"])
+    dispatcher.register_rule(lambda x: x, Super)
+    o_super = Super()
+    assert dispatcher(o_super) == o_super
+    o_sub = Sub()
+    assert dispatcher(o_sub) == o_sub
+    with pytest.raises(TypeError):
+        dispatcher(object())
 
-    def test_register_rule_with_wrong_number_types_parameters(self):
-        dispatcher = self.createDispatcher(1, args=["x", "y"])
-        self.assertRaises(
-            TypeError,
-            dispatcher.register_rule, lambda x, y: x, int, str)
+    dispatcher.register_rule(lambda x: (x, x), Sub)
+    o_super = Super()
+    assert dispatcher(o_super) == o_super
+    o_sub = Sub()
+    assert dispatcher(o_sub) == (o_sub, o_sub)
 
-    def test_register_rule_with_partial_dispatching(self):
-        dispatcher = self.createDispatcher(1, args=["x", "y"])
-        dispatcher.register_rule(lambda x, y: x, int)
-        self.assertEqual(dispatcher(1, 2), 1)
-        self.assertEqual(dispatcher(1, "2"), 1)
-        self.assertRaises(TypeError, dispatcher, "2", 1)
+
+def test_register_rule_with_wrong_arity():
+    dispatcher = create_dispatcher(1, args=["x"])
+    dispatcher.register_rule(lambda x: x, int)
+    with pytest.raises(TypeError):
         dispatcher.register_rule(lambda x, y: x, str)
-        self.assertEqual(dispatcher(1, 2), 1)
-        self.assertEqual(dispatcher(1, "2"), 1)
-        self.assertEqual(dispatcher("1", "2"), "1")
-        self.assertEqual(dispatcher("1", 2), "1")
 
-class MultifunctionTests(unittest.TestCase):
 
-    def test_it(self):
-        from generic.multidispatch import multifunction
+def test_register_rule_with_different_arg_names():
+    dispatcher = create_dispatcher(1, args=["x"])
+    dispatcher.register_rule(lambda y: y, int)
+    assert dispatcher(1) == 1
 
-        @multifunction(int, str)
-        def func(x, y):
-            return str(x) + y
 
-        self.assertEqual(func(1, "2"), "12")
-        self.assertRaises(TypeError, func, 1, 2)
-        self.assertRaises(TypeError, func, "1", 2)
-        self.assertRaises(TypeError, func, "1", "2")
+def test_dispatching_with_varargs():
+    dispatcher = create_dispatcher(1, args=["x"], varargs="va")
+    dispatcher.register_rule(lambda x, *va: x, int)
+    assert dispatcher(1) == 1
+    with pytest.raises(TypeError):
+        dispatcher("1", 2, 3)
 
-        @func.when(str, str)
-        def func(x, y):
-            return x + y
 
-        self.assertEqual(func(1, "2"), "12")
-        self.assertEqual(func("1", "2"), "12")
-        self.assertRaises(TypeError, func, 1, 2)
-        self.assertRaises(TypeError, func, "1", 2)
+def test_dispatching_with_varkw():
+    dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
+    dispatcher.register_rule(lambda x, **vk: x, int)
+    assert dispatcher(1) == 1
+    with pytest.raises(TypeError):
+        dispatcher("1", a=1, b=2)
 
-    def test_overriding(self):
-        from generic.multidispatch import multifunction
 
-        @multifunction(int, str)
-        def func(x, y):
-            return str(x) + y
+def test_dispatching_with_kw():
+    dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
+    dispatcher.register_rule(lambda x, y=1: x, int)
+    assert dispatcher(1) == 1
+    with pytest.raises(TypeError):
+        dispatcher("1", k=1)
 
-        self.assertEqual(func(1, "2"), "12")
-        self.assertRaises(ValueError, func.when(int, str), lambda x, y: str(x))
 
-        @func.override(int, str)
-        def func(x, y):
-            return y + str(x)
+def test_create_dispatcher_with_pos_args_less_multi_arity():
+    with pytest.raises(TypeError):
+        create_dispatcher(2, args=["x"])
+    with pytest.raises(TypeError):
+        create_dispatcher(2, args=["x", "y"], defaults=["x"])
 
-        self.assertEqual(func(1, "2"), "21")
 
-class MultimethodTests(unittest.TestCase):
+def test_register_rule_with_wrong_number_types_parameters():
+    dispatcher = create_dispatcher(1, args=["x", "y"])
+    with pytest.raises(TypeError):
+        dispatcher.register_rule(lambda x, y: x, int, str)
 
-    def test_multimethod(self):
-        from generic.multidispatch import multimethod
-        from generic.multidispatch import has_multimethods
 
-        @has_multimethods
-        class Dummy(object):
+def test_register_rule_with_partial_dispatching():
+    dispatcher = create_dispatcher(1, args=["x", "y"])
+    dispatcher.register_rule(lambda x, y: x, int)
+    assert dispatcher(1, 2) == 1
+    assert dispatcher(1, "2") == 1
+    with pytest.raises(TypeError):
+        dispatcher("2", 1)
+    dispatcher.register_rule(lambda x, y: x, str)
+    assert dispatcher(1, 2) == 1
+    assert dispatcher(1, "2") == 1
+    assert dispatcher("1", "2") == "1"
+    assert dispatcher("1", 2) == "1"
 
-            @multimethod(int)
-            def foo(self, x):
-                return x + 1
 
-            @foo.when(str)
-            def foo(self, x):
-                return x + "1"
+def test_default_dispatcher():
+    @multidispatch(int, str)
+    def func(x, y):
+        return str(x) + y
 
-        self.assertEqual(Dummy().foo(1), 2)
-        self.assertEqual(Dummy().foo("1"), "11")
-        self.assertRaises(TypeError, Dummy().foo, [])
+    assert func(1, "2") == "12"
+    with pytest.raises(TypeError):
+        func(1, 2)
+    with pytest.raises(TypeError):
+        func("1", 2)
+    with pytest.raises(TypeError):
+        func("1", "2")
 
-    def test_inheritance(self):
-        from generic.multidispatch import multimethod
-        from generic.multidispatch import has_multimethods
 
-        @has_multimethods
-        class Dummy(object):
+def test_multiple_functions():
+    @multidispatch(int, str)
+    def func(x, y):
+        return str(x) + y
 
-            @multimethod(int)
-            def foo(self, x):
-                return x + 1
+    @func.register(str, str)
+    def _(x, y):
+        return x + y
 
-            @foo.when(float)
-            def foo(self, x):
-                return x + 1.5
+    assert func(1, "2") == "12"
+    assert func("1", "2") == "12"
+    with pytest.raises(TypeError):
+        func(1, 2)
+    with pytest.raises(TypeError):
+        func("1", 2)
 
-        @has_multimethods
-        class DummySub(Dummy):
 
-            @Dummy.foo.when(str)
-            def foo(self, x):
-                return x + "1"
+def test_default():
+    @multidispatch()
+    def func(x, y):
+        return x + y
 
-            @foo.when(tuple)
-            def foo(self, x):
-                return x + (1,)
+    @func.register(str, str)
+    def _(x, y):
+        return y + x
 
-            @Dummy.foo.when(bool)
-            def foo(self, x):
-                return not x
+    assert func(1, 1) == 2
+    assert func("1", "2") == "21"
 
-        self.assertEqual(Dummy().foo(1), 2)
-        self.assertEqual(Dummy().foo(1.5), 3.0)
-        self.assertRaises(TypeError, Dummy().foo, "1")
-        self.assertEqual(DummySub().foo(1), 2)
-        self.assertEqual(DummySub().foo(1.5), 3.0)
-        self.assertEqual(DummySub().foo("1"), "11")
-        self.assertEqual(DummySub().foo((1,2)), (1,2,1))
-        self.assertEqual(DummySub().foo(True), False)
-        self.assertRaises(TypeError, DummySub().foo, [])
 
-    def test_override(self):
-        from generic.multidispatch import multimethod
-        from generic.multidispatch import has_multimethods
+def test_on_classes():
+    @multidispatch()
+    class A:
+        def __init__(self, a, b):
+            self.v = a + b
 
-        @has_multimethods
-        class Dummy(object):
+    @A.register(str, str)  # type: ignore[attr-defined]
+    class B:
+        def __init__(self, a, b):
+            self.v = b + a
 
-            @multimethod(str, str)
-            def foo(self, x, y):
-                return x + y
-
-            @foo.when(str, str)
-            def foo(self, x, y):
-                return y + x
-
-        self.assertEqual(Dummy().foo("1", "2"), "21")
-
-    def test_inheritance_override(self):
-        from generic.multidispatch import multimethod
-        from generic.multidispatch import has_multimethods
-
-        @has_multimethods
-        class Dummy(object):
-
-            @multimethod(int)
-            def foo(self, x):
-                return x + 1
-
-        @has_multimethods
-        class DummySub(Dummy):
-
-            @Dummy.foo.when(int)
-            def foo(self, x):
-                return x + 2
-
-        self.assertEqual(Dummy().foo(1), 2)
-        self.assertEqual(DummySub().foo(1), 3)
+    assert A(1, 1).v == 2
+    assert A("1", "2").v == "21"
diff --git a/tests/test_registry.py b/tests/test_registry.py
index eda2b80..d02384f 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -1,135 +1,147 @@
 """ Tests for :module:`generic.registry`."""
 
-import unittest
+import pytest
 
-__all__ = ("RegistryTests",)
+from typing import Union
+from generic.registry import Registry, SimpleAxis, TypeAxis
 
-class RegistryTests(unittest.TestCase):
 
-    def test_one_axis_no_specificity(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('foo', SimpleAxis()))
-        a = object()
-        b = object()
-        registry.register(a)
-        registry.register(b, 'foo')
-
-        self.assertEqual(registry.lookup(), a)
-        self.assertEqual(registry.lookup('foo'), b)
-        self.assertEqual(registry.lookup('bar'), None)
-
-    def test_two_axes(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        from generic.registry import TypeAxis
-        registry = Registry(('type', TypeAxis()),
-                            ('name', SimpleAxis()))
-
-        target1 = Target('one')
-        registry.register(target1, object)
-
-        target2 = Target('two')
-        registry.register(target2, DummyA)
-
-        target3 = Target('three')
-        registry.register(target3, DummyA, 'foo')
-
-        context1 = object()
-        self.assertEqual(registry.lookup(context1), target1)
-
-        context2 = DummyB()
-        self.assertEqual(registry.lookup(context2), target2)
-        self.assertEqual(registry.lookup(context2, 'foo'), target3)
-
-        target4 = object()
-        registry.register(target4, DummyB)
-
-        self.assertEqual(registry.lookup(context2), target4)
-        self.assertEqual(registry.lookup(context2, 'foo'), target3)
-
-    def test_get_registration(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        from generic.registry import TypeAxis
-        registry = Registry(('type', TypeAxis()),
-                            ('name', SimpleAxis()))
-        registry.register('one', object)
-        registry.register('two', DummyA, 'foo')
-        self.assertEqual(registry.get_registration(object), 'one')
-        self.assertEqual(registry.get_registration(DummyA, 'foo'), 'two')
-        self.assertEqual(registry.get_registration(object, 'foo'), None)
-        self.assertEqual(registry.get_registration(DummyA), None)
-
-    def test_register_too_many_keys(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('name', SimpleAxis()))
-        self.assertRaises(ValueError, registry.register, object(),
-                          'one', 'two')
-
-    def test_lookup_too_many_keys(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('name', SimpleAxis()))
-        self.assertRaises(ValueError, registry.lookup, 'one', 'two')
-
-    def test_conflict_error(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('name', SimpleAxis()))
-        registry.register(object(), name='foo')
-        self.assertRaises(ValueError, registry.register, object(), 'foo')
-
-    def test_override(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('name', SimpleAxis()))
-        registry.register(1, name='foo')
-        registry.override(2, name='foo')
-        self.assertEqual(registry.lookup('foo'), 2)
-
-    def test_skip_nodes(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(
-            ('one', SimpleAxis()),
-            ('two', SimpleAxis()),
-            ('three', SimpleAxis())
-            )
-        registry.register('foo', one=1, three=3)
-        self.assertEqual(registry.lookup(1, three=3), 'foo')
-
-    def test_miss(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(
-            ('one', SimpleAxis()),
-            ('two', SimpleAxis()),
-            ('three', SimpleAxis())
-            )
-        registry.register('foo', 1, 2)
-        self.assertEqual(registry.lookup(one=1, three=3), None)
-
-    def test_bad_lookup(self):
-        from generic.registry import Registry
-        from generic.registry import SimpleAxis
-        registry = Registry(('name', SimpleAxis()),
-                            ('grade', SimpleAxis()))
-        self.assertRaises(ValueError, registry.register, 1, foo=1)
-        self.assertRaises(ValueError, registry.lookup, foo=1)
-        self.assertRaises(ValueError, registry.register, 1, 'foo', name='foo')
-
-class DummyA(object):
+class DummyA:
     pass
 
+
 class DummyB(DummyA):
     pass
 
-class Target(object):
-    def __init__(self, name):
-        self.name = name
 
-    # Only called if being printed due to a failing test
-    def __repr__(self): #pragma NO COVERAGE
-        return "Target('%s')" % self.name
+def test_one_axis_no_specificity():
+    registry: Registry[object] = Registry(("foo", SimpleAxis()))
+    a = object()
+    b = object()
+    registry.register(a)
+    registry.register(b, "foo")
+
+    assert registry.lookup() == a
+    assert registry.lookup("foo") == b
+    assert registry.lookup("bar") is None
+
+
+def test_subtyping_on_axes():
+    registry: Registry[str] = Registry(("type", TypeAxis()))
+
+    target1 = "one"
+    registry.register(target1, object)
+
+    target2 = "two"
+    registry.register(target2, DummyA)
+
+    target3 = "three"
+    registry.register(target3, DummyB)
+
+    assert registry.lookup(object()) == target1
+    assert registry.lookup(DummyA()) == target2
+    assert registry.lookup(DummyB()) == target3
+
+
+def test_query_subtyping_on_axes():
+    registry: Registry[str] = Registry(("type", TypeAxis()))
+
+    target1 = "one"
+    registry.register(target1, object)
+
+    target2 = "two"
+    registry.register(target2, DummyA)
+
+    target3 = "three"
+    registry.register(target3, DummyB)
+
+    target4 = "four"
+    registry.register(target4, int)
+
+    assert list(registry.query(object())) == [target1]
+    assert list(registry.query(DummyA())) == [target2, target1]
+    assert list(registry.query(DummyB())) == [target3, target2, target1]
+    assert list(registry.query(3)) == [target4, target1]
+
+
+def test_two_axes():
+    registry: Registry[Union[str, object]] = Registry(
+        ("type", TypeAxis()), ("name", SimpleAxis())
+    )
+
+    target1 = "one"
+    registry.register(target1, object)
+
+    target2 = "two"
+    registry.register(target2, DummyA)
+
+    target3 = "three"
+    registry.register(target3, DummyA, "foo")
+
+    context1 = object()
+    assert registry.lookup(context1) == target1
+
+    context2 = DummyB()
+    assert registry.lookup(context2) == target2
+    assert registry.lookup(context2, "foo") == target3
+
+    target4 = object()
+    registry.register(target4, DummyB)
+
+    assert registry.lookup(context2) == target4
+    assert registry.lookup(context2, "foo") == target3
+
+
+def test_get_registration():
+    registry: Registry[str] = Registry(("type", TypeAxis()), ("name", SimpleAxis()))
+    registry.register("one", object)
+    registry.register("two", DummyA, "foo")
+    assert registry.get_registration(object) == "one"
+    assert registry.get_registration(DummyA, "foo") == "two"
+    assert registry.get_registration(object, "foo") is None
+    assert registry.get_registration(DummyA) is None
+
+
+def test_register_too_many_keys():
+    registry: Registry[type] = Registry(("name", SimpleAxis()))
+    with pytest.raises(ValueError):
+        registry.register(object, "one", "two")
+
+
+def test_lookup_too_many_keys():
+    registry: Registry[object] = Registry(("name", SimpleAxis()))
+    with pytest.raises(ValueError):
+        registry.register(registry.lookup("one", "two"))
+
+
+def test_conflict_error():
+    registry: Registry[Union[object, type]] = Registry(("name", SimpleAxis()))
+    registry.register(object(), name="foo")
+    with pytest.raises(ValueError):
+        registry.register(object, "foo")
+
+
+def test_skip_nodes():
+    registry: Registry[str] = Registry(
+        ("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
+    )
+    registry.register("foo", one=1, three=3)
+    assert registry.lookup(1, three=3) == "foo"
+
+
+def test_miss():
+    registry: Registry[str] = Registry(
+        ("one", SimpleAxis()), ("two", SimpleAxis()), ("three", SimpleAxis())
+    )
+    registry.register("foo", 1, 2)
+    assert registry.lookup(one=1, three=3) is None
+
+
+def test_bad_lookup():
+    registry: Registry[int] = Registry(("name", SimpleAxis()), ("grade", SimpleAxis()))
+    with pytest.raises(ValueError):
+        registry.register(1, foo=1)
+    with pytest.raises(ValueError):
+        registry.lookup(foo=1)
+    with pytest.raises(ValueError):
+        registry.register(1, "foo", name="foo")