2020-09-16 02:50:55 +03:00
"""Tests for :mod:`generic.multidispatch`."""
2010-07-15 10:39:45 +04:00
2020-08-29 05:08:26 +03:00
from inspect import FullArgSpec
2019-11-08 18:35:11 +03:00
import pytest
2020-08-29 05:08:26 +03:00
from generic . multidispatch import FunctionDispatcher , multidispatch
2019-11-08 18:35:11 +03:00
def create_dispatcher(
    params_arity, args=None, varargs=None, keywords=None, defaults=None
) -> FunctionDispatcher:
    """Build a :class:`FunctionDispatcher` from a hand-written argument spec.

    :param params_arity: number of type parameters each rule is registered
        with, i.e. how many leading positional arguments are dispatched on.
    :param args: names of the positional parameters.
    :param varargs: name of the ``*args`` parameter, or ``None``.
    :param keywords: name of the ``**kwargs`` parameter, or ``None``; stored
        as ``varkw`` in the resulting spec.
    :param defaults: default values for the trailing positional parameters.
    :returns: a dispatcher for the described signature.
    """
    # Keyword-only parameters and annotations are never exercised by these
    # tests, so they are always left empty.
    spec = FullArgSpec(
        args=args,
        varargs=varargs,
        varkw=keywords,
        defaults=defaults,
        kwonlyargs=[],
        kwonlydefaults={},
        annotations={},
    )
    return FunctionDispatcher(spec, params_arity)
def test_one_argument():
    """Single-argument dispatch picks the rule registered for the value's type."""
    dispatch = create_dispatcher(1, args=["x"])

    def increment(x):
        return x + 1

    dispatch.register_rule(increment, int)
    assert dispatch(1) == 2
    # No rule covers ``str`` yet.
    with pytest.raises(TypeError):
        dispatch("s")

    def append_one(x):
        return f"{x}1"

    dispatch.register_rule(append_one, str)
    # The ``int`` rule keeps working alongside the new ``str`` rule.
    assert dispatch(1) == 2
    assert dispatch("1") == "11"
    # A tuple matches neither registered type.
    with pytest.raises(TypeError):
        dispatch(())
2019-11-08 18:35:11 +03:00
def test_two_arguments():
    """Two-argument dispatch requires a rule matching both argument types."""
    dispatch = create_dispatcher(2, args=["x", "y"])

    dispatch.register_rule(lambda x, y: x + y + 1, int, int)
    assert dispatch(1, 2) == 4
    for bad_args in (("s", "ss"), (1, "ss"), ("s", 2)):
        with pytest.raises(TypeError):
            dispatch(*bad_args)

    dispatch.register_rule(lambda x, y: x + y + "1", str, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    # Mixed type pairs are still unmatched.
    for bad_args in (("1", 1), (1, "1")):
        with pytest.raises(TypeError):
            dispatch(*bad_args)

    dispatch.register_rule(lambda x, y: str(x) + y + "1", int, str)
    assert dispatch(1, 2) == 4
    assert dispatch("1", "2") == "121"
    assert dispatch(1, "2") == "121"
    # Only (int, str) was added; (str, int) remains unmatched.
    with pytest.raises(TypeError):
        dispatch("1", 1)
def test_bottom_rule():
    """A rule registered for ``object`` accepts values of any type."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, object)
    for value in (1, "1", [1], (1,)):
        assert dispatch(value) == value
def test_subtype_evaluation():
    """A base-class rule covers subclasses until a more specific rule exists."""

    class Super:
        pass

    class Sub(Super):
        pass

    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, Super)

    base_obj = Super()
    derived_obj = Sub()
    assert dispatch(base_obj) == base_obj
    assert dispatch(derived_obj) == derived_obj
    # ``object`` is a supertype of ``Super``, not a subtype, so no rule applies.
    with pytest.raises(TypeError):
        dispatch(object())

    dispatch.register_rule(lambda x: (x, x), Sub)
    base_obj = Super()
    derived_obj = Sub()
    assert dispatch(base_obj) == base_obj
    # The ``Sub`` rule now takes precedence for ``Sub`` instances.
    assert dispatch(derived_obj) == (derived_obj, derived_obj)
def test_register_rule_with_wrong_arity():
    """Registering a rule with a different parameter count raises TypeError."""
    dispatch = create_dispatcher(1, args=["x"])
    dispatch.register_rule(lambda x: x, int)

    def two_params(x, y):
        return x

    with pytest.raises(TypeError):
        dispatch.register_rule(two_params, str)
2010-07-15 10:39:45 +04:00
2019-11-08 18:35:11 +03:00
def test_register_rule_with_different_arg_names():
    """A rule's parameter names need not match the dispatcher's argspec."""
    dispatch = create_dispatcher(1, args=["x"])

    def identity(y):
        return y

    dispatch.register_rule(identity, int)
    assert dispatch(1) == 1
2010-07-15 10:39:45 +04:00
2019-11-08 18:35:11 +03:00
def test_dispatching_with_varargs():
    """Dispatch on ``x`` while extra positional arguments reach the rule.

    Arity is 1, so only the first argument's type selects the rule; anything
    captured by ``*va`` is expected to be passed through to that rule.
    """
    dispatcher = create_dispatcher(1, args=["x"], varargs="va")
    # Keep the rule's ``*va`` in step with ``varargs="va"`` above.
    dispatcher.register_rule(lambda x, *va: (x, va), int)
    assert dispatcher(1) == (1, ())
    # Extra positionals must also be forwarded when a rule matches, not only
    # tolerated on the failure path below.
    assert dispatcher(1, 2, 3) == (1, (2, 3))
    # Extra arguments are not dispatched on: a ``str`` first argument has no rule.
    with pytest.raises(TypeError):
        dispatcher("1", 2, 3)
2010-07-15 10:39:45 +04:00
2019-11-08 18:35:11 +03:00
def test_dispatching_with_varkw():
    """Dispatch on ``x`` while extra keyword arguments reach the rule.

    Arity is 1, so only the first argument's type selects the rule; keyword
    arguments captured by ``**vk`` are expected to be passed through to it.
    """
    dispatcher = create_dispatcher(1, args=["x"], keywords="vk")
    # Keep the rule's ``**vk`` in step with ``keywords="vk"`` above.
    dispatcher.register_rule(lambda x, **vk: (x, vk), int)
    assert dispatcher(1) == (1, {})
    # Keyword arguments must be forwarded on the success path as well.
    assert dispatcher(1, a=1, b=2) == (1, {"a": 1, "b": 2})
    with pytest.raises(TypeError):
        dispatcher("1", a=1, b=2)
2010-07-15 21:47:41 +04:00
2019-11-08 18:35:11 +03:00
def test_dispatching_with_kw():
    """Dispatch on ``x`` when the signature also has a defaulted parameter ``y``.

    ``y`` is not dispatched on (arity 1); it may be omitted, given
    positionally, or given by keyword, and the rule should see its value.
    """
    # NOTE(review): the default's value ("vk") looks irrelevant here; presumably
    # only the number of defaults must match the rule's -- confirm.
    dispatcher = create_dispatcher(1, args=["x", "y"], defaults=["vk"])
    dispatcher.register_rule(lambda x, y=1: (x, y), int)
    assert dispatcher(1) == (1, 1)
    # Exercise the defaulted parameter explicitly, both ways it can be passed.
    assert dispatcher(1, 2) == (1, 2)
    assert dispatcher(1, y=3) == (1, 3)
    with pytest.raises(TypeError):
        dispatcher("1", k=1)
2010-07-15 21:47:41 +04:00
2010-07-19 00:32:14 +04:00
2019-11-08 18:35:11 +03:00
def test_create_dispatcher_with_pos_args_less_multi_arity():
    """An arity above the number of usable positional parameters is rejected."""
    # First case: a single positional parameter for two dispatched arguments.
    # Second case: two positional parameters, one of which has a default;
    # this test expects that combination to be rejected for arity 2 as well.
    too_few_positionals = (
        {"args": ["x"]},
        {"args": ["x", "y"], "defaults": ["x"]},
    )
    for spec_kwargs in too_few_positionals:
        with pytest.raises(TypeError):
            create_dispatcher(2, **spec_kwargs)
2010-07-19 00:39:21 +04:00
2019-11-08 18:35:11 +03:00
def test_register_rule_with_wrong_number_types_parameters():
    """Registering with more types than the dispatch arity raises TypeError."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    def first(x, y):
        return x

    # Arity is 1, but two types are supplied.
    with pytest.raises(TypeError):
        dispatch.register_rule(first, int, str)
2010-07-19 00:32:14 +04:00
2019-11-08 18:35:11 +03:00
def test_register_rule_with_partial_dispatching():
    """With arity 1 only the first argument's type selects the rule."""
    dispatch = create_dispatcher(1, args=["x", "y"])

    dispatch.register_rule(lambda x, y: x, int)
    for second in (2, "2"):
        assert dispatch(1, second) == 1
    # ``y`` is ignored for dispatch, so a ``str`` first argument has no rule yet.
    with pytest.raises(TypeError):
        dispatch("2", 1)

    dispatch.register_rule(lambda x, y: x, str)
    for second in (2, "2"):
        assert dispatch(1, second) == 1
        assert dispatch("1", second) == "1"
2010-07-19 00:32:14 +04:00
2019-11-08 18:35:11 +03:00
def test_default_dispatcher():
    """A function decorated with explicit types only accepts those types."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    assert func(1, "2") == "12"
    for bad_args in ((1, 2), ("1", 2), ("1", "2")):
        with pytest.raises(TypeError):
            func(*bad_args)
2010-07-19 00:32:14 +04:00
2019-11-08 18:35:11 +03:00
def test_multiple_functions():
    """Further implementations can be attached through ``func.register``."""

    @multidispatch(int, str)
    def func(x, y):
        return str(x) + y

    @func.register(str, str)
    def concat_strings(x, y):
        return x + y

    # Both registered signatures produce the same concatenation.
    for good_args in ((1, "2"), ("1", "2")):
        assert func(*good_args) == "12"
    for bad_args in ((1, 2), ("1", 2)):
        with pytest.raises(TypeError):
            func(*bad_args)
2010-07-19 00:32:14 +04:00
2019-11-08 18:35:11 +03:00
def test_default():
    """``multidispatch()`` without types still accepts arbitrary-typed calls.

    The decorated function handles ``(int, int)`` here, while a rule
    registered for ``(str, str)`` overrides it for strings.
    """

    @multidispatch()
    def func(x, y):
        return x + y

    @func.register(str, str)
    def reversed_concat(x, y):
        return y + x

    assert func(1, 1) == 2
    # The (str, str) rule concatenates in reverse order.
    assert func("1", "2") == "21"
2010-07-19 00:32:14 +04:00
2010-07-19 00:39:21 +04:00
2019-11-08 18:35:11 +03:00
def test_on_classes():
    """Classes can be multidispatched on their constructor arguments."""

    @multidispatch()
    class A:
        def __init__(self, a, b):
            self.v = a + b

    @A.register(str, str)  # type: ignore[attr-defined]
    class B:
        def __init__(self, a, b):
            self.v = b + a

    summed = A(1, 1)
    assert summed.v == 2
    # String arguments are routed to ``B``, which reverses the order.
    swapped = A("1", "2")
    assert swapped.v == "21"