mirror of
https://github.com/containous/traefik.git
synced 2025-03-19 18:50:12 +03:00
Add WebSocket headers if they are present in the request
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
parent
1cfcf0d318
commit
1ccbf743cb
@ -171,6 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if reqUpType != "" {
|
||||
outReq.Header.Set("Connection", "Upgrade")
|
||||
outReq.Header.Set("Upgrade", reqUpType)
|
||||
|
||||
if strings.EqualFold(reqUpType, "websocket") {
|
||||
cleanWebSocketHeaders(&outReq.Header)
|
||||
}
|
||||
@ -351,7 +352,7 @@ func isGraphic(s string) bool {
|
||||
type fasthttpHeader interface {
|
||||
Peek(key string) []byte
|
||||
Set(key string, value string)
|
||||
SetBytesV(key string, value []byte)
|
||||
SetCanonical(key []byte, value []byte)
|
||||
DelBytes(key []byte)
|
||||
Del(key string)
|
||||
ConnectionUpgrade() bool
|
||||
@ -382,18 +383,33 @@ func fixPragmaCacheControl(header fasthttpHeader) {
|
||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebSocketHeaders(headers fasthttpHeader) {
|
||||
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
|
||||
headers.Del("Sec-Websocket-Key")
|
||||
secWebsocketKey := headers.Peek("Sec-Websocket-Key")
|
||||
if len(secWebsocketKey) > 0 {
|
||||
headers.SetCanonical([]byte("Sec-WebSocket-Key"), secWebsocketKey)
|
||||
headers.Del("Sec-Websocket-Key")
|
||||
}
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
|
||||
headers.Del("Sec-Websocket-Extensions")
|
||||
secWebsocketExtensions := headers.Peek("Sec-Websocket-Extensions")
|
||||
if len(secWebsocketExtensions) > 0 {
|
||||
headers.SetCanonical([]byte("Sec-WebSocket-Extensions"), secWebsocketExtensions)
|
||||
headers.Del("Sec-Websocket-Extensions")
|
||||
}
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
|
||||
headers.Del("Sec-Websocket-Accept")
|
||||
secWebsocketAccept := headers.Peek("Sec-Websocket-Accept")
|
||||
if len(secWebsocketAccept) > 0 {
|
||||
headers.SetCanonical([]byte("Sec-WebSocket-Accept"), secWebsocketAccept)
|
||||
headers.Del("Sec-Websocket-Accept")
|
||||
}
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
|
||||
headers.Del("Sec-Websocket-Protocol")
|
||||
secWebsocketProtocol := headers.Peek("Sec-Websocket-Protocol")
|
||||
if len(secWebsocketProtocol) > 0 {
|
||||
headers.SetCanonical([]byte("Sec-WebSocket-Protocol"), secWebsocketProtocol)
|
||||
headers.Del("Sec-Websocket-Protocol")
|
||||
}
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
|
||||
headers.DelBytes([]byte("Sec-Websocket-Version"))
|
||||
secWebsocketVersion := headers.Peek("Sec-Websocket-Version")
|
||||
if len(secWebsocketVersion) > 0 {
|
||||
headers.SetCanonical([]byte("Sec-WebSocket-Version"), secWebsocketVersion)
|
||||
headers.Del("Sec-Websocket-Version")
|
||||
}
|
||||
}
|
||||
|
@ -18,9 +18,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
func TestWebSocketUpgradeCase(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||
@ -49,6 +52,31 @@ func TestWebSocketUpgradeCase(t *testing.T) {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestCleanWebSocketHeaders(t *testing.T) {
|
||||
// Asserts that no headers are sent if the request contain anything.
|
||||
req := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
|
||||
cleanWebSocketHeaders(&req.Header)
|
||||
|
||||
want := "GET / HTTP/1.1\r\n\r\n"
|
||||
assert.Equal(t, want, req.Header.String())
|
||||
|
||||
// Asserts that the Sec-WebSocket-* is enforced.
|
||||
req.Reset()
|
||||
|
||||
req.Header.Set("Sec-Websocket-Key", "key")
|
||||
req.Header.Set("Sec-Websocket-Extensions", "extensions")
|
||||
req.Header.Set("Sec-Websocket-Accept", "accept")
|
||||
req.Header.Set("Sec-Websocket-Protocol", "protocol")
|
||||
req.Header.Set("Sec-Websocket-Version", "version")
|
||||
|
||||
cleanWebSocketHeaders(&req.Header)
|
||||
|
||||
want = "GET / HTTP/1.1\r\nSec-WebSocket-Key: key\r\nSec-WebSocket-Extensions: extensions\r\nSec-WebSocket-Accept: accept\r\nSec-WebSocket-Protocol: protocol\r\nSec-WebSocket-Version: version\r\n\r\n"
|
||||
assert.Equal(t, want, req.Header.String())
|
||||
}
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
@ -535,29 +563,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
@ -592,7 +597,28 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
|
@ -70,7 +70,9 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
cleanWebSocketHeaders(outReq)
|
||||
if isWebSocketUpgrade(outReq) {
|
||||
cleanWebSocketHeaders(outReq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,10 +81,6 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu
|
||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebSocketHeaders(req *http.Request) {
|
||||
if !isWebSocketUpgrade(req) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
||||
delete(req.Header, "Sec-Websocket-Key")
|
||||
|
||||
|
@ -2,6 +2,7 @@ package httputil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -18,6 +19,8 @@ import (
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
@ -419,28 +422,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
@ -495,7 +476,58 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
func TestCleanWebSocketHeaders(t *testing.T) {
|
||||
// Asserts that no headers are sent if the request contain anything.
|
||||
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||
req.Header.Del("User-Agent")
|
||||
|
||||
cleanWebSocketHeaders(req)
|
||||
|
||||
b := bytes.NewBuffer(nil)
|
||||
err := req.Header.Write(b)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, b)
|
||||
|
||||
// Asserts that the Sec-WebSocket-* is enforced.
|
||||
req.Header.Set("Sec-Websocket-Key", "key")
|
||||
req.Header.Set("Sec-Websocket-Extensions", "extensions")
|
||||
req.Header.Set("Sec-Websocket-Accept", "accept")
|
||||
req.Header.Set("Sec-Websocket-Protocol", "protocol")
|
||||
req.Header.Set("Sec-Websocket-Version", "version")
|
||||
|
||||
cleanWebSocketHeaders(req)
|
||||
|
||||
want := http.Header{
|
||||
"Sec-WebSocket-Key": {"key"},
|
||||
"Sec-WebSocket-Extensions": {"extensions"},
|
||||
"Sec-WebSocket-Accept": {"accept"},
|
||||
"Sec-WebSocket-Protocol": {"protocol"},
|
||||
"Sec-WebSocket-Version": {"version"},
|
||||
}
|
||||
assert.Equal(t, want, req.Header)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
var upgrader gorillawebsocket.Upgrader
|
||||
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user