From 967051314e5d84cb389aa33ca8e0ff624b43c7eb Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 26 Jun 2021 11:49:08 +0100 Subject: [PATCH] Integrate claim extractor into providers --- providers/adfs.go | 13 +- providers/adfs_test.go | 6 +- providers/auth_test.go | 13 +- providers/azure.go | 19 +- providers/azure_test.go | 39 ++-- providers/digitalocean.go | 2 + providers/facebook.go | 1 + providers/linkedin.go | 2 + providers/nextcloud.go | 29 +-- providers/nextcloud_test.go | 72 ------- providers/oidc.go | 65 +----- providers/oidc_test.go | 343 +------------------------------- providers/provider_data.go | 138 +++++-------- providers/provider_data_test.go | 187 +++++++---------- providers/util.go | 16 -- 15 files changed, 212 insertions(+), 733 deletions(-) diff --git a/providers/adfs.go b/providers/adfs.go index 797c856..f5cbbfc 100644 --- a/providers/adfs.go +++ b/providers/adfs.go @@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt } func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { - idToken, err := p.Verifier.Verify(ctx, s.IDToken) + claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken) if err != nil { - return err + return fmt.Errorf("could not extract claims: %v", err) } - claims, err := p.getClaims(idToken) + + upn, found, err := claims.GetClaim(adfsUPNClaim) if err != nil { - return fmt.Errorf("couldn't extract claims from id_token (%v)", err) + return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err) } - upn := claims.raw[adfsUPNClaim] - if upn != nil { + + if found && fmt.Sprint(upn) != "" { s.Email = fmt.Sprint(upn) } return nil diff --git a/providers/adfs_test.go b/providers/adfs_test.go index 7eb1c48..93e61ea 100755 --- a/providers/adfs_test.go +++ b/providers/adfs_test.go @@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server { { "access_token": "my_access_token", "id_token": "my_id_token", - "refresh_token": "my_refresh_token" + "refresh_token": "my_refresh_token" } ` userInfo := ` @@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() { Context("with valid token", func() { It("should not throw an error", func() { rawIDToken, _ := newSignedTestIDToken(defaultIDToken) - idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) - Expect(err).To(BeNil()) - session, err := p.buildSessionFromClaims(idToken) + session, err := p.buildSessionFromClaims(rawIDToken, "") Expect(err).To(BeNil()) session.IDToken = rawIDToken err = p.EnrichSession(context.Background(), session) diff --git a/providers/auth_test.go b/providers/auth_test.go index 2ece923..bda93b9 100644 --- a/providers/auth_test.go +++ b/providers/auth_test.go @@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState { } func IsAuthorizedInHeader(reqHeader http.Header) bool { - return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken) + return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken) +} + +func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool { + return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token) } func IsAuthorizedInURL(reqURL *url.URL) bool { return reqURL.Query().Get("access_token") == authorizedAccessToken } + +func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool { + if token == "" { + return false + } + return reqURL.Query().Get("refresh_token") == token +} diff --git a/providers/azure.go b/providers/azure.go index 39beb83..10ea701 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { if p.ValidateURL == nil || p.ValidateURL.String() == "" { p.ValidateURL = p.ProfileURL } + p.getAuthorizationHeaderFunc = makeAzureHeader return &AzureProvider{ ProviderData: p, @@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* session.CreatedAtNow() session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) - email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) + email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken) // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 @@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* } if session.Email == "" { - email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken) + email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken) if err == nil && email != "" { session.Email = email } else { @@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err // verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token // when oidc verifier is configured -func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) { +func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) { email := "" - if token != "" && p.Verifier != nil { - token, err := p.Verifier.Verify(ctx, token) + if rawIDToken != "" && p.Verifier != nil { + _, err := p.Verifier.Verify(ctx, rawIDToken) // due to issues mentioned above, id_token may not be signed by AAD if err == nil { - claims, err := p.getClaims(token) + s, err := p.buildSessionFromClaims(rawIDToken, accessToken) if err == nil { - email = claims.Email + email = s.Email } else { logger.Printf("unable to get claims from token: %v", err) } @@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess s.CreatedAtNow() s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) - email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) + email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken) // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 @@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess } if s.Email == "" { - email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken) + email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken) if err == nil && email != "" { s.Email = email } else { diff --git a/providers/azure_test.go b/providers/azure_test.go index 3bb16d8..25b8c20 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -13,9 +13,8 @@ import ( "testing" "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt" - - oidc "github.com/coreos/go-oidc/v3/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" @@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) { assert.Equal(t, "openid", p.Data().Scope) } -func testAzureBackend(payload string) *httptest.Server { - return testAzureBackendWithError(payload, false) +func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server { + return testAzureBackendWithError(payload, accessToken, refreshToken, false) } -func testAzureBackendWithError(payload string, injectError bool) *httptest.Server { +func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server { path := "/v1.0/me" return httptest.NewServer(http.HandlerFunc( @@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve w.WriteHeader(200) } w.Write([]byte(payload)) - } else if !IsAuthorizedInHeader(r.Header) { + } else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) && + !isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) { w.WriteHeader(403) } else { w.WriteHeader(200) @@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) { host string ) if testCase.PayloadFromAzureBackend != "" { - b = testAzureBackend(testCase.PayloadFromAzureBackend) + b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "") defer b.Close() bURL, _ := url.Parse(b.URL) @@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) { payloadBytes, err := json.Marshal(payload) assert.NoError(t, err) - b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError) + b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError) defer b.Close() bURL, _ := url.Parse(b.URL) @@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { func TestAzureProviderRefresh(t *testing.T) { email := "foo@example.com" + subject := "foo" idToken := idTokenClaims{ - StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"}, - Email: email} + Email: email, + StandardClaims: jwt.StandardClaims{ + Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532", + Subject: subject, + }, + } idTokenString, err := newSignedTestIDToken(idToken) assert.NoError(t, err) + timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") assert.NoError(t, err) + + newAccessToken := "new_some_access_token" payload := azureOAuthPayload{ IDToken: idTokenString, RefreshToken: "new_some_refresh_token", - AccessToken: "new_some_access_token", + AccessToken: newAccessToken, ExpiresOn: timestamp.Unix(), } - payloadBytes, err := json.Marshal(payload) assert.NoError(t, err) - b := testAzureBackend(string(payloadBytes)) + + refreshToken := "some_refresh_token" + b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) expires := time.Now().Add(time.Duration(-1) * time.Hour) - session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} + session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires} refreshed, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.True(t, refreshed) assert.NotEqual(t, session, nil) - assert.Equal(t, "new_some_access_token", session.AccessToken) + assert.Equal(t, newAccessToken, session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) assert.Equal(t, idTokenString, session.IDToken) assert.Equal(t, email, session.Email) diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 4c1196d..acbd6f7 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { validateURL: digitalOceanDefaultProfileURL, scope: digitalOceanDefaultScope, }) + p.getAuthorizationHeaderFunc = makeOIDCHeader + return &DigitalOceanProvider{ProviderData: p} } diff --git a/providers/facebook.go b/providers/facebook.go index 6db9c38..cfa836d 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { validateURL: facebookDefaultProfileURL, scope: facebookDefaultScope, }) + p.getAuthorizationHeaderFunc = makeOIDCHeader return &FacebookProvider{ProviderData: p} } diff --git a/providers/linkedin.go b/providers/linkedin.go index cac8022..3904d84 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { validateURL: linkedinDefaultValidateURL, scope: linkedinDefaultScope, }) + p.getAuthorizationHeaderFunc = makeLinkedInHeader + return &LinkedInProvider{ProviderData: p} } diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 4a074d6..e915601 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -1,13 +1,5 @@ package providers -import ( - "context" - "fmt" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" -) - // NextcloudProvider represents an Nextcloud based Identity Provider type NextcloudProvider struct { *ProviderData @@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud" // NewNextcloudProvider initiates a new NextcloudProvider func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { p.ProviderName = nextCloudProviderName + p.getAuthorizationHeaderFunc = makeOIDCHeader + if p.EmailClaim == OIDCEmailClaim { + // This implies the email claim has not been overridden, we should set a default + // for this provider + p.EmailClaim = "ocs.data.email" + } return &NextcloudProvider{ProviderData: p} } - -// GetEmailAddress returns the Account email address -func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - json, err := requests.New(p.ValidateURL.String()). - WithContext(ctx). - WithHeaders(makeOIDCHeader(s.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return "", fmt.Errorf("error making request: %v", err) - } - - email, err := json.Get("ocs").Get("data").Get("email").String() - return email, err -} diff --git a/providers/nextcloud_test.go b/providers/nextcloud_test.go index cd26885..92f5030 100644 --- a/providers/nextcloud_test.go +++ b/providers/nextcloud_test.go @@ -1,18 +1,13 @@ package providers import ( - "context" - "net/http" - "net/http/httptest" "net/url" "testing" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) const formatJSON = "format=json" -const userPath = "/ocs/v2.php/cloud/user" func testNextcloudProvider(hostname string) *NextcloudProvider { p := NewNextcloudProvider( @@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider { return p } -func testNextcloudBackend(payload string) *httptest.Server { - path := userPath - query := formatJSON - - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != path || r.URL.RawQuery != query { - w.WriteHeader(404) - } else if !IsAuthorizedInHeader(r.Header) { - w.WriteHeader(403) - } else { - w.WriteHeader(200) - w.Write([]byte(payload)) - } - })) -} - func TestNextcloudProviderDefaults(t *testing.T) { p := testNextcloudProvider("") assert.NotEqual(t, nil, p) @@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) { assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON, p.Data().ValidateURL.String()) } - -func TestNextcloudProviderGetEmailAddress(t *testing.T) { - b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -// Note that trying to trigger the "failed building request" case is not -// practical, since the only way it can fail is if the URL fails to parse. -func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) { - b := testNextcloudBackend("unused payload") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - // We'll trigger a request failure by using an unexpected access - // token. Alternatively, we could allow the parsing of the payload as - // JSON to fail. - session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} - -func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { - b := testNextcloudBackend("{\"foo\": \"bar\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testNextcloudProvider(bURL.Host) - p.ValidateURL.Path = userPath - p.ValidateURL.RawQuery = formatJSON - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} diff --git a/providers/oidc.go b/providers/oidc.go index b1711d5..cccb8d7 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -5,12 +5,10 @@ import ( "errors" "fmt" "net/url" - "reflect" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "golang.org/x/oauth2" ) @@ -24,6 +22,8 @@ type OIDCProvider struct { // NewOIDCProvider initiates a new OIDCProvider func NewOIDCProvider(p *ProviderData) *OIDCProvider { p.ProviderName = "OpenID Connect" + p.getAuthorizationHeaderFunc = makeOIDCHeader + return &OIDCProvider{ ProviderData: p, SkipNonce: true, @@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s // EnrichSession is called after Redeem to allow providers to enrich session fields // such as User, Email, Groups with provider specific API calls. func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { - if p.ProfileURL.String() == "" { - if s.Email == "" { - return errors.New("id_token did not contain an email and profileURL is not defined") - } - return nil - } - - // Try to get missing emails or groups from a profileURL - if s.Email == "" || s.Groups == nil { - err := p.enrichFromProfileURL(ctx, s) - if err != nil { - logger.Errorf("Warning: Profile URL request failed: %v", err) - } - } - // If a mandatory email wasn't set, error at this point. if s.Email == "" { return errors.New("neither the id_token nor the profileURL set an email") @@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta return nil } -// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of -// an OIDC profile URL -func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error { - respJSON, err := requests.New(p.ProfileURL.String()). - WithContext(ctx). - WithHeaders(makeOIDCHeader(s.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return err - } - - email, err := respJSON.Get(p.EmailClaim).String() - if err == nil && s.Email == "" { - s.Email = email - } - - if len(s.Groups) > 0 { - return nil - } - for _, group := range coerceArray(respJSON, p.GroupsClaim) { - formatted, err := formatGroup(group) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(group), err) - continue - } - s.Groups = append(s.Groups, formatted) - } - - return nil -} - // ValidateSession checks that the session's IDToken is still valid func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - idToken, err := p.Verifier.Verify(ctx, s.IDToken) + _, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { logger.Errorf("id_token verification failed: %v", err) return false @@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS if p.SkipNonce { return true } - err = p.checkNonce(s, idToken) + err = p.checkNonce(s) if err != nil { logger.Errorf("nonce verification failed: %v", err) return false @@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) return nil, err } - ss, err := p.buildSessionFromClaims(idToken) + ss, err := p.buildSessionFromClaims(token, "") if err != nil { return nil, err } @@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) // createSession takes an oauth2.Token and creates a SessionState from it. // It alters behavior if called from Redeem vs Refresh func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { - idToken, err := p.verifyIDToken(ctx, token) + _, err := p.verifyIDToken(ctx, token) if err != nil { switch err { case ErrMissingIDToken: @@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r } } - ss, err := p.buildSessionFromClaims(idToken) + rawIDToken := getIDToken(token) + ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken) if err != nil { return nil, err } ss.AccessToken = token.AccessToken ss.RefreshToken = token.RefreshToken - ss.IDToken = getIDToken(token) + ss.IDToken = rawIDToken ss.CreatedAtNow() ss.SetExpiresOn(token.Expiry) diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 3c87da0..1a98f46 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/http/httptest" @@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { Scope: "openid profile offline_access", EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", Verifier: internaloidc.NewVerifier(oidc.NewVerifier( oidcIssuer, mockJWKS{}, @@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { assert.Equal(t, defaultIDToken.Phone, session.Email) } -func TestOIDCProvider_EnrichSession(t *testing.T) { - testCases := map[string]struct { - ExistingSession *sessions.SessionState - EmailClaim string - GroupsClaim string - ProfileJSON map[string]interface{} - ExpectedError error - ExpectedSession *sessions.SessionState - }{ - "Already Populated": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "found@email.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "found@email.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - - "Missing Email Only in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "found@email.com", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "found@email.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email with Custom Claim": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "weird", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "weird": "weird@claim.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Email: "weird@claim.com", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Email not in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "groups": []string{"new", "thing"}, - }, - ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), - ExpectedSession: &sessions.SessionState{ - User: "missing.email", - Groups: []string{"already", "populated"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"new", "thing"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Complex Groups in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []map[string]interface{}{ - { - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - }, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Singleton Complex Group in Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": map[string]interface{}{ - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Empty Groups Claims": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": []string{"new", "thing"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups with Custom Claim": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "roles", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "roles": []string{"new", "thing", "roles"}, - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"new", "thing", "roles"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups String Profile URL Response": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: nil, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - "groups": "singleton", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - Groups: []string{"singleton"}, - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - "Missing Groups in both Claims and Profile URL": { - ExistingSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - EmailClaim: "email", - GroupsClaim: "groups", - ProfileJSON: map[string]interface{}{ - "email": "new@thing.com", - }, - ExpectedError: nil, - ExpectedSession: &sessions.SessionState{ - User: "already", - Email: "already@populated.com", - IDToken: idToken, - AccessToken: accessToken, - RefreshToken: refreshToken, - }, - }, - } - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - jsonResp, err := json.Marshal(tc.ProfileJSON) - assert.NoError(t, err) - - server, provider := newTestOIDCSetup(jsonResp) - provider.ProfileURL, err = url.Parse(server.URL) - assert.NoError(t, err) - - provider.EmailClaim = tc.EmailClaim - provider.GroupsClaim = tc.GroupsClaim - defer server.Close() - - err = provider.EnrichSession(context.Background(), tc.ExistingSession) - assert.Equal(t, tc.ExpectedError, err) - assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) - }) - } -} - func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { idToken, _ := newSignedTestIDToken(defaultIDToken) @@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { ExpectedGroups: []string{"test:c", "test:d"}, }, "Complex Groups Claim": { - IDToken: complexGroupsIDToken, - GroupsClaim: "groups", - ExpectedUser: "123456789", - ExpectedEmail: "complex@claims.com", - ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: complexGroupsIDToken, + GroupsClaim: "groups", + ExpectedUser: "123456789", + ExpectedEmail: "complex@claims.com", + ExpectedGroups: []string{ + "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", + "12345", + "Just::A::String", + }, }, } for testName, tc := range testCases { diff --git a/providers/provider_data.go b/providers/provider_data.go index 38f1740..13241ee 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -5,20 +5,23 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "net/url" - "reflect" "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util" "golang.org/x/oauth2" ) const ( OIDCEmailClaim = "email" OIDCGroupsClaim = "groups" + // This is not exported as it's not currently user configurable + oidcUserClaim = "sub" ) var OIDCAudienceClaims = []string{"aud"} @@ -52,6 +55,8 @@ type ProviderData struct { // Universal Group authorization data structure // any provider can set to consume AllowedGroups map[string]struct{} + + getAuthorizationHeaderFunc func(string) http.Header } // Data returns the ProviderData @@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) { if p.Scope == "" { p.Scope = defaults.scope } + + if p.UserClaim == "" { + p.UserClaim = oidcUserClaim + } } // defaultURL will set return a default value if the given value is not set. @@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL { // OIDC compliant // **************************************************************************** -// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload -type OIDCClaims struct { - Subject string `json:"sub"` - Email string `json:"-"` - Groups []string `json:"-"` - Verified *bool `json:"email_verified"` - Nonce string `json:"nonce"` - - raw map[string]interface{} -} - func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken := getIDToken(token) if strings.TrimSpace(rawIDToken) == "" { @@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) ( // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState // with non-Token related fields. -func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { +func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) { ss := &sessions.SessionState{} - if idToken == nil { + if rawIDToken == "" { return ss, nil } - claims, err := p.getClaims(idToken) + extractor, err := p.getClaimExtractor(rawIDToken, accessToken) if err != nil { - return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) + return nil, err } - ss.User = claims.Subject - ss.Email = claims.Email - ss.Groups = claims.Groups - - // Allow specialized providers that embed OIDCProvider to control the User - // claim. Not exposed as a configuration flag to generic OIDC provider - // users (yet). - if p.UserClaim != "" { - user, ok := claims.raw[p.UserClaim].(string) - if !ok { - return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim) + // Use a slice of a struct (vs map) here in case the same claim is used twice + for _, c := range []struct { + claim string + dst interface{} + }{ + {p.UserClaim, &ss.User}, + {p.EmailClaim, &ss.Email}, + {p.GroupsClaim, &ss.Groups}, + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping + {"preferred_username", &ss.PreferredUsername}, + } { + if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil { + return nil, err } - ss.User = user - } - - // TODO (@NickMeves) Deprecate for dynamic claim to session mapping - if pref, ok := claims.raw["preferred_username"].(string); ok { - ss.PreferredUsername = pref } // `email_verified` must be present and explicitly set to `false` to be // considered unverified. verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail - if verifyEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + + var verified bool + exists, err := extractor.GetClaimInto("email_verified", &verified) + if err != nil { + return nil, err + } + + if verifyEmail && exists && !verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email) } return ss, nil } -// getClaims extracts IDToken claims into an OIDCClaims -func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { - claims := &OIDCClaims{} - - // Extract default claims. - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) - } - // Extract custom claims. - if err := idToken.Claims(&claims.raw); err != nil { - return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) +func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) { + extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken)) + if err != nil { + return nil, fmt.Errorf("could not initialise claim extractor: %v", err) } - email := claims.raw[p.EmailClaim] - if email != nil { - claims.Email = fmt.Sprint(email) - } - claims.Groups = p.extractGroups(claims.raw) - - return claims, nil + return extractor, nil } // checkNonce compares the session's nonce with the IDToken's nonce claim -func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error { - claims, err := p.getClaims(idToken) +func (p *ProviderData) checkNonce(s *sessions.SessionState) error { + extractor, err := p.getClaimExtractor(s.IDToken, "") if err != nil { return fmt.Errorf("id_token claims extraction failed: %v", err) } - if !s.CheckNonce(claims.Nonce) { + var nonce string + if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil { + return fmt.Errorf("could not extract nonce from ID Token: %v", err) + } + + if !s.CheckNonce(nonce) { return errors.New("id_token nonce claim does not match the session nonce") } return nil } -// extractGroups extracts groups from a claim to a list in a type safe manner. -// If the claim isn't present, `nil` is returned. If the groups claim is -// present but empty, `[]string{}` is returned. -func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { - rawClaim, ok := claims[p.GroupsClaim] - if !ok { - return nil +func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header { + if p.getAuthorizationHeaderFunc != nil && accessToken != "" { + return p.getAuthorizationHeaderFunc(accessToken) } - - // Handle traditional list-based groups as well as non-standard singleton - // based groups. Both variants support complex objects if needed. - var claimGroups []interface{} - switch raw := rawClaim.(type) { - case []interface{}: - claimGroups = raw - case interface{}: - claimGroups = []interface{}{raw} - } - - groups := []string{} - for _, rawGroup := range claimGroups { - formattedGroup, err := formatGroup(rawGroup) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(rawGroup), err) - continue - } - groups = append(groups, formattedGroup) - } - return groups + return nil } diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 8e6d12c..64c8326 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -60,16 +60,30 @@ var ( StandardClaims: standardClaims, } + numericGroupsIDToken = idTokenClaims{ + Name: "Jane Dobbs", + Email: "janed@me.com", + Phone: "+4798765432", + Picture: "http://mugbook.com/janed/me.jpg", + Groups: []interface{}{1, 2, 3}, + Roles: []string{"test:c", "test:d"}, + Verified: &verified, + Nonce: encryption.HashNonce([]byte(oidcNonce)), + StandardClaims: standardClaims, + } + complexGroupsIDToken = idTokenClaims{ Name: "Complex Claim", Email: "complex@claims.com", Phone: "+5439871234", Picture: "http://mugbook.com/complex/claims.jpg", - Groups: []map[string]interface{}{ - { + Groups: []interface{}{ + map[string]interface{}{ "groupId": "Admin Group Id", "roles": []string{"Admin"}, }, + 12345, + "Just::A::String", }, Roles: []string{"test:simple", "test:roles"}, Verified: &verified, @@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "unverified@email.com", @@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "email", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ - User: "123456789", - Email: "complex@claims.com", - Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + User: "123456789", + Email: "complex@claims.com", + Groups: []string{ + "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", + "12345", + "Just::A::String", + }, PreferredUsername: "Complex Claim", }, }, @@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Jane Dobbs", }, }, - "User Claim Invalid": { + "User Claim switched to non string": { IDToken: defaultIDToken, AllowUnverified: true, - UserClaim: "groups", + UserClaim: "roles", EmailClaim: "email", GroupsClaim: "groups", - ExpectedError: errors.New("unable to extract custom UserClaim (groups)"), + ExpectedSession: &sessions.SessionState{ + User: "[\"test:c\",\"test:d\"]", + Email: "janed@me.com", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Jane Dobbs", + }, }, "Email Claim Switched": { IDToken: unverifiedIDToken, AllowUnverified: true, EmailClaim: "phone_number", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "+4025205729", @@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "roles", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", - Email: "[test:c test:d]", + Email: "[\"test:c\",\"test:d\"]", Groups: []string{"test:a", "test:b"}, PreferredUsername: "Mystery Man", }, @@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: true, EmailClaim: "aksjdfhjksadh", GroupsClaim: "groups", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "", @@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "roles", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { AllowUnverified: false, EmailClaim: "email", GroupsClaim: "alskdjfsalkdjf", + UserClaim: "sub", ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", @@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Jane Dobbs", }, }, + "Groups Claim Numeric values": { + IDToken: numericGroupsIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "groups", + UserClaim: "sub", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"1", "2", "3"}, + PreferredUsername: "Jane Dobbs", + }, + }, + "Groups Claim string values": { + IDToken: defaultIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "email", + UserClaim: "sub", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"janed@me.com"}, + PreferredUsername: "Jane Dobbs", + }, + }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { @@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { rawIDToken, err := newSignedTestIDToken(tc.IDToken) g.Expect(err).ToNot(HaveOccurred()) - idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) - g.Expect(err).ToNot(HaveOccurred()) - - ss, err := provider.buildSessionFromClaims(idToken) + ss, err := provider.buildSessionFromClaims(rawIDToken, "") if err != nil { g.Expect(err).To(Equal(tc.ExpectedError)) } @@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) { t.Run(testName, func(t *testing.T) { g := NewWithT(t) + // Ensure that the ID token in the session is valid (signed and contains a nonce) + // as the nonce claim is extracted to compare with the session nonce + rawIDToken, err := newSignedTestIDToken(tc.IDToken) + g.Expect(err).ToNot(HaveOccurred()) + tc.Session.IDToken = rawIDToken + verificationOptions := &internaloidc.IDTokenVerificationOptions{ AudienceClaims: []string{"aud"}, ClientID: oidcClientID, @@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) { ), verificationOptions), } - rawIDToken, err := newSignedTestIDToken(tc.IDToken) - g.Expect(err).ToNot(HaveOccurred()) - - idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) - g.Expect(err).ToNot(HaveOccurred()) - - err = provider.checkNonce(tc.Session, idToken) - if err != nil { + if err := provider.checkNonce(tc.Session); err != nil { g.Expect(err).To(Equal(tc.ExpectedError)) } else { g.Expect(err).ToNot(HaveOccurred()) @@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) { }) } } - -func TestProviderData_extractGroups(t *testing.T) { - testCases := map[string]struct { - Claims map[string]interface{} - GroupsClaim string - ExpectedGroups []string - }{ - "Standard String Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{"three", "string", "groups"}, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"three", "string", "groups"}, - }, - "Different Claim Name": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "roles": []interface{}{"three", "string", "roles"}, - }, - GroupsClaim: "roles", - ExpectedGroups: []string{"three", "string", "roles"}, - }, - "Numeric Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{1, 2, 3}, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"1", "2", "3"}, - }, - "Complex Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": []interface{}{ - map[string]interface{}{ - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - 12345, - "Just::A::String", - }, - }, - GroupsClaim: "groups", - ExpectedGroups: []string{ - "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", - "12345", - "Just::A::String", - }, - }, - "Missing Groups Claim Returns Nil": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - }, - GroupsClaim: "groups", - ExpectedGroups: nil, - }, - "Non List Groups": { - Claims: map[string]interface{}{ - "email": "this@does.not.matter.com", - "groups": "singleton", - }, - GroupsClaim: "groups", - ExpectedGroups: []string{"singleton"}, - }, - } - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - g := NewWithT(t) - - verificationOptions := &internaloidc.IDTokenVerificationOptions{ - AudienceClaims: []string{"aud"}, - ClientID: oidcClientID, - } - provider := &ProviderData{ - Verifier: internaloidc.NewVerifier(oidc.NewVerifier( - oidcIssuer, - mockJWKS{}, - &oidc.Config{ClientID: oidcClientID}, - ), verificationOptions), - } - provider.GroupsClaim = tc.GroupsClaim - - groups := provider.extractGroups(tc.Claims) - if tc.ExpectedGroups != nil { - g.Expect(groups).To(Equal(tc.ExpectedGroups)) - } else { - g.Expect(groups).To(BeNil()) - } - }) - } -} diff --git a/providers/util.go b/providers/util.go index e6fdc34..0507dde 100644 --- a/providers/util.go +++ b/providers/util.go @@ -6,7 +6,6 @@ import ( "net/http" "net/url" - "github.com/bitly/go-simplejson" "golang.org/x/oauth2" ) @@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) { } return string(jsonGroup), nil } - -// coerceArray extracts a field from simplejson.Json that might be a -// singleton or a list and coerces it into a list. -func coerceArray(sj *simplejson.Json, key string) []interface{} { - array, err := sj.Get(key).Array() - if err == nil { - return array - } - - single := sj.Get(key).Interface() - if single == nil { - return nil - } - return []interface{}{single} -}