diff --git a/pkg/proxy/fast/proxy.go b/pkg/proxy/fast/proxy.go index 06e68513e..76b44a1db 100644 --- a/pkg/proxy/fast/proxy.go +++ b/pkg/proxy/fast/proxy.go @@ -171,6 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if reqUpType != "" { outReq.Header.Set("Connection", "Upgrade") outReq.Header.Set("Upgrade", reqUpType) + if strings.EqualFold(reqUpType, "websocket") { cleanWebSocketHeaders(&outReq.Header) } @@ -351,7 +352,7 @@ func isGraphic(s string) bool { type fasthttpHeader interface { Peek(key string) []byte Set(key string, value string) - SetBytesV(key string, value []byte) + SetCanonical(key []byte, value []byte) DelBytes(key []byte) Del(key string) ConnectionUpgrade() bool @@ -382,18 +383,33 @@ func fixPragmaCacheControl(header fasthttpHeader) { // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. // https://tools.ietf.org/html/rfc6455#page-20 func cleanWebSocketHeaders(headers fasthttpHeader) { - headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key")) - headers.Del("Sec-Websocket-Key") + secWebsocketKey := headers.Peek("Sec-Websocket-Key") + if len(secWebsocketKey) > 0 { + headers.SetCanonical([]byte("Sec-WebSocket-Key"), secWebsocketKey) + headers.Del("Sec-Websocket-Key") + } - headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions")) - headers.Del("Sec-Websocket-Extensions") + secWebsocketExtensions := headers.Peek("Sec-Websocket-Extensions") + if len(secWebsocketExtensions) > 0 { + headers.SetCanonical([]byte("Sec-WebSocket-Extensions"), secWebsocketExtensions) + headers.Del("Sec-Websocket-Extensions") + } - headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept")) - headers.Del("Sec-Websocket-Accept") + secWebsocketAccept := headers.Peek("Sec-Websocket-Accept") + if len(secWebsocketAccept) > 0 { + headers.SetCanonical([]byte("Sec-WebSocket-Accept"), secWebsocketAccept) + headers.Del("Sec-Websocket-Accept") + } - headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol")) - headers.Del("Sec-Websocket-Protocol") + secWebsocketProtocol := headers.Peek("Sec-Websocket-Protocol") + if len(secWebsocketProtocol) > 0 { + headers.SetCanonical([]byte("Sec-WebSocket-Protocol"), secWebsocketProtocol) + headers.Del("Sec-Websocket-Protocol") + } - headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version")) - headers.DelBytes([]byte("Sec-Websocket-Version")) + secWebsocketVersion := headers.Peek("Sec-Websocket-Version") + if len(secWebsocketVersion) > 0 { + headers.SetCanonical([]byte("Sec-WebSocket-Version"), secWebsocketVersion) + headers.Del("Sec-Websocket-Version") + } } diff --git a/pkg/proxy/fast/proxy_websocket_test.go b/pkg/proxy/fast/proxy_websocket_test.go index ef22895cc..313718322 100644 --- a/pkg/proxy/fast/proxy_websocket_test.go +++ b/pkg/proxy/fast/proxy_websocket_test.go @@ -18,9 +18,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traefik/traefik/v3/pkg/testhelpers" + "github.com/valyala/fasthttp" "golang.org/x/net/websocket" ) +const dialTimeout = time.Second + func TestWebSocketUpgradeCase(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { challengeKey := r.Header.Get("Sec-Websocket-Key") @@ -49,6 +52,31 @@ func TestWebSocketUpgradeCase(t *testing.T) { conn.Close() } +func TestCleanWebSocketHeaders(t *testing.T) { + // Asserts that no headers are sent if the request contain anything. + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + + cleanWebSocketHeaders(&req.Header) + + want := "GET / HTTP/1.1\r\n\r\n" + assert.Equal(t, want, req.Header.String()) + + // Asserts that the Sec-WebSocket-* is enforced. + req.Reset() + + req.Header.Set("Sec-Websocket-Key", "key") + req.Header.Set("Sec-Websocket-Extensions", "extensions") + req.Header.Set("Sec-Websocket-Accept", "accept") + req.Header.Set("Sec-Websocket-Protocol", "protocol") + req.Header.Set("Sec-Websocket-Version", "version") + + cleanWebSocketHeaders(&req.Header) + + want = "GET / HTTP/1.1\r\nSec-WebSocket-Key: key\r\nSec-WebSocket-Extensions: extensions\r\nSec-WebSocket-Accept: accept\r\nSec-WebSocket-Protocol: protocol\r\nSec-WebSocket-Version: version\r\n\r\n" + assert.Equal(t, want, req.Header.String()) +} + func TestWebSocketTCPClose(t *testing.T) { errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} @@ -535,29 +563,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) { assert.Equal(t, "ok", resp) } -func createTLSWebsocketServer() *httptest.Server { - upgrader := gorillawebsocket.Upgrader{} - srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - mt, message, err := conn.ReadMessage() - if err != nil { - break - } - - err = conn.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - return srv -} - func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() @@ -592,7 +597,28 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { assert.Equal(t, "ok", resp) } -const dialTimeout = time.Second +func createTLSWebsocketServer() *httptest.Server { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + return srv +} type websocketRequestOpt func(w *websocketRequest) diff --git a/pkg/proxy/httputil/proxy.go b/pkg/proxy/httputil/proxy.go index 8e4a43fe5..11213952a 100644 --- a/pkg/proxy/httputil/proxy.go +++ b/pkg/proxy/httputil/proxy.go @@ -70,7 +70,9 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu outReq.Host = outReq.URL.Host } - cleanWebSocketHeaders(outReq) + if isWebSocketUpgrade(outReq) { + cleanWebSocketHeaders(outReq) + } } } @@ -79,10 +81,6 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. // https://tools.ietf.org/html/rfc6455#page-20 func cleanWebSocketHeaders(req *http.Request) { - if !isWebSocketUpgrade(req) { - return - } - req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"] delete(req.Header, "Sec-Websocket-Key") diff --git a/pkg/proxy/httputil/proxy_websocket_test.go b/pkg/proxy/httputil/proxy_websocket_test.go index 5472bcce0..48296f955 100644 --- a/pkg/proxy/httputil/proxy_websocket_test.go +++ b/pkg/proxy/httputil/proxy_websocket_test.go @@ -2,6 +2,7 @@ package httputil import ( "bufio" + "bytes" "crypto/tls" "errors" "fmt" @@ -18,6 +19,8 @@ import ( "golang.org/x/net/websocket" ) +const dialTimeout = time.Second + func TestWebSocketTCPClose(t *testing.T) { errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} @@ -419,28 +422,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) { assert.Equal(t, "ok", resp) } -func createTLSWebsocketServer() *httptest.Server { - upgrader := gorillawebsocket.Upgrader{} - srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - mt, message, err := conn.ReadMessage() - if err != nil { - break - } - err = conn.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - return srv -} - func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() @@ -495,7 +476,58 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { assert.Equal(t, "ok", resp) } -const dialTimeout = time.Second +func TestCleanWebSocketHeaders(t *testing.T) { + // Asserts that no headers are sent if the request contain anything. + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Del("User-Agent") + + cleanWebSocketHeaders(req) + + b := bytes.NewBuffer(nil) + err := req.Header.Write(b) + require.NoError(t, err) + + assert.Empty(t, b) + + // Asserts that the Sec-WebSocket-* is enforced. + req.Header.Set("Sec-Websocket-Key", "key") + req.Header.Set("Sec-Websocket-Extensions", "extensions") + req.Header.Set("Sec-Websocket-Accept", "accept") + req.Header.Set("Sec-Websocket-Protocol", "protocol") + req.Header.Set("Sec-Websocket-Version", "version") + + cleanWebSocketHeaders(req) + + want := http.Header{ + "Sec-WebSocket-Key": {"key"}, + "Sec-WebSocket-Extensions": {"extensions"}, + "Sec-WebSocket-Accept": {"accept"}, + "Sec-WebSocket-Protocol": {"protocol"}, + "Sec-WebSocket-Version": {"version"}, + } + assert.Equal(t, want, req.Header) +} + +func createTLSWebsocketServer() *httptest.Server { + var upgrader gorillawebsocket.Upgrader + return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) +} type websocketRequestOpt func(w *websocketRequest)