1
0
mirror of https://github.com/containous/traefik.git synced 2024-12-22 13:34:03 +03:00

Extreme Makeover: server refactoring

This commit is contained in:
Ludovic Fernandez 2018-06-11 11:36:03 +02:00 committed by Traefiker Bot
parent bddb4cc33c
commit eac20d61df
19 changed files with 2356 additions and 1965 deletions

View File

@ -19,12 +19,18 @@ import (
var singleton *HealthCheck
var once sync.Once
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
once.Do(func() {
singleton = newHealthCheck(metrics)
})
return singleton
// BalancerHandler includes functionality for load-balancing management.
type BalancerHandler interface {
ServeHTTP(w http.ResponseWriter, req *http.Request)
Servers() []*url.URL
RemoveServer(u *url.URL) error
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
}
// metricsRegistry is a local interface in the health check package, exposing only the required metrics
// necessary for the health check package. This makes it easier for the tests.
type metricsRegistry interface {
BackendServerUpGauge() metrics.Gauge
}
// Options are the public health check options.
@ -36,59 +42,59 @@ type Options struct {
Port int
Transport http.RoundTripper
Interval time.Duration
LB LoadBalancer
LB BalancerHandler
}
func (opt Options) String() string {
return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Port: %d Interval: %s]", opt.Hostname, opt.Headers, opt.Path, opt.Port, opt.Interval)
}
// BackendHealthCheck HealthCheck configuration for a backend
type BackendHealthCheck struct {
// BackendConfig HealthCheck configuration for a backend
type BackendConfig struct {
Options
name string
disabledURLs []*url.URL
requestTimeout time.Duration
}
func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) {
u := &url.URL{}
*u = *serverURL
if len(b.Scheme) > 0 {
u.Scheme = b.Scheme
}
if b.Port != 0 {
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
}
u.Path += b.Path
return http.NewRequest(http.MethodGet, u.String(), nil)
}
// this function adds additional http headers and hostname to http.request
func (b *BackendConfig) addHeadersAndHost(req *http.Request) *http.Request {
if b.Options.Hostname != "" {
req.Host = b.Options.Hostname
}
for k, v := range b.Options.Headers {
req.Header.Set(k, v)
}
return req
}
// HealthCheck struct
type HealthCheck struct {
Backends map[string]*BackendHealthCheck
Backends map[string]*BackendConfig
metrics metricsRegistry
cancel context.CancelFunc
}
// LoadBalancer includes functionality for load-balancing management.
type LoadBalancer interface {
RemoveServer(u *url.URL) error
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
Servers() []*url.URL
}
func newHealthCheck(metrics metricsRegistry) *HealthCheck {
return &HealthCheck{
Backends: make(map[string]*BackendHealthCheck),
metrics: metrics,
}
}
// metricsRegistry is a local interface in the health check package, exposing only the required metrics
// necessary for the health check package. This makes it easier for the tests.
type metricsRegistry interface {
BackendServerUpGauge() metrics.Gauge
}
// NewBackendHealthCheck Instantiate a new BackendHealthCheck
func NewBackendHealthCheck(options Options, backendName string) *BackendHealthCheck {
return &BackendHealthCheck{
Options: options,
name: backendName,
requestTimeout: 5 * time.Second,
}
}
// SetBackendsConfiguration set backends configuration
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendHealthCheck) {
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) {
hc.Backends = backends
if hc.cancel != nil {
hc.cancel()
@ -104,7 +110,7 @@ func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backe
}
}
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck) {
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) {
log.Debugf("Initial health check for backend: %q", backend.name)
hc.checkBackend(backend)
ticker := time.NewTicker(backend.Interval)
@ -121,7 +127,7 @@ func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck)
}
}
func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) {
func (hc *HealthCheck) checkBackend(backend *BackendConfig) {
enabledURLs := backend.LB.Servers()
var newDisabledURLs []*url.URL
for _, url := range backend.disabledURLs {
@ -152,38 +158,33 @@ func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) {
}
}
func (b *BackendHealthCheck) newRequest(serverURL *url.URL) (*http.Request, error) {
u := &url.URL{}
*u = *serverURL
if len(b.Scheme) > 0 {
u.Scheme = b.Scheme
}
if b.Port != 0 {
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
}
u.Path += b.Path
return http.NewRequest(http.MethodGet, u.String(), nil)
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
once.Do(func() {
singleton = newHealthCheck(metrics)
})
return singleton
}
// this function adds additional http headers and hostname to http.request
func (b *BackendHealthCheck) addHeadersAndHost(req *http.Request) *http.Request {
if b.Options.Hostname != "" {
req.Host = b.Options.Hostname
func newHealthCheck(metrics metricsRegistry) *HealthCheck {
return &HealthCheck{
Backends: make(map[string]*BackendConfig),
metrics: metrics,
}
}
for k, v := range b.Options.Headers {
req.Header.Set(k, v)
// NewBackendConfig Instantiate a new BackendConfig
func NewBackendConfig(options Options, backendName string) *BackendConfig {
return &BackendConfig{
Options: options,
name: backendName,
requestTimeout: 5 * time.Second,
}
return req
}
// checkHealth returns a nil error in case it was successful and otherwise
// a non-nil error with a meaningful description why the health check failed.
func checkHealth(serverURL *url.URL, backend *BackendHealthCheck) error {
func checkHealth(serverURL *url.URL, backend *BackendConfig) error {
req, err := backend.newRequest(serverURL)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %s", err)

View File

@ -102,7 +102,7 @@ func TestSetBackendsConfiguration(t *testing.T) {
defer ts.Close()
lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}}
backend := NewBackendHealthCheck(Options{
backend := NewBackendConfig(Options{
Path: "/path",
Interval: healthCheckInterval,
LB: lb,
@ -117,7 +117,7 @@ func TestSetBackendsConfiguration(t *testing.T) {
collectingMetrics := testhelpers.NewCollectingHealthCheckMetrics()
check := HealthCheck{
Backends: make(map[string]*BackendHealthCheck),
Backends: make(map[string]*BackendConfig),
metrics: collectingMetrics,
}
@ -209,7 +209,7 @@ func TestNewRequest(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := NewBackendHealthCheck(test.options, "backendName")
backend := NewBackendConfig(test.options, "backendName")
u, err := url.Parse(test.serverURL)
require.NoError(t, err)
@ -279,7 +279,7 @@ func TestAddHeadersAndHost(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := NewBackendHealthCheck(test.options, "backendName")
backend := NewBackendConfig(test.options, "backendName")
u, err := url.Parse(test.serverURL)
require.NoError(t, err)
@ -305,6 +305,10 @@ type testLoadBalancer struct {
servers []*url.URL
}
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// noop
}
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
lb.Lock()
defer lb.Unlock()

View File

@ -102,7 +102,7 @@ func (s *AccessLogSuite) TestAccessLogAuthFrontend(c *check.C) {
formatOnly: false,
code: "401",
user: "-",
frontendName: "Auth for frontend-Host-frontend-auth-docker-local",
frontendName: "Basic Auth for frontend-Host-frontend-auth-docker-local",
backendURL: "/",
},
}
@ -354,7 +354,7 @@ func (s *AccessLogSuite) TestAccessLogEntrypointRedirect(c *check.C) {
formatOnly: false,
code: "302",
user: "-",
frontendName: "entrypoint redirect for frontend-",
frontendName: "entrypoint redirect for httpRedirect",
backendURL: "/",
},
{

View File

@ -31,38 +31,43 @@ func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing
if authConfig == nil {
return nil, fmt.Errorf("error creating Authenticator: auth is nil")
}
var err error
authenticator := Authenticator{}
tracingAuthenticator := tracingAuthenticator{}
authenticator := &Authenticator{}
tracingAuth := tracingAuthenticator{}
if authConfig.Basic != nil {
authenticator.users, err = parserBasicUsers(authConfig.Basic)
if err != nil {
return nil, err
}
basicAuth := goauth.NewBasicAuthenticator("traefik", authenticator.secretBasic)
tracingAuthenticator.handler = createAuthBasicHandler(basicAuth, authConfig)
tracingAuthenticator.name = "Auth Basic"
tracingAuthenticator.clientSpanKind = false
tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
tracingAuth.name = "Auth Basic"
tracingAuth.clientSpanKind = false
} else if authConfig.Digest != nil {
authenticator.users, err = parserDigestUsers(authConfig.Digest)
if err != nil {
return nil, err
}
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
tracingAuthenticator.handler = createAuthDigestHandler(digestAuth, authConfig)
tracingAuthenticator.name = "Auth Digest"
tracingAuthenticator.clientSpanKind = false
tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
tracingAuth.name = "Auth Digest"
tracingAuth.clientSpanKind = false
} else if authConfig.Forward != nil {
tracingAuthenticator.handler = createAuthForwardHandler(authConfig)
tracingAuthenticator.name = "Auth Forward"
tracingAuthenticator.clientSpanKind = true
tracingAuth.handler = createAuthForwardHandler(authConfig)
tracingAuth.name = "Auth Forward"
tracingAuth.clientSpanKind = true
}
if tracingMiddleware != nil {
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuthenticator.name, tracingAuthenticator.handler, tracingAuthenticator.clientSpanKind)
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
} else {
authenticator.handler = tracingAuthenticator.handler
authenticator.handler = tracingAuth.handler
}
return &authenticator, nil
return authenticator, nil
}
func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {

View File

@ -23,14 +23,14 @@ func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tracing.LogEventf(r, "blocked by circuitbreaker (%q)", expression)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
}))
return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
}))
}
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
cb.circuitBreaker.ServeHTTP(rw, r)
}

View File

@ -10,19 +10,18 @@ import (
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type EmptyBackendHandler struct {
lb healthcheck.LoadBalancer
next http.Handler
next healthcheck.BalancerHandler
}
// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
func NewEmptyBackendHandler(lb healthcheck.LoadBalancer, next http.Handler) *EmptyBackendHandler {
return &EmptyBackendHandler{lb: lb, next: next}
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
return &EmptyBackendHandler{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if len(h.lb.Servers()) == 0 {
if len(h.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
} else {

View File

@ -32,10 +32,7 @@ func TestEmptyBackendHandler(t *testing.T) {
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
t.Parallel()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer}, nextHandler)
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
@ -53,12 +50,8 @@ type healthCheckLoadBalancer struct {
amountServer int
}
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
@ -68,3 +61,23 @@ func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
}
return servers
}
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
return 0, false
}
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
return nil, nil
}
func (lb *healthCheckLoadBalancer) Next() http.Handler {
return nil
}

View File

@ -6,20 +6,6 @@ import (
"github.com/urfave/negroni"
)
// NegroniHandlerWrapper is used to wrap negroni handler middleware
type NegroniHandlerWrapper struct {
name string
next negroni.Handler
clientSpanKind bool
}
// HTTPHandlerWrapper is used to wrap http handler middleware
type HTTPHandlerWrapper struct {
name string
handler http.Handler
clientSpanKind bool
}
// NewNegroniHandlerWrapper return a negroni.Handler struct
func (t *Tracing) NewNegroniHandlerWrapper(name string, handler negroni.Handler, clientSpanKind bool) negroni.Handler {
if t.IsEnabled() && handler != nil {
@ -44,6 +30,13 @@ func (t *Tracing) NewHTTPHandlerWrapper(name string, handler http.Handler, clien
return handler
}
// NegroniHandlerWrapper is used to wrap negroni handler middleware
type NegroniHandlerWrapper struct {
name string
next negroni.Handler
clientSpanKind bool
}
func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
var finish func()
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
@ -54,6 +47,13 @@ func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
}
}
// HTTPHandlerWrapper is used to wrap http handler middleware
type HTTPHandlerWrapper struct {
name string
handler http.Handler
clientSpanKind bool
}
func (t *HTTPHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
var finish func()
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)

View File

@ -1,9 +0,0 @@
package server
import (
"net/http"
)
func notFoundHandler(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}

View File

@ -2,7 +2,7 @@ package server
import "sync"
const bufferPoolSize int = 32 * 1024
const bufferPoolSize = 32 * 1024
func newBufferPool() *bufferPool {
return &bufferPool{

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,581 @@
package server
import (
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"reflect"
"sort"
"strings"
"time"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/rules"
"github.com/containous/traefik/safe"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/eapache/channels"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// loadConfiguration manages dynamically frontends, backends and TLS configurations
func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
newConfigurations := make(types.Configurations)
for k, v := range currentConfigurations {
newConfigurations[k] = v
}
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
s.metricsRegistry.ConfigReloadsCounter().Add(1)
newServerEntryPoints, err := s.loadConfig(newConfigurations, s.globalConfiguration)
if err != nil {
s.metricsRegistry.ConfigReloadsFailureCounter().Add(1)
s.metricsRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix()))
log.Error("Error loading new configuration, aborted ", err)
return
}
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil {
if newServerEntryPoint.certs.Get() != nil {
log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName)
}
} else {
s.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get())
}
log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
}
s.currentConfigurations.Set(newConfigurations)
for _, listener := range s.configurationListeners {
listener(*configMsg.Configuration)
}
s.postLoadConfiguration()
}
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
// provider configurations.
func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]*serverEntryPoint, error) {
redirectHandlers, err := s.buildEntryPointRedirect()
if err != nil {
return nil, err
}
serverEntryPoints := s.buildServerEntryPoints()
errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{})
backendsHandlers := map[string]http.Handler{}
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
var postConfigs []handlerPostConfig
for providerName, config := range configurations {
frontendNames := sortedFrontendNamesForConfig(config)
for _, frontendName := range frontendNames {
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
redirectHandlers, serverEntryPoints, errorHandler,
backendsHandlers, backendsHealthCheck)
if err != nil {
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
}
if len(frontendPostConfigs) > 0 {
postConfigs = append(postConfigs, frontendPostConfigs...)
}
}
}
for _, postConfig := range postConfigs {
err := postConfig(backendsHandlers)
if err != nil {
log.Errorf("middleware post configuration error: %v", err)
}
}
healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck)
// Get new certificates list sorted per entrypoints
// Update certificates
entryPointsCertificates, err := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints)
// FIXME error management
// Sort routes and update certificates
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
if _, exists := entryPointsCertificates[serverEntryPointName]; exists {
serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName])
}
}
return serverEntryPoints, err
}
func (s *Server) loadFrontendConfig(
providerName string, frontendName string, config *types.Configuration,
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler,
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
) ([]handlerPostConfig, error) {
frontend := config.Frontends[frontendName]
if len(frontend.EntryPoints) == 0 {
return nil, fmt.Errorf("no entrypoint defined for frontend %s", frontendName)
}
backend := config.Backends[frontend.Backend]
if backend == nil {
return nil, fmt.Errorf("undefined backend '%s' for frontend %s", frontend.Backend, frontendName)
}
frontendHash, err := frontend.Hash()
if err != nil {
return nil, fmt.Errorf("error calculating hash value for frontend %s: %v", frontendName, err)
}
var postConfigs []handlerPostConfig
for _, entryPointName := range frontend.EntryPoints {
log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName)
entryPoint := s.entryPoints[entryPointName].Configuration
if backendsHandlers[entryPointName+providerName+frontendHash] == nil {
log.Debugf("Creating backend %s", frontend.Backend)
handlers, responseModifier, postConfig, err := s.buildMiddlewares(frontendName, frontend, config.Backends, entryPointName, entryPoint, providerName)
if err != nil {
return nil, err
}
if postConfig != nil {
postConfigs = append(postConfigs, postConfig)
}
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier)
if err != nil {
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
}
lb, healthCheckConfig, err := s.buildBalancerMiddlewares(frontendName, frontend, backend, fwd)
if err != nil {
return nil, err
}
if healthCheckConfig != nil {
backendsHealthCheck[entryPointName+providerName+frontendHash] = healthCheckConfig
}
n := negroni.New()
if _, exist := redirectHandlers[entryPointName]; exist {
n.Use(redirectHandlers[entryPointName])
}
for _, handler := range handlers {
n.Use(handler)
}
n.UseHandler(lb)
backendsHandlers[entryPointName+providerName+frontendHash] = n
} else {
log.Debugf("Reusing backend %s [%s - %s - %s - %s]",
frontend.Backend, entryPointName, providerName, frontendName, frontendHash)
}
serverRoute, err := buildServerRoute(serverEntryPoints[entryPointName], frontendName, frontend)
if err != nil {
return nil, err
}
handler := buildMatcherMiddlewares(serverRoute, backendsHandlers[entryPointName+providerName+frontendHash])
serverRoute.Route.Handler(handler)
err = serverRoute.Route.GetError()
if err != nil {
// FIXME error management
log.Errorf("Error building route: %s", err)
}
}
return postConfigs, nil
}
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
frontendName string, frontend *types.Frontend,
errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) {
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
if err != nil {
return nil, fmt.Errorf("failed to create RoundTripper for frontend %s: %v", frontendName, err)
}
rewriter, err := NewHeaderRewriter(entryPoint.ForwardedHeaders.TrustedIPs, entryPoint.ForwardedHeaders.Insecure)
if err != nil {
return nil, fmt.Errorf("error creating rewriter for frontend %s: %v", frontendName, err)
}
var fwd http.Handler
fwd, err = forward.New(
forward.Stream(true),
forward.PassHostHeader(frontend.PassHostHeader),
forward.RoundTripper(roundTripper),
forward.ErrorHandler(errorHandler),
forward.Rewriter(rewriter),
forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool),
)
if err != nil {
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
}
if s.tracingMiddleware.IsEnabled() {
tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend)
next := fwd
fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tm.ServeHTTP(w, r, next.ServeHTTP)
})
}
return fwd, nil
}
func buildServerRoute(serverEntryPoint *serverEntryPoint, frontendName string, frontend *types.Frontend) (*types.ServerRoute, error) {
serverRoute := &types.ServerRoute{Route: serverEntryPoint.httpRouter.GetHandler().NewRoute().Name(frontendName)}
priority := 0
for routeName, route := range frontend.Routes {
rls := rules.Rules{Route: serverRoute}
newRoute, err := rls.Parse(route.Rule)
if err != nil {
return nil, fmt.Errorf("error creating route for frontend %s: %v", frontendName, err)
}
serverRoute.Route = newRoute
priority += len(route.Rule)
log.Debugf("Creating route %s %s", routeName, route.Rule)
}
if frontend.Priority > 0 {
serverRoute.Route.Priority(frontend.Priority)
} else {
serverRoute.Route.Priority(priority)
}
return serverRoute, nil
}
func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration)
s.defaultConfigurationValues(configMsg.Configuration)
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
jsonConf, _ := json.Marshal(configMsg.Configuration)
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil {
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
return
}
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
return
}
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
if !ok {
providerConfigUpdateCh = make(chan types.ConfigMessage)
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
s.routinesPool.Go(func(stop chan bool) {
s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
})
}
providerConfigUpdateCh <- configMsg
}
func (s *Server) defaultConfigurationValues(configuration *types.Configuration) {
if configuration == nil || configuration.Frontends == nil {
return
}
s.configureFrontends(configuration.Frontends)
configureBackends(configuration.Backends)
}
func (s *Server) configureFrontends(frontends map[string]*types.Frontend) {
defaultEntrypoints := s.globalConfiguration.DefaultEntryPoints
for frontendName, frontend := range frontends {
// default endpoints if not defined in frontends
if len(frontend.EntryPoints) == 0 {
frontend.EntryPoints = defaultEntrypoints
}
frontendEntryPoints, undefinedEntryPoints := s.filterEntryPoints(frontend.EntryPoints)
if len(undefinedEntryPoints) > 0 {
log.Errorf("Undefined entry point(s) '%s' for frontend %s", strings.Join(undefinedEntryPoints, ","), frontendName)
}
frontend.EntryPoints = frontendEntryPoints
}
}
func (s *Server) filterEntryPoints(entryPoints []string) ([]string, []string) {
var frontendEntryPoints []string
var undefinedEntryPoints []string
for _, fepName := range entryPoints {
var exist bool
for epName := range s.entryPoints {
if epName == fepName {
exist = true
break
}
}
if exist {
frontendEntryPoints = append(frontendEntryPoints, fepName)
} else {
undefinedEntryPoints = append(undefinedEntryPoints, fepName)
}
}
return frontendEntryPoints, undefinedEntryPoints
}
func configureBackends(backends map[string]*types.Backend) {
for backendName := range backends {
backend := backends[backendName]
if backend.LoadBalancer != nil && backend.LoadBalancer.Sticky {
log.Warnf("Deprecated configuration found: %s. Please use %s.", "backend.LoadBalancer.Sticky", "backend.LoadBalancer.Stickiness")
}
_, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
if err == nil {
if backend.LoadBalancer != nil && backend.LoadBalancer.Stickiness == nil && backend.LoadBalancer.Sticky {
backend.LoadBalancer.Stickiness = &types.Stickiness{
CookieName: "_TRAEFIK_BACKEND",
}
}
} else {
log.Debugf("Backend %s: %v", backendName, err)
var stickiness *types.Stickiness
if backend.LoadBalancer != nil {
if backend.LoadBalancer.Stickiness == nil {
if backend.LoadBalancer.Sticky {
stickiness = &types.Stickiness{
CookieName: "_TRAEFIK_BACKEND",
}
}
} else {
stickiness = backend.LoadBalancer.Stickiness
}
}
backend.LoadBalancer = &types.LoadBalancer{
Method: "wrr",
Stickiness: stickiness,
}
}
}
}
func (s *Server) listenConfigurations(stop chan bool) {
for {
select {
case <-stop:
return
case configMsg, ok := <-s.configurationValidatedChan:
if !ok || configMsg.Configuration == nil {
return
}
s.loadConfiguration(configMsg)
}
}
}
// throttleProviderConfigReload throttles the configuration reload speed for a single provider.
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
// it will publish the last of the newly received configurations.
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) {
ring := channels.NewRingChannel(1)
defer ring.Close()
s.routinesPool.Go(func(stop chan bool) {
for {
select {
case <-stop:
return
case nextConfig := <-ring.Out():
publish <- nextConfig.(types.ConfigMessage)
time.Sleep(throttle)
}
}
})
for {
select {
case <-stop:
return
case nextConfig := <-in:
ring.In() <- nextConfig
}
}
}
func buildMatcherMiddlewares(serverRoute *types.ServerRoute, handler http.Handler) http.Handler {
// path replace - This needs to always be the very last on the handler chain (first in the order in this function)
// -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran
if len(serverRoute.ReplacePath) > 0 {
handler = &middlewares.ReplacePath{
Path: serverRoute.ReplacePath,
Handler: handler,
}
}
if len(serverRoute.ReplacePathRegex) > 0 {
sp := strings.Split(serverRoute.ReplacePathRegex, " ")
if len(sp) == 2 {
handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler)
} else {
log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex)
}
}
// add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function)
// -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured)
if len(serverRoute.AddPrefix) > 0 {
handler = &middlewares.AddPrefix{
Prefix: serverRoute.AddPrefix,
Handler: handler,
}
}
// strip prefix
if len(serverRoute.StripPrefixes) > 0 {
handler = &middlewares.StripPrefix{
Prefixes: serverRoute.StripPrefixes,
Handler: handler,
}
}
// strip prefix with regex
if len(serverRoute.StripPrefixesRegex) > 0 {
handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex)
}
return handler
}
func (s *Server) postLoadConfiguration() {
if s.metricsRegistry.IsEnabled() {
activeConfig := s.currentConfigurations.Get().(types.Configurations)
metrics.OnConfigurationUpdate(activeConfig)
}
if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
return
}
if s.globalConfiguration.ACME.OnHostRule {
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
for _, config := range currentConfigurations {
for _, frontend := range config.Frontends {
// check if one of the frontend entrypoints is configured with TLS
// and is configured with ACME
acmeEnabled := false
for _, entryPoint := range frontend.EntryPoints {
if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
acmeEnabled = true
break
}
}
if acmeEnabled {
for _, route := range frontend.Routes {
rls := rules.Rules{}
domains, err := rls.ParseDomains(route.Rule)
if err != nil {
log.Errorf("Error parsing domains: %v", err)
} else {
s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
}
}
}
}
}
}
}
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) (map[string]map[string]*tls.Certificate, error) {
newEPCertificates := make(map[string]map[string]*tls.Certificate)
// Get all certificates
for _, config := range configurations {
if config.TLS != nil && len(config.TLS) > 0 {
if err := traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, defaultEntryPoints); err != nil {
return nil, err
}
}
}
return newEPCertificates, nil
}
func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
serverEntryPoints := make(map[string]*serverEntryPoint)
for entryPointName, entryPoint := range s.entryPoints {
serverEntryPoints[entryPointName] = &serverEntryPoint{
httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()),
onDemandListener: entryPoint.OnDemandListener,
}
if entryPoint.CertificateStore != nil {
serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore.DynamicCerts
} else {
serverEntryPoints[entryPointName].certs = &safe.Safe{}
}
}
return serverEntryPoints
}
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
rt := mux.NewRouter()
rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(http.NotFound), "backend not found")
rt.StrictSlash(true)
rt.SkipClean(true)
return rt
}
func sortedFrontendNamesForConfig(configuration *types.Configuration) []string {
var keys []string
for key := range configuration.Frontends {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}

View File

@ -0,0 +1,484 @@
package server
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/containous/flaeg"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/rules"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/roundrobin"
)
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var (
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
fblo6RBxUQ==
-----END CERTIFICATE-----`)
// LocalhostKey is the private key for localhostCert.
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
-----END RSA PRIVATE KEY-----`)
)
type testLoadBalancer struct{}
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// noop
}
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *testLoadBalancer) Servers() []*url.URL {
return []*url.URL{}
}
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
healthChecks := []*types.HealthCheck{
nil,
{
Path: "/path",
},
}
for _, lbMethod := range []string{"Wrr", "Drr"} {
for _, healthCheck := range healthChecks {
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)},
}
entryPoints := map[string]EntryPoint{
"http": {
Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: lbMethod,
},
HealthCheck: healthCheck,
},
},
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
EntryPoints: []string{"http"},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
require.NoError(t, err)
expectedNumHealthCheckBackends := 0
if healthCheck != nil {
expectedNumHealthCheckBackends = 1
}
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
})
}
}
}
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
EntryPoints: configuration.EntryPoints{
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
BasicAuth: []string{""},
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: "Wrr",
},
},
},
},
}
entryPoints := map[string]EntryPoint{}
for key, value := range globalConfig.EntryPoints {
entryPoints[key] = EntryPoint{
Configuration: value,
}
}
srv := NewServer(globalConfig, nil, entryPoints)
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
require.NoError(t, err)
}
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http", "https"},
}
entryPoints := map[string]EntryPoint{
"https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}},
"http": {Configuration: &configuration.EntryPoint{}},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil {
t.Fatalf("got error: %s", err)
} else if mapEntryPoints["https"].certs.Get() == nil {
t.Fatal("got error: https entryPoint must have TLS certificates.")
}
}
func TestReuseBackend(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http"},
}
entryPoints := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
}},
}
dynamicConfigs := types.Configurations{
"config": th.BuildConfiguration(
th.WithFrontends(
th.WithFrontend("backend",
th.WithFrontendName("frontend0"),
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
th.WithFrontend("backend",
th.WithFrontendName("frontend1"),
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
th.WithBasicAuth("foo", "bar")),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLBMethod("wrr"),
th.WithServersNew(th.WithServerNew(testServer.URL))),
),
),
}
srv := NewServer(globalConfig, nil, entryPoints)
serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig)
if err != nil {
t.Fatalf("error loading config: %s", err)
}
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
// Test that the /unauthorized path returns a 401 because of
// the basic authentication defined on the frontend.
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
}
func TestThrottleProviderConfigReload(t *testing.T) {
throttleDuration := 30 * time.Millisecond
publishConfig := make(chan types.ConfigMessage)
providerConfig := make(chan types.ConfigMessage)
stop := make(chan bool)
defer func() {
stop <- true
}()
globalConfig := configuration.GlobalConfiguration{}
server := NewServer(globalConfig, nil, nil)
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
publishedConfigCount := 0
stopConsumeConfigs := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-stopConsumeConfigs:
return
case <-publishConfig:
publishedConfigCount++
}
}
}()
// publish 5 new configs, one new config each 10 milliseconds
for i := 0; i < 5; i++ {
providerConfig <- types.ConfigMessage{}
time.Sleep(10 * time.Millisecond)
}
// after 50 milliseconds 5 new configs were published
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
stopConsumeConfigs <- true
select {
case <-publishConfig:
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
select {
case <-publishConfig:
t.Error("extra config publication found")
case <-time.After(100 * time.Millisecond):
return
}
case <-time.After(100 * time.Millisecond):
t.Error("Last config was not published in time")
}
}
func TestServerMultipleFrontendRules(t *testing.T) {
testCases := []struct {
expression string
requestURL string
expectedURL string
}{
{
expression: "Host:foo.bar",
requestURL: "http://foo.bar",
expectedURL: "http://foo.bar",
},
{
expression: "PathPrefix:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
{
expression: "Host:foo.bar;AddPrefix:/blah",
requestURL: "http://foo.bar/baz",
expectedURL: "http://foo.bar/blah/baz",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/four",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/zero/four",
},
{
expression: "AddPrefix:/blah;ReplacePath:/baz",
requestURL: "http://foo.bar/hello",
expectedURL: "http://foo.bar/baz",
},
{
expression: "PathPrefixStrip:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
}
for _, test := range testCases {
test := test
t.Run(test.expression, func(t *testing.T) {
t.Parallel()
router := mux.NewRouter()
route := router.NewRoute()
serverRoute := &types.ServerRoute{Route: route}
rls := &rules.Rules{Route: serverRoute}
expression := test.expression
routeResult, err := rls.Parse(expression)
if err != nil {
t.Fatalf("Error while building route for %s: %+v", expression, err)
}
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult})
if !routeMatch {
t.Fatalf("Rule %s doesn't match", expression)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
})
hd := buildMatcherMiddlewares(serverRoute, handler)
serverRoute.Route.Handler(hd)
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
})
}
}
func TestServerBuildHealthCheckOptions(t *testing.T) {
lb := &testLoadBalancer{}
globalInterval := 15 * time.Second
testCases := []struct {
desc string
hc *types.HealthCheck
expectedOpts *healthcheck.Options
}{
{
desc: "nil health check",
hc: nil,
expectedOpts: nil,
},
{
desc: "empty path",
hc: &types.HealthCheck{
Path: "",
},
expectedOpts: nil,
},
{
desc: "unparseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "unparseable",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
},
},
{
desc: "sub-zero interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "-42s",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
},
},
{
desc: "parseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "5m",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: 5 * time.Minute,
LB: lb,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
opts := buildHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)})
assert.Equal(t, test.expectedOpts, opts, "health check options")
})
}
}

View File

@ -0,0 +1,428 @@
package server
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"time"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/server/cookie"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/vulcand/oxy/buffer"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/roundrobin"
"github.com/vulcand/oxy/utils"
"golang.org/x/net/http2"
)
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
func (s *Server) buildBalancerMiddlewares(frontendName string, frontend *types.Frontend, backend *types.Backend, fwd http.Handler) (http.Handler, *healthcheck.BackendConfig, error) {
balancer, err := s.buildLoadBalancer(frontendName, frontend.Backend, backend, fwd)
if err != nil {
return nil, nil, err
}
// Health Check
var backendHealthCheck *healthcheck.BackendConfig
if hcOpts := buildHealthCheckOptions(balancer, frontend.Backend, backend.HealthCheck, s.globalConfiguration.HealthCheck); hcOpts != nil {
log.Debugf("Setting up backend health check %s", *hcOpts)
hcOpts.Transport = s.defaultForwardingRoundTripper
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, frontend.Backend)
}
// Empty (backend with no servers)
var lb http.Handler = middlewares.NewEmptyBackendHandler(balancer)
// Rate Limit
if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 {
handler, err := buildRateLimiter(lb, frontend.RateLimit)
if err != nil {
return nil, nil, fmt.Errorf("error creating rate limiter: %v", err)
}
lb = s.wrapHTTPHandlerWithAccessLog(
s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", handler, false),
fmt.Sprintf("rate limit for %s", frontendName),
)
}
// Max Connections
if backend.MaxConn != nil && backend.MaxConn.Amount != 0 {
log.Debugf("Creating load-balancer connection limit")
handler, err := buildMaxConn(lb, backend.MaxConn)
if err != nil {
return nil, nil, err
}
lb = s.wrapHTTPHandlerWithAccessLog(handler, fmt.Sprintf("connection limit for %s", frontendName))
}
// Retry
if s.globalConfiguration.Retry != nil {
handler := s.buildRetryMiddleware(lb, s.globalConfiguration.Retry, len(backend.Servers), frontend.Backend)
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", handler, false)
}
// Buffering
if backend.Buffering != nil {
handler, err := buildBufferingMiddleware(lb, backend.Buffering)
if err != nil {
return nil, nil, fmt.Errorf("error setting up buffering middleware: %s", err)
}
// TODO refactor ?
lb = handler
}
// Circuit Breaker
if backend.CircuitBreaker != nil {
log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression)
expression := backend.CircuitBreaker.Expression
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression))
if err != nil {
return nil, nil, fmt.Errorf("error creating circuit breaker: %v", err)
}
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Circuit breaker", circuitBreaker, false)
}
return lb, backendHealthCheck, nil
}
func (s *Server) buildLoadBalancer(frontendName string, backendName string, backend *types.Backend, fwd http.Handler) (healthcheck.BalancerHandler, error) {
var rr *roundrobin.RoundRobin
var saveFrontend http.Handler
if s.accessLoggerMiddleware != nil {
saveBackend := accesslog.NewSaveBackend(fwd, backendName)
saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName)
rr, _ = roundrobin.New(saveFrontend)
} else {
rr, _ = roundrobin.New(fwd)
}
var stickySession *roundrobin.StickySession
var cookieName string
if stickiness := backend.LoadBalancer.Stickiness; stickiness != nil {
cookieName = cookie.GetName(stickiness.CookieName, backendName)
stickySession = roundrobin.NewStickySession(cookieName)
}
lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
if err != nil {
return nil, fmt.Errorf("error loading load balancer method '%+v' for frontend %s: %v", backend.LoadBalancer, frontendName, err)
}
var lb healthcheck.BalancerHandler
switch lbMethod {
case types.Drr:
log.Debug("Creating load-balancer drr")
if stickySession != nil {
log.Debugf("Sticky session with cookie %v", cookieName)
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.NewRebalancer(rr)
if err != nil {
return nil, err
}
}
case types.Wrr:
log.Debug("Creating load-balancer wrr")
if stickySession != nil {
log.Debugf("Sticky session with cookie %v", cookieName)
if s.accessLoggerMiddleware != nil {
lb, err = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
}
} else {
lb = rr
}
default:
return nil, fmt.Errorf("invalid load-balancing method %q", lbMethod)
}
if err := s.configureLBServers(lb, backend, backendName); err != nil {
return nil, fmt.Errorf("error configuring load balancer for frontend %s: %v", frontendName, err)
}
return lb, nil
}
func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *types.Backend, backendName string) error {
for name, srv := range backend.Servers {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
}
log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight)
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
}
s.metricsRegistry.BackendServerUpGauge().With("backend", backendName, "url", srv.URL).Set(1)
}
return nil
}
// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one
// given a custom TLS configuration is passed and the passTLSCert option is set to true.
func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) {
if passTLSCert {
tlsConfig, err := createClientTLSConfig(entryPointName, tls)
if err != nil {
return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err)
}
transport, err := createHTTPTransport(s.globalConfiguration)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP transport: %v", err)
}
transport.TLSClientConfig = tlsConfig
return transport, nil
}
return s.defaultForwardingRoundTripper, nil
}
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
// behaviour and backwards compatibility issues.
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if globalConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
Transport: &http2.Transport{
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(netw, addr)
},
AllowHTTP: true,
},
})
if globalConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
}
if globalConfiguration.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if len(globalConfiguration.RootCAs) > 0 {
transport.TLSClientConfig = &tls.Config{
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
}
}
err := http2.ConfigureTransport(transport)
if err != nil {
return nil, err
}
return transport, nil
}
func createRootCACertPool(rootCAs traefiktls.RootCAs) *x509.CertPool {
roots := x509.NewCertPool()
for _, cert := range rootCAs {
certContent, err := cert.Read()
if err != nil {
log.Error("Error while read RootCAs", err)
continue
}
roots.AppendCertsFromPEM(certContent)
}
return roots
}
func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) {
if tlsOption == nil {
return nil, errors.New("no TLS provided")
}
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
if len(tlsOption.ClientCAFiles) > 0 {
log.Warnf("Deprecated configuration found during client TLS configuration creation: %s. Please use %s (which allows to make the CA Files optional).", "tls.ClientCAFiles", "tls.ClientCA.files")
tlsOption.ClientCA.Files = tlsOption.ClientCAFiles
tlsOption.ClientCA.Optional = false
}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !pool.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
config.RootCAs = pool
}
config.BuildNameToCertificate()
return config, nil
}
func (s *Server) buildRetryMiddleware(handler http.Handler, retry *configuration.Retry, countServers int, backendName string) http.Handler {
retryListeners := middlewares.RetryListeners{}
if s.metricsRegistry.IsEnabled() {
retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName))
}
if s.accessLoggerMiddleware != nil {
retryListeners = append(retryListeners, &accesslog.SaveRetries{})
}
retryAttempts := countServers
if retry.Attempts > 0 {
retryAttempts = retry.Attempts
}
log.Debugf("Creating retries max attempts %d", retryAttempts)
return middlewares.NewRetry(retryAttempts, handler, retryListeners)
}
func buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) {
extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc)
if err != nil {
return nil, err
}
log.Debugf("Creating load-balancer rate limiter")
rateSet := ratelimit.NewRateSet()
for _, rate := range rlConfig.RateSet {
if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
}
}
return ratelimit.New(handler, extractFunc, rateSet)
}
func buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) {
log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes,
config.MaxResponseBodyBytes, config.RetryExpression)
return buffer.New(
handler,
buffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
buffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)),
)
}
func buildMaxConn(lb http.Handler, maxConns *types.MaxConn) (http.Handler, error) {
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
log.Debugf("Creating load-balancer connection limit")
handler, err := connlimit.New(lb, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return handler, nil
}
func buildHealthCheckOptions(lb healthcheck.BalancerHandler, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options {
if hc == nil || hc.Path == "" || hcConfig == nil {
return nil
}
interval := time.Duration(hcConfig.Interval)
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
if err != nil {
log.Errorf("Illegal health check interval for backend '%s': %s", backend, err)
} else if intervalOverride <= 0 {
log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend)
} else {
interval = intervalOverride
}
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Path: hc.Path,
Port: hc.Port,
Interval: interval,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
}
}

View File

@ -0,0 +1,81 @@
package server
import (
"testing"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
)
func TestConfigureBackends(t *testing.T) {
validMethod := "Drr"
defaultMethod := "wrr"
testCases := []struct {
desc string
lb *types.LoadBalancer
expectedMethod string
expectedStickiness *types.Stickiness
}{
{
desc: "valid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: &types.Stickiness{},
},
expectedMethod: validMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "valid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: nil,
},
expectedMethod: validMethod,
},
{
desc: "invalid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: &types.Stickiness{},
},
expectedMethod: defaultMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "invalid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: nil,
},
expectedMethod: defaultMethod,
},
{
desc: "missing load balancer",
lb: nil,
expectedMethod: defaultMethod,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := &types.Backend{
LoadBalancer: test.lb,
}
configureBackends(map[string]*types.Backend{
"backend": backend,
})
expected := types.LoadBalancer{
Method: test.expectedMethod,
Stickiness: test.expectedStickiness,
}
assert.Equal(t, expected, *backend.LoadBalancer)
})
}
}

View File

@ -0,0 +1,316 @@
package server
import (
"fmt"
"net/http"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
mauth "github.com/containous/traefik/middlewares/auth"
"github.com/containous/traefik/middlewares/errorpages"
"github.com/containous/traefik/middlewares/redirect"
"github.com/containous/traefik/types"
thoas_stats "github.com/thoas/stats"
"github.com/unrolled/secure"
"github.com/urfave/negroni"
)
type handlerPostConfig func(backendsHandlers map[string]http.Handler) error
type modifyResponse func(*http.Response) error
func (s *Server) buildMiddlewares(frontendName string, frontend *types.Frontend,
backends map[string]*types.Backend,
entryPointName string, entryPoint *configuration.EntryPoint,
providerName string) ([]negroni.Handler, modifyResponse, handlerPostConfig, error) {
var middle []negroni.Handler
var postConfig handlerPostConfig
// Error pages
if len(frontend.Errors) > 0 {
handlers, err := buildErrorPagesMiddleware(frontendName, frontend, backends, entryPointName, providerName)
if err != nil {
return nil, nil, nil, err
}
postConfig = errorPagesPostConfig(handlers)
for _, handler := range handlers {
middle = append(middle, handler)
}
}
// Metrics
if s.metricsRegistry.IsEnabled() {
handler := middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend)
middle = append(middle, handler)
}
// Whitelist
ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, frontend.WhitelistSourceRange)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating IP Whitelister: %s", err)
}
if ipWhitelistMiddleware != nil {
log.Debugf("Configured IP Whitelists: %v", frontend.WhiteList.SourceRange)
handler := s.tracingMiddleware.NewNegroniHandlerWrapper(
"IP whitelist",
s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)),
false)
middle = append(middle, handler)
}
// Redirect
if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint {
rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating Frontend Redirect: %v", err)
}
handler := s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName))
middle = append(middle, handler)
log.Debugf("Frontend %s redirect created", frontendName)
}
// Header
headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers)
if headerMiddleware != nil {
log.Debugf("Adding header middleware for frontend %s", frontendName)
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false)
middle = append(middle, handler)
}
// Secure
secureMiddleware := middlewares.NewSecure(frontend.Headers)
if secureMiddleware != nil {
log.Debugf("Adding secure middleware for frontend %s", frontendName)
handler := negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly)
middle = append(middle, handler)
}
// Basic auth
if len(frontend.BasicAuth) > 0 {
log.Debugf("Adding basic authentication for frontend %s", frontendName)
authMiddleware, err := s.buildBasicAuthMiddleware(frontend.BasicAuth)
if err != nil {
return nil, nil, nil, err
}
handler := s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Basic Auth for %s", frontendName))
middle = append(middle, handler)
}
return middle, buildModifyResponse(secureMiddleware, headerMiddleware), postConfig, nil
}
func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string, serverEntryPoint *serverEntryPoint) ([]negroni.Handler, error) {
serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()}
if s.tracingMiddleware.IsEnabled() {
serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(serverEntryPointName))
}
if s.accessLoggerMiddleware != nil {
serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware)
}
if s.metricsRegistry.IsEnabled() {
serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, serverEntryPointName))
}
if s.globalConfiguration.API != nil {
if s.globalConfiguration.API.Stats == nil {
s.globalConfiguration.API.Stats = thoas_stats.New()
}
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats)
if s.globalConfiguration.API.Statistics != nil {
if s.globalConfiguration.API.StatsRecorder == nil {
s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors)
}
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder)
}
}
if s.entryPoints[serverEntryPointName].Configuration.Auth != nil {
authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[serverEntryPointName].Configuration.Auth, s.tracingMiddleware)
if err != nil {
return nil, fmt.Errorf("failed to create authentication middleware: %v", err)
}
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", serverEntryPointName)))
}
if s.entryPoints[serverEntryPointName].Configuration.Compress {
serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{})
}
ipWhitelistMiddleware, err := buildIPWhiteLister(
s.entryPoints[serverEntryPointName].Configuration.WhiteList,
s.entryPoints[serverEntryPointName].Configuration.WhitelistSourceRange)
if err != nil {
return nil, fmt.Errorf("failed to create ip whitelist middleware: %v", err)
}
if ipWhitelistMiddleware != nil {
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName)))
}
return serverMiddlewares, nil
}
func errorPagesPostConfig(epHandlers []*errorpages.Handler) handlerPostConfig {
return func(backendsHandlers map[string]http.Handler) error {
for _, errorPageHandler := range epHandlers {
if handler, ok := backendsHandlers[errorPageHandler.BackendName]; ok {
err := errorPageHandler.PostLoad(handler)
if err != nil {
return fmt.Errorf("failed to configure error pages for backend %s: %v", errorPageHandler.BackendName, err)
}
} else {
err := errorPageHandler.PostLoad(nil)
if err != nil {
return fmt.Errorf("failed to configure error pages for %s: %v", errorPageHandler.FallbackURL, err)
}
}
}
return nil
}
}
func buildErrorPagesMiddleware(frontendName string, frontend *types.Frontend, backends map[string]*types.Backend, entryPointName string, providerName string) ([]*errorpages.Handler, error) {
var errorPageHandlers []*errorpages.Handler
for errorPageName, errorPage := range frontend.Errors {
if frontend.Backend == errorPage.Backend {
log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).",
errorPageName, frontendName, errorPage.Backend)
} else if backends[errorPage.Backend] == nil {
log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.",
errorPageName, frontendName, errorPage.Backend)
} else {
errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend)
if err != nil {
return nil, fmt.Errorf("error creating error pages: %v", err)
}
if errorPageServer, ok := backends[errorPage.Backend].Servers["error"]; ok {
errorPagesHandler.FallbackURL = errorPageServer.URL
}
errorPageHandlers = append(errorPageHandlers, errorPagesHandler)
}
}
return errorPageHandlers, nil
}
func (s *Server) buildBasicAuthMiddleware(authData []string) (*mauth.Authenticator, error) {
users := types.Users{}
for _, user := range authData {
users = append(users, user)
}
auth := &types.Auth{}
auth.Basic = &types.Basic{
Users: users,
}
authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware)
if err != nil {
return nil, fmt.Errorf("error creating Basic Auth: %v", err)
}
return authMiddleware, nil
}
func (s *Server) buildEntryPointRedirect() (map[string]negroni.Handler, error) {
redirectHandlers := map[string]negroni.Handler{}
for entryPointName, ep := range s.entryPoints {
entryPoint := ep.Configuration
if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint {
handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect)
if err != nil {
return nil, fmt.Errorf("error loading configuration for entrypoint %s: %v", entryPointName, err)
}
handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", entryPointName))
redirectHandlers[entryPointName] = handlerToUse
}
}
return redirectHandlers, nil
}
func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) {
// entry point redirect
if len(opt.EntryPoint) > 0 {
entryPoint := s.entryPoints[opt.EntryPoint].Configuration
if entryPoint == nil {
return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName)
}
log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint)
return redirect.NewEntryPointHandler(entryPoint, opt.Permanent)
}
// regex redirect
redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent)
if err != nil {
return nil, err
}
log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement)
return redirection, nil
}
func buildIPWhiteLister(whiteList *types.WhiteList, wlRange []string) (*middlewares.IPWhiteLister, error) {
if whiteList != nil &&
len(whiteList.SourceRange) > 0 {
return middlewares.NewIPWhiteLister(whiteList.SourceRange, whiteList.UseXForwardedFor)
} else if len(wlRange) > 0 {
return middlewares.NewIPWhiteLister(wlRange, false)
}
return nil, nil
}
func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler {
if s.accessLoggerMiddleware != nil {
saveBackend := accesslog.NewSaveNegroniBackend(handler, "Træfik")
saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName)
return saveFrontend
}
return handler
}
func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler {
if s.accessLoggerMiddleware != nil {
saveBackend := accesslog.NewSaveBackend(handler, "Træfik")
saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName)
return saveFrontend
}
return handler
}
func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error {
return func(res *http.Response) error {
if secure != nil {
if err := secure.ModifyResponseHeaders(res); err != nil {
return err
}
}
if header != nil {
if err := header.ModifyResponseHeaders(res); err != nil {
return err
}
}
return nil
}
}

View File

@ -0,0 +1,253 @@
package server
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestServerEntryPointWhitelistConfig(t *testing.T) {
testCases := []struct {
desc string
entrypoint *configuration.EntryPoint
expectMiddleware bool
}{
{
desc: "no whitelist middleware if no config on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: false,
},
{
desc: "whitelist middleware should be added if configured on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
WhitelistSourceRange: []string{
"127.0.0.1/32",
},
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
metricsRegistry: metrics.NewVoidRegistry(),
entryPoints: map[string]EntryPoint{
"test": {
Configuration: test.entrypoint,
},
},
}
srv.serverEntryPoints = srv.buildServerEntryPoints()
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
found := false
for _, handler := range handler.Handlers() {
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
found = true
}
}
if found && !test.expectMiddleware {
t.Error("ip whitelist middleware was installed even though it should not")
}
if !found && test.expectMiddleware {
t.Error("ip whitelist middleware was not installed even though it should have")
}
})
}
}
func TestBuildIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whitelistSourceRange []string
whiteList *types.WhiteList
middlewareConfigured bool
errMessage string
}{
{
desc: "no whitelists configured",
whitelistSourceRange: nil,
middlewareConfigured: false,
errMessage: "",
},
{
desc: "whitelists configured (deprecated)",
whitelistSourceRange: []string{
"1.2.3.4/24",
"fe80::/16",
},
middlewareConfigured: true,
errMessage: "",
},
{
desc: "invalid whitelists configured (deprecated)",
whitelistSourceRange: []string{
"foo",
},
middlewareConfigured: false,
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
},
{
desc: "whitelists configured",
whiteList: &types.WhiteList{
SourceRange: []string{
"1.2.3.4/24",
"fe80::/16",
},
UseXForwardedFor: false,
},
middlewareConfigured: true,
errMessage: "",
},
{
desc: "invalid whitelists configured (deprecated)",
whiteList: &types.WhiteList{
SourceRange: []string{
"foo",
},
UseXForwardedFor: false,
},
middlewareConfigured: false,
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange)
if test.errMessage != "" {
require.EqualError(t, err, test.errMessage)
} else {
assert.NoError(t, err)
if test.middlewareConfigured {
require.NotNil(t, middleware, "not expected middleware to be configured")
} else {
require.Nil(t, middleware, "expected middleware to be configured")
}
}
})
}
}
func TestBuildRedirectHandler(t *testing.T) {
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
entryPoints: map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
},
}
testCases := []struct {
desc string
srcEntryPointName string
url string
entryPoint *configuration.EntryPoint
redirect *types.Redirect
expectedURL string
}{
{
desc: "redirect regex",
srcEntryPointName: "http",
url: "http://foo.com",
redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foobar.com",
},
{
desc: "redirect entry point",
srcEntryPointName: "http",
url: "http://foo:80",
redirect: &types.Redirect{
EntryPoint: "https",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
},
},
expectedURL: "https://foo:443",
},
{
desc: "redirect entry point with regex (ignored)",
srcEntryPointName: "http",
url: "http://foo.com:80",
redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foo.com:443",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
require.NoError(t, err)
req := th.MustNewRequest(http.MethodGet, test.url, nil)
recorder := httptest.NewRecorder()
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Location", "fail")
}))
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
})
}
}

View File

@ -22,12 +22,12 @@ func (s *Server) listenSignals() {
if s.accessLoggerMiddleware != nil {
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
log.Errorf("Error rotating access log: %s", err)
log.Errorf("Error rotating access log: %v", err)
}
}
if err := log.RotateFile(); err != nil {
log.Errorf("Error rotating traefik log: %s", err)
log.Errorf("Error rotating traefik log: %v", err)
}
}
}

View File

@ -2,83 +2,22 @@ package server
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"time"
"github.com/containous/flaeg"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/rules"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/unrolled/secure"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/roundrobin"
)
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var (
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
fblo6RBxUQ==
-----END CERTIFICATE-----`)
// LocalhostKey is the private key for localhostCert.
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
-----END RSA PRIVATE KEY-----`)
)
type testLoadBalancer struct{}
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *testLoadBalancer) Servers() []*url.URL {
return []*url.URL{}
}
func TestPrepareServerTimeouts(t *testing.T) {
testCases := []struct {
desc string
@ -282,559 +221,6 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c
return server, stop, invokeStopChan
}
func TestThrottleProviderConfigReload(t *testing.T) {
throttleDuration := 30 * time.Millisecond
publishConfig := make(chan types.ConfigMessage)
providerConfig := make(chan types.ConfigMessage)
stop := make(chan bool)
defer func() {
stop <- true
}()
globalConfig := configuration.GlobalConfiguration{}
server := NewServer(globalConfig, nil, nil)
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
publishedConfigCount := 0
stopConsumeConfigs := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-stopConsumeConfigs:
return
case <-publishConfig:
publishedConfigCount++
}
}
}()
// publish 5 new configs, one new config each 10 milliseconds
for i := 0; i < 5; i++ {
providerConfig <- types.ConfigMessage{}
time.Sleep(10 * time.Millisecond)
}
// after 50 milliseconds 5 new configs were published
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
stopConsumeConfigs <- true
select {
case <-publishConfig:
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
select {
case <-publishConfig:
t.Error("extra config publication found")
case <-time.After(100 * time.Millisecond):
return
}
case <-time.After(100 * time.Millisecond):
t.Error("Last config was not published in time")
}
}
func TestServerMultipleFrontendRules(t *testing.T) {
testCases := []struct {
expression string
requestURL string
expectedURL string
}{
{
expression: "Host:foo.bar",
requestURL: "http://foo.bar",
expectedURL: "http://foo.bar",
},
{
expression: "PathPrefix:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
{
expression: "Host:foo.bar;AddPrefix:/blah",
requestURL: "http://foo.bar/baz",
expectedURL: "http://foo.bar/blah/baz",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/four",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/zero/four",
},
{
expression: "AddPrefix:/blah;ReplacePath:/baz",
requestURL: "http://foo.bar/hello",
expectedURL: "http://foo.bar/baz",
},
{
expression: "PathPrefixStrip:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
}
for _, test := range testCases {
test := test
t.Run(test.expression, func(t *testing.T) {
t.Parallel()
router := mux.NewRouter()
route := router.NewRoute()
serverRoute := &types.ServerRoute{Route: route}
rules := &rules.Rules{Route: serverRoute}
expression := test.expression
routeResult, err := rules.Parse(expression)
if err != nil {
t.Fatalf("Error while building route for %s: %+v", expression, err)
}
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult})
if !routeMatch {
t.Fatalf("Rule %s doesn't match", expression)
}
server := new(Server)
server.wireFrontendBackend(serverRoute, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
}))
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
})
}
}
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
healthChecks := []*types.HealthCheck{
nil,
{
Path: "/path",
},
}
for _, lbMethod := range []string{"Wrr", "Drr"} {
for _, healthCheck := range healthChecks {
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)},
}
entryPoints := map[string]EntryPoint{
"http": {
Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: lbMethod,
},
HealthCheck: healthCheck,
},
},
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
EntryPoints: []string{"http"},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
require.NoError(t, err)
expectedNumHealthCheckBackends := 0
if healthCheck != nil {
expectedNumHealthCheckBackends = 1
}
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
})
}
}
}
func TestServerParseHealthCheckOptions(t *testing.T) {
lb := &testLoadBalancer{}
globalInterval := 15 * time.Second
testCases := []struct {
desc string
hc *types.HealthCheck
expectedOpts *healthcheck.Options
}{
{
desc: "nil health check",
hc: nil,
expectedOpts: nil,
},
{
desc: "empty path",
hc: &types.HealthCheck{
Path: "",
},
expectedOpts: nil,
},
{
desc: "unparseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "unparseable",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
},
},
{
desc: "sub-zero interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "-42s",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
},
},
{
desc: "parseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "5m",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: 5 * time.Minute,
LB: lb,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
opts := parseHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)})
assert.Equal(t, test.expectedOpts, opts, "health check options")
})
}
}
func TestBuildIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whitelistSourceRange []string
whiteList *types.WhiteList
middlewareConfigured bool
errMessage string
}{
{
desc: "no whitelists configured",
whitelistSourceRange: nil,
middlewareConfigured: false,
errMessage: "",
},
{
desc: "whitelists configured (deprecated)",
whitelistSourceRange: []string{
"1.2.3.4/24",
"fe80::/16",
},
middlewareConfigured: true,
errMessage: "",
},
{
desc: "invalid whitelists configured (deprecated)",
whitelistSourceRange: []string{
"foo",
},
middlewareConfigured: false,
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
},
{
desc: "whitelists configured",
whiteList: &types.WhiteList{
SourceRange: []string{
"1.2.3.4/24",
"fe80::/16",
},
UseXForwardedFor: false,
},
middlewareConfigured: true,
errMessage: "",
},
{
desc: "invalid whitelists configured (deprecated)",
whiteList: &types.WhiteList{
SourceRange: []string{
"foo",
},
UseXForwardedFor: false,
},
middlewareConfigured: false,
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange)
if test.errMessage != "" {
require.EqualError(t, err, test.errMessage)
} else {
assert.NoError(t, err)
if test.middlewareConfigured {
require.NotNil(t, middleware, "not expected middleware to be configured")
} else {
require.Nil(t, middleware, "expected middleware to be configured")
}
}
})
}
}
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
EntryPoints: configuration.EntryPoints{
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
BasicAuth: []string{""},
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: "Wrr",
},
},
},
},
}
srv := NewServer(globalConfig, nil, nil)
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
require.NoError(t, err)
}
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http", "https"},
}
entryPoints := map[string]EntryPoint{
"https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}},
"http": {Configuration: &configuration.EntryPoint{}},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil {
t.Fatalf("got error: %s", err)
} else if mapEntryPoints["https"].certs.Get() == nil {
t.Fatal("got error: https entryPoint must have TLS certificates.")
}
}
func TestConfigureBackends(t *testing.T) {
validMethod := "Drr"
defaultMethod := "wrr"
testCases := []struct {
desc string
lb *types.LoadBalancer
expectedMethod string
expectedStickiness *types.Stickiness
}{
{
desc: "valid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: &types.Stickiness{},
},
expectedMethod: validMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "valid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: nil,
},
expectedMethod: validMethod,
},
{
desc: "invalid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: &types.Stickiness{},
},
expectedMethod: defaultMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "invalid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: nil,
},
expectedMethod: defaultMethod,
},
{
desc: "missing load balancer",
lb: nil,
expectedMethod: defaultMethod,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := &types.Backend{
LoadBalancer: test.lb,
}
configureBackends(map[string]*types.Backend{
"backend": backend,
})
expected := types.LoadBalancer{
Method: test.expectedMethod,
Stickiness: test.expectedStickiness,
}
assert.Equal(t, expected, *backend.LoadBalancer)
})
}
}
func TestServerEntryPointWhitelistConfig(t *testing.T) {
testCases := []struct {
desc string
entrypoint *configuration.EntryPoint
expectMiddleware bool
}{
{
desc: "no whitelist middleware if no config on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: false,
},
{
desc: "whitelist middleware should be added if configured on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
WhitelistSourceRange: []string{
"127.0.0.1/32",
},
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
metricsRegistry: metrics.NewVoidRegistry(),
entryPoints: map[string]EntryPoint{
"test": {
Configuration: test.entrypoint,
},
},
}
srv.serverEntryPoints = srv.buildEntryPoints()
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
found := false
for _, handler := range handler.Handlers() {
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
found = true
}
}
if found && !test.expectMiddleware {
t.Error("ip whitelist middleware was installed even though it should not")
}
if !found && test.expectMiddleware {
t.Error("ip whitelist middleware was not installed even though it should have")
}
})
}
}
func TestServerResponseEmptyBackend(t *testing.T) {
const requestPath = "/path"
const routeRule = "Path:" + requestPath
@ -962,157 +348,6 @@ func TestServerResponseEmptyBackend(t *testing.T) {
}
}
func TestReuseBackend(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http"},
}
entryPoints := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
}},
}
dynamicConfigs := types.Configurations{
"config": th.BuildConfiguration(
th.WithFrontends(
th.WithFrontend("backend",
th.WithFrontendName("frontend0"),
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
th.WithFrontend("backend",
th.WithFrontendName("frontend1"),
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
th.WithBasicAuth("foo", "bar")),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLBMethod("wrr"),
th.WithServersNew(th.WithServerNew(testServer.URL))),
),
),
}
srv := NewServer(globalConfig, nil, entryPoints)
serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig)
if err != nil {
t.Fatalf("error loading config: %s", err)
}
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
// Test that the /unauthorized path returns a 401 because of
// the basic authentication defined on the frontend.
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
}
func TestBuildRedirectHandler(t *testing.T) {
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
entryPoints: map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
},
}
testCases := []struct {
desc string
srcEntryPointName string
url string
entryPoint *configuration.EntryPoint
redirect *types.Redirect
expectedURL string
}{
{
desc: "redirect regex",
srcEntryPointName: "http",
url: "http://foo.com",
redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foobar.com",
},
{
desc: "redirect entry point",
srcEntryPointName: "http",
url: "http://foo:80",
redirect: &types.Redirect{
EntryPoint: "https",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
},
},
expectedURL: "https://foo:443",
},
{
desc: "redirect entry point with regex (ignored)",
srcEntryPointName: "http",
url: "http://foo.com:80",
redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foo.com:443",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
require.NoError(t, err)
req := th.MustNewRequest(http.MethodGet, test.url, nil)
recorder := httptest.NewRecorder()
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Location", "fail")
}))
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
})
}
}
type mockContext struct {
headers http.Header
}