diff --git a/autogen/gentemplates/gen.go b/autogen/gentemplates/gen.go index 7d129c2a7..1c6f6129e 100644 --- a/autogen/gentemplates/gen.go +++ b/autogen/gentemplates/gen.go @@ -231,7 +231,7 @@ var _templatesConsul_catalogTmpl = []byte(`[backends] status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} @@ -632,7 +632,7 @@ var _templatesDockerTmpl = []byte(`{{$backendServers := .Servers}} status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} @@ -884,7 +884,7 @@ var _templatesEcsTmpl = []byte(`[backends] status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} @@ -1588,7 +1588,7 @@ var _templatesMarathonTmpl = []byte(`{{ $apps := .Applications }} status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} @@ -1826,7 +1826,7 @@ var _templatesMesosTmpl = []byte(`[backends] status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} @@ -2117,7 +2117,7 @@ var _templatesRancherTmpl = []byte(`{{ $backendServers := .Backends }} status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/middlewares/error_pages.go b/middlewares/error_pages.go deleted file mode 100644 index 735aa6f24..000000000 --- a/middlewares/error_pages.go +++ /dev/null @@ -1,174 +0,0 @@ -package middlewares - -import ( - "bufio" - "bytes" - "net" - "net/http" - "strconv" - "strings" - - "github.com/containous/traefik/log" - "github.com/containous/traefik/types" - "github.com/vulcand/oxy/forward" - "github.com/vulcand/oxy/utils" -) - -// Compile time validation that the response recorder implements http interfaces correctly. -var _ Stateful = &errorPagesResponseRecorderWithCloseNotify{} - -//ErrorPagesHandler is a middleware that provides the custom error pages -type ErrorPagesHandler struct { - HTTPCodeRanges types.HTTPCodeRanges - BackendURL string - errorPageForwarder *forward.Forwarder -} - -//NewErrorPagesHandler initializes the utils.ErrorHandler for the custom error pages -func NewErrorPagesHandler(errorPage *types.ErrorPage, backendURL string) (*ErrorPagesHandler, error) { - fwd, err := forward.New() - if err != nil { - return nil, err - } - - httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status) - if err != nil { - return nil, err - } - - return &ErrorPagesHandler{ - HTTPCodeRanges: httpCodeRanges, - BackendURL: backendURL + errorPage.Query, - errorPageForwarder: fwd}, - nil -} - -func (ep *ErrorPagesHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) { - recorder := newErrorPagesResponseRecorder(w) - - next.ServeHTTP(recorder, req) - - w.WriteHeader(recorder.GetCode()) - //check the recorder code against the configured http status code ranges - for _, block := range ep.HTTPCodeRanges { - if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { - log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) - finalURL := strings.Replace(ep.BackendURL, "{status}", strconv.Itoa(recorder.GetCode()), -1) - if newReq, err := http.NewRequest(http.MethodGet, finalURL, nil); err != nil { - w.Write([]byte(http.StatusText(recorder.GetCode()))) - } else { - ep.errorPageForwarder.ServeHTTP(w, newReq) - } - return - } - } - - //did not catch a configured status code so proceed with the request - utils.CopyHeaders(w.Header(), recorder.Header()) - w.Write(recorder.GetBody().Bytes()) -} - -type errorPagesResponseRecorder interface { - http.ResponseWriter - http.Flusher - GetCode() int - GetBody() *bytes.Buffer - IsStreamingResponseStarted() bool -} - -// newErrorPagesResponseRecorder returns an initialized responseRecorder. -func newErrorPagesResponseRecorder(rw http.ResponseWriter) errorPagesResponseRecorder { - recorder := &errorPagesResponseRecorderWithoutCloseNotify{ - HeaderMap: make(http.Header), - Body: new(bytes.Buffer), - Code: http.StatusOK, - responseWriter: rw, - } - if _, ok := rw.(http.CloseNotifier); ok { - return &errorPagesResponseRecorderWithCloseNotify{recorder} - } - return recorder -} - -// errorPagesResponseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that -// records its mutations for later inspection. -type errorPagesResponseRecorderWithoutCloseNotify struct { - Code int // the HTTP response code from WriteHeader - HeaderMap http.Header // the HTTP response headers - Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to - - responseWriter http.ResponseWriter - err error - streamingResponseStarted bool -} - -type errorPagesResponseRecorderWithCloseNotify struct { - *errorPagesResponseRecorderWithoutCloseNotify -} - -// CloseNotify returns a channel that receives at most a -// single value (true) when the client connection has gone -// away. -func (rw *errorPagesResponseRecorderWithCloseNotify) CloseNotify() <-chan bool { - return rw.responseWriter.(http.CloseNotifier).CloseNotify() -} - -// Header returns the response headers. -func (rw *errorPagesResponseRecorderWithoutCloseNotify) Header() http.Header { - m := rw.HeaderMap - if m == nil { - m = make(http.Header) - rw.HeaderMap = m - } - return m -} - -func (rw *errorPagesResponseRecorderWithoutCloseNotify) GetCode() int { - return rw.Code -} - -func (rw *errorPagesResponseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { - return rw.Body -} - -func (rw *errorPagesResponseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { - return rw.streamingResponseStarted -} - -// Write always succeeds and writes to rw.Body, if not nil. -func (rw *errorPagesResponseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { - if rw.err != nil { - return 0, rw.err - } - return rw.Body.Write(buf) -} - -// WriteHeader sets rw.Code. -func (rw *errorPagesResponseRecorderWithoutCloseNotify) WriteHeader(code int) { - rw.Code = code -} - -// Hijack hijacks the connection -func (rw *errorPagesResponseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return rw.responseWriter.(http.Hijacker).Hijack() -} - -// Flush sends any buffered data to the client. -func (rw *errorPagesResponseRecorderWithoutCloseNotify) Flush() { - if !rw.streamingResponseStarted { - utils.CopyHeaders(rw.responseWriter.Header(), rw.Header()) - rw.responseWriter.WriteHeader(rw.Code) - rw.streamingResponseStarted = true - } - - _, err := rw.responseWriter.Write(rw.Body.Bytes()) - if err != nil { - log.Errorf("Error writing response in responseRecorder: %s", err) - rw.err = err - } - rw.Body.Reset() - flusher, ok := rw.responseWriter.(http.Flusher) - if ok { - flusher.Flush() - } -} diff --git a/middlewares/error_pages_test.go b/middlewares/error_pages_test.go deleted file mode 100644 index 0c6c25b74..000000000 --- a/middlewares/error_pages_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package middlewares - -import ( - "fmt" - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/containous/traefik/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/urfave/negroni" -) - -func TestErrorPage(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Test Server") - })) - defer ts.Close() - - testErrorPage := &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}} - - testHandler, err := NewErrorPagesHandler(testErrorPage, ts.URL) - require.NoError(t, err) - - assert.Equal(t, testHandler.BackendURL, ts.URL+"/test", "Should be equal") - - recorder := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, ts.URL+"/test", nil) - require.NoError(t, err) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "traefik") - }) - n := negroni.New() - n.Use(testHandler) - n.UseHandler(handler) - - n.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status") - assert.Contains(t, recorder.Body.String(), "traefik") - - // ---- - - handler500 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintln(w, "oops") - }) - recorder500 := httptest.NewRecorder() - n500 := negroni.New() - n500.Use(testHandler) - n500.UseHandler(handler500) - - n500.ServeHTTP(recorder500, req) - - assert.Equal(t, http.StatusInternalServerError, recorder500.Code, "HTTP status Internal Server Error") - assert.Contains(t, recorder500.Body.String(), "Test Server") - assert.NotContains(t, recorder500.Body.String(), "oops", "Should not return the oops page") - - handler502 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - fmt.Fprintln(w, "oops") - }) - recorder502 := httptest.NewRecorder() - n502 := negroni.New() - n502.Use(testHandler) - n502.UseHandler(handler502) - - n502.ServeHTTP(recorder502, req) - - assert.Equal(t, http.StatusBadGateway, recorder502.Code, "HTTP status Bad Gateway") - assert.Contains(t, recorder502.Body.String(), "oops") - assert.NotContains(t, recorder502.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code") -} - -func TestErrorPageQuery(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.RequestURI() == "/"+strconv.Itoa(503) { - fmt.Fprintln(w, "503 Test Server") - } else { - fmt.Fprintln(w, "Failed") - } - - })) - defer ts.Close() - - testErrorPage := &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}} - - testHandler, err := NewErrorPagesHandler(testErrorPage, ts.URL) - require.NoError(t, err) - - assert.Equal(t, testHandler.BackendURL, ts.URL+"/{status}", "Should be equal") - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintln(w, "oops") - }) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest(http.MethodGet, ts.URL+"/test", nil) - require.NoError(t, err) - - n := negroni.New() - n.Use(testHandler) - n.UseHandler(handler) - - n.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status Service Unavailable") - assert.Contains(t, recorder.Body.String(), "503 Test Server") - assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") -} - -func TestErrorPageSingleCode(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.RequestURI() == "/"+strconv.Itoa(503) { - fmt.Fprintln(w, "503 Test Server") - } else { - fmt.Fprintln(w, "Failed") - } - - })) - defer ts.Close() - - testErrorPage := &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}} - - testHandler, err := NewErrorPagesHandler(testErrorPage, ts.URL) - require.NoError(t, err) - - assert.Equal(t, testHandler.BackendURL, ts.URL+"/{status}", "Should be equal") - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - fmt.Fprintln(w, "oops") - }) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest(http.MethodGet, ts.URL+"/test", nil) - require.NoError(t, err) - - n := negroni.New() - n.Use(testHandler) - n.UseHandler(handler) - - n.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status Service Unavailable") - assert.Contains(t, recorder.Body.String(), "503 Test Server") - assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") -} - -func TestNewErrorPagesResponseRecorder(t *testing.T) { - testCases := []struct { - desc string - rw http.ResponseWriter - expected http.ResponseWriter - }{ - { - desc: "Without Close Notify", - rw: httptest.NewRecorder(), - expected: &errorPagesResponseRecorderWithoutCloseNotify{}, - }, - { - desc: "With Close Notify", - rw: &mockRWCloseNotify{}, - expected: &errorPagesResponseRecorderWithCloseNotify{}, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - rec := newErrorPagesResponseRecorder(test.rw) - - assert.IsType(t, rec, test.expected) - }) - } -} - -type mockRWCloseNotify struct{} - -func (m *mockRWCloseNotify) CloseNotify() <-chan bool { - panic("implement me") -} - -func (m *mockRWCloseNotify) Header() http.Header { - panic("implement me") -} - -func (m *mockRWCloseNotify) Write([]byte) (int, error) { - panic("implement me") -} - -func (m *mockRWCloseNotify) WriteHeader(int) { - panic("implement me") -} diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go new file mode 100644 index 000000000..49c218639 --- /dev/null +++ b/middlewares/errorpages/error_pages.go @@ -0,0 +1,205 @@ +package errorpages + +import ( + "bufio" + "bytes" + "net" + "net/http" + "strconv" + "strings" + + "github.com/containous/traefik/log" + "github.com/containous/traefik/middlewares" + "github.com/containous/traefik/types" + "github.com/pkg/errors" + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/utils" +) + +// Compile time validation that the response recorder implements http interfaces correctly. +var _ middlewares.Stateful = &responseRecorderWithCloseNotify{} + +// Handler is a middleware that provides the custom error pages +type Handler struct { + BackendName string + backendHandler http.Handler + httpCodeRanges types.HTTPCodeRanges + backendURL string + backendQuery string + FallbackURL string // Deprecated +} + +// NewHandler initializes the utils.ErrorHandler for the custom error pages +func NewHandler(errorPage *types.ErrorPage, backendName string) (*Handler, error) { + if len(backendName) == 0 { + return nil, errors.New("error pages: backend name is mandatory ") + } + + httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status) + if err != nil { + return nil, err + } + + return &Handler{ + BackendName: backendName, + httpCodeRanges: httpCodeRanges, + backendQuery: errorPage.Query, + backendURL: "http://0.0.0.0", + }, nil +} + +// PostLoad adds backend handler if available +func (h *Handler) PostLoad(backendHandler http.Handler) error { + if backendHandler == nil { + fwd, err := forward.New() + if err != nil { + return err + } + + h.backendHandler = fwd + h.backendURL = h.FallbackURL + } else { + h.backendHandler = backendHandler + } + + return nil +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) { + if h.backendHandler == nil { + log.Error("Error pages: no backend handler.") + next.ServeHTTP(w, req) + return + } + + recorder := newResponseRecorder(w) + next.ServeHTTP(recorder, req) + + w.WriteHeader(recorder.GetCode()) + + // check the recorder code against the configured http status code ranges + for _, block := range h.httpCodeRanges { + if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { + log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) + + var query string + if len(h.backendQuery) > 0 { + query = "/" + strings.TrimPrefix(h.backendQuery, "/") + query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1) + } + + if newReq, err := http.NewRequest(http.MethodGet, h.backendURL+query, nil); err != nil { + w.Write([]byte(http.StatusText(recorder.GetCode()))) + } else { + h.backendHandler.ServeHTTP(w, newReq) + } + return + } + } + + // did not catch a configured status code so proceed with the request + utils.CopyHeaders(w.Header(), recorder.Header()) + w.Write(recorder.GetBody().Bytes()) +} + +type responseRecorder interface { + http.ResponseWriter + http.Flusher + GetCode() int + GetBody() *bytes.Buffer + IsStreamingResponseStarted() bool +} + +// newResponseRecorder returns an initialized responseRecorder. +func newResponseRecorder(rw http.ResponseWriter) responseRecorder { + recorder := &responseRecorderWithoutCloseNotify{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: http.StatusOK, + responseWriter: rw, + } + if _, ok := rw.(http.CloseNotifier); ok { + return &responseRecorderWithCloseNotify{recorder} + } + return recorder +} + +// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that +// records its mutations for later inspection. +type responseRecorderWithoutCloseNotify struct { + Code int // the HTTP response code from WriteHeader + HeaderMap http.Header // the HTTP response headers + Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to + + responseWriter http.ResponseWriter + err error + streamingResponseStarted bool +} + +type responseRecorderWithCloseNotify struct { + *responseRecorderWithoutCloseNotify +} + +// CloseNotify returns a channel that receives at most a +// single value (true) when the client connection has gone away. +func (rw *responseRecorderWithCloseNotify) CloseNotify() <-chan bool { + return rw.responseWriter.(http.CloseNotifier).CloseNotify() +} + +// Header returns the response headers. +func (rw *responseRecorderWithoutCloseNotify) Header() http.Header { + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + return rw.HeaderMap +} + +func (rw *responseRecorderWithoutCloseNotify) GetCode() int { + return rw.Code +} + +func (rw *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { + return rw.Body +} + +func (rw *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { + return rw.streamingResponseStarted +} + +// Write always succeeds and writes to rw.Body, if not nil. +func (rw *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { + if rw.err != nil { + return 0, rw.err + } + return rw.Body.Write(buf) +} + +// WriteHeader sets rw.Code. +func (rw *responseRecorderWithoutCloseNotify) WriteHeader(code int) { + rw.Code = code +} + +// Hijack hijacks the connection +func (rw *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.responseWriter.(http.Hijacker).Hijack() +} + +// Flush sends any buffered data to the client. +func (rw *responseRecorderWithoutCloseNotify) Flush() { + if !rw.streamingResponseStarted { + utils.CopyHeaders(rw.responseWriter.Header(), rw.Header()) + rw.responseWriter.WriteHeader(rw.Code) + rw.streamingResponseStarted = true + } + + _, err := rw.responseWriter.Write(rw.Body.Bytes()) + if err != nil { + log.Errorf("Error writing response in responseRecorder: %s", err) + rw.err = err + } + rw.Body.Reset() + + if flusher, ok := rw.responseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/middlewares/errorpages/error_pages_test.go b/middlewares/errorpages/error_pages_test.go new file mode 100644 index 000000000..4700e5b3b --- /dev/null +++ b/middlewares/errorpages/error_pages_test.go @@ -0,0 +1,383 @@ +package errorpages + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/containous/traefik/testhelpers" + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/negroni" +) + +func TestHandler(t *testing.T) { + testCases := []struct { + desc string + errorPage *types.ErrorPage + backendCode int + backendErrorHandler http.HandlerFunc + validate func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + desc: "no error", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusOK, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK)) + }, + }, + { + desc: "in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusInternalServerError, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My error page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + { + desc: "not in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusBadGateway, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway)) + assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code") + }, + }, + { + desc: "query replacement", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}}, + backendCode: http.StatusServiceUnavailable, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + fmt.Fprintln(w, "My 503 page.") + } else { + fmt.Fprintln(w, "Failed") + } + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + { + desc: "Single code", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}}, + backendCode: http.StatusServiceUnavailable, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + fmt.Fprintln(w, "My 503 page.") + } else { + fmt.Fprintln(w, "Failed") + } + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + errorPageHandler, err := NewHandler(test.errorPage, "test") + require.NoError(t, err) + + errorPageHandler.backendHandler = test.backendErrorHandler + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(test.backendCode) + fmt.Fprintln(w, http.StatusText(test.backendCode)) + }) + + req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil) + + n := negroni.New() + n.Use(errorPageHandler) + n.UseHandler(handler) + + recorder := httptest.NewRecorder() + n.ServeHTTP(recorder, req) + + test.validate(t, recorder) + }) + } +} + +func TestHandlerOldWay(t *testing.T) { + testCases := []struct { + desc string + errorPage *types.ErrorPage + backendCode int + errorPageForwarder http.HandlerFunc + validate func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + desc: "no error", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusOK, + errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "OK") + }, + }, + { + desc: "in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusInternalServerError, + errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "My error page.") + assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page") + }, + }, + { + desc: "not in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusBadGateway, + errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusBadGateway, recorder.Code) + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway)) + assert.NotContains(t, recorder.Body.String(), "My error page.", "Should return the oops page since we have not configured the 502 code") + }, + }, + { + desc: "query replacement", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}}, + backendCode: http.StatusServiceUnavailable, + errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + fmt.Fprintln(w, "My 503 page.") + } else { + fmt.Fprintln(w, "Failed") + } + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + { + desc: "Single code", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}}, + backendCode: http.StatusServiceUnavailable, + errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + fmt.Fprintln(w, "My 503 page.") + } else { + fmt.Fprintln(w, "Failed") + } + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + } + + req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil) + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + errorPageHandler, err := NewHandler(test.errorPage, "test") + require.NoError(t, err) + errorPageHandler.FallbackURL = "http://localhost" + + errorPageHandler.PostLoad(test.errorPageForwarder) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(test.backendCode) + fmt.Fprintln(w, http.StatusText(test.backendCode)) + }) + + n := negroni.New() + n.Use(errorPageHandler) + n.UseHandler(handler) + + recorder := httptest.NewRecorder() + n.ServeHTTP(recorder, req) + + test.validate(t, recorder) + }) + } +} + +func TestHandlerOldWayIntegration(t *testing.T) { + errorPagesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/"+strconv.Itoa(503) { + fmt.Fprintln(w, "My 503 page.") + } else { + fmt.Fprintln(w, "Test Server") + } + })) + defer errorPagesServer.Close() + + testCases := []struct { + desc string + errorPage *types.ErrorPage + backendCode int + validate func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + desc: "no error", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusOK, + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "OK") + }, + }, + { + desc: "in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusInternalServerError, + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Test Server") + assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page") + }, + }, + { + desc: "not in the range", + errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusBadGateway, + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusBadGateway, recorder.Code) + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway)) + assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code") + }, + }, + { + desc: "query replacement", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}}, + backendCode: http.StatusServiceUnavailable, + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + { + desc: "Single code", + errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}}, + backendCode: http.StatusServiceUnavailable, + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "My 503 page.") + assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page") + }, + }, + } + + req := testhelpers.MustNewRequest(http.MethodGet, errorPagesServer.URL+"/test", nil) + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + + errorPageHandler, err := NewHandler(test.errorPage, "test") + require.NoError(t, err) + errorPageHandler.FallbackURL = errorPagesServer.URL + + err = errorPageHandler.PostLoad(nil) + require.NoError(t, err) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(test.backendCode) + fmt.Fprintln(w, http.StatusText(test.backendCode)) + }) + + n := negroni.New() + n.Use(errorPageHandler) + n.UseHandler(handler) + + recorder := httptest.NewRecorder() + n.ServeHTTP(recorder, req) + + test.validate(t, recorder) + }) + } +} + +func TestNewResponseRecorder(t *testing.T) { + testCases := []struct { + desc string + rw http.ResponseWriter + expected http.ResponseWriter + }{ + { + desc: "Without Close Notify", + rw: httptest.NewRecorder(), + expected: &responseRecorderWithoutCloseNotify{}, + }, + { + desc: "With Close Notify", + rw: &mockRWCloseNotify{}, + expected: &responseRecorderWithCloseNotify{}, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rec := newResponseRecorder(test.rw) + + assert.IsType(t, rec, test.expected) + }) + } +} + +type mockRWCloseNotify struct{} + +func (m *mockRWCloseNotify) CloseNotify() <-chan bool { + panic("implement me") +} + +func (m *mockRWCloseNotify) Header() http.Header { + panic("implement me") +} + +func (m *mockRWCloseNotify) Write([]byte) (int, error) { + panic("implement me") +} + +func (m *mockRWCloseNotify) WriteHeader(int) { + panic("implement me") +} diff --git a/provider/docker/config_container_docker_test.go b/provider/docker/config_container_docker_test.go index d6b96f3b1..2d1de1218 100644 --- a/provider/docker/config_container_docker_test.go +++ b/provider/docker/config_container_docker_test.go @@ -236,12 +236,12 @@ func TestDockerBuildConfiguration(t *testing.T) { "foo": { Status: []string{"404"}, Query: "foo_query", - Backend: "foobar", + Backend: "backend-foobar", }, "bar": { Status: []string{"500", "600"}, Query: "bar_query", - Backend: "foobar", + Backend: "backend-foobar", }, }, RateLimit: &types.RateLimit{ diff --git a/provider/docker/config_container_swarm_test.go b/provider/docker/config_container_swarm_test.go index 5cab4c4a3..9b09c637f 100644 --- a/provider/docker/config_container_swarm_test.go +++ b/provider/docker/config_container_swarm_test.go @@ -243,12 +243,12 @@ func TestSwarmBuildConfiguration(t *testing.T) { "foo": { Status: []string{"404"}, Query: "foo_query", - Backend: "foobar", + Backend: "backend-foobar", }, "bar": { Status: []string{"500", "600"}, Query: "bar_query", - Backend: "foobar", + Backend: "backend-foobar", }, }, RateLimit: &types.RateLimit{ diff --git a/provider/docker/config_segment_test.go b/provider/docker/config_segment_test.go index babdc04d5..6c2bda38b 100644 --- a/provider/docker/config_segment_test.go +++ b/provider/docker/config_segment_test.go @@ -193,12 +193,12 @@ func TestSegmentBuildConfiguration(t *testing.T) { "foo": { Status: []string{"404"}, Query: "foo_query", - Backend: "foobar", + Backend: "backend-foobar", }, "bar": { Status: []string{"500", "600"}, Query: "bar_query", - Backend: "foobar", + Backend: "backend-foobar", }, }, RateLimit: &types.RateLimit{ diff --git a/provider/ecs/config_test.go b/provider/ecs/config_test.go index 3aea2ba2e..18b096d60 100644 --- a/provider/ecs/config_test.go +++ b/provider/ecs/config_test.go @@ -317,14 +317,14 @@ func TestBuildConfiguration(t *testing.T) { "500", "600", }, - Backend: "foobar", + Backend: "backend-foobar", Query: "bar_query", }, "foo": { Status: []string{ "404", }, - Backend: "foobar", + Backend: "backend-foobar", Query: "foo_query", }, }, diff --git a/provider/marathon/config_test.go b/provider/marathon/config_test.go index a6f5ac6d1..1df7fcbf8 100644 --- a/provider/marathon/config_test.go +++ b/provider/marathon/config_test.go @@ -307,14 +307,14 @@ func TestBuildConfiguration(t *testing.T) { "500", "600", }, - Backend: "foobar", + Backend: "backendfoobar", Query: "bar_query", }, "foo": { Status: []string{ "404", }, - Backend: "foobar", + Backend: "backendfoobar", Query: "foo_query", }, }, @@ -674,14 +674,14 @@ func TestBuildConfigurationSegments(t *testing.T) { "500", "600", }, - Backend: "foobar", + Backend: "backendfoobar", Query: "bar_query", }, "foo": { Status: []string{ "404", }, - Backend: "foobar", + Backend: "backendfoobar", Query: "foo_query", }, }, diff --git a/provider/mesos/config_test.go b/provider/mesos/config_test.go index 335858723..a394f9b82 100644 --- a/provider/mesos/config_test.go +++ b/provider/mesos/config_test.go @@ -260,12 +260,12 @@ func TestBuildConfiguration(t *testing.T) { "foo": { Status: []string{"404"}, Query: "foo_query", - Backend: "foobar", + Backend: "backend-foobar", }, "bar": { Status: []string{"500", "600"}, Query: "bar_query", - Backend: "foobar", + Backend: "backend-foobar", }, }, RateLimit: &types.RateLimit{ diff --git a/provider/provider_test.go b/provider/provider_test.go index 3e52f3cb5..b1b1ddaec 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -23,27 +23,25 @@ func (p *myProvider) Foo() string { func TestConfigurationErrors(t *testing.T) { templateErrorFile, err := ioutil.TempFile("", "provider-configuration-error") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(templateErrorFile.Name()) + data := []byte("Not a valid template {{ Bar }}") + err = ioutil.WriteFile(templateErrorFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) templateInvalidTOMLFile, err := ioutil.TempFile("", "provider-configuration-error") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(templateInvalidTOMLFile.Name()) + data = []byte(`Hello {{ .Name }} {{ Foo }}`) + err = ioutil.WriteFile(templateInvalidTOMLFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) invalids := []struct { provider *myProvider @@ -54,10 +52,9 @@ func TestConfigurationErrors(t *testing.T) { }{ { provider: &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: "/non/existent/template.tmpl", }, - nil, }, expectedError: "open /non/existent/template.tmpl: no such file or directory", }, @@ -68,19 +65,17 @@ func TestConfigurationErrors(t *testing.T) { }, { provider: &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateErrorFile.Name(), }, - nil, }, expectedError: `function "Bar" not defined`, }, { provider: &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateInvalidTOMLFile.Name(), }, - nil, }, expectedError: "Near line 1 (last key parsed 'Hello'): expected key separator '=', but got '<' instead", funcMap: template.FuncMap{ @@ -97,18 +92,17 @@ func TestConfigurationErrors(t *testing.T) { if err == nil || !strings.Contains(err.Error(), invalid.expectedError) { t.Fatalf("should have generate an error with %q, got %v", invalid.expectedError, err) } - if configuration != nil { - t.Fatalf("shouldn't have return a configuration object : %v", configuration) - } + + assert.Nil(t, configuration) } } func TestGetConfiguration(t *testing.T) { templateFile, err := ioutil.TempFile("", "provider-configuration") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(templateFile.Name()) + data := []byte(`[backends] [backends.backend1] [backends.backend1.circuitbreaker] @@ -127,120 +121,103 @@ func TestGetConfiguration(t *testing.T) { [frontends.frontend11.routes.test_2] rule = "Path" value = "/test"`) + err = ioutil.WriteFile(templateFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) provider := &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateFile.Name(), }, - nil, } + configuration, err := provider.GetConfiguration(templateFile.Name(), nil, nil) - if err != nil { - t.Fatalf("Shouldn't have error out, got %v", err) - } - if configuration == nil { - t.Fatal("Configuration should not be nil, but was") - } + require.NoError(t, err) + + assert.NotNil(t, configuration) } func TestGetConfigurationReturnsCorrectMaxConnConfiguration(t *testing.T) { templateFile, err := ioutil.TempFile("", "provider-configuration") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(templateFile.Name()) + data := []byte(`[backends] [backends.backend1] [backends.backend1.maxconn] amount = 10 extractorFunc = "request.host"`) + err = ioutil.WriteFile(templateFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) provider := &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateFile.Name(), }, - nil, } + configuration, err := provider.GetConfiguration(templateFile.Name(), nil, nil) - if err != nil { - t.Fatalf("Shouldn't have error out, got %v", err) - } - if configuration == nil { - t.Fatal("Configuration should not be nil, but was") - } + require.NoError(t, err) - if configuration.Backends["backend1"].MaxConn.Amount != 10 { - t.Fatal("Configuration did not parse MaxConn.Amount properly") - } - - if configuration.Backends["backend1"].MaxConn.ExtractorFunc != "request.host" { - t.Fatal("Configuration did not parse MaxConn.ExtractorFunc properly") - } + require.NotNil(t, configuration) + require.Contains(t, configuration.Backends, "backend1") + assert.EqualValues(t, 10, configuration.Backends["backend1"].MaxConn.Amount) + assert.Equal(t, "request.host", configuration.Backends["backend1"].MaxConn.ExtractorFunc) } func TestNilClientTLS(t *testing.T) { - provider := &myProvider{ - BaseProvider{ + p := &myProvider{ + BaseProvider: BaseProvider{ Filename: "", }, - nil, - } - _, err := provider.TLS.CreateTLSConfig() - if err != nil { - t.Fatal("CreateTLSConfig should assume that consumer does not want a TLS configuration if input is nil") } + + _, err := p.TLS.CreateTLSConfig() + require.NoError(t, err, "CreateTLSConfig should assume that consumer does not want a TLS configuration if input is nil") } func TestInsecureSkipVerifyClientTLS(t *testing.T) { - provider := &myProvider{ - BaseProvider{ + p := &myProvider{ + BaseProvider: BaseProvider{ Filename: "", }, - &types.ClientTLS{ + TLS: &types.ClientTLS{ InsecureSkipVerify: true, }, } - config, err := provider.TLS.CreateTLSConfig() - if err != nil { - t.Fatal("CreateTLSConfig should assume that consumer does not want a TLS configuration if input is nil") - } - if !config.InsecureSkipVerify { - t.Fatal("CreateTLSConfig should support setting only InsecureSkipVerify property") - } + + config, err := p.TLS.CreateTLSConfig() + require.NoError(t, err, "CreateTLSConfig should assume that consumer does not want a TLS configuration if input is nil") + + assert.True(t, config.InsecureSkipVerify, "CreateTLSConfig should support setting only InsecureSkipVerify property") } func TestInsecureSkipVerifyFalseClientTLS(t *testing.T) { - provider := &myProvider{ - BaseProvider{ + p := &myProvider{ + BaseProvider: BaseProvider{ Filename: "", }, - &types.ClientTLS{ + TLS: &types.ClientTLS{ InsecureSkipVerify: false, }, } - _, err := provider.TLS.CreateTLSConfig() - if err == nil { - t.Fatal("CreateTLSConfig should error if consumer does not set a TLS cert or key configuration and not chooses InsecureSkipVerify to be true") - } - t.Log(err) + + _, err := p.TLS.CreateTLSConfig() + assert.Errorf(t, err, "CreateTLSConfig should error if consumer does not set a TLS cert or key configuration and not chooses InsecureSkipVerify to be true") } func TestMatchingConstraints(t *testing.T) { - cases := []struct { + testCases := []struct { + desc string constraints types.Constraints tags []string expected bool }{ // simple test: must match { + desc: "tag==us-east-1 with us-east-1", constraints: types.Constraints{ { Key: "tag", @@ -255,6 +232,7 @@ func TestMatchingConstraints(t *testing.T) { }, // simple test: must match but does not match { + desc: "tag==us-east-1 with us-east-2", constraints: types.Constraints{ { Key: "tag", @@ -269,6 +247,7 @@ func TestMatchingConstraints(t *testing.T) { }, // simple test: must not match { + desc: "tag!=us-east-1 with us-east-1", constraints: types.Constraints{ { Key: "tag", @@ -283,6 +262,7 @@ func TestMatchingConstraints(t *testing.T) { }, // complex test: globbing { + desc: "tag!=us-east-* with us-east-1", constraints: types.Constraints{ { Key: "tag", @@ -297,6 +277,7 @@ func TestMatchingConstraints(t *testing.T) { }, // complex test: multiple constraints { + desc: "tag==us-east-* & tag!=api with us-east-1 & api", constraints: types.Constraints{ { Key: "tag", @@ -317,26 +298,23 @@ func TestMatchingConstraints(t *testing.T) { }, } - for i, c := range cases { - provider := myProvider{ - BaseProvider{ - Constraints: c.constraints, + for _, test := range testCases { + p := myProvider{ + BaseProvider: BaseProvider{ + Constraints: test.constraints, }, - nil, - } - actual, _ := provider.MatchConstraints(c.tags) - if actual != c.expected { - t.Fatalf("test #%v: expected %t, got %t, for %#v", i, c.expected, actual, c.constraints) } + + actual, _ := p.MatchConstraints(test.tags) + assert.Equal(t, test.expected, actual) } } func TestDefaultFuncMap(t *testing.T) { templateFile, err := ioutil.TempFile("", "provider-configuration") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer os.RemoveAll(templateFile.Name()) + data := []byte(` [backends] [backends.{{ "backend-1" | replace "-" "" }}] @@ -360,38 +338,30 @@ func TestDefaultFuncMap(t *testing.T) { [frontends.frontend-1.routes.test_2] rule = "Path" value = "/test"`) + err = ioutil.WriteFile(templateFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) provider := &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateFile.Name(), }, - nil, } + configuration, err := provider.GetConfiguration(templateFile.Name(), nil, nil) - if err != nil { - t.Fatalf("Shouldn't have error out, got %v", err) - } - if configuration == nil { - t.Fatal("Configuration should not be nil, but was") - } - if _, ok := configuration.Backends["backend1"]; !ok { - t.Fatal("backend1 should exists, but it not") - } - if _, ok := configuration.Frontends["frontend-1"]; !ok { - t.Fatal("Frontend frontend-1 should exists, but it not") - } + require.NoError(t, err) + + require.NotNil(t, configuration) + assert.Contains(t, configuration.Backends, "backend1") + assert.Contains(t, configuration.Frontends, "frontend-1") } func TestSprigFunctions(t *testing.T) { templateFile, err := ioutil.TempFile("", "provider-configuration") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(templateFile.Name()) + data := []byte(` {{$backend_name := trimAll "-" uuidv4}} [backends] @@ -408,30 +378,22 @@ func TestSprigFunctions(t *testing.T) { [frontends.frontend-1.routes.test_2] rule = "Path" value = "/test"`) + err = ioutil.WriteFile(templateFile.Name(), data, 0700) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) provider := &myProvider{ - BaseProvider{ + BaseProvider: BaseProvider{ Filename: templateFile.Name(), }, - nil, } + configuration, err := provider.GetConfiguration(templateFile.Name(), nil, nil) - if err != nil { - t.Fatalf("Shouldn't have error out, got %v", err) - } - if configuration == nil { - t.Fatal("Configuration should not be nil, but was") - } - if len(configuration.Backends) != 1 { - t.Fatal("one backend should be defined, but it's not") - } - if _, ok := configuration.Frontends["frontend-1"]; !ok { - t.Fatal("Frontend frontend-1 should exists, but it not") - } + require.NoError(t, err) + + require.NotNil(t, configuration) + assert.Len(t, configuration.Backends, 1) + assert.Contains(t, configuration.Frontends, "frontend-1") } func TestBaseProvider_GetConfiguration(t *testing.T) { diff --git a/provider/rancher/config_test.go b/provider/rancher/config_test.go index 88aca6fb6..b6223ab7b 100644 --- a/provider/rancher/config_test.go +++ b/provider/rancher/config_test.go @@ -179,12 +179,12 @@ func TestProviderBuildConfiguration(t *testing.T) { "foo": { Status: []string{"404"}, Query: "foo_query", - Backend: "foobar", + Backend: "backend-foobar", }, "bar": { Status: []string{"500", "600"}, Query: "bar_query", - Backend: "foobar", + Backend: "backend-foobar", }, }, RateLimit: &types.RateLimit{ @@ -371,12 +371,12 @@ func TestProviderBuildConfiguration(t *testing.T) { Errors: map[string]*types.ErrorPage{ "bar": { Status: []string{"500", "600"}, - Backend: "foobar", + Backend: "backend-foobar", Query: "bar_query", }, "foo": { Status: []string{"404"}, - Backend: "foobar", + Backend: "backend-foobar", Query: "foo_query", }, }, diff --git a/server/server.go b/server/server.go index 30ea7bf6f..af1f9916c 100644 --- a/server/server.go +++ b/server/server.go @@ -30,6 +30,7 @@ import ( "github.com/containous/traefik/middlewares" "github.com/containous/traefik/middlewares/accesslog" mauth "github.com/containous/traefik/middlewares/auth" + "github.com/containous/traefik/middlewares/errorpages" "github.com/containous/traefik/middlewares/redirect" "github.com/containous/traefik/middlewares/tracing" "github.com/containous/traefik/provider" @@ -912,6 +913,8 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura redirectHandlers := make(map[string]negroni.Handler) backends := map[string]http.Handler{} backendsHealthCheck := map[string]*healthcheck.BackendHealthCheck{} + var errorPageHandlers []*errorpages.Handler + errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) for providerName, config := range configurations { @@ -1089,46 +1092,57 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura } if len(frontend.Errors) > 0 { - for _, errorPage := range frontend.Errors { - if config.Backends[errorPage.Backend] != nil && config.Backends[errorPage.Backend].Servers["error"].URL != "" { - errorPageHandler, err := middlewares.NewErrorPagesHandler(errorPage, config.Backends[errorPage.Backend].Servers["error"].URL) - if err != nil { - log.Errorf("Error creating custom error page middleware, %v", err) - } else { - n.Use(errorPageHandler) - } + for errorPageName, errorPage := range frontend.Errors { + if frontend.Backend == errorPage.Backend { + log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).", + errorPageName, frontendName, errorPage.Backend) + } else if config.Backends[errorPage.Backend] == nil { + log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.", + errorPageName, errorPage.Backend) } else { - log.Errorf("Error Page is configured for Frontend %s, but either Backend %s is not set or Backend URL is missing", frontendName, errorPage.Backend) + errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend) + if err != nil { + log.Errorf("Error creating error pages: %v", err) + } else { + if errorPageServer, ok := config.Backends[errorPage.Backend].Servers["error"]; ok { + errorPagesHandler.FallbackURL = errorPageServer.URL + } + + errorPageHandlers = append(errorPageHandlers, errorPagesHandler) + n.Use(errorPagesHandler) + } } } } if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 { lb, err = s.buildRateLimiter(lb, frontend.RateLimit) - lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("rate limit for %s", frontendName)) if err != nil { log.Errorf("Error creating rate limiter: %v", err) log.Errorf("Skipping frontend %s...", frontendName) continue frontend } + lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("rate limit for %s", frontendName)) } maxConns := config.Backends[frontend.Backend].MaxConn if maxConns != nil && maxConns.Amount != 0 { extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc) if err != nil { - log.Errorf("Error creating connlimit: %v", err) + log.Errorf("Error creating connection limit: %v", err) log.Errorf("Skipping frontend %s...", frontendName) continue frontend } - log.Debugf("Creating load-balancer connlimit") + + log.Debugf("Creating load-balancer connection limit") + lb, err = connlimit.New(lb, extractFunc, maxConns.Amount) - lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("connection limit for %s", frontendName)) if err != nil { - log.Errorf("Error creating connlimit: %v", err) + log.Errorf("Error creating connection limit: %v", err) log.Errorf("Skipping frontend %s...", frontendName) continue frontend } + lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("connection limit for %s", frontendName)) } if globalConfiguration.Retry != nil { @@ -1229,15 +1243,25 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura } } } + + for _, errorPageHandler := range errorPageHandlers { + if handler, ok := backends[errorPageHandler.BackendName]; ok { + errorPageHandler.PostLoad(handler) + } else { + errorPageHandler.PostLoad(nil) + } + } + healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck) + // Get new certificates list sorted per entrypoints // Update certificates entryPointsCertificates, err := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints) + // Sort routes and update certificates for serverEntryPointName, serverEntryPoint := range serverEntryPoints { serverEntryPoint.httpRouter.GetHandler().SortRoutes() - _, exists := entryPointsCertificates[serverEntryPointName] - if exists { + if _, exists := entryPointsCertificates[serverEntryPointName]; exists { serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName]) } } diff --git a/templates/consul_catalog.tmpl b/templates/consul_catalog.tmpl index cea3bd5c6..ed30f0582 100644 --- a/templates/consul_catalog.tmpl +++ b/templates/consul_catalog.tmpl @@ -96,7 +96,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/templates/docker.tmpl b/templates/docker.tmpl index 307275914..72387dbef 100644 --- a/templates/docker.tmpl +++ b/templates/docker.tmpl @@ -97,7 +97,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/templates/ecs.tmpl b/templates/ecs.tmpl index 719e56f55..5b59889f5 100644 --- a/templates/ecs.tmpl +++ b/templates/ecs.tmpl @@ -96,7 +96,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/templates/marathon.tmpl b/templates/marathon.tmpl index 83e04a9be..d9dfc3d6a 100644 --- a/templates/marathon.tmpl +++ b/templates/marathon.tmpl @@ -99,7 +99,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/templates/mesos.tmpl b/templates/mesos.tmpl index 9ea840d7c..30e3cac4e 100644 --- a/templates/mesos.tmpl +++ b/templates/mesos.tmpl @@ -99,7 +99,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}} diff --git a/templates/rancher.tmpl b/templates/rancher.tmpl index e78e5d9aa..f275fe940 100644 --- a/templates/rancher.tmpl +++ b/templates/rancher.tmpl @@ -97,7 +97,7 @@ status = [{{range $page.Status }} "{{.}}", {{end}}] - backend = "{{ $page.Backend }}" + backend = "backend-{{ $page.Backend }}" query = "{{ $page.Query }}" {{end}} {{end}}