"""Tests for :mod:`generic.multidispatch`."""
import unittest

# Export every test class defined in this module so that star-import based
# test aggregation picks all of them up, not just DispatcherTests.
__all__ = ("DispatcherTests", "MultifunctionTests", "MultimethodTests")


class DispatcherTests(unittest.TestCase):
    """Tests for ``generic.multidispatch.FunctionDispatcher``.

    Each test builds a dispatcher for a synthetic signature via
    ``createDispatcher``, registers rules with ``register_rule`` and checks
    which rule a call is routed to, or that ``TypeError`` is raised when no
    rule matches / a rule is rejected.
    """

    def createDispatcher(self, params_arity, args=None, varargs=None,
                         keywords=None, defaults=None):
        """Return a ``FunctionDispatcher`` for a synthetic signature.

        :param params_arity: number of type parameters, i.e. how many leading
            positional arguments take part in dispatch.
        :param args: list of positional parameter names.
        :param varargs: name of the ``*args`` parameter, or ``None``.
        :param keywords: name of the ``**kwargs`` parameter, or ``None``.
        :param defaults: default values for the trailing positional
            parameters, or ``None``.
        """
        # NOTE(review): ``inspect.ArgSpec`` was removed in Python 3.11; this
        # helper only works on interpreters that still provide it.
        from inspect import ArgSpec
        from generic.multidispatch import FunctionDispatcher
        return FunctionDispatcher(ArgSpec(args=args, varargs=varargs,
                                          keywords=keywords,
                                          defaults=defaults), params_arity)

    def test_one_argument(self):
        dispatcher = self.createDispatcher(1, args=["x"])

        dispatcher.register_rule(lambda x: x + 1, int)
        self.assertEqual(dispatcher(1), 2)
        self.assertRaises(TypeError, dispatcher, "s")

        dispatcher.register_rule(lambda x: x + "1", str)
        self.assertEqual(dispatcher(1), 2)
        self.assertEqual(dispatcher("1"), "11")
        self.assertRaises(TypeError, dispatcher, tuple())

    def test_two_arguments(self):
        dispatcher = self.createDispatcher(2, args=["x", "y"])

        dispatcher.register_rule(lambda x, y: x + y + 1, int, int)
        self.assertEqual(dispatcher(1, 2), 4)
        self.assertRaises(TypeError, dispatcher, "s", "ss")
        self.assertRaises(TypeError, dispatcher, 1, "ss")
        self.assertRaises(TypeError, dispatcher, "s", 2)

        dispatcher.register_rule(lambda x, y: x + y + "1", str, str)
        self.assertEqual(dispatcher(1, 2), 4)
        self.assertEqual(dispatcher("1", "2"), "121")
        self.assertRaises(TypeError, dispatcher, "1", 1)
        self.assertRaises(TypeError, dispatcher, 1, "1")

        dispatcher.register_rule(lambda x, y: str(x) + y + "1", int, str)
        self.assertEqual(dispatcher(1, 2), 4)
        self.assertEqual(dispatcher("1", "2"), "121")
        self.assertEqual(dispatcher(1, "2"), "121")
        self.assertRaises(TypeError, dispatcher, "1", 1)

    def test_bottom_rule(self):
        dispatcher = self.createDispatcher(1, args=["x"])

        # A rule registered for ``object`` matches any argument type.
        dispatcher.register_rule(lambda x: x, object)
        self.assertEqual(dispatcher(1), 1)
        self.assertEqual(dispatcher("1"), "1")
        self.assertEqual(dispatcher([1]), [1])
        self.assertEqual(dispatcher((1,)), (1,))

    def test_subtype_evaluation(self):
        class Super(object):
            pass

        class Sub(Super):
            pass

        dispatcher = self.createDispatcher(1, args=["x"])

        dispatcher.register_rule(lambda x: x, Super)
        o_super = Super()
        self.assertEqual(dispatcher(o_super), o_super)
        o_sub = Sub()
        self.assertEqual(dispatcher(o_sub), o_sub)
        self.assertRaises(TypeError, dispatcher, object())

        # A rule for the subclass must take precedence over the base rule.
        dispatcher.register_rule(lambda x: (x, x), Sub)
        o_super = Super()
        self.assertEqual(dispatcher(o_super), o_super)
        o_sub = Sub()
        self.assertEqual(dispatcher(o_sub), (o_sub, o_sub))

    def test_register_rule_with_wrong_arity(self):
        dispatcher = self.createDispatcher(1, args=["x"])
        dispatcher.register_rule(lambda x: x, int)
        self.assertRaises(
            TypeError,
            dispatcher.register_rule, lambda x, y: x, str)

    def test_register_rule_with_different_arg_names(self):
        dispatcher = self.createDispatcher(1, args=["x"])
        dispatcher.register_rule(lambda y: y, int)
        self.assertEqual(dispatcher(1), 1)

    def test_dispatching_with_varargs(self):
        dispatcher = self.createDispatcher(1, args=["x"], varargs="va")
        dispatcher.register_rule(lambda x, *va: x, int)
        self.assertEqual(dispatcher(1), 1)
        # Extra positional arguments must be forwarded to the rule.
        self.assertEqual(dispatcher(1, 2, 3), 1)
        self.assertRaises(TypeError, dispatcher, "1", 2, 3)

    def test_dispatching_with_varkw(self):
        dispatcher = self.createDispatcher(1, args=["x"], keywords="vk")
        dispatcher.register_rule(lambda x, **vk: x, int)
        self.assertEqual(dispatcher(1), 1)
        # Arbitrary keyword arguments must be forwarded to the rule.
        self.assertEqual(dispatcher(1, a=1, b=2), 1)
        self.assertRaises(TypeError, dispatcher, "1", a=1, b=2)

    def test_dispatching_with_kw(self):
        dispatcher = self.createDispatcher(1, args=["x", "y"], defaults=["vk"])
        dispatcher.register_rule(lambda x, y=1: x, int)
        self.assertEqual(dispatcher(1), 1)
        # The defaulted parameter may be supplied positionally or by name.
        self.assertEqual(dispatcher(1, 2), 1)
        self.assertEqual(dispatcher(1, y=2), 1)
        # No rule for ``str``: fails regardless of extra keyword arguments.
        self.assertRaises(TypeError, dispatcher, "1")
        self.assertRaises(TypeError, dispatcher, "1", k=1)

    def test_create_dispatcher_with_pos_args_less_multi_arity(self):
        # Fewer required positional parameters than type parameters.
        self.assertRaises(TypeError, self.createDispatcher, 2, args=["x"])
        self.assertRaises(TypeError, self.createDispatcher, 2, args=["x", "y"],
                          defaults=["x"])

    def test_register_rule_with_wrong_number_types_parameters(self):
        dispatcher = self.createDispatcher(1, args=["x", "y"])
        self.assertRaises(
            TypeError,
            dispatcher.register_rule, lambda x, y: x, int, str)

    def test_register_rule_with_partial_dispatching(self):
        # Only the first argument takes part in dispatch (params_arity == 1).
        dispatcher = self.createDispatcher(1, args=["x", "y"])

        dispatcher.register_rule(lambda x, y: x, int)
        self.assertEqual(dispatcher(1, 2), 1)
        self.assertEqual(dispatcher(1, "2"), 1)
        self.assertRaises(TypeError, dispatcher, "2", 1)

        dispatcher.register_rule(lambda x, y: x, str)
        self.assertEqual(dispatcher(1, 2), 1)
        self.assertEqual(dispatcher(1, "2"), 1)
        self.assertEqual(dispatcher("1", "2"), "1")
        self.assertEqual(dispatcher("1", 2), "1")


class MultifunctionTests(unittest.TestCase):
    """Tests for the ``multifunction`` decorator and the ``when`` /
    ``override`` hooks of the object it produces."""

    def test_it(self):
        from generic.multidispatch import multifunction

        @multifunction(int, str)
        def func(x, y):
            return str(x) + y

        self.assertEqual(func(1, "2"), "12")
        # Only (int, str) is registered so far; every other pair is rejected.
        for unmatched in ((1, 2), ("1", 2), ("1", "2")):
            self.assertRaises(TypeError, func, *unmatched)

        @func.when(str, str)
        def func(x, y):
            return x + y

        self.assertEqual(func(1, "2"), "12")
        self.assertEqual(func("1", "2"), "12")
        for unmatched in ((1, 2), ("1", 2)):
            self.assertRaises(TypeError, func, *unmatched)

    def test_overriding(self):
        from generic.multidispatch import multifunction

        @multifunction(int, str)
        def func(x, y):
            return str(x) + y

        self.assertEqual(func(1, "2"), "12")

        # Re-registering an already covered signature through ``when`` is
        # expected to raise ValueError ...
        register_again = func.when(int, str)
        with self.assertRaises(ValueError):
            register_again(lambda x, y: str(x))

        # ... while ``override`` replaces the existing rule.
        @func.override(int, str)
        def func(x, y):
            return y + str(x)

        self.assertEqual(func(1, "2"), "21")


class MultimethodTests(unittest.TestCase):
    """Tests for ``multimethod`` used inside ``has_multimethods`` classes,
    including inheritance and re-registration of signatures."""

    def test_multimethod(self):
        from generic.multidispatch import has_multimethods, multimethod

        @has_multimethods
        class Dummy(object):
            @multimethod(int)
            def foo(self, x):
                return x + 1

            @foo.when(str)
            def foo(self, x):
                return x + "1"

        for arg, expected in ((1, 2), ("1", "11")):
            self.assertEqual(Dummy().foo(arg), expected)
        self.assertRaises(TypeError, Dummy().foo, [])

    def test_inheritance(self):
        from generic.multidispatch import has_multimethods, multimethod

        @has_multimethods
        class Dummy(object):
            @multimethod(int)
            def foo(self, x):
                return x + 1

            @foo.when(float)
            def foo(self, x):
                return x + 1.5

        @has_multimethods
        class DummySub(Dummy):
            # Rules added in the subclass extend the inherited multimethod.
            @Dummy.foo.when(str)
            def foo(self, x):
                return x + "1"

            @foo.when(tuple)
            def foo(self, x):
                return x + (1,)

            @Dummy.foo.when(bool)
            def foo(self, x):
                return not x

        inherited = ((1, 2), (1.5, 3.0))
        for arg, expected in inherited:
            self.assertEqual(Dummy().foo(arg), expected)
        # The base class must not see rules registered by the subclass.
        self.assertRaises(TypeError, Dummy().foo, "1")

        subclass_only = (("1", "11"), ((1, 2), (1, 2, 1)), (True, False))
        for arg, expected in inherited + subclass_only:
            self.assertEqual(DummySub().foo(arg), expected)
        self.assertRaises(TypeError, DummySub().foo, [])

    def test_override(self):
        from generic.multidispatch import has_multimethods, multimethod

        @has_multimethods
        class Dummy(object):
            @multimethod(str, str)
            def foo(self, x, y):
                return x + y

            # Same signature registered again in the same class body; the
            # assertion below expects this later rule to be the one used.
            @foo.when(str, str)
            def foo(self, x, y):
                return y + x

        self.assertEqual(Dummy().foo("1", "2"), "21")

    def test_inheritance_override(self):
        from generic.multidispatch import has_multimethods, multimethod

        @has_multimethods
        class Dummy(object):
            @multimethod(int)
            def foo(self, x):
                return x + 1

        @has_multimethods
        class DummySub(Dummy):
            # Re-registers ``int`` for the subclass only.
            @Dummy.foo.when(int)
            def foo(self, x):
                return x + 2

        base_result = Dummy().foo(1)
        sub_result = DummySub().foo(1)
        self.assertEqual(base_result, 2)
        self.assertEqual(sub_result, 3)