diff --git a/pkg/proxy/fast/connpool.go b/pkg/proxy/fast/connpool.go index 375fac2fb..4c0be0fe9 100644 --- a/pkg/proxy/fast/connpool.go +++ b/pkg/proxy/fast/connpool.go @@ -20,8 +20,9 @@ import ( // rwWithUpgrade contains a ResponseWriter and an upgradeHandler, // used to upgrade the connection (e.g. Websockets). type rwWithUpgrade struct { - RW http.ResponseWriter - Upgrade upgradeHandler + ReqMethod string + RW http.ResponseWriter + Upgrade upgradeHandler } // conn is an enriched net.Conn. @@ -211,6 +212,10 @@ func (c *conn) handleResponse(r rwWithUpgrade) error { r.RW.WriteHeader(res.StatusCode()) + if noResponseBodyExpected(r.ReqMethod) { + return nil + } + if res.Header.ContentLength() == 0 { return nil } @@ -444,8 +449,8 @@ func (c *connPool) askForNewConn(errCh chan<- error) { c.releaseConn(newConn) } -// isBodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC 7230, section 3.3. +// isBodyAllowedForStatus reports whether a given response status code permits a body. +// See RFC 7230, section 3.3. // From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459 func isBodyAllowedForStatus(status int) bool { switch { @@ -458,3 +463,9 @@ func isBodyAllowedForStatus(status int) bool { } return true } + +// noResponseBodyExpected reports whether a given request method permits a body. +// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L250 +func noResponseBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} diff --git a/pkg/proxy/fast/proxy.go b/pkg/proxy/fast/proxy.go index 717b1ff06..06e68513e 100644 --- a/pkg/proxy/fast/proxy.go +++ b/pkg/proxy/fast/proxy.go @@ -284,8 +284,9 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR // Sending the responseWriter unlocks the connection readLoop, to handle the response. co.RWCh <- rwWithUpgrade{ - RW: rw, - Upgrade: upgradeResponseHandler(req.Context(), reqUpType), + ReqMethod: req.Method, + RW: rw, + Upgrade: upgradeResponseHandler(req.Context(), reqUpType), } if err := <-co.ErrCh; err != nil { diff --git a/pkg/proxy/fast/proxy_test.go b/pkg/proxy/fast/proxy_test.go index f4593d9ce..b9f95b606 100644 --- a/pkg/proxy/fast/proxy_test.go +++ b/pkg/proxy/fast/proxy_test.go @@ -278,6 +278,34 @@ func TestPreservePath(t *testing.T) { assert.Equal(t, http.StatusOK, res.Code) } +func TestHeadRequest(t *testing.T) { + var callCount int + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + callCount++ + + assert.Equal(t, http.MethodHead, req.Method) + + rw.Header().Set("Content-Length", "42") + })) + t.Cleanup(server.Close) + + builder := NewProxyBuilder(&transportManagerMock{}, static.FastProxyConfig{}) + + serverURL, err := url.JoinPath(server.URL) + require.NoError(t, err) + + proxyHandler, err := builder.Build("", testhelpers.MustParseURL(serverURL), true, true) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodHead, "/", http.NoBody) + res := httptest.NewRecorder() + + proxyHandler.ServeHTTP(res, req) + + assert.Equal(t, 1, callCount) + assert.Equal(t, http.StatusOK, res.Code) +} + func newCertificate(t *testing.T, domain string) *tls.Certificate { t.Helper()