2010-07-15 10:39:45 +04:00
""" Tests for :module:`generic.multidispatch`. """
import unittest
2010-07-15 17:05:54 +04:00
# Public test classes of this module; both must be listed so that
# ``from <this module> import *`` picks up every test case.
__all__ = ["DispatcherTests", "TestMultimethod"]
2010-07-15 10:39:45 +04:00
class DispatcherTests ( unittest . TestCase ) :
2010-07-15 17:05:54 +04:00
def createDispatcher ( self , multi_arity , args = None , varargs = None ,
keywords = None , defaults = None ) :
from inspect import ArgSpec
2010-07-15 10:39:45 +04:00
from generic . multidispatch import Dispatcher
2010-07-15 17:05:54 +04:00
return Dispatcher ( ArgSpec ( args = args , varargs = varargs , keywords = keywords ,
defaults = defaults ) , multi_arity )
def test_one_argument ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : x + 1 , int )
self . assertEqual ( dispatcher ( 1 ) , 2 )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , " s " )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : x + " 1 " , str )
self . assertEqual ( dispatcher ( 1 ) , 2 )
self . assertEqual ( dispatcher ( " 1 " ) , " 11 " )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , tuple ( ) )
2010-07-15 10:39:45 +04:00
def test_two_arguments ( self ) :
2010-07-15 17:05:54 +04:00
dispatcher = self . createDispatcher ( 2 , args = [ " x " , " y " ] )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x , y : x + y + 1 , int , int )
self . assertEqual ( dispatcher ( 1 , 2 ) , 4 )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , " s " , " ss " )
self . assertRaises ( TypeError , dispatcher , 1 , " ss " )
self . assertRaises ( TypeError , dispatcher , " s " , 2 )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x , y : x + y + " 1 " , str , str )
self . assertEqual ( dispatcher ( 1 , 2 ) , 4 )
self . assertEqual ( dispatcher ( " 1 " , " 2 " ) , " 121 " )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , " 1 " , 1 )
self . assertRaises ( TypeError , dispatcher , 1 , " 1 " )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x , y : str ( x ) + y + " 1 " , int , str )
self . assertEqual ( dispatcher ( 1 , 2 ) , 4 )
self . assertEqual ( dispatcher ( " 1 " , " 2 " ) , " 121 " )
self . assertEqual ( dispatcher ( 1 , " 2 " ) , " 121 " )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , " 1 " , 1 )
2010-07-15 10:39:45 +04:00
def test_bottom_rule ( self ) :
2010-07-15 17:05:54 +04:00
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : x , object )
self . assertEqual ( dispatcher ( 1 ) , 1 )
self . assertEqual ( dispatcher ( " 1 " ) , " 1 " )
self . assertEqual ( dispatcher ( [ 1 ] ) , [ 1 ] )
self . assertEqual ( dispatcher ( ( 1 , ) ) , ( 1 , ) )
def test_subtype_evaluation ( self ) :
class Super ( object ) :
pass
class Sub ( Super ) :
pass
2010-07-15 17:05:54 +04:00
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : x , Super )
o_super = Super ( )
self . assertEqual ( dispatcher ( o_super ) , o_super )
o_sub = Sub ( )
self . assertEqual ( dispatcher ( o_sub ) , o_sub )
2010-07-15 17:05:54 +04:00
self . assertRaises ( TypeError , dispatcher , object ( ) )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : ( x , x ) , Sub )
o_super = Super ( )
self . assertEqual ( dispatcher ( o_super ) , o_super )
o_sub = Sub ( )
self . assertEqual ( dispatcher ( o_sub ) , ( o_sub , o_sub ) )
2010-07-15 17:05:54 +04:00
def test_register_rule_with_wrong_arity ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] )
2010-07-15 10:39:45 +04:00
dispatcher . register_rule ( lambda x : x , int )
2010-07-15 17:05:54 +04:00
self . assertRaises (
TypeError ,
dispatcher . register_rule , lambda x , y : x , str )
2010-07-15 10:39:45 +04:00
2010-07-15 17:05:54 +04:00
def test_register_rule_with_different_arg_names ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] )
dispatcher . register_rule ( lambda y : y , int )
self . assertEqual ( dispatcher ( 1 ) , 1 )
def test_dispatching_with_varargs ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] , varargs = " va " )
dispatcher . register_rule ( lambda x , * va : x , int )
self . assertEqual ( dispatcher ( 1 ) , 1 )
self . assertRaises ( TypeError , dispatcher , " 1 " , 2 , 3 )
def test_dispatching_with_varkw ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " ] , keywords = " vk " )
dispatcher . register_rule ( lambda x , * * vk : x , int )
self . assertEqual ( dispatcher ( 1 ) , 1 )
self . assertRaises ( TypeError , dispatcher , " 1 " , a = 1 , b = 2 )
def test_dispatching_with_kw ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " , " y " ] , defaults = [ " vk " ] )
dispatcher . register_rule ( lambda x , y = 1 : x , int )
self . assertEqual ( dispatcher ( 1 ) , 1 )
self . assertRaises ( TypeError , dispatcher , " 1 " , k = 1 )
def test_create_dispatcher_with_pos_args_less_multi_arity ( self ) :
self . assertRaises ( TypeError , self . createDispatcher , 2 , args = [ " x " ] )
self . assertRaises ( TypeError , self . createDispatcher , 2 , args = [ " x " , " y " ] ,
defaults = [ " x " ] )
def test_register_rule_with_wrong_number_types_parameters ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " , " y " ] )
2010-07-15 10:39:45 +04:00
self . assertRaises (
2010-07-15 17:05:54 +04:00
TypeError ,
dispatcher . register_rule , lambda x , y : x , int , str )
def test_register_rule_with_partial_dispatching ( self ) :
dispatcher = self . createDispatcher ( 1 , args = [ " x " , " y " ] )
dispatcher . register_rule ( lambda x , y : x , int )
self . assertEqual ( dispatcher ( 1 , 2 ) , 1 )
self . assertEqual ( dispatcher ( 1 , " 2 " ) , 1 )
self . assertRaises ( TypeError , dispatcher , " 2 " , 1 )
dispatcher . register_rule ( lambda x , y : x , str )
self . assertEqual ( dispatcher ( 1 , 2 ) , 1 )
self . assertEqual ( dispatcher ( 1 , " 2 " ) , 1 )
self . assertEqual ( dispatcher ( " 1 " , " 2 " ) , " 1 " )
self . assertEqual ( dispatcher ( " 1 " , 2 ) , " 1 " )
2010-07-15 10:39:45 +04:00
class TestMultimethod ( unittest . TestCase ) :
def test_it ( self ) :
from generic . multidispatch import multimethod
@multimethod ( int , str )
def func ( x , y ) :
return str ( x ) + y
self . assertEqual ( func ( 1 , " 2 " ) , " 12 " )
self . assertRaises ( TypeError , func , 1 , 2 )
self . assertRaises ( TypeError , func , " 1 " , 2 )
self . assertRaises ( TypeError , func , " 1 " , " 2 " )
2010-07-15 20:35:38 +04:00
@func.when ( str , str )
2010-07-15 10:39:45 +04:00
def func ( x , y ) :
return x + y
self . assertEqual ( func ( 1 , " 2 " ) , " 12 " )
self . assertEqual ( func ( " 1 " , " 2 " ) , " 12 " )
self . assertRaises ( TypeError , func , 1 , 2 )
self . assertRaises ( TypeError , func , " 1 " , 2 )
2010-07-15 21:47:41 +04:00
def test_overriding ( self ) :
2010-07-15 21:54:16 +04:00
# XXX: for now, overriding is not allowed and Value error is raised
2010-07-15 21:47:41 +04:00
# open questions are:
# 1. Should we allow overriding by default.
# a. If yes, should it be implicit or explicit (something like
# Dispatcher.override method)
# b. If no -- what exception we should raise.
from generic . multidispatch import multimethod
@multimethod ( int , str )
def func ( x , y ) :
return str ( x ) + y
self . assertRaises ( ValueError , func . when ( int , str ) , lambda x , y : str ( x ) )