2018-11-14 10:18:03 +01:00
package customerrors
import (
"context"
"fmt"
2023-06-14 17:42:44 +02:00
"io"
2018-11-14 10:18:03 +01:00
"net/http"
"net/http/httptest"
2023-06-14 17:42:44 +02:00
"net/http/httptrace"
"net/textproto"
2018-11-14 10:18:03 +01:00
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
2023-02-03 15:24:05 +01:00
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/testhelpers"
2018-11-14 10:18:03 +01:00
)
func TestHandler ( t * testing . T ) {
testCases := [ ] struct {
desc string
2019-07-10 09:26:04 +02:00
errorPage * dynamic . ErrorPage
2018-11-14 10:18:03 +01:00
backendCode int
backendErrorHandler http . HandlerFunc
validate func ( t * testing . T , recorder * httptest . ResponseRecorder )
} {
{
desc : "no error" ,
2019-07-10 09:26:04 +02:00
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "500-501" , "503-599" } } ,
2018-11-14 10:18:03 +01:00
backendCode : http . StatusOK ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My error page." )
2018-11-14 10:18:03 +01:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2018-11-14 10:18:03 +01:00
assert . Equal ( t , http . StatusOK , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , http . StatusText ( http . StatusOK ) )
} ,
} ,
2019-09-12 16:20:05 +02:00
{
desc : "no error, but not a 200" ,
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "500-501" , "503-599" } } ,
backendCode : http . StatusPartialContent ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My error page." )
2019-09-12 16:20:05 +02:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2019-09-12 16:20:05 +02:00
assert . Equal ( t , http . StatusPartialContent , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , http . StatusText ( http . StatusPartialContent ) )
} ,
} ,
{
desc : "a 304, so no Write called" ,
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "500-501" , "503-599" } } ,
backendCode : http . StatusNotModified ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "whatever, should not be called" )
2019-09-12 16:20:05 +02:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2019-09-12 16:20:05 +02:00
assert . Equal ( t , http . StatusNotModified , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "" )
} ,
} ,
2018-11-14 10:18:03 +01:00
{
desc : "in the range" ,
2019-07-10 09:26:04 +02:00
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "500-501" , "503-599" } } ,
2018-11-14 10:18:03 +01:00
backendCode : http . StatusInternalServerError ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My error page." )
2018-11-14 10:18:03 +01:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2018-11-14 10:18:03 +01:00
assert . Equal ( t , http . StatusInternalServerError , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "My error page." )
} ,
} ,
{
desc : "not in the range" ,
2019-07-10 09:26:04 +02:00
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "500-501" , "503-599" } } ,
2018-11-14 10:18:03 +01:00
backendCode : http . StatusBadGateway ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My error page." )
2018-11-14 10:18:03 +01:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2018-11-14 10:18:03 +01:00
assert . Equal ( t , http . StatusBadGateway , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , http . StatusText ( http . StatusBadGateway ) )
} ,
} ,
{
desc : "query replacement" ,
2019-07-10 09:26:04 +02:00
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/{status}" , Status : [ ] string { "503-503" } } ,
2018-11-14 10:18:03 +01:00
backendCode : http . StatusServiceUnavailable ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
if r . RequestURI != "/503" {
return
2018-11-14 10:18:03 +01:00
}
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My 503 page." )
2018-11-14 10:18:03 +01:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2018-11-14 10:18:03 +01:00
assert . Equal ( t , http . StatusServiceUnavailable , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "My 503 page." )
} ,
} ,
{
2021-09-27 17:40:13 +02:00
desc : "single code and query replacement" ,
2019-07-10 09:26:04 +02:00
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/{status}" , Status : [ ] string { "503" } } ,
2018-11-14 10:18:03 +01:00
backendCode : http . StatusServiceUnavailable ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2021-09-27 17:40:13 +02:00
if r . RequestURI != "/503" {
return
2018-11-14 10:18:03 +01:00
}
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , "My 503 page." )
2018-11-14 10:18:03 +01:00
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
2021-01-28 09:00:03 +01:00
t . Helper ( )
2018-11-14 10:18:03 +01:00
assert . Equal ( t , http . StatusServiceUnavailable , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "My 503 page." )
2021-09-27 17:40:13 +02:00
} ,
} ,
{
desc : "forward request host header" ,
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/test" , Status : [ ] string { "503" } } ,
backendCode : http . StatusServiceUnavailable ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
_ , _ = fmt . Fprintln ( w , r . Host )
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
t . Helper ( )
assert . Equal ( t , http . StatusServiceUnavailable , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "localhost" )
2018-11-14 10:18:03 +01:00
} ,
} ,
2022-05-10 11:00:09 +02:00
{
desc : "full query replacement" ,
errorPage : & dynamic . ErrorPage { Service : "error" , Query : "/?status={status}&url={url}" , Status : [ ] string { "503" } } ,
backendCode : http . StatusServiceUnavailable ,
backendErrorHandler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
if r . RequestURI != "/?status=503&url=http%3A%2F%2Flocalhost%2Ftest%3Ffoo%3Dbar%26baz%3Dbuz" {
t . Log ( r . RequestURI )
return
}
_ , _ = fmt . Fprintln ( w , "My 503 page." )
} ) ,
validate : func ( t * testing . T , recorder * httptest . ResponseRecorder ) {
t . Helper ( )
assert . Equal ( t , http . StatusServiceUnavailable , recorder . Code , "HTTP status" )
assert . Contains ( t , recorder . Body . String ( ) , "My 503 page." )
} ,
} ,
2018-11-14 10:18:03 +01:00
}
for _ , test := range testCases {
test := test
t . Run ( test . desc , func ( t * testing . T ) {
t . Parallel ( )
serviceBuilderMock := & mockServiceBuilder { handler : test . backendErrorHandler }
handler := http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
w . WriteHeader ( test . backendCode )
2021-09-27 17:40:13 +02:00
2019-09-12 16:20:05 +02:00
if test . backendCode == http . StatusNotModified {
return
}
2021-09-27 17:40:13 +02:00
_ , _ = fmt . Fprintln ( w , http . StatusText ( test . backendCode ) )
2018-11-14 10:18:03 +01:00
} )
errorPageHandler , err := New ( context . Background ( ) , handler , * test . errorPage , serviceBuilderMock , "test" )
require . NoError ( t , err )
2022-05-10 11:00:09 +02:00
req := testhelpers . MustNewRequest ( http . MethodGet , "http://localhost/test?foo=bar&baz=buz" , nil )
2018-11-14 10:18:03 +01:00
recorder := httptest . NewRecorder ( )
errorPageHandler . ServeHTTP ( recorder , req )
test . validate ( t , recorder )
} )
}
}
2023-06-14 17:42:44 +02:00
// This test is an adapted version of net/http/httputil.Test1xxResponses test.
func Test1xxResponses ( t * testing . T ) {
next := http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
h := w . Header ( )
h . Add ( "Link" , "</style.css>; rel=preload; as=style" )
h . Add ( "Link" , "</script.js>; rel=preload; as=script" )
w . WriteHeader ( http . StatusEarlyHints )
h . Add ( "Link" , "</foo.js>; rel=preload; as=script" )
w . WriteHeader ( http . StatusProcessing )
h . Add ( "User-Agent" , "foobar" )
_ , _ = w . Write ( [ ] byte ( "Hello" ) )
w . WriteHeader ( http . StatusBadGateway )
} )
serviceBuilderMock := & mockServiceBuilder { handler : http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
_ , _ = fmt . Fprintln ( w , "My error page." )
} ) }
config := dynamic . ErrorPage { Service : "error" , Query : "/" , Status : [ ] string { "200" } }
errorPageHandler , err := New ( context . Background ( ) , next , config , serviceBuilderMock , "test" )
require . NoError ( t , err )
server := httptest . NewServer ( errorPageHandler )
t . Cleanup ( server . Close )
frontendClient := server . Client ( )
checkLinkHeaders := func ( t * testing . T , expected , got [ ] string ) {
t . Helper ( )
if len ( expected ) != len ( got ) {
t . Errorf ( "Expected %d link headers; got %d" , len ( expected ) , len ( got ) )
}
for i := range expected {
if i >= len ( got ) {
t . Errorf ( "Expected %q link header; got nothing" , expected [ i ] )
continue
}
if expected [ i ] != got [ i ] {
t . Errorf ( "Expected %q link header; got %q" , expected [ i ] , got [ i ] )
}
}
}
var respCounter uint8
trace := & httptrace . ClientTrace {
Got1xxResponse : func ( code int , header textproto . MIMEHeader ) error {
switch code {
case http . StatusEarlyHints :
checkLinkHeaders ( t , [ ] string { "</style.css>; rel=preload; as=style" , "</script.js>; rel=preload; as=script" } , header [ "Link" ] )
case http . StatusProcessing :
checkLinkHeaders ( t , [ ] string { "</style.css>; rel=preload; as=style" , "</script.js>; rel=preload; as=script" , "</foo.js>; rel=preload; as=script" } , header [ "Link" ] )
default :
t . Error ( "Unexpected 1xx response" )
}
respCounter ++
return nil
} ,
}
req , _ := http . NewRequestWithContext ( httptrace . WithClientTrace ( context . Background ( ) , trace ) , http . MethodGet , server . URL , nil )
res , err := frontendClient . Do ( req )
2023-11-17 01:50:06 +01:00
assert . NoError ( t , err )
2023-06-14 17:42:44 +02:00
defer res . Body . Close ( )
if respCounter != 2 {
t . Errorf ( "Expected 2 1xx responses; got %d" , respCounter )
}
checkLinkHeaders ( t , [ ] string { "</style.css>; rel=preload; as=style" , "</script.js>; rel=preload; as=script" , "</foo.js>; rel=preload; as=script" } , res . Header [ "Link" ] )
body , _ := io . ReadAll ( res . Body )
assert . Equal ( t , "My error page.\n" , string ( body ) )
}
// mockServiceBuilder is a test double for the service builder handed to New.
// Every build request yields the same preconfigured handler.
type mockServiceBuilder struct {
	handler http.Handler
}

// BuildHTTP ignores both the context and the requested service name, and
// returns the stored handler together with a nil error.
func (sb *mockServiceBuilder) BuildHTTP(context.Context, string) (http.Handler, error) {
	h := sb.handler
	return h, nil
}