From 807dc46ad0b160cb9e5b4f312752dca1d1bde8e8 Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Mon, 6 Jan 2020 16:56:05 +0100 Subject: [PATCH] Handle respondingtimeout and better shutdown tests. Co-authored-by: Mathieu Lonjaret --- cmd/traefik/traefik.go | 2 +- pkg/server/server_entrypoint_tcp.go | 86 ++++--- pkg/server/server_entrypoint_tcp_test.go | 273 +++++++++++++++-------- pkg/tcp/router.go | 36 ++- 4 files changed, 258 insertions(+), 139 deletions(-) diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index b773a0917..d457f1232 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -172,7 +172,7 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err acmeProviders := initACMEProvider(staticConfiguration, &providerAggregator, tlsManager) - serverEntryPointsTCP, err := server.NewTCPEntryPoints(*staticConfiguration) + serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints) if err != nil { return nil, err } diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index e67b5b200..db4dcb5bd 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -52,9 +52,9 @@ func (h *httpForwarder) Accept() (net.Conn, error) { type TCPEntryPoints map[string]*TCPEntryPoint // NewTCPEntryPoints creates a new TCPEntryPoints. -func NewTCPEntryPoints(staticConfiguration static.Configuration) (TCPEntryPoints, error) { +func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, error) { serverEntryPointsTCP := make(TCPEntryPoints) - for entryPointName, config := range staticConfiguration.EntryPoints { + for entryPointName, config := range entryPointsConfig { ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName)) var err error @@ -171,6 +171,23 @@ func (e *TCPEntryPoint) StartTCP(ctx context.Context) { } safe.Go(func() { + // Enforce read/write deadlines at the connection level, + // because when we're peeking the first byte to determine whether we are doing TLS, + // the deadlines at the server level are not taken into account. + if e.transportConfiguration.RespondingTimeouts.ReadTimeout > 0 { + err := writeCloser.SetReadDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.ReadTimeout))) + if err != nil { + logger.Errorf("Error while setting read deadline: %v", err) + } + } + + if e.transportConfiguration.RespondingTimeouts.WriteTimeout > 0 { + err = writeCloser.SetWriteDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.WriteTimeout))) + if err != nil { + logger.Errorf("Error while setting write deadline: %v", err) + } + } + e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker)) }) } @@ -191,48 +208,48 @@ func (e *TCPEntryPoint) Shutdown(ctx context.Context) { logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut) var wg sync.WaitGroup + + shutdownServer := func(server stoppableServer) { + defer wg.Done() + err := server.Shutdown(ctx) + if err == nil { + return + } + if ctx.Err() == context.DeadlineExceeded { + logger.Debugf("Server failed to shutdown within deadline because: %s", err) + if err = server.Close(); err != nil { + logger.Error(err) + } + return + } + logger.Error(err) + // We expect Close to fail again because Shutdown most likely failed when trying to close a listener. + // We still call it however, to make sure that all connections get closed as well. + server.Close() + } + if e.httpServer.Server != nil { wg.Add(1) - go func() { - defer wg.Done() - if err := e.httpServer.Server.Shutdown(ctx); err != nil { - if ctx.Err() == context.DeadlineExceeded { - logger.Debugf("Wait server shutdown is overdue to: %s", err) - err = e.httpServer.Server.Close() - if err != nil { - logger.Error(err) - } - } - } - }() + go shutdownServer(e.httpServer.Server) } if e.httpsServer.Server != nil { wg.Add(1) - go func() { - defer wg.Done() - if err := e.httpsServer.Server.Shutdown(ctx); err != nil { - if ctx.Err() == context.DeadlineExceeded { - logger.Debugf("Wait server shutdown is overdue to: %s", err) - err = e.httpsServer.Server.Close() - if err != nil { - logger.Error(err) - } - } - } - }() + go shutdownServer(e.httpsServer.Server) } if e.tracker != nil { wg.Add(1) go func() { defer wg.Done() - if err := e.tracker.Shutdown(ctx); err != nil { - if ctx.Err() == context.DeadlineExceeded { - logger.Debugf("Wait hijack connection is overdue to: %s", err) - e.tracker.Close() - } + err := e.tracker.Shutdown(ctx) + if err == nil { + return } + if ctx.Err() == context.DeadlineExceeded { + logger.Debugf("Server failed to shutdown before deadline because: %s", err) + } + e.tracker.Close() }() } @@ -459,8 +476,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati } serverHTTP := &http.Server{ - Handler: handler, - ErrorLog: httpServerLogger, + Handler: handler, + ErrorLog: httpServerLogger, + ReadTimeout: time.Duration(configuration.Transport.RespondingTimeouts.ReadTimeout), + WriteTimeout: time.Duration(configuration.Transport.RespondingTimeouts.WriteTimeout), + IdleTimeout: time.Duration(configuration.Transport.RespondingTimeouts.IdleTimeout), } listener := newHTTPForwarder(ln) diff --git a/pkg/server/server_entrypoint_tcp_test.go b/pkg/server/server_entrypoint_tcp_test.go index 425a06390..6a0f75c08 100644 --- a/pkg/server/server_entrypoint_tcp_test.go +++ b/pkg/server/server_entrypoint_tcp_test.go @@ -3,8 +3,11 @@ package server import ( "bufio" "context" + "errors" + "io" "net" "net/http" + "strings" "testing" "time" @@ -15,128 +18,206 @@ import ( "github.com/stretchr/testify/require" ) -func TestShutdownHTTP(t *testing.T) { - entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ - Address: ":0", - Transport: &static.EntryPointsTransport{ - LifeCycle: &static.LifeCycle{ - RequestAcceptGraceTimeout: 0, - GraceTimeOut: types.Duration(5 * time.Second), - }, - }, - ForwardedHeaders: &static.ForwardedHeaders{}, - }) - require.NoError(t, err) - - go entryPoint.StartTCP(context.Background()) - - router := &tcp.Router{} - router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - time.Sleep(1 * time.Second) - rw.WriteHeader(http.StatusOK) - })) - entryPoint.SwitchRouter(router) - - conn, err := net.Dial("tcp", entryPoint.listener.Addr().String()) - require.NoError(t, err) - - go entryPoint.Shutdown(context.Background()) - - request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil) - require.NoError(t, err) - - err = request.Write(conn) - require.NoError(t, err) - - resp, err := http.ReadResponse(bufio.NewReader(conn), request) - require.NoError(t, err) - assert.Equal(t, resp.StatusCode, http.StatusOK) -} - -func TestShutdownHTTPHijacked(t *testing.T) { - entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ - Address: ":0", - Transport: &static.EntryPointsTransport{ - LifeCycle: &static.LifeCycle{ - RequestAcceptGraceTimeout: 0, - GraceTimeOut: types.Duration(5 * time.Second), - }, - }, - ForwardedHeaders: &static.ForwardedHeaders{}, - }) - require.NoError(t, err) - - go entryPoint.StartTCP(context.Background()) - +func TestShutdownHijacked(t *testing.T) { router := &tcp.Router{} router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { conn, _, err := rw.(http.Hijacker).Hijack() require.NoError(t, err) - time.Sleep(1 * time.Second) resp := http.Response{StatusCode: http.StatusOK} err = resp.Write(conn) require.NoError(t, err) })) - - entryPoint.SwitchRouter(router) - - conn, err := net.Dial("tcp", entryPoint.listener.Addr().String()) - require.NoError(t, err) - - go entryPoint.Shutdown(context.Background()) - - request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil) - require.NoError(t, err) - - err = request.Write(conn) - require.NoError(t, err) - - resp, err := http.ReadResponse(bufio.NewReader(conn), request) - require.NoError(t, err) - assert.Equal(t, resp.StatusCode, http.StatusOK) + testShutdown(t, router) } -func TestShutdownTCPConn(t *testing.T) { +func TestShutdownHTTP(t *testing.T) { + router := &tcp.Router{} + router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + time.Sleep(time.Second) + })) + testShutdown(t, router) +} + +func TestShutdownTCP(t *testing.T) { + router := &tcp.Router{} + router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) { + for { + _, err := http.ReadRequest(bufio.NewReader(conn)) + + if err == io.EOF || (err != nil && strings.HasSuffix(err.Error(), "use of closed network connection")) { + return + } + require.NoError(t, err) + + resp := http.Response{StatusCode: http.StatusOK} + err = resp.Write(conn) + require.NoError(t, err) + } + })) + + testShutdown(t, router) +} + +func testShutdown(t *testing.T, router *tcp.Router) { + epConfig := &static.EntryPointsTransport{} + epConfig.SetDefaults() + + epConfig.LifeCycle.RequestAcceptGraceTimeout = 0 + epConfig.LifeCycle.GraceTimeOut = types.Duration(5 * time.Second) + entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ - Address: ":0", - Transport: &static.EntryPointsTransport{ - LifeCycle: &static.LifeCycle{ - RequestAcceptGraceTimeout: 0, - GraceTimeOut: types.Duration(5 * time.Second), - }, - }, + // We explicitly use an IPV4 address because on Alpine, with an IPV6 address + // there seems to be shenanigans related to properly cleaning up file descriptors + Address: "127.0.0.1:0", + Transport: epConfig, ForwardedHeaders: &static.ForwardedHeaders{}, }) require.NoError(t, err) - go entryPoint.StartTCP(context.Background()) + conn, err := startEntrypoint(entryPoint, router) + require.NoError(t, err) - router := &tcp.Router{} - router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) { - _, err := http.ReadRequest(bufio.NewReader(conn)) - require.NoError(t, err) - time.Sleep(1 * time.Second) + epAddr := entryPoint.listener.Addr().String() - resp := http.Response{StatusCode: http.StatusOK} - err = resp.Write(conn) - require.NoError(t, err) - })) + request, err := http.NewRequest(http.MethodHead, "http://127.0.0.1:8082", nil) + require.NoError(t, err) - entryPoint.SwitchRouter(router) + time.Sleep(time.Millisecond * 100) - conn, err := net.Dial("tcp", entryPoint.listener.Addr().String()) + // We need to do a write on the conn before the shutdown to make it "exist". + // Because the connection indeed exists as far as TCP is concerned, + // but since we only pass it along to the HTTP server after at least one byte is peaked, + // the HTTP server (and hence its shutdown) does not know about the connection until that first byte peaking. + err = request.Write(conn) require.NoError(t, err) go entryPoint.Shutdown(context.Background()) - request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil) - require.NoError(t, err) + // Make sure that new connections are not permitted anymore. + // Note that this should be true not only after Shutdown has returned, + // but technically also as early as the Shutdown has closed the listener, + // i.e. during the shutdown and before the gracetime is over. + var testOk bool + for i := 0; i < 10; i++ { + loopConn, err := net.Dial("tcp", epAddr) + if err == nil { + loopConn.Close() + time.Sleep(time.Millisecond * 100) + continue + } + if !strings.HasSuffix(err.Error(), "connection refused") && !strings.HasSuffix(err.Error(), "reset by peer") { + t.Fatalf(`unexpected error: got %v, wanted "connection refused" or "reset by peer"`, err) + } + testOk = true + break + } + if !testOk { + t.Fatal("entry point never closed") + } - err = request.Write(conn) - require.NoError(t, err) + // And make sure that the connection we had opened before shutting things down is still operational resp, err := http.ReadResponse(bufio.NewReader(conn), request) require.NoError(t, err) - assert.Equal(t, resp.StatusCode, http.StatusOK) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func startEntrypoint(entryPoint *TCPEntryPoint, router *tcp.Router) (net.Conn, error) { + go entryPoint.StartTCP(context.Background()) + + entryPoint.SwitchRouter(router) + + var conn net.Conn + var err error + var epStarted bool + for i := 0; i < 10; i++ { + conn, err = net.Dial("tcp", entryPoint.listener.Addr().String()) + if err != nil { + time.Sleep(time.Millisecond * 100) + continue + } + epStarted = true + break + } + if !epStarted { + return nil, errors.New("entry point never started") + } + return conn, err +} + +func TestReadTimeoutWithoutFirstByte(t *testing.T) { + epConfig := &static.EntryPointsTransport{} + epConfig.SetDefaults() + epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2) + + entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ + Address: ":0", + Transport: epConfig, + ForwardedHeaders: &static.ForwardedHeaders{}, + }) + require.NoError(t, err) + + router := &tcp.Router{} + router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + conn, err := startEntrypoint(entryPoint, router) + require.NoError(t, err) + + errChan := make(chan error) + + go func() { + b := make([]byte, 2048) + _, err := conn.Read(b) + errChan <- err + }() + + select { + case err := <-errChan: + require.Equal(t, io.EOF, err) + case <-time.Tick(time.Second * 5): + t.Error("Timeout while read") + } +} + +func TestReadTimeoutWithFirstByte(t *testing.T) { + epConfig := &static.EntryPointsTransport{} + epConfig.SetDefaults() + epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2) + + entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ + Address: ":0", + Transport: epConfig, + ForwardedHeaders: &static.ForwardedHeaders{}, + }) + require.NoError(t, err) + + router := &tcp.Router{} + router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + conn, err := startEntrypoint(entryPoint, router) + require.NoError(t, err) + + _, err = conn.Write([]byte("GET /some HTTP/1.1\r\n")) + require.NoError(t, err) + + errChan := make(chan error) + + go func() { + b := make([]byte, 2048) + _, err := conn.Read(b) + errChan <- err + }() + + select { + case err := <-errChan: + require.Equal(t, io.EOF, err) + case <-time.Tick(time.Second * 5): + t.Error("Timeout while read") + } } diff --git a/pkg/tcp/router.go b/pkg/tcp/router.go index f2b3d8e88..89ad868ea 100644 --- a/pkg/tcp/router.go +++ b/pkg/tcp/router.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "time" "github.com/containous/traefik/v2/pkg/log" ) @@ -34,7 +35,23 @@ func (r *Router) ServeTCP(conn WriteCloser) { } br := bufio.NewReader(conn) - serverName, tls, peeked := clientHelloServerName(br) + serverName, tls, peeked, err := clientHelloServerName(br) + if err != nil { + conn.Close() + return + } + + // Remove read/write deadline and delegate this to underlying tcp server (for now only handled by HTTP Server) + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + log.WithoutContext().Errorf("Error while setting read deadline: %v", err) + } + + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + log.WithoutContext().Errorf("Error while setting write deadline: %v", err) + } + if !tls { switch { case r.catchAllNoTLS != nil: @@ -176,33 +193,34 @@ func (c *Conn) Read(p []byte) (n int, err error) { // clientHelloServerName returns the SNI server name inside the TLS ClientHello, // without consuming any bytes from br. // On any error, the empty string is returned. -func clientHelloServerName(br *bufio.Reader) (string, bool, string) { +func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { hdr, err := br.Peek(1) if err != nil { - if err != io.EOF { - log.Errorf("Error while Peeking first byte: %s", err) + opErr, ok := err.(*net.OpError) + if err != io.EOF && (!ok || !opErr.Timeout()) { + log.WithoutContext().Errorf("Error while Peeking first byte: %s", err) } - return "", false, "" + return "", false, "", err } const recordTypeHandshake = 0x16 if hdr[0] != recordTypeHandshake { // log.Errorf("Error not tls") - return "", false, getPeeked(br) // Not TLS. + return "", false, getPeeked(br), nil // Not TLS. } const recordHeaderLen = 5 hdr, err = br.Peek(recordHeaderLen) if err != nil { log.Errorf("Error while Peeking hello: %s", err) - return "", false, getPeeked(br) + return "", false, getPeeked(br), nil } recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] helloBytes, err := br.Peek(recordHeaderLen + recLen) if err != nil { log.Errorf("Error while Hello: %s", err) - return "", true, getPeeked(br) + return "", true, getPeeked(br), nil } sni := "" @@ -214,7 +232,7 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string) { }) _ = server.Handshake() - return sni, true, getPeeked(br) + return sni, true, getPeeked(br), nil } func getPeeked(br *bufio.Reader) string {