2017-04-18 09:22:06 +03:00
package middlewares
import (
2017-05-03 11:20:33 +03:00
"context"
2017-04-18 09:22:06 +03:00
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
func TestRetry ( t * testing . T ) {
testCases := [ ] struct {
failAtCalls [ ] int
attempts int
responseStatus int
listener * countingRetryListener
retriedCount int
} {
{
failAtCalls : [ ] int { 1 , 2 } ,
attempts : 3 ,
responseStatus : http . StatusOK ,
listener : & countingRetryListener { } ,
retriedCount : 2 ,
} ,
{
failAtCalls : [ ] int { 1 , 2 } ,
attempts : 2 ,
responseStatus : http . StatusBadGateway ,
listener : & countingRetryListener { } ,
retriedCount : 1 ,
} ,
}
for _ , tc := range testCases {
// bind tc locally
tc := tc
tcName := fmt . Sprintf ( "FailAtCalls(%v) RetryAttempts(%v)" , tc . failAtCalls , tc . attempts )
t . Run ( tcName , func ( t * testing . T ) {
t . Parallel ( )
2017-08-18 03:18:02 +03:00
var httpHandler http . Handler = & networkFailingHTTPHandler { failAtCalls : tc . failAtCalls , netErrorRecorder : & DefaultNetErrorRecorder { } }
2017-04-18 09:22:06 +03:00
httpHandler = NewRetry ( tc . attempts , httpHandler , tc . listener )
recorder := httptest . NewRecorder ( )
2017-11-20 11:40:03 +03:00
req , err := http . NewRequest ( http . MethodGet , "http://localhost:3000/ok" , ioutil . NopCloser ( nil ) )
2017-04-18 09:22:06 +03:00
if err != nil {
t . Fatalf ( "could not create request: %+v" , err )
}
httpHandler . ServeHTTP ( recorder , req )
if tc . responseStatus != recorder . Code {
t . Errorf ( "wrong status code %d, want %d" , recorder . Code , tc . responseStatus )
}
if tc . retriedCount != tc . listener . timesCalled {
t . Errorf ( "RetryListener called %d times, want %d times" , tc . listener . timesCalled , tc . retriedCount )
}
} )
}
}
2017-05-03 11:20:33 +03:00
func TestDefaultNetErrorRecorderSuccess ( t * testing . T ) {
boolNetErrorOccurred := false
recorder := DefaultNetErrorRecorder { }
recorder . Record ( context . WithValue ( context . Background ( ) , defaultNetErrCtxKey , & boolNetErrorOccurred ) )
if ! boolNetErrorOccurred {
t . Errorf ( "got %v after recording net error, wanted %v" , boolNetErrorOccurred , true )
}
}
func TestDefaultNetErrorRecorderInvalidValueType ( t * testing . T ) {
stringNetErrorOccured := "nonsense"
recorder := DefaultNetErrorRecorder { }
recorder . Record ( context . WithValue ( context . Background ( ) , defaultNetErrCtxKey , & stringNetErrorOccured ) )
if stringNetErrorOccured != "nonsense" {
t . Errorf ( "got %v after recording net error, wanted %v" , stringNetErrorOccured , "nonsense" )
}
}
func TestDefaultNetErrorRecorderNilValue ( t * testing . T ) {
nilNetErrorOccured := interface { } ( nil )
recorder := DefaultNetErrorRecorder { }
recorder . Record ( context . WithValue ( context . Background ( ) , defaultNetErrCtxKey , & nilNetErrorOccured ) )
if nilNetErrorOccured != interface { } ( nil ) {
t . Errorf ( "got %v after recording net error, wanted %v" , nilNetErrorOccured , interface { } ( nil ) )
}
}
2017-08-28 13:50:02 +03:00
// TestRetryListeners notifies a RetryListeners collection twice and expects
// each contained countingRetryListener to have been called twice.
func TestRetryListeners(t *testing.T) {
	const wantCalls = 2

	listeners := RetryListeners{&countingRetryListener{}, &countingRetryListener{}}
	request := httptest.NewRequest(http.MethodGet, "/", nil)

	// Both notifications report attempt number 1.
	for i := 0; i < wantCalls; i++ {
		listeners.Retried(request, 1)
	}

	for idx := range listeners {
		counting := listeners[idx].(*countingRetryListener)
		if counting.timesCalled == wantCalls {
			continue
		}
		t.Errorf("retry listener was called %d times, want %d", counting.timesCalled, wantCalls)
	}
}
2017-04-18 09:22:06 +03:00
// networkFailingHTTPHandler is an http.Handler implementation you can use to test retries.
type networkFailingHTTPHandler struct {
2017-05-03 11:20:33 +03:00
netErrorRecorder NetErrorRecorder
failAtCalls [ ] int
callNumber int
2017-04-18 09:22:06 +03:00
}
func ( handler * networkFailingHTTPHandler ) ServeHTTP ( w http . ResponseWriter , r * http . Request ) {
handler . callNumber ++
for _ , failAtCall := range handler . failAtCalls {
if handler . callNumber == failAtCall {
2017-05-03 11:20:33 +03:00
handler . netErrorRecorder . Record ( r . Context ( ) )
2017-04-18 09:22:06 +03:00
w . WriteHeader ( http . StatusBadGateway )
return
}
}
w . WriteHeader ( http . StatusOK )
}
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
	// timesCalled is the number of Retried invocations seen so far.
	timesCalled int
}

// Retried increments the call counter. The request and the attempt number are
// ignored; only the number of notifications is tracked.
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
	l.timesCalled++
}