diff --git a/integration/retry_test.go b/integration/retry_test.go index 1def53435..ca935779e 100644 --- a/integration/retry_test.go +++ b/integration/retry_test.go @@ -7,6 +7,7 @@ import ( "github.com/containous/traefik/integration/try" "github.com/go-check/check" + "github.com/gorilla/websocket" checker "github.com/vdemeester/shakers" ) @@ -38,3 +39,29 @@ func (s *RetrySuite) TestRetry(c *check.C) { c.Assert(err, checker.IsNil) c.Assert(response.StatusCode, checker.Equals, http.StatusOK) } + +func (s *RetrySuite) TestRetryWebsocket(c *check.C) { + whoamiEndpoint := s.composeProject.Container(c, "whoami").NetworkSettings.IPAddress + file := s.adaptFile(c, "fixtures/retry/simple.toml", struct { + WhoamiEndpoint string + }{whoamiEndpoint}) + defer os.Remove(file) + + cmd, display := s.traefikCmd(withConfigFile(file)) + defer display(c) + err := cmd.Start() + c.Assert(err, checker.IsNil) + defer cmd.Process.Kill() + + err = try.GetRequest("http://127.0.0.1:8080/api/providers", 60*time.Second, try.BodyContains("PathPrefix:/")) + c.Assert(err, checker.IsNil) + + // This simulates a DialTimeout when connecting to the backend server. + _, response, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil) + c.Assert(err, checker.IsNil) + c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols) + + _, response, err = websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil) + c.Assert(err, checker.IsNil) + c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols) +} diff --git a/middlewares/retry.go b/middlewares/retry.go index 6ecbdf1cd..57c631ded 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -2,6 +2,7 @@ package middlewares import ( "bufio" + "fmt" "io/ioutil" "net" "net/http" @@ -41,11 +42,8 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { attempts := 1 for { attemptsExhausted := attempts >= retry.attempts - // Websocket requests can't be retried at this point in time. - // This is due to the fact that gorilla/websocket doesn't use the request - // context and so we don't get httptrace information. - // Websocket clients should however retry on their own anyway. - shouldRetry := !attemptsExhausted && !isWebsocketRequest(r) + + shouldRetry := !attemptsExhausted retryResponseWriter := newRetryResponseWriter(rw, shouldRetry) // Disable retries when the backend already received request data @@ -150,7 +148,11 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { } func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return rr.responseWriter.(http.Hijacker).Hijack() + hijacker, ok := rr.responseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("%T is not a http.Hijacker", rr.responseWriter) + } + return hijacker.Hijack() } func (rr *retryResponseWriterWithoutCloseNotify) Flush() { diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index b73044b29..9c51c567f 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -3,21 +3,24 @@ package middlewares import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/containous/traefik/testhelpers" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/roundrobin" ) func TestRetry(t *testing.T) { testCases := []struct { - desc string - maxRequestAttempts int - wantRetryAttempts int - wantResponseStatus int - amountFaultyEndpoints int - isWebsocketHandshakeRequest bool + desc string + maxRequestAttempts int + wantRetryAttempts int + wantResponseStatus int + amountFaultyEndpoints int }{ { desc: "no retry on success", @@ -54,14 +57,6 @@ func TestRetry(t *testing.T) { wantResponseStatus: http.StatusInternalServerError, amountFaultyEndpoints: 3, }, - { - desc: "websocket request should not be retried", - maxRequestAttempts: 3, - wantRetryAttempts: 0, - wantResponseStatus: http.StatusBadGateway, - amountFaultyEndpoints: 1, - isWebsocketHandshakeRequest: true, - }, } backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -74,10 +69,10 @@ func TestRetry(t *testing.T) { t.Fatalf("Error creating forwarder: %s", err) } - for _, tc := range testCases { - tc := tc + for _, test := range testCases { + test := test - t.Run(tc.desc, func(t *testing.T) { + t.Run(test.desc, func(t *testing.T) { t.Parallel() loadBalancer, err := roundrobin.New(forwarder) @@ -86,7 +81,7 @@ func TestRetry(t *testing.T) { } basePort := 33444 - for i := 0; i < tc.amountFaultyEndpoints; i++ { + for i := 0; i < test.amountFaultyEndpoints; i++ { // 192.0.2.0 is a non-routable IP for testing purposes. // See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928 // We only use the port specification here because the URL is used as identifier @@ -98,24 +93,91 @@ func TestRetry(t *testing.T) { loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL)) retryListener := &countingRetryListener{} - retry := NewRetry(tc.maxRequestAttempts, loadBalancer, retryListener) + retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener) recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil) - if tc.isWebsocketHandshakeRequest { - req.Header.Add("Connection", "Upgrade") - req.Header.Add("Upgrade", "websocket") - } - retry.ServeHTTP(recorder, req) - if tc.wantResponseStatus != recorder.Code { - t.Errorf("got status code %d, want %d", recorder.Code, tc.wantResponseStatus) + assert.Equal(t, test.wantResponseStatus, recorder.Code) + assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled) + }) + } +} + +func TestRetryWebsocket(t *testing.T) { + testCases := []struct { + desc string + maxRequestAttempts int + expectedRetryAttempts int + expectedResponseStatus int + expectedError bool + amountFaultyEndpoints int + }{ + { + desc: "Switching ok after 2 retries", + maxRequestAttempts: 3, + expectedRetryAttempts: 2, + amountFaultyEndpoints: 2, + expectedResponseStatus: http.StatusSwitchingProtocols, + }, + { + desc: "Switching failed", + maxRequestAttempts: 2, + expectedRetryAttempts: 1, + amountFaultyEndpoints: 2, + expectedResponseStatus: http.StatusBadGateway, + expectedError: true, + }, + } + + forwarder, err := forward.New() + if err != nil { + t.Fatalf("Error creating forwarder: %s", err) + } + + backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + upgrader := websocket.Upgrader{} + upgrader.Upgrade(rw, req, nil) + })) + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + loadBalancer, err := roundrobin.New(forwarder) + if err != nil { + t.Fatalf("Error creating load balancer: %s", err) } - if tc.wantRetryAttempts != retryListener.timesCalled { - t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, tc.wantRetryAttempts) + + basePort := 33444 + for i := 0; i < test.amountFaultyEndpoints; i++ { + // 192.0.2.0 is a non-routable IP for testing purposes. + // See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928 + // We only use the port specification here because the URL is used as identifier + // in the load balancer and using the exact same URL would not add a new server. + loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i))) } + + // add the functioning server to the end of the load balancer list + loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL)) + + retryListener := &countingRetryListener{} + retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener) + + retryServer := httptest.NewServer(retry) + + url := strings.Replace(retryServer.URL, "http", "ws", 1) + _, response, err := websocket.DefaultDialer.Dial(url, nil) + + if !test.expectedError { + require.NoError(t, err) + } + + assert.Equal(t, test.expectedResponseStatus, response.StatusCode) + assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled) }) } }