From ea5b8cc21fb6c8feac6226b1131c45e9a173df2f Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sun, 29 Nov 2020 14:58:01 -0800 Subject: [PATCH] Support non-list and complex groups --- CHANGELOG.md | 1 + providers/oidc.go | 18 +++--- providers/oidc_test.go | 107 ++++++++++++++++++++++++++++---- providers/provider_data.go | 36 +++++++---- providers/provider_data_test.go | 20 ++++-- providers/util.go | 20 ++++++ 6 files changed, 166 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74c272e..84f54de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) +- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe) - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) diff --git a/providers/oidc.go b/providers/oidc.go index d7d3470..98cefb4 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -59,7 +60,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta } // Try to get missing emails or groups from a profileURL - if s.Email == "" || len(s.Groups) == 0 { + if s.Email == "" || s.Groups == nil { err := p.callProfileURL(ctx, s) if err != nil { logger.Errorf("Warning: Profile URL request failed: %v", err) @@ -90,16 +91,15 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt s.Email = email } - // Handle array & singleton groups cases if len(s.Groups) == 0 { - groups, err := respJSON.Get(p.GroupsClaim).StringArray() - if err == nil { - s.Groups = groups - } else { - group, err := respJSON.Get(p.GroupsClaim).String() - if err == nil { - s.Groups = []string{group} + for _, group := range coerceArray(respJSON, p.GroupsClaim) { + formatted, err := formatGroup(group) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(group), err) + continue } + s.Groups = append(s.Groups, formatted) } } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 2651b4e..7ac9863 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -68,7 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) { return u, s } -func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { +func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) { redeemURL, server := newOIDCServer(body) provider := newOIDCProvider(redeemURL) return server, provider @@ -85,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") @@ -108,7 +108,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) provider.EmailClaim = "phone_number" defer server.Close() @@ -247,7 +247,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { ExistingSession: &sessions.SessionState{ User: "already", Email: "already@populated.com", - Groups: []string{}, + Groups: nil, IDToken: idToken, AccessToken: accessToken, RefreshToken: refreshToken, @@ -268,6 +268,89 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { RefreshToken: refreshToken, }, }, + "Missing Groups with Complex Groups in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []map[string]interface{}{ + { + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + }, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups with Singleton Complex Group in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": map[string]interface{}{ + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Empty Groups Claims": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, "Missing Groups with Custom Claim": { ExistingSession: &sessions.SessionState{ User: "already", @@ -297,7 +380,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { ExistingSession: &sessions.SessionState{ User: "already", Email: "already@populated.com", - Groups: []string{}, + Groups: nil, IDToken: idToken, AccessToken: accessToken, RefreshToken: refreshToken, @@ -346,7 +429,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { jsonResp, err := json.Marshal(tc.ProfileJSON) assert.NoError(t, err) - server, provider := newTestSetup(jsonResp) + server, provider := newTestOIDCSetup(jsonResp) provider.ProfileURL, err = url.Parse(server.URL) assert.NoError(t, err) @@ -371,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { RefreshToken: refreshToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() existingSession := &sessions.SessionState{ @@ -405,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() existingSession := &sessions.SessionState{ @@ -433,7 +516,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { GroupsClaim string ExpectedUser string ExpectedEmail string - ExpectedGroups interface{} + ExpectedGroups []string }{ "Default IDToken": { IDToken: defaultIDToken, @@ -447,7 +530,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { GroupsClaim: "groups", ExpectedUser: "123456789", ExpectedEmail: "123456789", - ExpectedGroups: []string{}, + ExpectedGroups: nil, }, "Custom Groups Claim": { IDToken: defaultIDToken, @@ -466,7 +549,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { - server, provider := newTestSetup([]byte(`{}`)) + server, provider := newTestOIDCSetup([]byte(`{}`)) provider.GroupsClaim = tc.GroupsClaim defer server.Close() @@ -478,9 +561,9 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedEmail, ss.Email) + assert.Equal(t, tc.ExpectedGroups, ss.Groups) assert.Equal(t, rawIDToken, ss.IDToken) assert.Equal(t, rawIDToken, ss.AccessToken) - assert.Equal(t, tc.ExpectedGroups, ss.Groups) assert.Equal(t, "", ss.RefreshToken) }) } diff --git a/providers/provider_data.go b/providers/provider_data.go index ae43451..098e619 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -189,20 +189,34 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { return claims, nil } -// extractGroups extracts groups from a claim to a list in a type safe manner +// extractGroups extracts groups from a claim to a list in a type safe manner. +// If the claim isn't present, `nil` is returned. If the groups claim is +// present but empty, `[]string{}` is returned. func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { + rawClaim, ok := claims[p.GroupsClaim] + if !ok { + return nil + } + + // Handle traditional list-based groups as well as non-standard singleton + // based groups. Both variants support complex objects if needed. + var claimGroups []interface{} + switch raw := rawClaim.(type) { + case []interface{}: + claimGroups = raw + case interface{}: + claimGroups = []interface{}{raw} + } + groups := []string{} - rawGroups, ok := claims[p.GroupsClaim].([]interface{}) - if rawGroups != nil && ok { - for _, rawGroup := range rawGroups { - formattedGroup, err := formatGroup(rawGroup) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(rawGroup), err) - continue - } - groups = append(groups, formattedGroup) + for _, rawGroup := range claimGroups { + formattedGroup, err := formatGroup(rawGroup) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(rawGroup), err) + continue } + groups = append(groups, formattedGroup) } return groups } diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 4aed73e..f94c0db 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -300,7 +300,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", - Groups: []string{}, + Groups: nil, PreferredUsername: "Jane Dobbs", }, }, @@ -386,12 +386,20 @@ func TestProviderData_extractGroups(t *testing.T) { "Just::A::String", }, }, - "Missing Groups": { + "Missing Groups Claim Returns Nil": { Claims: map[string]interface{}{ "email": "this@does.not.matter.com", }, GroupsClaim: "groups", - ExpectedGroups: []string{}, + ExpectedGroups: nil, + }, + "Non List Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "groups": "singleton", + }, + GroupsClaim: "groups", + ExpectedGroups: []string{"singleton"}, }, } for testName, tc := range testCases { @@ -408,7 +416,11 @@ func TestProviderData_extractGroups(t *testing.T) { provider.GroupsClaim = tc.GroupsClaim groups := provider.extractGroups(tc.Claims) - g.Expect(groups).To(Equal(tc.ExpectedGroups)) + if tc.ExpectedGroups != nil { + g.Expect(groups).To(Equal(tc.ExpectedGroups)) + } else { + g.Expect(groups).To(BeNil()) + } }) } } diff --git a/providers/util.go b/providers/util.go index acf2090..055d29d 100644 --- a/providers/util.go +++ b/providers/util.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" + "github.com/bitly/go-simplejson" "golang.org/x/oauth2" ) @@ -59,6 +60,8 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va return a } +// getIDToken extracts an IDToken stored in the `Extra` fields of an +// oauth2.Token func getIDToken(token *oauth2.Token) string { idToken, ok := token.Extra("id_token").(string) if !ok { @@ -67,6 +70,8 @@ func getIDToken(token *oauth2.Token) string { return idToken } +// formatGroup coerces an OIDC groups claim into a string +// If it is non-string, marshal it into JSON. func formatGroup(rawGroup interface{}) (string, error) { group, ok := rawGroup.(string) if !ok { @@ -78,3 +83,18 @@ func formatGroup(rawGroup interface{}) (string, error) { } return group, nil } + +// coerceArray extracts a field from simplejson.Json that might be a +// singleton or a list and coerces it into a list. +func coerceArray(sj *simplejson.Json, key string) []interface{} { + array, err := sj.Get(key).Array() + if err == nil { + return array + } + + single := sj.Get(key).Interface() + if single == nil { + return nil + } + return []interface{}{single} +}