Support non-list and complex groups
This commit is contained in:
parent
eb56f24d6d
commit
ea5b8cc21f
@ -48,6 +48,7 @@
|
||||
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
|
||||
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
|
||||
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
|
||||
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
|
||||
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
|
||||
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
|
||||
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
@ -59,7 +60,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
|
||||
}
|
||||
|
||||
// Try to get missing emails or groups from a profileURL
|
||||
if s.Email == "" || len(s.Groups) == 0 {
|
||||
if s.Email == "" || s.Groups == nil {
|
||||
err := p.callProfileURL(ctx, s)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||
@ -90,16 +91,15 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt
|
||||
s.Email = email
|
||||
}
|
||||
|
||||
// Handle array & singleton groups cases
|
||||
if len(s.Groups) == 0 {
|
||||
groups, err := respJSON.Get(p.GroupsClaim).StringArray()
|
||||
if err == nil {
|
||||
s.Groups = groups
|
||||
} else {
|
||||
group, err := respJSON.Get(p.GroupsClaim).String()
|
||||
if err == nil {
|
||||
s.Groups = []string{group}
|
||||
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
|
||||
formatted, err := formatGroup(group)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(group), err)
|
||||
continue
|
||||
}
|
||||
s.Groups = append(s.Groups, formatted)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
|
||||
return u, s
|
||||
}
|
||||
|
||||
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
||||
func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
||||
redeemURL, server := newOIDCServer(body)
|
||||
provider := newOIDCProvider(redeemURL)
|
||||
return server, provider
|
||||
@ -85,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
|
||||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||
@ -108,7 +108,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
||||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
provider.EmailClaim = "phone_number"
|
||||
defer server.Close()
|
||||
|
||||
@ -247,7 +247,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
@ -268,6 +268,89 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Complex Groups in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Singleton Complex Group in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Empty Groups Claims": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
@ -297,7 +380,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
@ -346,7 +429,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||
assert.NoError(t, err)
|
||||
|
||||
server, provider := newTestSetup(jsonResp)
|
||||
server, provider := newTestOIDCSetup(jsonResp)
|
||||
provider.ProfileURL, err = url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -371,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
existingSession := &sessions.SessionState{
|
||||
@ -405,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
existingSession := &sessions.SessionState{
|
||||
@ -433,7 +516,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
GroupsClaim string
|
||||
ExpectedUser string
|
||||
ExpectedEmail string
|
||||
ExpectedGroups interface{}
|
||||
ExpectedGroups []string
|
||||
}{
|
||||
"Default IDToken": {
|
||||
IDToken: defaultIDToken,
|
||||
@ -447,7 +530,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "123456789",
|
||||
ExpectedGroups: []string{},
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Custom Groups Claim": {
|
||||
IDToken: defaultIDToken,
|
||||
@ -466,7 +549,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
server, provider := newTestSetup([]byte(`{}`))
|
||||
server, provider := newTestOIDCSetup([]byte(`{}`))
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
defer server.Close()
|
||||
|
||||
@ -478,9 +561,9 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
|
||||
assert.Equal(t, tc.ExpectedUser, ss.User)
|
||||
assert.Equal(t, tc.ExpectedEmail, ss.Email)
|
||||
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
||||
assert.Equal(t, rawIDToken, ss.IDToken)
|
||||
assert.Equal(t, rawIDToken, ss.AccessToken)
|
||||
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
})
|
||||
}
|
||||
|
@ -189,12 +189,27 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// extractGroups extracts groups from a claim to a list in a type safe manner
|
||||
// extractGroups extracts groups from a claim to a list in a type safe manner.
|
||||
// If the claim isn't present, `nil` is returned. If the groups claim is
|
||||
// present but empty, `[]string{}` is returned.
|
||||
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||
rawClaim, ok := claims[p.GroupsClaim]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle traditional list-based groups as well as non-standard singleton
|
||||
// based groups. Both variants support complex objects if needed.
|
||||
var claimGroups []interface{}
|
||||
switch raw := rawClaim.(type) {
|
||||
case []interface{}:
|
||||
claimGroups = raw
|
||||
case interface{}:
|
||||
claimGroups = []interface{}{raw}
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
rawGroups, ok := claims[p.GroupsClaim].([]interface{})
|
||||
if rawGroups != nil && ok {
|
||||
for _, rawGroup := range rawGroups {
|
||||
for _, rawGroup := range claimGroups {
|
||||
formattedGroup, err := formatGroup(rawGroup)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
@ -203,6 +218,5 @@ func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||
}
|
||||
groups = append(groups, formattedGroup)
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
@ -300,7 +300,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{},
|
||||
Groups: nil,
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
@ -386,12 +386,20 @@ func TestProviderData_extractGroups(t *testing.T) {
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
"Missing Groups": {
|
||||
"Missing Groups Claim Returns Nil": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{},
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Non List Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"singleton"},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
@ -408,7 +416,11 @@ func TestProviderData_extractGroups(t *testing.T) {
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
|
||||
groups := provider.extractGroups(tc.Claims)
|
||||
if tc.ExpectedGroups != nil {
|
||||
g.Expect(groups).To(Equal(tc.ExpectedGroups))
|
||||
} else {
|
||||
g.Expect(groups).To(BeNil())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -59,6 +60,8 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
|
||||
return a
|
||||
}
|
||||
|
||||
// getIDToken extracts an IDToken stored in the `Extra` fields of an
|
||||
// oauth2.Token
|
||||
func getIDToken(token *oauth2.Token) string {
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
@ -67,6 +70,8 @@ func getIDToken(token *oauth2.Token) string {
|
||||
return idToken
|
||||
}
|
||||
|
||||
// formatGroup coerces an OIDC groups claim into a string
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func formatGroup(rawGroup interface{}) (string, error) {
|
||||
group, ok := rawGroup.(string)
|
||||
if !ok {
|
||||
@ -78,3 +83,18 @@ func formatGroup(rawGroup interface{}) (string, error) {
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// coerceArray extracts a field from simplejson.Json that might be a
|
||||
// singleton or a list and coerces it into a list.
|
||||
func coerceArray(sj *simplejson.Json, key string) []interface{} {
|
||||
array, err := sj.Get(key).Array()
|
||||
if err == nil {
|
||||
return array
|
||||
}
|
||||
|
||||
single := sj.Get(key).Interface()
|
||||
if single == nil {
|
||||
return nil
|
||||
}
|
||||
return []interface{}{single}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user