diff --git a/pkg/middlewares/retry/retry.go b/pkg/middlewares/retry/retry.go index e1eee533d..5c95486d3 100644 --- a/pkg/middlewares/retry/retry.go +++ b/pkg/middlewares/retry/retry.go @@ -196,6 +196,9 @@ func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { if r.ShouldRetry() { return len(buf), nil } + if !r.written { + r.WriteHeader(http.StatusOK) + } return r.responseWriter.Write(buf) } diff --git a/pkg/middlewares/retry/retry_test.go b/pkg/middlewares/retry/retry_test.go index be2606ce4..8d493cf11 100644 --- a/pkg/middlewares/retry/retry_test.go +++ b/pkg/middlewares/retry/retry_test.go @@ -169,7 +169,6 @@ func TestRetryListeners(t *testing.T) { func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { attempt := 0 - expectedHeaderName := "X-Foo-Test-2" expectedHeaderValue := "bar" next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -180,44 +179,55 @@ func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { return } - // Request has been successfully written to backend + // Request has been successfully written to backend. trace := httptrace.ContextClientTrace(req.Context()) trace.WroteHeaders() - // And we decide to answer to client + // And we decide to answer to client. rw.WriteHeader(http.StatusNoContent) }) retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest") require.NoError(t, err) - responseRecorder := httptest.NewRecorder() - retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody)) + res := httptest.NewRecorder() + retry.ServeHTTP(res, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody)) - headerValue := responseRecorder.Header().Get(expectedHeaderName) - - // Validate if we have the correct header - if headerValue != expectedHeaderValue { - t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue) - } + // The third header attempt is kept. + headerValue := res.Header().Get("X-Foo-Test-2") + assert.Equal(t, expectedHeaderValue, headerValue) // Validate that we don't have headers from previous attempts for i := range attempt { headerName := fmt.Sprintf("X-Foo-Test-%d", i) - headerValue = responseRecorder.Header().Get("headerName") + headerValue = res.Header().Get(headerName) if headerValue != "" { t.Errorf("Expected no value for header %s, got %s", headerName, headerValue) } } } -// countingRetryListener is a Listener implementation to count the times the Retried fn is called. -type countingRetryListener struct { - timesCalled int -} +func TestRetryShouldNotLooseHeadersOnWrite(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Add("X-Foo-Test", "bar") -func (l *countingRetryListener) Retried(req *http.Request, attempt int) { - l.timesCalled++ + // Request has been successfully written to backend. + trace := httptrace.ContextClientTrace(req.Context()) + trace.WroteHeaders() + + // And we decide to answer to client without calling WriteHeader. + _, err := rw.Write([]byte("bar")) + require.NoError(t, err) + }) + + retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest") + require.NoError(t, err) + + res := httptest.NewRecorder() + retry.ServeHTTP(res, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody)) + + headerValue := res.Header().Get("X-Foo-Test") + assert.Equal(t, "bar", headerValue) } func TestRetryWithFlush(t *testing.T) { @@ -387,3 +397,12 @@ func Test1xxResponses(t *testing.T) { assert.Equal(t, 0, retryListener.timesCalled) } + +// countingRetryListener is a Listener implementation to count the times the Retried fn is called. +type countingRetryListener struct { + timesCalled int +} + +func (l *countingRetryListener) Retried(req *http.Request, attempt int) { + l.timesCalled++ +}