diff --git a/middlewares/retry.go b/middlewares/retry.go index e2bfdc8af..e4146e9bb 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -2,10 +2,10 @@ package middlewares import ( "bufio" - "context" "io/ioutil" "net" "net/http" + "net/http/httptrace" "github.com/containous/traefik/log" ) @@ -40,11 +40,24 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { attempts := 1 for { - netErrorOccurred := false - // We pass in a pointer to netErrorOccurred so that we can set it to true on network errors - // when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler. - newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred) - retryResponseWriter := newRetryResponseWriter(rw, attempts >= retry.attempts, &netErrorOccurred) + attemptsExhausted := attempts >= retry.attempts + // Websocket requests can't be retried at this point in time. + // This is due to the fact that gorilla/websocket doesn't use the request + // context and so we don't get httptrace information. + // Websocket clients should however retry on their own anyway. + shouldRetry := !attemptsExhausted && !isWebsocketRequest(r) + retryResponseWriter := newRetryResponseWriter(rw, shouldRetry) + + // Disable retries when the backend already received request data + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + retryResponseWriter.DisableRetries() + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + retryResponseWriter.DisableRetries() + }, + } + newCtx := httptrace.WithClientTrace(r.Context(), trace) retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx)) if !retryResponseWriter.ShouldRetry() { @@ -57,31 +70,6 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } } -// netErrorCtxKey is a custom type that is used as key for the context. -type netErrorCtxKey string - -// defaultNetErrCtxKey is the actual key which value is used to record network errors. 
-var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey" - -// NetErrorRecorder is an interface to record net errors. -type NetErrorRecorder interface { - // Record can be used to signal the retry middleware that an network error happened - // and therefore the request should be retried. - Record(ctx context.Context) -} - -// DefaultNetErrorRecorder is the default NetErrorRecorder implementation. -type DefaultNetErrorRecorder struct{} - -// Record is recording network errors by setting the context value for the defaultNetErrCtxKey to true. -func (DefaultNetErrorRecorder) Record(ctx context.Context) { - val := ctx.Value(defaultNetErrCtxKey) - - if netErrorOccurred, isBoolPointer := val.(*bool); isBoolPointer { - *netErrorOccurred = true - } -} - // RetryListener is used to inform about retry attempts. type RetryListener interface { // Retried will be called when a retry happens, with the request attempt passed to it. @@ -104,13 +92,13 @@ type retryResponseWriter interface { http.ResponseWriter http.Flusher ShouldRetry() bool + DisableRetries() } -func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netErrorOccured *bool) retryResponseWriter { +func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter { responseWriter := &retryResponseWriterWithoutCloseNotify{ - responseWriter: rw, - attemptsExhausted: attemptsExhausted, - netErrorOccured: netErrorOccured, + responseWriter: rw, + shouldRetry: shouldRetry, } if _, ok := rw.(http.CloseNotifier); ok { return &retryResponseWriterWithCloseNotify{responseWriter} @@ -119,13 +107,16 @@ func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netE } type retryResponseWriterWithoutCloseNotify struct { - responseWriter http.ResponseWriter - attemptsExhausted bool - netErrorOccured *bool + responseWriter http.ResponseWriter + shouldRetry bool } func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool { - return *rr.netErrorOccured && 
!rr.attemptsExhausted + return rr.shouldRetry +} + +func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() { + rr.shouldRetry = false } func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header { @@ -143,6 +134,15 @@ func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) } func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { + if rr.ShouldRetry() && code == http.StatusServiceUnavailable { + // We get a 503 HTTP Status Code when there is no backend server in the pool + // to which the request could be sent. Also, note that rr.ShouldRetry() + // will never return true in case there was a connection established to + // the backend server and so we can be sure that the 503 was produced + // inside Traefik already and we don't have to retry in this case. + rr.DisableRetries() + } + if rr.ShouldRetry() { return } diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index 4f74efba7..b73044b29 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -1,91 +1,155 @@ package middlewares import ( - "context" - "fmt" - "io/ioutil" "net/http" "net/http/httptest" "testing" + + "github.com/containous/traefik/testhelpers" + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/roundrobin" ) func TestRetry(t *testing.T) { testCases := []struct { - failAtCalls []int - attempts int - responseStatus int - listener *countingRetryListener - retriedCount int + desc string + maxRequestAttempts int + wantRetryAttempts int + wantResponseStatus int + amountFaultyEndpoints int + isWebsocketHandshakeRequest bool }{ { - failAtCalls: []int{1, 2}, - attempts: 3, - responseStatus: http.StatusOK, - listener: &countingRetryListener{}, - retriedCount: 2, + desc: "no retry on success", + maxRequestAttempts: 1, + wantRetryAttempts: 0, + wantResponseStatus: http.StatusOK, + amountFaultyEndpoints: 0, }, { - failAtCalls: []int{1, 2}, - attempts: 2, - responseStatus: http.StatusBadGateway, - listener: 
&countingRetryListener{}, - retriedCount: 1, + desc: "no retry when max request attempts is one", + maxRequestAttempts: 1, + wantRetryAttempts: 0, + wantResponseStatus: http.StatusInternalServerError, + amountFaultyEndpoints: 1, + }, + { + desc: "one retry when one server is faulty", + maxRequestAttempts: 2, + wantRetryAttempts: 1, + wantResponseStatus: http.StatusOK, + amountFaultyEndpoints: 1, + }, + { + desc: "two retries when two servers are faulty", + maxRequestAttempts: 3, + wantRetryAttempts: 2, + wantResponseStatus: http.StatusOK, + amountFaultyEndpoints: 2, + }, + { + desc: "max attempts exhausted delivers the 5xx response", + maxRequestAttempts: 3, + wantRetryAttempts: 2, + wantResponseStatus: http.StatusInternalServerError, + amountFaultyEndpoints: 3, + }, + { + desc: "websocket request should not be retried", + maxRequestAttempts: 3, + wantRetryAttempts: 0, + wantResponseStatus: http.StatusBadGateway, + amountFaultyEndpoints: 1, + isWebsocketHandshakeRequest: true, }, } - for _, tc := range testCases { - // bind tc locally - tc := tc - tcName := fmt.Sprintf("FailAtCalls(%v) RetryAttempts(%v)", tc.failAtCalls, tc.attempts) + backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + rw.Write([]byte("OK")) + })) - t.Run(tcName, func(t *testing.T) { + forwarder, err := forward.New() + if err != nil { + t.Fatalf("Error creating forwarder: %s", err) + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.desc, func(t *testing.T) { t.Parallel() - var httpHandler http.Handler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}} - httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener) + loadBalancer, err := roundrobin.New(forwarder) + if err != nil { + t.Fatalf("Error creating load balancer: %s", err) + } + + basePort := 33444 + for i := 0; i < tc.amountFaultyEndpoints; i++ { + // 192.0.2.0 is a non-routable IP for 
testing purposes. + // See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928 + // We only use the port specification here because the URL is used as identifier + // in the load balancer and using the exact same URL would not add a new server. + loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i))) + } + + // add the functioning server to the end of the load balancer list + loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL)) + + retryListener := &countingRetryListener{} + retry := NewRetry(tc.maxRequestAttempts, loadBalancer, retryListener) recorder := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "http://localhost:3000/ok", ioutil.NopCloser(nil)) - if err != nil { - t.Fatalf("could not create request: %+v", err) + req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil) + + if tc.isWebsocketHandshakeRequest { + req.Header.Add("Connection", "Upgrade") + req.Header.Add("Upgrade", "websocket") } - httpHandler.ServeHTTP(recorder, req) + retry.ServeHTTP(recorder, req) - if tc.responseStatus != recorder.Code { - t.Errorf("wrong status code %d, want %d", recorder.Code, tc.responseStatus) + if tc.wantResponseStatus != recorder.Code { + t.Errorf("got status code %d, want %d", recorder.Code, tc.wantResponseStatus) } - if tc.retriedCount != tc.listener.timesCalled { - t.Errorf("RetryListener called %d times, want %d times", tc.listener.timesCalled, tc.retriedCount) + if tc.wantRetryAttempts != retryListener.timesCalled { + t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, tc.wantRetryAttempts) } }) } } -func TestDefaultNetErrorRecorderSuccess(t *testing.T) { - boolNetErrorOccurred := false - recorder := DefaultNetErrorRecorder{} - recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred)) - if !boolNetErrorOccurred { - t.Errorf("got %v after recording net error, 
wanted %v", boolNetErrorOccurred, true) +func TestRetryEmptyServerList(t *testing.T) { + forwarder, err := forward.New() + if err != nil { + t.Fatalf("Error creating forwarder: %s", err) } -} -func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) { - stringNetErrorOccured := "nonsense" - recorder := DefaultNetErrorRecorder{} - recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured)) - if stringNetErrorOccured != "nonsense" { - t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense") + loadBalancer, err := roundrobin.New(forwarder) + if err != nil { + t.Fatalf("Error creating load balancer: %s", err) } -} -func TestDefaultNetErrorRecorderNilValue(t *testing.T) { - nilNetErrorOccured := interface{}(nil) - recorder := DefaultNetErrorRecorder{} - recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured)) - if nilNetErrorOccured != interface{}(nil) { - t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil)) + // The EmptyBackendHandler middleware ensures that there is a 503 + // response status set when there is no backend server in the pool. 
+ next := NewEmptyBackendHandler(loadBalancer) + + retryListener := &countingRetryListener{} + retry := NewRetry(3, next, retryListener) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil) + + retry.ServeHTTP(recorder, req) + + const wantResponseStatus = http.StatusServiceUnavailable + if wantResponseStatus != recorder.Code { + t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus) + } + const wantRetryAttempts = 0 + if wantRetryAttempts != retryListener.timesCalled { + t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts) } } @@ -99,33 +163,11 @@ func TestRetryListeners(t *testing.T) { for _, retryListener := range retryListeners { listener := retryListener.(*countingRetryListener) if listener.timesCalled != 2 { - t.Errorf("retry listener was called %d times, want %d", listener.timesCalled, 2) + t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2) } } } -// networkFailingHTTPHandler is an http.Handler implementation you can use to test retries. -type networkFailingHTTPHandler struct { - netErrorRecorder NetErrorRecorder - failAtCalls []int - callNumber int -} - -func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - handler.callNumber++ - - for _, failAtCall := range handler.failAtCalls { - if handler.callNumber == failAtCall { - handler.netErrorRecorder.Record(r.Context()) - - w.WriteHeader(http.StatusBadGateway) - return - } - } - - w.WriteHeader(http.StatusOK) -} - // countingRetryListener is a RetryListener implementation to count the times the Retried fn is called. 
type countingRetryListener struct { timesCalled int diff --git a/server/errorhandler.go b/server/errorhandler.go deleted file mode 100644 index 80cc9fae6..000000000 --- a/server/errorhandler.go +++ /dev/null @@ -1,40 +0,0 @@ -package server - -import ( - "io" - "net" - "net/http" - - "github.com/containous/traefik/middlewares" -) - -// RecordingErrorHandler is an error handler, implementing the vulcand/oxy -// error handler interface, which is recording network errors by using the netErrorRecorder. -// In addition it sets a proper HTTP status code and body, depending on the type of error occurred. -type RecordingErrorHandler struct { - netErrorRecorder middlewares.NetErrorRecorder -} - -// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler. -func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler { - return &RecordingErrorHandler{recorder} -} - -func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { - statusCode := http.StatusInternalServerError - - if e, ok := err.(net.Error); ok { - eh.netErrorRecorder.Record(req.Context()) - if e.Timeout() { - statusCode = http.StatusGatewayTimeout - } else { - statusCode = http.StatusBadGateway - } - } else if err == io.EOF { - eh.netErrorRecorder.Record(req.Context()) - statusCode = http.StatusBadGateway - } - - w.WriteHeader(statusCode) - w.Write([]byte(http.StatusText(statusCode))) -} diff --git a/server/errorhandler_test.go b/server/errorhandler_test.go deleted file mode 100644 index 0ff0a0255..000000000 --- a/server/errorhandler_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package server - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "net/http/httptest" - "testing" -) - -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -func TestServeHTTP(t 
*testing.T) { - tests := []struct { - name string - err error - wantHTTPStatus int - wantNetErrRecorded bool - }{ - { - name: "net.Error", - err: net.UnknownNetworkError("any network error"), - wantHTTPStatus: http.StatusBadGateway, - wantNetErrRecorded: true, - }, - { - name: "net.Error with Timeout", - err: &timeoutError{}, - wantHTTPStatus: http.StatusGatewayTimeout, - wantNetErrRecorded: true, - }, - { - name: "io.EOF", - err: io.EOF, - wantHTTPStatus: http.StatusBadGateway, - wantNetErrRecorded: true, - }, - { - name: "custom error", - err: errors.New("any error"), - wantHTTPStatus: http.StatusInternalServerError, - wantNetErrRecorded: false, - }, - { - name: "nil error", - err: nil, - wantHTTPStatus: http.StatusInternalServerError, - wantNetErrRecorded: false, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - recorder := httptest.NewRecorder() - - errorRecorder := &netErrorRecorder{} - req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil) - - recordingErrorHandler := NewRecordingErrorHandler(errorRecorder) - recordingErrorHandler.ServeHTTP(recorder, req, test.err) - - if recorder.Code != test.wantHTTPStatus { - t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus) - } - if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded { - t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded) - } - }) - } -} - -type netErrorRecorder struct { - netErrorWasRecorded bool -} - -func (recorder *netErrorRecorder) Record(ctx context.Context) { - recorder.netErrorWasRecorded = true -} diff --git a/server/server_configuration.go b/server/server_configuration.go index 1e3ede0bf..c8146e44a 100644 --- a/server/server_configuration.go +++ b/server/server_configuration.go @@ -23,7 +23,6 @@ import ( "github.com/eapache/channels" "github.com/urfave/negroni" "github.com/vulcand/oxy/forward" - 
"github.com/vulcand/oxy/utils" ) // loadConfiguration manages dynamically frontends, backends and TLS configurations @@ -80,7 +79,6 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura } serverEntryPoints := s.buildServerEntryPoints() - errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) backendsHandlers := map[string]http.Handler{} backendsHealthCheck := map[string]*healthcheck.BackendConfig{} @@ -92,7 +90,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura for _, frontendName := range frontendNames { frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config, - redirectHandlers, serverEntryPoints, errorHandler, + redirectHandlers, serverEntryPoints, backendsHandlers, backendsHealthCheck) if err != nil { log.Errorf("%v. Skipping frontend %s...", err, frontendName) @@ -131,7 +129,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura func (s *Server) loadFrontendConfig( providerName string, frontendName string, config *types.Configuration, - redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler, + redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig, ) ([]handlerPostConfig, error) { @@ -170,7 +168,7 @@ func (s *Server) loadFrontendConfig( postConfigs = append(postConfigs, postConfig) } - fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier) + fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier) if err != nil { return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err) } @@ -222,7 +220,7 @@ func (s *Server) loadFrontendConfig( func (s *Server) 
buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint, frontendName string, frontend *types.Frontend, - errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) { + responseModifier modifyResponse) (http.Handler, error) { roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS) if err != nil { @@ -239,7 +237,6 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration forward.Stream(true), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(roundTripper), - forward.ErrorHandler(errorHandler), forward.Rewriter(rewriter), forward.ResponseModifier(responseModifier), forward.BufferPool(s.bufferPool),