From 76dcbe34291f7e53c68f10a93ec7b32bc749c24d Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Mon, 23 Apr 2018 11:28:04 +0200 Subject: [PATCH] Fix error pages redirect and headers. --- middlewares/errorpages/error_pages.go | 33 ++++++++++++++++++---- middlewares/errorpages/error_pages_test.go | 6 ++-- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go index 49c218639..a2e9d426c 100644 --- a/middlewares/errorpages/error_pages.go +++ b/middlewares/errorpages/error_pages.go @@ -3,8 +3,10 @@ package errorpages import ( "bufio" "bytes" + "fmt" "net" "net/http" + "net/url" "strconv" "strings" @@ -75,8 +77,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. recorder := newResponseRecorder(w) next.ServeHTTP(recorder, req) - w.WriteHeader(recorder.GetCode()) - // check the recorder code against the configured http status code ranges for _, block := range h.httpCodeRanges { if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { @@ -88,20 +88,43 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1) } - if newReq, err := http.NewRequest(http.MethodGet, h.backendURL+query, nil); err != nil { + pageReq, err := newRequest(h.backendURL + query) + if err != nil { + log.Error(err) + w.WriteHeader(recorder.GetCode()) w.Write([]byte(http.StatusText(recorder.GetCode()))) - } else { - h.backendHandler.ServeHTTP(w, newReq) + return } + + utils.CopyHeaders(pageReq.Header, req.Header) + utils.CopyHeaders(w.Header(), recorder.Header()) + w.WriteHeader(recorder.GetCode()) + h.backendHandler.ServeHTTP(w, pageReq) return } } // did not catch a configured status code so proceed with the request utils.CopyHeaders(w.Header(), recorder.Header()) + w.WriteHeader(recorder.GetCode()) w.Write(recorder.GetBody().Bytes()) } +func newRequest(baseURL string) (*http.Request, error) { + u, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("error pages: error when parse URL: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("error pages: error when create query: %v", err) + } + + req.RequestURI = u.RequestURI() + return req, nil +} + type responseRecorder interface { http.ResponseWriter http.Flusher diff --git a/middlewares/errorpages/error_pages_test.go b/middlewares/errorpages/error_pages_test.go index 0ff70689b..2264dc336 100644 --- a/middlewares/errorpages/error_pages_test.go +++ b/middlewares/errorpages/error_pages_test.go @@ -65,7 +65,7 @@ func TestHandler(t *testing.T) { errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}}, backendCode: http.StatusServiceUnavailable, backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + if r.RequestURI == "/503" { fmt.Fprintln(w, "My 503 page.") } else { fmt.Fprintln(w, "Failed") @@ -82,7 +82,7 @@ func TestHandler(t *testing.T) { errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}}, backendCode: http.StatusServiceUnavailable, backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + if r.RequestURI == "/503" { fmt.Fprintln(w, "My 503 page.") } else { fmt.Fprintln(w, "Failed") @@ -318,6 +318,7 @@ func TestHandlerOldWayIntegration(t *testing.T) { require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Foo", "bar") w.WriteHeader(test.backendCode) fmt.Fprintln(w, http.StatusText(test.backendCode)) }) @@ -330,6 +331,7 @@ func TestHandlerOldWayIntegration(t *testing.T) { n.ServeHTTP(recorder, req) test.validate(t, recorder) + assert.Equal(t, "bar", recorder.Header().Get("X-Foo"), "missing header") }) } }