diff --git a/pkg/middlewares/retry/retry.go b/pkg/middlewares/retry/retry.go index 5c95486d3..3e8167ba9 100644 --- a/pkg/middlewares/retry/retry.go +++ b/pkg/middlewares/retry/retry.go @@ -37,6 +37,48 @@ type Listener interface { // each of them about a retry attempt. type Listeners []Listener +// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries. +func (l Listeners) Retried(req *http.Request, attempt int) { + for _, listener := range l { + listener.Retried(req, attempt) + } +} + +type shouldRetryContextKey struct{} + +// ShouldRetry is a function allowing to enable/disable the retry middleware mechanism. +type ShouldRetry func(shouldRetry bool) + +// ContextShouldRetry returns the ShouldRetry function if it has been set by the Retry middleware in the chain. +func ContextShouldRetry(ctx context.Context) ShouldRetry { + f, _ := ctx.Value(shouldRetryContextKey{}).(ShouldRetry) + return f +} + +// WrapHandler wraps a given http.Handler to inject the httptrace.ClientTrace in the request context when it is needed +// by the retry middleware. +func WrapHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if shouldRetry := ContextShouldRetry(req.Context()); shouldRetry != nil { + shouldRetry(true) + + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + shouldRetry(false) + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + shouldRetry(false) + }, + } + newCtx := httptrace.WithClientTrace(req.Context(), trace) + next.ServeHTTP(rw, req.WithContext(newCtx)) + return + } + + next.ServeHTTP(rw, req) + }) +} + // retry is a middleware that retries requests. type retry struct { attempts int @@ -83,19 +125,13 @@ func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) { attempts := 1 operation := func() error { - shouldRetry := attempts < r.attempts - retryResponseWriter := newResponseWriter(rw, shouldRetry) + remainAttempts := attempts < r.attempts + retryResponseWriter := newResponseWriter(rw) - // Disable retries when the backend already received request data - trace := &httptrace.ClientTrace{ - WroteHeaders: func() { - retryResponseWriter.DisableRetries() - }, - WroteRequest: func(httptrace.WroteRequestInfo) { - retryResponseWriter.DisableRetries() - }, + var shouldRetry ShouldRetry = func(shouldRetry bool) { + retryResponseWriter.SetShouldRetry(remainAttempts && shouldRetry) } - newCtx := httptrace.WithClientTrace(req.Context(), trace) + newCtx := context.WithValue(req.Context(), shouldRetryContextKey{}, shouldRetry) r.next.ServeHTTP(retryResponseWriter, req.Clone(newCtx)) @@ -142,25 +178,17 @@ func (r *retry) newBackOff() backoff.BackOff { return b } -// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries. -func (l Listeners) Retried(req *http.Request, attempt int) { - for _, listener := range l { - listener.Retried(req, attempt) - } -} - type responseWriter interface { http.ResponseWriter http.Flusher ShouldRetry() bool - DisableRetries() + SetShouldRetry(shouldRetry bool) } -func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter { +func newResponseWriter(rw http.ResponseWriter) responseWriter { responseWriter := &responseWriterWithoutCloseNotify{ responseWriter: rw, headers: make(http.Header), - shouldRetry: shouldRetry, } if _, ok := rw.(http.CloseNotifier); ok { return &responseWriterWithCloseNotify{ @@ -181,8 +209,8 @@ func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool { return r.shouldRetry } -func (r *responseWriterWithoutCloseNotify) DisableRetries() { - r.shouldRetry = false +func (r *responseWriterWithoutCloseNotify) SetShouldRetry(shouldRetry bool) { + r.shouldRetry = shouldRetry } func (r *responseWriterWithoutCloseNotify) Header() http.Header { @@ -193,7 +221,7 @@ func (r *responseWriterWithoutCloseNotify) Header() http.Header { } func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { - if r.ShouldRetry() { + if r.shouldRetry { return len(buf), nil } if !r.written { @@ -203,16 +231,7 @@ func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { } func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) { - if r.ShouldRetry() && code == http.StatusServiceUnavailable { - // We get a 503 HTTP Status Code when there is no backend server in the pool - // to which the request could be sent. Also, note that r.ShouldRetry() - // will never return true in case there was a connection established to - // the backend server and so we can be sure that the 503 was produced - // inside Traefik already and we don't have to retry in this cases. - r.DisableRetries() - } - - if r.ShouldRetry() || r.written { + if r.shouldRetry || r.written { return } diff --git a/pkg/middlewares/retry/retry_test.go b/pkg/middlewares/retry/retry_test.go index 8d493cf11..b7332fe1e 100644 --- a/pkg/middlewares/retry/retry_test.go +++ b/pkg/middlewares/retry/retry_test.go @@ -105,12 +105,21 @@ func TestRetry(t *testing.T) { t.Parallel() retryAttempts := 0 - next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // This signals that a connection will be established with the backend + // to enable the Retry middleware mechanism. + shouldRetry := ContextShouldRetry(req.Context()) + if shouldRetry != nil { + shouldRetry(true) + } + retryAttempts++ if retryAttempts > test.amountFaultyEndpoints { - // calls WroteHeaders on httptrace. - _ = r.Write(io.Discard) + // This signals that request headers have been sent to the backend. + if shouldRetry != nil { + shouldRetry(false) + } rw.WriteHeader(http.StatusOK) return @@ -152,26 +161,16 @@ func TestRetryEmptyServerList(t *testing.T) { assert.Equal(t, 0, retryListener.timesCalled) } -func TestRetryListeners(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/", nil) - retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}} - - retryListeners.Retried(req, 1) - retryListeners.Retried(req, 1) - - for _, retryListener := range retryListeners { - listener := retryListener.(*countingRetryListener) - if listener.timesCalled != 2 { - t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2) - } - } -} - func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { attempt := 0 expectedHeaderValue := "bar" next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + shouldRetry := ContextShouldRetry(req.Context()) + if shouldRetry != nil { + shouldRetry(true) + } + headerName := fmt.Sprintf("X-Foo-Test-%d", attempt) rw.Header().Add(headerName, expectedHeaderValue) if attempt < 2 { @@ -179,9 +178,8 @@ func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { return } - // Request has been successfully written to backend. - trace := httptrace.ContextClientTrace(req.Context()) - trace.WroteHeaders() + // Request has been successfully written to backend + shouldRetry(false) // And we decide to answer to client. rw.WriteHeader(http.StatusNoContent) @@ -212,9 +210,10 @@ func TestRetryShouldNotLooseHeadersOnWrite(t *testing.T) { rw.Header().Add("X-Foo-Test", "bar") // Request has been successfully written to backend. - trace := httptrace.ContextClientTrace(req.Context()) - trace.WroteHeaders() - + shouldRetry := ContextShouldRetry(req.Context()) + if shouldRetry != nil { + shouldRetry(false) + } // And we decide to answer to client without calling WriteHeader. _, err := rw.Write([]byte("bar")) require.NoError(t, err) @@ -285,12 +284,24 @@ func TestRetryWebsocket(t *testing.T) { t.Parallel() retryAttempts := 0 - next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // This signals that a connection will be established with the backend + // to enable the Retry middleware mechanism. + shouldRetry := ContextShouldRetry(req.Context()) + if shouldRetry != nil { + shouldRetry(true) + } + retryAttempts++ if retryAttempts > test.amountFaultyEndpoints { + // This signals that request headers have been sent to the backend. + if shouldRetry != nil { + shouldRetry(false) + } + upgrader := websocket.Upgrader{} - _, err := upgrader.Upgrade(rw, r, nil) + _, err := upgrader.Upgrade(rw, req, nil) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index f30770b62..0a13bb521 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -21,6 +21,7 @@ import ( "github.com/traefik/traefik/v2/pkg/middlewares/emptybackendhandler" metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics" "github.com/traefik/traefik/v2/pkg/middlewares/pipelining" + "github.com/traefik/traefik/v2/pkg/middlewares/retry" "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/cookie" "github.com/traefik/traefik/v2/pkg/server/provider" @@ -283,16 +284,20 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName if err != nil { return nil, err } + // The retry wrapping must be done just before the proxy handler, + // to make sure that the retry will not be triggered/disabled by + // middlewares in the chain. + fwd = retry.WrapHandler(fwd) - alHandler := func(next http.Handler) (http.Handler, error) { - return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil - } chain := alice.New() if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() { chain = chain.Append(metricsMiddle.WrapServiceHandler(ctx, m.metricsRegistry, serviceName)) } + chain = chain.Append(func(next http.Handler) (http.Handler, error) { + return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil + }) - handler, err := chain.Append(alHandler).Then(pipelining.New(ctx, fwd, "pipelining")) + handler, err := chain.Then(pipelining.New(ctx, fwd, "pipelining")) if err != nil { return nil, err }