Support non-list and complex groups

This commit is contained in:
Nick Meves 2020-11-29 14:58:01 -08:00
parent eb56f24d6d
commit ea5b8cc21f
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
6 changed files with 166 additions and 36 deletions

View File

@ -48,6 +48,7 @@
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -59,7 +60,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
}
// Try to get missing emails or groups from a profileURL
if s.Email == "" || len(s.Groups) == 0 {
if s.Email == "" || s.Groups == nil {
err := p.callProfileURL(ctx, s)
if err != nil {
logger.Errorf("Warning: Profile URL request failed: %v", err)
@ -90,16 +91,15 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt
s.Email = email
}
// Handle array & singleton groups cases
if len(s.Groups) == 0 {
groups, err := respJSON.Get(p.GroupsClaim).StringArray()
if err == nil {
s.Groups = groups
} else {
group, err := respJSON.Get(p.GroupsClaim).String()
if err == nil {
s.Groups = []string{group}
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
formatted, err := formatGroup(group)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(group), err)
continue
}
s.Groups = append(s.Groups, formatted)
}
}

View File

@ -68,7 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
return u, s
}
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) {
func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
redeemURL, server := newOIDCServer(body)
provider := newOIDCProvider(redeemURL)
return server, provider
@ -85,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
@ -108,7 +108,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
provider.EmailClaim = "phone_number"
defer server.Close()
@ -247,7 +247,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
@ -268,6 +268,89 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
RefreshToken: refreshToken,
},
},
"Missing Groups with Complex Groups in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Singleton Complex Group in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Empty Groups Claims": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "already",
@ -297,7 +380,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
@ -346,7 +429,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
jsonResp, err := json.Marshal(tc.ProfileJSON)
assert.NoError(t, err)
server, provider := newTestSetup(jsonResp)
server, provider := newTestOIDCSetup(jsonResp)
provider.ProfileURL, err = url.Parse(server.URL)
assert.NoError(t, err)
@ -371,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
RefreshToken: refreshToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
existingSession := &sessions.SessionState{
@ -405,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
existingSession := &sessions.SessionState{
@ -433,7 +516,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
GroupsClaim string
ExpectedUser string
ExpectedEmail string
ExpectedGroups interface{}
ExpectedGroups []string
}{
"Default IDToken": {
IDToken: defaultIDToken,
@ -447,7 +530,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
GroupsClaim: "groups",
ExpectedUser: "123456789",
ExpectedEmail: "123456789",
ExpectedGroups: []string{},
ExpectedGroups: nil,
},
"Custom Groups Claim": {
IDToken: defaultIDToken,
@ -466,7 +549,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
server, provider := newTestSetup([]byte(`{}`))
server, provider := newTestOIDCSetup([]byte(`{}`))
provider.GroupsClaim = tc.GroupsClaim
defer server.Close()
@ -478,9 +561,9 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
assert.Equal(t, tc.ExpectedUser, ss.User)
assert.Equal(t, tc.ExpectedEmail, ss.Email)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, rawIDToken, ss.IDToken)
assert.Equal(t, rawIDToken, ss.AccessToken)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, "", ss.RefreshToken)
})
}

View File

@ -189,12 +189,27 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
return claims, nil
}
// extractGroups extracts groups from a claim to a list in a type safe manner
// extractGroups extracts groups from a claim to a list in a type safe manner.
// If the claim isn't present, `nil` is returned. If the groups claim is
// present but empty, `[]string{}` is returned.
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
rawClaim, ok := claims[p.GroupsClaim]
if !ok {
return nil
}
// Handle traditional list-based groups as well as non-standard singleton
// based groups. Both variants support complex objects if needed.
var claimGroups []interface{}
switch raw := rawClaim.(type) {
case []interface{}:
claimGroups = raw
case interface{}:
claimGroups = []interface{}{raw}
}
groups := []string{}
rawGroups, ok := claims[p.GroupsClaim].([]interface{})
if rawGroups != nil && ok {
for _, rawGroup := range rawGroups {
for _, rawGroup := range claimGroups {
formattedGroup, err := formatGroup(rawGroup)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
@ -203,6 +218,5 @@ func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
}
groups = append(groups, formattedGroup)
}
}
return groups
}

View File

@ -300,7 +300,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{},
Groups: nil,
PreferredUsername: "Jane Dobbs",
},
},
@ -386,12 +386,20 @@ func TestProviderData_extractGroups(t *testing.T) {
"Just::A::String",
},
},
"Missing Groups": {
"Missing Groups Claim Returns Nil": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
},
GroupsClaim: "groups",
ExpectedGroups: []string{},
ExpectedGroups: nil,
},
"Non List Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": "singleton",
},
GroupsClaim: "groups",
ExpectedGroups: []string{"singleton"},
},
}
for testName, tc := range testCases {
@ -408,7 +416,11 @@ func TestProviderData_extractGroups(t *testing.T) {
provider.GroupsClaim = tc.GroupsClaim
groups := provider.extractGroups(tc.Claims)
if tc.ExpectedGroups != nil {
g.Expect(groups).To(Equal(tc.ExpectedGroups))
} else {
g.Expect(groups).To(BeNil())
}
})
}
}

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"github.com/bitly/go-simplejson"
"golang.org/x/oauth2"
)
@ -59,6 +60,8 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
return a
}
// getIDToken extracts an IDToken stored in the `Extra` fields of an
// oauth2.Token
func getIDToken(token *oauth2.Token) string {
idToken, ok := token.Extra("id_token").(string)
if !ok {
@ -67,6 +70,8 @@ func getIDToken(token *oauth2.Token) string {
return idToken
}
// formatGroup coerces an OIDC groups claim into a string
// If it is non-string, marshal it into JSON.
func formatGroup(rawGroup interface{}) (string, error) {
group, ok := rawGroup.(string)
if !ok {
@ -78,3 +83,18 @@ func formatGroup(rawGroup interface{}) (string, error) {
}
return group, nil
}
// coerceArray extracts a field from simplejson.Json that might be a
// singleton or a list and coerces it into a list.
func coerceArray(sj *simplejson.Json, key string) []interface{} {
array, err := sj.Get(key).Array()
if err == nil {
return array
}
single := sj.Get(key).Interface()
if single == nil {
return nil
}
return []interface{}{single}
}