Integrate claim extractor into providers
This commit is contained in:
parent
537e596904
commit
967051314e
@ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt
|
||||
}
|
||||
|
||||
func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error {
|
||||
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("could not extract claims: %v", err)
|
||||
}
|
||||
claims, err := p.getClaims(idToken)
|
||||
|
||||
upn, found, err := claims.GetClaim(adfsUPNClaim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err)
|
||||
}
|
||||
upn := claims.raw[adfsUPNClaim]
|
||||
if upn != nil {
|
||||
|
||||
if found && fmt.Sprint(upn) != "" {
|
||||
s.Email = fmt.Sprint(upn)
|
||||
}
|
||||
return nil
|
||||
|
@ -79,7 +79,7 @@ func testADFSBackend() *httptest.Server {
|
||||
{
|
||||
"access_token": "my_access_token",
|
||||
"id_token": "my_id_token",
|
||||
"refresh_token": "my_refresh_token"
|
||||
"refresh_token": "my_refresh_token"
|
||||
}
|
||||
`
|
||||
userInfo := `
|
||||
@ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() {
|
||||
Context("with valid token", func() {
|
||||
It("should not throw an error", func() {
|
||||
rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken)
|
||||
Expect(err).To(BeNil())
|
||||
session, err := p.buildSessionFromClaims(idToken)
|
||||
session, err := p.buildSessionFromClaims(rawIDToken, "")
|
||||
Expect(err).To(BeNil())
|
||||
session.IDToken = rawIDToken
|
||||
err = p.EnrichSession(context.Background(), session)
|
||||
|
@ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState {
|
||||
}
|
||||
|
||||
func IsAuthorizedInHeader(reqHeader http.Header) bool {
|
||||
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken)
|
||||
return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken)
|
||||
}
|
||||
|
||||
func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool {
|
||||
return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token)
|
||||
}
|
||||
|
||||
func IsAuthorizedInURL(reqURL *url.URL) bool {
|
||||
return reqURL.Query().Get("access_token") == authorizedAccessToken
|
||||
}
|
||||
|
||||
func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
return reqURL.Query().Get("refresh_token") == token
|
||||
}
|
||||
|
@ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
|
||||
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
||||
p.ValidateURL = p.ProfileURL
|
||||
}
|
||||
p.getAuthorizationHeaderFunc = makeAzureHeader
|
||||
|
||||
return &AzureProvider{
|
||||
ProviderData: p,
|
||||
@ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
session.CreatedAtNow()
|
||||
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken)
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken)
|
||||
|
||||
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
|
||||
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
|
||||
@ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*
|
||||
}
|
||||
|
||||
if session.Email == "" {
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken)
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken)
|
||||
if err == nil && email != "" {
|
||||
session.Email = email
|
||||
} else {
|
||||
@ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err
|
||||
|
||||
// verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
|
||||
// when oidc verifier is configured
|
||||
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) {
|
||||
func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) {
|
||||
email := ""
|
||||
|
||||
if token != "" && p.Verifier != nil {
|
||||
token, err := p.Verifier.Verify(ctx, token)
|
||||
if rawIDToken != "" && p.Verifier != nil {
|
||||
_, err := p.Verifier.Verify(ctx, rawIDToken)
|
||||
// due to issues mentioned above, id_token may not be signed by AAD
|
||||
if err == nil {
|
||||
claims, err := p.getClaims(token)
|
||||
s, err := p.buildSessionFromClaims(rawIDToken, accessToken)
|
||||
if err == nil {
|
||||
email = claims.Email
|
||||
email = s.Email
|
||||
} else {
|
||||
logger.Printf("unable to get claims from token: %v", err)
|
||||
}
|
||||
@ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
s.CreatedAtNow()
|
||||
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
||||
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken)
|
||||
email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken)
|
||||
|
||||
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
|
||||
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
|
||||
@ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess
|
||||
}
|
||||
|
||||
if s.Email == "" {
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken)
|
||||
email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken)
|
||||
if err == nil && email != "" {
|
||||
s.Email = email
|
||||
} else {
|
||||
|
@ -13,9 +13,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
|
||||
|
||||
@ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) {
|
||||
assert.Equal(t, "openid", p.Data().Scope)
|
||||
}
|
||||
|
||||
func testAzureBackend(payload string) *httptest.Server {
|
||||
return testAzureBackendWithError(payload, false)
|
||||
func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server {
|
||||
return testAzureBackendWithError(payload, accessToken, refreshToken, false)
|
||||
}
|
||||
|
||||
func testAzureBackendWithError(payload string, injectError bool) *httptest.Server {
|
||||
func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server {
|
||||
path := "/v1.0/me"
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
@ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
w.Write([]byte(payload))
|
||||
} else if !IsAuthorizedInHeader(r.Header) {
|
||||
} else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) &&
|
||||
!isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) {
|
||||
w.WriteHeader(403)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
@ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) {
|
||||
host string
|
||||
)
|
||||
if testCase.PayloadFromAzureBackend != "" {
|
||||
b = testAzureBackend(testCase.PayloadFromAzureBackend)
|
||||
b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
@ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) {
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
assert.NoError(t, err)
|
||||
|
||||
b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError)
|
||||
b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError)
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
@ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
|
||||
|
||||
func TestAzureProviderRefresh(t *testing.T) {
|
||||
email := "foo@example.com"
|
||||
subject := "foo"
|
||||
idToken := idTokenClaims{
|
||||
StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"},
|
||||
Email: email}
|
||||
Email: email,
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532",
|
||||
Subject: subject,
|
||||
},
|
||||
}
|
||||
idTokenString, err := newSignedTestIDToken(idToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z")
|
||||
assert.NoError(t, err)
|
||||
|
||||
newAccessToken := "new_some_access_token"
|
||||
payload := azureOAuthPayload{
|
||||
IDToken: idTokenString,
|
||||
RefreshToken: "new_some_refresh_token",
|
||||
AccessToken: "new_some_access_token",
|
||||
AccessToken: newAccessToken,
|
||||
ExpiresOn: timestamp.Unix(),
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
assert.NoError(t, err)
|
||||
b := testAzureBackend(string(payloadBytes))
|
||||
|
||||
refreshToken := "some_refresh_token"
|
||||
b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken)
|
||||
defer b.Close()
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
expires := time.Now().Add(time.Duration(-1) * time.Hour)
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires}
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.True(t, refreshed)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "new_some_access_token", session.AccessToken)
|
||||
assert.Equal(t, newAccessToken, session.AccessToken)
|
||||
assert.Equal(t, "new_some_refresh_token", session.RefreshToken)
|
||||
assert.Equal(t, idTokenString, session.IDToken)
|
||||
assert.Equal(t, email, session.Email)
|
||||
|
@ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider {
|
||||
validateURL: digitalOceanDefaultProfileURL,
|
||||
scope: digitalOceanDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
|
||||
return &DigitalOceanProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
|
||||
validateURL: facebookDefaultProfileURL,
|
||||
scope: facebookDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
return &FacebookProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
|
||||
validateURL: linkedinDefaultValidateURL,
|
||||
scope: linkedinDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeLinkedInHeader
|
||||
|
||||
return &LinkedInProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,5 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
)
|
||||
|
||||
// NextcloudProvider represents an Nextcloud based Identity Provider
|
||||
type NextcloudProvider struct {
|
||||
*ProviderData
|
||||
@ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud"
|
||||
// NewNextcloudProvider initiates a new NextcloudProvider
|
||||
func NewNextcloudProvider(p *ProviderData) *NextcloudProvider {
|
||||
p.ProviderName = nextCloudProviderName
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
if p.EmailClaim == OIDCEmailClaim {
|
||||
// This implies the email claim has not been overridden, we should set a default
|
||||
// for this provider
|
||||
p.EmailClaim = "ocs.data.email"
|
||||
}
|
||||
return &NextcloudProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
email, err := json.Get("ocs").Get("data").Get("email").String()
|
||||
return email, err
|
||||
}
|
||||
|
@ -1,18 +1,13 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const formatJSON = "format=json"
|
||||
const userPath = "/ocs/v2.php/cloud/user"
|
||||
|
||||
func testNextcloudProvider(hostname string) *NextcloudProvider {
|
||||
p := NewNextcloudProvider(
|
||||
@ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider {
|
||||
return p
|
||||
}
|
||||
|
||||
func testNextcloudBackend(payload string) *httptest.Server {
|
||||
path := userPath
|
||||
query := formatJSON
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != path || r.URL.RawQuery != query {
|
||||
w.WriteHeader(404)
|
||||
} else if !IsAuthorizedInHeader(r.Header) {
|
||||
w.WriteHeader(403)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(payload))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestNextcloudProviderDefaults(t *testing.T) {
|
||||
p := testNextcloudProvider("")
|
||||
assert.NotEqual(t, nil, p)
|
||||
@ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) {
|
||||
assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON,
|
||||
p.Data().ValidateURL.String())
|
||||
}
|
||||
|
||||
func TestNextcloudProviderGetEmailAddress(t *testing.T) {
|
||||
b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testNextcloudBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
||||
func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testNextcloudBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testNextcloudProvider(bURL.Host)
|
||||
p.ValidateURL.Path = userPath
|
||||
p.ValidateURL.RawQuery = formatJSON
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
@ -5,12 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -24,6 +22,8 @@ type OIDCProvider struct {
|
||||
// NewOIDCProvider initiates a new OIDCProvider
|
||||
func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||
p.ProviderName = "OpenID Connect"
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
|
||||
return &OIDCProvider{
|
||||
ProviderData: p,
|
||||
SkipNonce: true,
|
||||
@ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s
|
||||
// EnrichSession is called after Redeem to allow providers to enrich session fields
|
||||
// such as User, Email, Groups with provider specific API calls.
|
||||
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||
if p.ProfileURL.String() == "" {
|
||||
if s.Email == "" {
|
||||
return errors.New("id_token did not contain an email and profileURL is not defined")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to get missing emails or groups from a profileURL
|
||||
if s.Email == "" || s.Groups == nil {
|
||||
err := p.enrichFromProfileURL(ctx, s)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If a mandatory email wasn't set, error at this point.
|
||||
if s.Email == "" {
|
||||
return errors.New("neither the id_token nor the profileURL set an email")
|
||||
@ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta
|
||||
return nil
|
||||
}
|
||||
|
||||
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
|
||||
// an OIDC profile URL
|
||||
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
|
||||
respJSON, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
email, err := respJSON.Get(p.EmailClaim).String()
|
||||
if err == nil && s.Email == "" {
|
||||
s.Email = email
|
||||
}
|
||||
|
||||
if len(s.Groups) > 0 {
|
||||
return nil
|
||||
}
|
||||
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
|
||||
formatted, err := formatGroup(group)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(group), err)
|
||||
continue
|
||||
}
|
||||
s.Groups = append(s.Groups, formatted)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSession checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
if err != nil {
|
||||
logger.Errorf("id_token verification failed: %v", err)
|
||||
return false
|
||||
@ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS
|
||||
if p.SkipNonce {
|
||||
return true
|
||||
}
|
||||
err = p.checkNonce(s, idToken)
|
||||
err = p.checkNonce(s)
|
||||
if err != nil {
|
||||
logger.Errorf("nonce verification failed: %v", err)
|
||||
return false
|
||||
@ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
ss, err := p.buildSessionFromClaims(token, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string)
|
||||
// createSession takes an oauth2.Token and creates a SessionState from it.
|
||||
// It alters behavior if called from Redeem vs Refresh
|
||||
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
|
||||
idToken, err := p.verifyIDToken(ctx, token)
|
||||
_, err := p.verifyIDToken(ctx, token)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrMissingIDToken:
|
||||
@ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r
|
||||
}
|
||||
}
|
||||
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
rawIDToken := getIDToken(token)
|
||||
ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss.AccessToken = token.AccessToken
|
||||
ss.RefreshToken = token.RefreshToken
|
||||
ss.IDToken = getIDToken(token)
|
||||
ss.IDToken = rawIDToken
|
||||
|
||||
ss.CreatedAtNow()
|
||||
ss.SetExpiresOn(token.Expiry)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
||||
Scope: "openid profile offline_access",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
@ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
||||
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
||||
}
|
||||
|
||||
func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ExistingSession *sessions.SessionState
|
||||
EmailClaim string
|
||||
GroupsClaim string
|
||||
ProfileJSON map[string]interface{}
|
||||
ExpectedError error
|
||||
ExpectedSession *sessions.SessionState
|
||||
}{
|
||||
"Already Populated": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
|
||||
"Missing Email Only in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "weird",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"weird": "weird@claim.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "weird@claim.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email not in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Complex Groups in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Singleton Complex Group in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Empty Groups Claims": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"roles": []string{"new", "thing", "roles"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing", "roles"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups String Profile URL Response": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"singleton"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups in both Claims and Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||
assert.NoError(t, err)
|
||||
|
||||
server, provider := newTestOIDCSetup(jsonResp)
|
||||
provider.ProfileURL, err = url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider.EmailClaim = tc.EmailClaim
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
defer server.Close()
|
||||
|
||||
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
|
||||
assert.Equal(t, tc.ExpectedError, err)
|
||||
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
|
||||
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
@ -565,11 +238,15 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
ExpectedGroups: []string{"test:c", "test:d"},
|
||||
},
|
||||
"Complex Groups Claim": {
|
||||
IDToken: complexGroupsIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "complex@claims.com",
|
||||
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: complexGroupsIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "complex@claims.com",
|
||||
ExpectedGroups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
|
@ -5,20 +5,23 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
OIDCEmailClaim = "email"
|
||||
OIDCGroupsClaim = "groups"
|
||||
// This is not exported as it's not currently user configurable
|
||||
oidcUserClaim = "sub"
|
||||
)
|
||||
|
||||
var OIDCAudienceClaims = []string{"aud"}
|
||||
@ -52,6 +55,8 @@ type ProviderData struct {
|
||||
// Universal Group authorization data structure
|
||||
// any provider can set to consume
|
||||
AllowedGroups map[string]struct{}
|
||||
|
||||
getAuthorizationHeaderFunc func(string) http.Header
|
||||
}
|
||||
|
||||
// Data returns the ProviderData
|
||||
@ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) {
|
||||
if p.Scope == "" {
|
||||
p.Scope = defaults.scope
|
||||
}
|
||||
|
||||
if p.UserClaim == "" {
|
||||
p.UserClaim = oidcUserClaim
|
||||
}
|
||||
}
|
||||
|
||||
// defaultURL will set return a default value if the given value is not set.
|
||||
@ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
|
||||
// OIDC compliant
|
||||
// ****************************************************************************
|
||||
|
||||
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
|
||||
type OIDCClaims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"-"`
|
||||
Groups []string `json:"-"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
Nonce string `json:"nonce"`
|
||||
|
||||
raw map[string]interface{}
|
||||
}
|
||||
|
||||
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||
rawIDToken := getIDToken(token)
|
||||
if strings.TrimSpace(rawIDToken) == "" {
|
||||
@ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (
|
||||
|
||||
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
|
||||
// with non-Token related fields.
|
||||
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) {
|
||||
ss := &sessions.SessionState{}
|
||||
|
||||
if idToken == nil {
|
||||
if rawIDToken == "" {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
claims, err := p.getClaims(idToken)
|
||||
extractor, err := p.getClaimExtractor(rawIDToken, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss.User = claims.Subject
|
||||
ss.Email = claims.Email
|
||||
ss.Groups = claims.Groups
|
||||
|
||||
// Allow specialized providers that embed OIDCProvider to control the User
|
||||
// claim. Not exposed as a configuration flag to generic OIDC provider
|
||||
// users (yet).
|
||||
if p.UserClaim != "" {
|
||||
user, ok := claims.raw[p.UserClaim].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim)
|
||||
// Use a slice of a struct (vs map) here in case the same claim is used twice
|
||||
for _, c := range []struct {
|
||||
claim string
|
||||
dst interface{}
|
||||
}{
|
||||
{p.UserClaim, &ss.User},
|
||||
{p.EmailClaim, &ss.Email},
|
||||
{p.GroupsClaim, &ss.Groups},
|
||||
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||
{"preferred_username", &ss.PreferredUsername},
|
||||
} {
|
||||
if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ss.User = user
|
||||
}
|
||||
|
||||
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||
if pref, ok := claims.raw["preferred_username"].(string); ok {
|
||||
ss.PreferredUsername = pref
|
||||
}
|
||||
|
||||
// `email_verified` must be present and explicitly set to `false` to be
|
||||
// considered unverified.
|
||||
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
|
||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
|
||||
var verified bool
|
||||
exists, err := extractor.GetClaimInto("email_verified", &verified)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if verifyEmail && exists && !verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email)
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// getClaims extracts IDToken claims into an OIDCClaims
|
||||
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||
claims := &OIDCClaims{}
|
||||
|
||||
// Extract default claims.
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
||||
}
|
||||
// Extract custom claims.
|
||||
if err := idToken.Claims(&claims.raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||
func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) {
|
||||
extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not initialise claim extractor: %v", err)
|
||||
}
|
||||
|
||||
email := claims.raw[p.EmailClaim]
|
||||
if email != nil {
|
||||
claims.Email = fmt.Sprint(email)
|
||||
}
|
||||
claims.Groups = p.extractGroups(claims.raw)
|
||||
|
||||
return claims, nil
|
||||
return extractor, nil
|
||||
}
|
||||
|
||||
// checkNonce compares the session's nonce with the IDToken's nonce claim
|
||||
func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error {
|
||||
claims, err := p.getClaims(idToken)
|
||||
func (p *ProviderData) checkNonce(s *sessions.SessionState) error {
|
||||
extractor, err := p.getClaimExtractor(s.IDToken, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("id_token claims extraction failed: %v", err)
|
||||
}
|
||||
if !s.CheckNonce(claims.Nonce) {
|
||||
var nonce string
|
||||
if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil {
|
||||
return fmt.Errorf("could not extract nonce from ID Token: %v", err)
|
||||
}
|
||||
|
||||
if !s.CheckNonce(nonce) {
|
||||
return errors.New("id_token nonce claim does not match the session nonce")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractGroups extracts groups from a claim to a list in a type safe manner.
|
||||
// If the claim isn't present, `nil` is returned. If the groups claim is
|
||||
// present but empty, `[]string{}` is returned.
|
||||
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||
rawClaim, ok := claims[p.GroupsClaim]
|
||||
if !ok {
|
||||
return nil
|
||||
func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header {
|
||||
if p.getAuthorizationHeaderFunc != nil && accessToken != "" {
|
||||
return p.getAuthorizationHeaderFunc(accessToken)
|
||||
}
|
||||
|
||||
// Handle traditional list-based groups as well as non-standard singleton
|
||||
// based groups. Both variants support complex objects if needed.
|
||||
var claimGroups []interface{}
|
||||
switch raw := rawClaim.(type) {
|
||||
case []interface{}:
|
||||
claimGroups = raw
|
||||
case interface{}:
|
||||
claimGroups = []interface{}{raw}
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, rawGroup := range claimGroups {
|
||||
formattedGroup, err := formatGroup(rawGroup)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(rawGroup), err)
|
||||
continue
|
||||
}
|
||||
groups = append(groups, formattedGroup)
|
||||
}
|
||||
return groups
|
||||
return nil
|
||||
}
|
||||
|
@ -60,16 +60,30 @@ var (
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
numericGroupsIDToken = idTokenClaims{
|
||||
Name: "Jane Dobbs",
|
||||
Email: "janed@me.com",
|
||||
Phone: "+4798765432",
|
||||
Picture: "http://mugbook.com/janed/me.jpg",
|
||||
Groups: []interface{}{1, 2, 3},
|
||||
Roles: []string{"test:c", "test:d"},
|
||||
Verified: &verified,
|
||||
Nonce: encryption.HashNonce([]byte(oidcNonce)),
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
complexGroupsIDToken = idTokenClaims{
|
||||
Name: "Complex Claim",
|
||||
Email: "complex@claims.com",
|
||||
Phone: "+5439871234",
|
||||
Picture: "http://mugbook.com/complex/claims.jpg",
|
||||
Groups: []map[string]interface{}{
|
||||
{
|
||||
Groups: []interface{}{
|
||||
map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
12345,
|
||||
"Just::A::String",
|
||||
},
|
||||
Roles: []string{"test:simple", "test:roles"},
|
||||
Verified: &verified,
|
||||
@ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "unverified@email.com",
|
||||
@ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "complex@claims.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
User: "123456789",
|
||||
Email: "complex@claims.com",
|
||||
Groups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
PreferredUsername: "Complex Claim",
|
||||
},
|
||||
},
|
||||
@ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"User Claim Invalid": {
|
||||
"User Claim switched to non string": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: true,
|
||||
UserClaim: "groups",
|
||||
UserClaim: "roles",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedError: errors.New("unable to extract custom UserClaim (groups)"),
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "[\"test:c\",\"test:d\"]",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Email Claim Switched": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "phone_number",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "+4025205729",
|
||||
@ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "roles",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "[test:c test:d]",
|
||||
Email: "[\"test:c\",\"test:d\"]",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
@ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "aksjdfhjksadh",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "",
|
||||
@ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "alskdjfsalkdjf",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
@ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Groups Claim Numeric values": {
|
||||
IDToken: numericGroupsIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"1", "2", "3"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Groups Claim string values": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "email",
|
||||
UserClaim: "sub",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"janed@me.com"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
@ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ss, err := provider.buildSessionFromClaims(idToken)
|
||||
ss, err := provider.buildSessionFromClaims(rawIDToken, "")
|
||||
if err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
}
|
||||
@ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
// Ensure that the ID token in the session is valid (signed and contains a nonce)
|
||||
// as the nonce claim is extracted to compare with the session nonce
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
tc.Session.IDToken = rawIDToken
|
||||
|
||||
verificationOptions := &internaloidc.IDTokenVerificationOptions{
|
||||
AudienceClaims: []string{"aud"},
|
||||
ClientID: oidcClientID,
|
||||
@ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
), verificationOptions),
|
||||
}
|
||||
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = provider.checkNonce(tc.Session, idToken)
|
||||
if err != nil {
|
||||
if err := provider.checkNonce(tc.Session); err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
} else {
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
@ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderData_extractGroups(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
Claims map[string]interface{}
|
||||
GroupsClaim string
|
||||
ExpectedGroups []string
|
||||
}{
|
||||
"Standard String Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{"three", "string", "groups"},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"three", "string", "groups"},
|
||||
},
|
||||
"Different Claim Name": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"roles": []interface{}{"three", "string", "roles"},
|
||||
},
|
||||
GroupsClaim: "roles",
|
||||
ExpectedGroups: []string{"three", "string", "roles"},
|
||||
},
|
||||
"Numeric Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{1, 2, 3},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"1", "2", "3"},
|
||||
},
|
||||
"Complex Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{
|
||||
map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
12345,
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
"Missing Groups Claim Returns Nil": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Non List Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"singleton"},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
verificationOptions := &internaloidc.IDTokenVerificationOptions{
|
||||
AudienceClaims: []string{"aud"},
|
||||
ClientID: oidcClientID,
|
||||
}
|
||||
provider := &ProviderData{
|
||||
Verifier: internaloidc.NewVerifier(oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
), verificationOptions),
|
||||
}
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
|
||||
groups := provider.extractGroups(tc.Claims)
|
||||
if tc.ExpectedGroups != nil {
|
||||
g.Expect(groups).To(Equal(tc.ExpectedGroups))
|
||||
} else {
|
||||
g.Expect(groups).To(BeNil())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) {
|
||||
}
|
||||
return string(jsonGroup), nil
|
||||
}
|
||||
|
||||
// coerceArray extracts a field from simplejson.Json that might be a
|
||||
// singleton or a list and coerces it into a list.
|
||||
func coerceArray(sj *simplejson.Json, key string) []interface{} {
|
||||
array, err := sj.Get(key).Array()
|
||||
if err == nil {
|
||||
return array
|
||||
}
|
||||
|
||||
single := sj.Get(key).Interface()
|
||||
if single == nil {
|
||||
return nil
|
||||
}
|
||||
return []interface{}{single}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user