diff --git a/docs/content/middlewares/http/compress.md b/docs/content/middlewares/http/compress.md index 46028ee0b..1f204be29 100644 --- a/docs/content/middlewares/http/compress.md +++ b/docs/content/middlewares/http/compress.md @@ -179,9 +179,15 @@ http: _Optional, Default=1024_ `minResponseBodyBytes` specifies the minimum amount of bytes a response body must have to be compressed. - Responses smaller than the specified values will not be compressed. +!!! tip "Streaming" + + When data is sent to the client on flush, the `minResponseBodyBytes` configuration is ignored and the data is compressed. + This is particularly the case when data is streamed to the client when using `Transfer-encoding: chunked` response. + +When chunked data is sent to the client on flush, it will be compressed by default even if the received data has not reached + ```yaml tab="Docker & Swarm" labels: - "traefik.http.middlewares.test-compress.compress.minresponsebodybytes=1200" diff --git a/pkg/middlewares/compress/compress_test.go b/pkg/middlewares/compress/compress_test.go index 8279165ac..430df7611 100644 --- a/pkg/middlewares/compress/compress_test.go +++ b/pkg/middlewares/compress/compress_test.go @@ -609,83 +609,106 @@ func TestMinResponseBodyBytes(t *testing.T) { func Test1xxResponses(t *testing.T) { fakeBody := generateBytes(100000) - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h.Add("Link", "; rel=preload; as=style") - h.Add("Link", "; rel=preload; as=script") - w.WriteHeader(http.StatusEarlyHints) - - h.Add("Link", "; rel=preload; as=script") - w.WriteHeader(http.StatusProcessing) - - if _, err := w.Write(fakeBody); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - }) - cfg := dynamic.Compress{ - MinResponseBodyBytes: 1024, - Encodings: defaultSupportedEncodings, - } - compress, err := New(context.Background(), next, cfg, "testing") - require.NoError(t, err) - - server := httptest.NewServer(compress) - t.Cleanup(server.Close) - frontendClient := server.Client() - - checkLinkHeaders := func(t *testing.T, expected, got []string) { - t.Helper() - - if len(expected) != len(got) { - t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) - } - - for i := range expected { - if i >= len(got) { - t.Errorf("Expected %q link header; got nothing", expected[i]) - - continue - } - - if expected[i] != got[i] { - t.Errorf("Expected %q link header; got %q", expected[i], got[i]) - } - } - } - - var respCounter uint8 - trace := &httptrace.ClientTrace{ - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - switch code { - case http.StatusEarlyHints: - checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) - case http.StatusProcessing: - checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) - default: - t.Error("Unexpected 1xx response") - } - - respCounter++ - - return nil + testCases := []struct { + desc string + encoding string + }{ + { + desc: "gzip", + encoding: gzipName, + }, + { + desc: "brotli", + encoding: brotliName, + }, + { + desc: "zstd", + encoding: zstdName, }, } - req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) - req.Header.Add(acceptEncodingHeader, gzipName) + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - res, err := frontendClient.Do(req) - assert.NoError(t, err) + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) - defer res.Body.Close() + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) - if respCounter != 2 { - t.Errorf("Expected 2 1xx responses; got %d", respCounter) + if _, err := w.Write(fakeBody); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }) + cfg := dynamic.Compress{ + MinResponseBodyBytes: 1024, + Encodings: defaultSupportedEncodings, + } + compress, err := New(context.Background(), next, cfg, "testing") + require.NoError(t, err) + + server := httptest.NewServer(compress) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + req.Header.Add(acceptEncodingHeader, test.encoding) + + res, err := frontendClient.Do(req) + assert.NoError(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + assert.Equal(t, test.encoding, res.Header.Get(contentEncodingHeader)) + body, _ := io.ReadAll(res.Body) + assert.NotEqualValues(t, body, fakeBody) + }) } - checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) - - assert.Equal(t, gzipName, res.Header.Get(contentEncodingHeader)) - body, _ := io.ReadAll(res.Body) - assert.NotEqualValues(t, body, fakeBody) } func BenchmarkCompressGzip(b *testing.B) { diff --git a/pkg/middlewares/compress/compression_handler.go b/pkg/middlewares/compress/compression_handler.go index 1c5065ca7..215607ccd 100644 --- a/pkg/middlewares/compress/compression_handler.go +++ b/pkg/middlewares/compress/compression_handler.go @@ -192,12 +192,17 @@ func (r *responseWriter) Header() http.Header { } func (r *responseWriter) WriteHeader(statusCode int) { - if r.statusCodeSet { + // Handle informational headers + // This is gated to not forward 1xx responses on builds prior to go1.20. + if statusCode >= 100 && statusCode <= 199 { + r.rw.WriteHeader(statusCode) return } - r.statusCode = statusCode - r.statusCodeSet = true + if !r.statusCodeSet { + r.statusCode = statusCode + r.statusCodeSet = true + } } func (r *responseWriter) Write(p []byte) (int, error) { @@ -319,11 +324,16 @@ func (r *responseWriter) Flush() { } // Here, nothing was ever written either to rw or to bw (since we're still - // waiting to decide whether to compress), so we do not need to flush anything. - // Note that we diverge with klauspost's gzip behavior, where they instead - // force compression and flush whatever was in the buffer in this case. + // waiting to decide whether to compress), so to be aligned with klauspost's + // gzip behavior we force the compression and flush whatever was in the buffer in this case. if !r.compressionStarted { - return + r.rw.Header().Del(contentLength) + + r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding()) + r.rw.WriteHeader(r.statusCode) + r.headersSent = true + + r.compressionStarted = true } // Conversely, we here know that something was already written to bw (or is diff --git a/pkg/middlewares/compress/compression_handler_test.go b/pkg/middlewares/compress/compression_handler_test.go index c702500d7..b078ed71f 100644 --- a/pkg/middlewares/compress/compression_handler_test.go +++ b/pkg/middlewares/compress/compression_handler_test.go @@ -498,6 +498,73 @@ func Test_FlushAfterAllWrites(t *testing.T) { } } +func Test_FlushForceCompress(t *testing.T) { + testCases := []struct { + desc string + cfg Config + algo string + readerBuilder func(io.Reader) (io.Reader, error) + acceptEncoding string + }{ + { + desc: "brotli", + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: brotliName, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return brotli.NewReader(reader), nil + }, + acceptEncoding: "br", + }, + { + desc: "zstd", + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: zstdName, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return zstd.NewReader(reader) + }, + acceptEncoding: "zstd", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + + _, err := rw.Write(smallTestBody) + require.NoError(t, err) + + rw.(http.Flusher).Flush() + }) + + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next)) + defer srv.Close() + + req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) + require.NoError(t, err) + + req.Header.Set(acceptEncoding, test.acceptEncoding) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding)) + + reader, err := test.readerBuilder(res.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, smallTestBody, got) + }) + } +} + func Test_ExcludedContentTypes(t *testing.T) { testCases := []struct { desc string