diff --git a/pkg/logs/fields.go b/pkg/logs/fields.go index 27e05198f..962c1f449 100644 --- a/pkg/logs/fields.go +++ b/pkg/logs/fields.go @@ -11,7 +11,6 @@ const ( ServiceName = "serviceName" MetricsProviderName = "metricsProviderName" TracingProviderName = "tracingProviderName" - ServerName = "serverName" ServerIndex = "serverIndex" TLSStoreName = "tlsStoreName" ServersTransportName = "serversTransport" diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index 72ce19ff1..41e260d43 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -3,6 +3,8 @@ package wrr import ( "container/heap" "context" + "crypto/sha256" + "encoding/hex" "errors" "hash/fnv" "net/http" @@ -15,9 +17,10 @@ import ( type namedHandler struct { http.Handler - name string - weight float64 - deadline float64 + name string + hashedName string + weight float64 + deadline float64 } type stickyCookie struct { @@ -53,9 +56,10 @@ type Balancer struct { handlersMu sync.RWMutex // References all the handlers by name and also by the hashed value of the name. - handlerMap map[string]*namedHandler - handlers []*namedHandler - curDeadline float64 + stickyMap map[string]*namedHandler + compatibilityStickyMap map[string]*namedHandler + handlers []*namedHandler + curDeadline float64 // status is a record of which child services of the Balancer are healthy, keyed // by name of child service. A service is initially added to the map when it is // created via Add, and it is later removed or added to the map as needed, @@ -73,7 +77,6 @@ func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { balancer := &Balancer{ status: make(map[string]struct{}), fenced: make(map[string]struct{}), - handlerMap: make(map[string]*namedHandler), wantsHealthCheck: wantHealthCheck, } if sticky != nil && sticky.Cookie != nil { @@ -88,6 +91,9 @@ func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { if sticky.Cookie.Path != nil { balancer.stickyCookie.path = *sticky.Cookie.Path } + + balancer.stickyMap = make(map[string]*namedHandler) + balancer.compatibilityStickyMap = make(map[string]*namedHandler) } return balancer @@ -218,7 +224,7 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if err == nil && cookie != nil { b.handlersMu.RLock() - handler, ok := b.handlerMap[cookie.Value] + handler, ok := b.stickyMap[cookie.Value] b.handlersMu.RUnlock() if ok && handler != nil { @@ -230,6 +236,22 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } } + + b.handlersMu.RLock() + handler, ok = b.compatibilityStickyMap[cookie.Value] + b.handlersMu.RUnlock() + + if ok && handler != nil { + b.handlersMu.RLock() + _, isHealthy := b.status[handler.name] + b.handlersMu.RUnlock() + if isHealthy { + b.writeStickyCookie(w, handler) + + handler.ServeHTTP(w, req) + return + } + } } } @@ -244,21 +266,25 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if b.stickyCookie != nil { - cookie := &http.Cookie{ - Name: b.stickyCookie.name, - Value: hash(server.name), - Path: b.stickyCookie.path, - HttpOnly: b.stickyCookie.httpOnly, - Secure: b.stickyCookie.secure, - SameSite: convertSameSite(b.stickyCookie.sameSite), - MaxAge: b.stickyCookie.maxAge, - } - http.SetCookie(w, cookie) + b.writeStickyCookie(w, server) } server.ServeHTTP(w, req) } +func (b *Balancer) writeStickyCookie(w http.ResponseWriter, handler *namedHandler) { + cookie := &http.Cookie{ + Name: b.stickyCookie.name, + Value: handler.hashedName, + Path: b.stickyCookie.path, + HttpOnly: b.stickyCookie.httpOnly, + Secure: b.stickyCookie.secure, + SameSite: convertSameSite(b.stickyCookie.sameSite), + MaxAge: b.stickyCookie.maxAge, + } + http.SetCookie(w, cookie) +} + // Add adds a handler. // A handler with a non-positive weight is ignored. func (b *Balancer) Add(name string, handler http.Handler, weight *int, fenced bool) { @@ -280,15 +306,41 @@ func (b *Balancer) Add(name string, handler http.Handler, weight *int, fenced bo if fenced { b.fenced[name] = struct{}{} } - b.handlerMap[name] = h - b.handlerMap[hash(name)] = h + + if b.stickyCookie != nil { + sha256HashedName := sha256Hash(name) + h.hashedName = sha256HashedName + + b.stickyMap[sha256HashedName] = h + b.compatibilityStickyMap[name] = h + + hashedName := fnvHash(name) + b.compatibilityStickyMap[hashedName] = h + + // server.URL was fnv hashed in service.Manager + // so we can have "double" fnv hash in already existing cookies + hashedName = fnvHash(hashedName) + b.compatibilityStickyMap[hashedName] = h + } b.handlersMu.Unlock() } -func hash(input string) string { +func fnvHash(input string) string { hasher := fnv.New64() // We purposely ignore the error because the implementation always returns nil. _, _ = hasher.Write([]byte(input)) return strconv.FormatUint(hasher.Sum64(), 16) } + +func sha256Hash(input string) string { + hash := sha256.New() + // We purposely ignore the error because the implementation always returns nil. + _, _ = hash.Write([]byte(input)) + + hashedInput := hex.EncodeToString(hash.Sum(nil)) + if len(hashedInput) < 16 { + return hashedInput + } + return hashedInput[:16] +} diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index d969bd279..1773c7ce0 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -296,7 +296,7 @@ func TestSticky_FallBack(t *testing.T) { rw.WriteHeader(http.StatusOK) }), pointer(2), false) - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} + recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}, cookies: make(map[string]*http.Cookie)} req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "test", Value: "second"}) @@ -373,7 +373,7 @@ func TestSticky_Fenced(t *testing.T) { rw.WriteHeader(http.StatusOK) }), pointer(1), true) - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} + recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}, cookies: make(map[string]*http.Cookie)} stickyReq := httptest.NewRequest(http.MethodGet, "/", nil) stickyReq.AddCookie(&http.Cookie{Name: "test", Value: "fenced"}) @@ -391,3 +391,99 @@ func TestSticky_Fenced(t *testing.T) { assert.Equal(t, 2, recorder.save["first"]) assert.Equal(t, 2, recorder.save["second"]) } + +func TestStickyWithCompatibility(t *testing.T) { + testCases := []struct { + desc string + servers []string + cookies []*http.Cookie + + expectedCookies []*http.Cookie + expectedServer string + }{ + { + desc: "No previous cookie", + servers: []string{"first"}, + + expectedServer: "first", + expectedCookies: []*http.Cookie{ + {Name: "test", Value: sha256Hash("first")}, + }, + }, + { + desc: "Sha256 previous cookie", + servers: []string{"first", "second"}, + cookies: []*http.Cookie{ + {Name: "test", Value: sha256Hash("first")}, + }, + expectedServer: "first", + expectedCookies: []*http.Cookie{}, + }, + { + desc: "Raw previous cookie", + servers: []string{"first", "second"}, + cookies: []*http.Cookie{ + {Name: "test", Value: "first"}, + }, + expectedServer: "first", + expectedCookies: []*http.Cookie{ + {Name: "test", Value: sha256Hash("first")}, + }, + }, + { + desc: "Fnv previous cookie", + servers: []string{"first", "second"}, + cookies: []*http.Cookie{ + {Name: "test", Value: fnvHash("first")}, + }, + expectedServer: "first", + expectedCookies: []*http.Cookie{ + {Name: "test", Value: sha256Hash("first")}, + }, + }, + { + desc: "Double fnv previous cookie", + servers: []string{"first", "second"}, + cookies: []*http.Cookie{ + {Name: "test", Value: fnvHash(fnvHash("first"))}, + }, + expectedServer: "first", + expectedCookies: []*http.Cookie{ + {Name: "test", Value: sha256Hash("first")}, + }, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + balancer := New(&dynamic.Sticky{Cookie: &dynamic.Cookie{Name: "test"}}, false) + + for _, server := range test.servers { + balancer.Add(server, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte(server)) + }), pointer(1), false) + } + + // Do it twice, to be sure it's not just the luck. + for range 2 { + req := httptest.NewRequest(http.MethodGet, "/", nil) + for _, cookie := range test.cookies { + req.AddCookie(cookie) + } + + recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}, cookies: make(map[string]*http.Cookie)} + balancer.ServeHTTP(recorder, req) + + assert.Equal(t, test.expectedServer, recorder.Body.String()) + + assert.Len(t, recorder.cookies, len(test.expectedCookies)) + for _, cookie := range test.expectedCookies { + assert.Equal(t, cookie.Value, recorder.cookies[cookie.Name].Value) + } + } + }) + } +} diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index 83e4f8d37..245cb08f6 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -2,11 +2,9 @@ package service import ( "context" - "encoding/hex" "encoding/json" "errors" "fmt" - "hash/fnv" "math/rand" "net/http" "net/url" @@ -335,18 +333,13 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName lb := wrr.New(service.Sticky, service.HealthCheck != nil) healthCheckTargets := make(map[string]*url.URL) - for _, server := range shuffle(service.Servers, m.rand) { - hasher := fnv.New64a() - _, _ = hasher.Write([]byte(server.URL)) // this will never return an error. - - proxyName := hex.EncodeToString(hasher.Sum(nil)) - + for i, server := range shuffle(service.Servers, m.rand) { target, err := url.Parse(server.URL) if err != nil { return nil, fmt.Errorf("error parsing server URL %s: %w", server.URL, err) } - logger.Debug().Str(logs.ServerName, proxyName).Stringer("target", target). + logger.Debug().Int(logs.ServerIndex, i).Str("URL", server.URL). Msg("Creating server") qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName) @@ -392,12 +385,12 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName proxy, _ = capture.Wrap(proxy) } - lb.Add(proxyName, proxy, server.Weight, server.Fenced) + lb.Add(server.URL, proxy, server.Weight, server.Fenced) // servers are considered UP by default. info.UpdateServerStatus(target.String(), runtime.StatusUp) - healthCheckTargets[proxyName] = target + healthCheckTargets[server.URL] = target } if service.HealthCheck != nil {