diff --git a/integration/fixtures/headers/secure.toml b/integration/fixtures/headers/secure.toml index d04e5a9c8..08a12319b 100644 --- a/integration/fixtures/headers/secure.toml +++ b/integration/fixtures/headers/secure.toml @@ -2,6 +2,9 @@ checkNewVersion = false sendAnonymousUsage = false +[api] + insecure = true + [log] level = "DEBUG" @@ -24,6 +27,11 @@ rule = "Host(`test2.localhost`)" service = "service1" + [http.routers.router3] + rule = "Host(`internal.localhost`)" + middlewares = ["secure"] + service = "api@internal" + [http.middlewares] [http.middlewares.secure.headers] featurePolicy = "vibrate 'none';" diff --git a/integration/headers_test.go b/integration/headers_test.go index 34df16a78..d598873ea 100644 --- a/integration/headers_test.go +++ b/integration/headers_test.go @@ -131,16 +131,18 @@ func (s *HeadersSuite) TestSecureHeadersResponses(c *check.C) { c.Assert(err, checker.IsNil) testCase := []struct { - desc string - expected http.Header - reqHost string + desc string + expected http.Header + reqHost string + internalReqHost string }{ { desc: "Feature-Policy Set", expected: http.Header{ "Feature-Policy": {"vibrate 'none';"}, }, - reqHost: "test.localhost", + reqHost: "test.localhost", + internalReqHost: "internal.localhost", }, } @@ -149,7 +151,14 @@ func (s *HeadersSuite) TestSecureHeadersResponses(c *check.C) { c.Assert(err, checker.IsNil) req.Host = test.reqHost - err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected)) + err = try.Request(req, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK), try.HasHeaderStruct(test.expected)) + c.Assert(err, checker.IsNil) + + req, err = http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/api/rawdata", nil) + c.Assert(err, checker.IsNil) + req.Host = test.internalReqHost + + err = try.Request(req, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK), try.HasHeaderStruct(test.expected)) c.Assert(err, checker.IsNil) } } diff --git a/pkg/middlewares/headers/headers.go b/pkg/middlewares/headers/headers.go index 43faa6dff..53cd34163 100644 --- a/pkg/middlewares/headers/headers.go +++ b/pkg/middlewares/headers/headers.go @@ -54,13 +54,13 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin nextHandler := next if hasSecureHeaders { - logger.Debug("Setting up secureHeaders from %v", cfg) + logger.Debugf("Setting up secureHeaders from %v", cfg) handler = newSecure(next, cfg) nextHandler = handler } if hasCustomHeaders || hasCorsHeaders { - logger.Debug("Setting up customHeaders/Cors from %v", cfg) + logger.Debugf("Setting up customHeaders/Cors from %v", cfg) handler = NewHeader(nextHandler, cfg) } diff --git a/pkg/responsemodifiers/response_modifier.go b/pkg/responsemodifiers/response_modifier.go index bc7b272cf..3fec88043 100644 --- a/pkg/responsemodifiers/response_modifier.go +++ b/pkg/responsemodifiers/response_modifier.go @@ -19,6 +19,7 @@ type Builder struct { } // Build Builds the response modifier. +// It returns nil if there is no modifier to apply. func (f *Builder) Build(ctx context.Context, names []string) func(*http.Response) error { var modifiers []func(*http.Response) error @@ -60,5 +61,5 @@ func (f *Builder) Build(ctx context.Context, names []string) func(*http.Response } } - return func(response *http.Response) error { return nil } + return nil } diff --git a/pkg/responsemodifiers/response_modifier_test.go b/pkg/responsemodifiers/response_modifier_test.go index 9b3d7806b..9854e5725 100644 --- a/pkg/responsemodifiers/response_modifier_test.go +++ b/pkg/responsemodifiers/response_modifier_test.go @@ -184,6 +184,9 @@ func TestBuilderBuild(t *testing.T) { builder := NewBuilder(rtConf.Middlewares) rm := builder.Build(context.Background(), test.middlewares) + if rm == nil { + return + } resp := test.buildResponse(test.conf) diff --git a/pkg/server/service/internalhandler.go b/pkg/server/service/internalhandler.go index 2beb0578e..2843ae91f 100644 --- a/pkg/server/service/internalhandler.go +++ b/pkg/server/service/internalhandler.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/containous/traefik/v2/pkg/config/runtime" + "github.com/containous/traefik/v2/pkg/log" ) type serviceManager interface { @@ -42,13 +43,87 @@ func NewInternalHandlers(api func(configuration *runtime.Configuration) http.Han } } -// BuildHTTP builds an HTTP handler. -func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) { - if strings.HasSuffix(serviceName, "@internal") { - return m.get(serviceName) +type responseModifier struct { + r *http.Request + w http.ResponseWriter + + headersSent bool // whether headers have already been sent + code int // status code, must default to 200 + + modifier func(*http.Response) error // can be nil + modified bool // whether modifier has already been called for the current request + modifierErr error // returned by modifier call +} + +// modifier can be nil. +func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier { + return &responseModifier{ + r: r, + w: w, + modifier: modifier, + code: http.StatusOK, + } +} + +func (w *responseModifier) WriteHeader(code int) { + if w.headersSent { + return + } + defer func() { + w.code = code + w.headersSent = true + }() + + if w.modifier == nil || w.modified { + w.w.WriteHeader(code) + return } - return m.serviceManager.BuildHTTP(rootCtx, serviceName, responseModifier) + resp := http.Response{ + Header: w.w.Header(), + Request: w.r, + } + + if err := w.modifier(&resp); err != nil { + w.modifierErr = err + // we are propagating when we are called in Write, but we're logging anyway, + // because we could be called from another place which does not take care of + // checking w.modifierErr. + log.Errorf("Error when applying response modifier: %v", err) + w.w.WriteHeader(http.StatusInternalServerError) + return + } + + w.modified = true + w.w.WriteHeader(code) +} + +func (w *responseModifier) Header() http.Header { + return w.w.Header() +} + +func (w *responseModifier) Write(b []byte) (int, error) { + w.WriteHeader(w.code) + if w.modifierErr != nil { + return 0, w.modifierErr + } + + return w.w.Write(b) +} + +// BuildHTTP builds an HTTP handler. +func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string, respModifier func(*http.Response) error) (http.Handler, error) { + if !strings.HasSuffix(serviceName, "@internal") { + return m.serviceManager.BuildHTTP(rootCtx, serviceName, respModifier) + } + + internalHandler, err := m.get(serviceName) + if err != nil { + return nil, err + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + internalHandler.ServeHTTP(newResponseModifier(w, r, respModifier), r) + }), nil } func (m *InternalHandlers) get(serviceName string) (http.Handler, error) {