Add claim extractor provider util
This commit is contained in:
parent
44dc3cad77
commit
537e596904
1
go.mod
1
go.mod
@ -23,6 +23,7 @@ require (
|
||||
github.com/onsi/gomega v1.10.2
|
||||
github.com/pierrec/lz4 v2.5.2+incompatible
|
||||
github.com/prometheus/client_golang v1.9.0
|
||||
github.com/spf13/cast v1.3.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.6.3
|
||||
github.com/stretchr/testify v1.6.1
|
||||
|
210
pkg/providers/util/claim_extractor.go
Normal file
210
pkg/providers/util/claim_extractor.go
Normal file
@ -0,0 +1,210 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
|
||||
// present, from the profile URL.
|
||||
type ClaimExtractor interface {
|
||||
// GetClaim fetches a named claim and returns the value.
|
||||
GetClaim(claim string) (interface{}, bool, error)
|
||||
|
||||
// GetClaimInto fetches a named claim and puts the value into the destination.
|
||||
GetClaimInto(claim string, dst interface{}) (bool, error)
|
||||
}
|
||||
|
||||
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
|
||||
// If needed, it will use the profile URL to look up a claim if it isn't present
|
||||
// within the ID Token.
|
||||
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
|
||||
payload, err := parseJWT(idToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
|
||||
}
|
||||
|
||||
tokenClaims, err := simplejson.NewJson(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
|
||||
}
|
||||
|
||||
return &claimExtractor{
|
||||
ctx: ctx,
|
||||
profileURL: profileURL,
|
||||
requestHeaders: profileRequestHeaders,
|
||||
tokenClaims: tokenClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// claimExtractor implements the ClaimExtractor interface
|
||||
type claimExtractor struct {
|
||||
profileURL *url.URL
|
||||
ctx context.Context
|
||||
requestHeaders map[string][]string
|
||||
tokenClaims *simplejson.Json
|
||||
profileClaims *simplejson.Json
|
||||
}
|
||||
|
||||
// GetClaim will return the value claim if it exists.
|
||||
// It will only return an error if the profile URL needs to be fetched due to
|
||||
// the claim not being present in the ID Token.
|
||||
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
|
||||
if claim == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
if c.profileClaims == nil {
|
||||
profileClaims, err := c.loadProfileClaims()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
|
||||
}
|
||||
|
||||
c.profileClaims = profileClaims
|
||||
}
|
||||
|
||||
if value := getClaimFrom(claim, c.profileClaims); value != nil {
|
||||
return value, true, nil
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// loadProfileClaims will fetch the profileURL using the provided headers as
|
||||
// authentication.
|
||||
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
|
||||
if c.profileURL == nil || c.requestHeaders == nil {
|
||||
// When no profileURL is set, we return a non-empty map so that
|
||||
// we don't attempt to populate the profile claims again.
|
||||
// If there are no headers, the request would be unauthorized so we also skip
|
||||
// in this case too.
|
||||
return simplejson.New(), nil
|
||||
}
|
||||
|
||||
claims, err := requests.New(c.profileURL.String()).
|
||||
WithContext(c.ctx).
|
||||
WithHeaders(c.requestHeaders).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request to profile URL: %v", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// GetClaimInto loads a claim and places it into the destination interface.
|
||||
// This will attempt to coerce the claim into the specified type.
|
||||
// If it cannot be coerced, an error may be returned.
|
||||
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
|
||||
value, exists, err := c.GetClaim(claim)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
|
||||
}
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
if err := coerceClaim(value, dst); err != nil {
|
||||
return false, fmt.Errorf("could no coerce claim: %v", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
|
||||
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// getClaimFrom gets a claim from a Json object.
|
||||
// It can accept either a single claim name or a json path.
|
||||
// Paths with indexes are not supported.
|
||||
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
|
||||
claimParts := strings.Split(claim, ".")
|
||||
return src.GetPath(claimParts...).Interface()
|
||||
}
|
||||
|
||||
// coerceClaim tries to convert the value into the destination interface type.
|
||||
// If it can convert the value, it will then store the value in the destination
|
||||
// interface.
|
||||
func coerceClaim(value, dst interface{}) error {
|
||||
switch d := dst.(type) {
|
||||
case *string:
|
||||
str, err := toString(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string: %v", err)
|
||||
}
|
||||
*d = str
|
||||
case *[]string:
|
||||
strSlice, err := toStringSlice(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string slice: %v", err)
|
||||
}
|
||||
*d = strSlice
|
||||
case *bool:
|
||||
*d = cast.ToBool(value)
|
||||
default:
|
||||
return fmt.Errorf("unknown type for destination: %T", dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toStringSlice converts an interface (either a slice or single value) into
|
||||
// a slice of strings.
|
||||
func toStringSlice(value interface{}) ([]string, error) {
|
||||
var sliceValues []interface{}
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
sliceValues = v
|
||||
case interface{}:
|
||||
sliceValues = []interface{}{v}
|
||||
default:
|
||||
sliceValues = cast.ToSlice(value)
|
||||
}
|
||||
|
||||
out := []string{}
|
||||
for _, v := range sliceValues {
|
||||
str, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
||||
}
|
||||
out = append(out, str)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// toString coerces a value into a string.
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func toString(value interface{}) (string, error) {
|
||||
if str, err := cast.ToStringE(value); err == nil {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
jsonStr, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonStr), nil
|
||||
}
|
530
pkg/providers/util/claim_extractor_test.go
Normal file
530
pkg/providers/util/claim_extractor_test.go
Normal file
@ -0,0 +1,530 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
emptyJSON = "{}"
|
||||
profilePath = "/userinfo"
|
||||
authorizedAccessToken = "valid_access_token"
|
||||
basicIDTokenPayload = `{
|
||||
"user": "idTokenUser",
|
||||
"email": "idTokenEmail",
|
||||
"groups": [
|
||||
"idTokenGroup1",
|
||||
"idTokenGroup2"
|
||||
]
|
||||
}`
|
||||
basicProfileURLPayload = `{
|
||||
"user": "profileUser",
|
||||
"email": "profileEmail",
|
||||
"groups": [
|
||||
"profileGroup1",
|
||||
"profileGroup2"
|
||||
]
|
||||
}`
|
||||
nestedClaimPayload = `{
|
||||
"auth": {
|
||||
"user": {
|
||||
"username": "nestedUser"
|
||||
}
|
||||
}
|
||||
}`
|
||||
complexGroupsPayload = `{
|
||||
"groups": [
|
||||
{
|
||||
"groupID": "group1",
|
||||
"roles": ["admin"]
|
||||
},
|
||||
{
|
||||
"groupID": "group2",
|
||||
"roles": ["user", "employee"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
)
|
||||
|
||||
var _ = Describe("Claim Extractor Suite", func() {
|
||||
Context("Claim Extractor", func() {
|
||||
type newClaimExtractorTableInput struct {
|
||||
idToken string
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("NewClaimExtractor",
|
||||
func(in newClaimExtractorTableInput) {
|
||||
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
},
|
||||
Entry("with a valid JWT", newClaimExtractorTableInput{
|
||||
idToken: createJWTFromPayload(basicIDTokenPayload),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
|
||||
idToken: createJWTFromPayload("this is not JSON"),
|
||||
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
|
||||
}),
|
||||
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
|
||||
idToken: "eyJeyJ",
|
||||
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
|
||||
}),
|
||||
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
|
||||
idToken: "{metadata}.{payload}.{signature}",
|
||||
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
|
||||
}),
|
||||
)
|
||||
|
||||
type getClaimTableInput struct {
|
||||
testClaimExtractorOpts
|
||||
claim string
|
||||
expectedValue interface{}
|
||||
expectExists bool
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("GetClaim",
|
||||
func(in getClaimTableInput) {
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
value, exists, err := claimExtractor.GetClaim(in.claim)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if in.expectedValue != nil {
|
||||
Expect(value).To(Equal(in.expectedValue))
|
||||
} else {
|
||||
Expect(value).To(BeNil())
|
||||
}
|
||||
|
||||
Expect(exists).To(Equal(in.expectExists))
|
||||
},
|
||||
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: true,
|
||||
expectedValue: "idTokenUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
expectExists: true,
|
||||
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the empty string", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
},
|
||||
claim: "",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: nil,
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("when the profile URL is unauthorized", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: make(http.Header),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: false,
|
||||
expectedValue: nil,
|
||||
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||
}),
|
||||
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
expectExists: true,
|
||||
expectedValue: "profileUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string claim from a nested path", getClaimTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: nestedClaimPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "auth.user.username",
|
||||
expectExists: true,
|
||||
expectedValue: "nestedUser",
|
||||
expectedError: nil,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
It("GetClaim should only call the profile URL once", func() {
|
||||
var counter int32
|
||||
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
rw.Write([]byte(basicProfileURLPayload))
|
||||
}
|
||||
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
|
||||
idTokenPayload: "{}",
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: countRequestsHandler,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
value, exists, err := claimExtractor.GetClaim("user")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(value).To(Equal("profileUser"))
|
||||
Expect(counter).To(BeEquivalentTo(1))
|
||||
|
||||
// Check a different claim, but expect the count not to increase
|
||||
value, exists, err = claimExtractor.GetClaim("email")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
Expect(value).To(Equal("profileEmail"))
|
||||
Expect(counter).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
type getClaimIntoTableInput struct {
|
||||
testClaimExtractorOpts
|
||||
into interface{}
|
||||
claim string
|
||||
expectedValue interface{}
|
||||
expectExists bool
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("GetClaimInto",
|
||||
func(in getClaimIntoTableInput) {
|
||||
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if serverClose != nil {
|
||||
defer serverClose()
|
||||
}
|
||||
|
||||
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if in.expectedValue != nil {
|
||||
Expect(in.into).To(Equal(in.expectedValue))
|
||||
} else {
|
||||
Expect(in.into).To(BeEmpty())
|
||||
}
|
||||
|
||||
Expect(exists).To(Equal(in.expectExists))
|
||||
},
|
||||
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringPointer(""),
|
||||
expectExists: true,
|
||||
expectedValue: stringPointer("idTokenUser"),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringPointer(""),
|
||||
expectExists: true,
|
||||
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: "",
|
||||
expectExists: false,
|
||||
expectedValue: "",
|
||||
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
|
||||
}),
|
||||
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: complexGroupsPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: shouldNotBeRequestedProfileHandler,
|
||||
},
|
||||
claim: "groups",
|
||||
into: stringSlicePointer([]string{}),
|
||||
expectExists: true,
|
||||
expectedValue: stringSlicePointer([]string{
|
||||
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
|
||||
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
|
||||
}),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: basicIDTokenPayload,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: newAuthorizedHeader(),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "not_found",
|
||||
into: stringPointer(""),
|
||||
expectExists: false,
|
||||
expectedValue: stringPointer(""),
|
||||
expectedError: nil,
|
||||
}),
|
||||
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
|
||||
testClaimExtractorOpts: testClaimExtractorOpts{
|
||||
idTokenPayload: emptyJSON,
|
||||
setProfileURL: true,
|
||||
profileRequestHeaders: make(http.Header),
|
||||
profileRequestHandler: requiresAuthProfileHandler,
|
||||
},
|
||||
claim: "user",
|
||||
into: stringPointer(""),
|
||||
expectExists: false,
|
||||
expectedValue: stringPointer(""),
|
||||
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
|
||||
}),
|
||||
)
|
||||
|
||||
type coerceClaimTableInput struct {
|
||||
value interface{}
|
||||
dst interface{}
|
||||
expectedDst interface{}
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("coerceClaim",
|
||||
func(in coerceClaimTableInput) {
|
||||
err := coerceClaim(in.value, in.dst)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(in.dst).To(Equal(in.expectedDst))
|
||||
},
|
||||
Entry("coerces a string to a string", coerceClaimTableInput{
|
||||
value: "some_string",
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("some_string"),
|
||||
}),
|
||||
Entry("coerces a slice to a string slice", coerceClaimTableInput{
|
||||
value: []interface{}{"a", "b"},
|
||||
dst: stringSlicePointer([]string{}),
|
||||
expectedDst: stringSlicePointer([]string{"a", "b"}),
|
||||
}),
|
||||
Entry("coerces a bool to a bool", coerceClaimTableInput{
|
||||
value: true,
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a string to a bool", coerceClaimTableInput{
|
||||
value: "true",
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a map to a string", coerceClaimTableInput{
|
||||
value: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz"},
|
||||
},
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
// ******************************************
|
||||
// Helpers for setting up the claim extractor
|
||||
// ******************************************
|
||||
|
||||
type testClaimExtractorOpts struct {
|
||||
idTokenPayload string
|
||||
setProfileURL bool
|
||||
profileRequestHeaders http.Header
|
||||
profileRequestHandler http.HandlerFunc
|
||||
}
|
||||
|
||||
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
|
||||
var profileURL *url.URL
|
||||
var closeServer func()
|
||||
if in.setProfileURL {
|
||||
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
|
||||
closeServer = server.Close
|
||||
|
||||
var err error
|
||||
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
rawIDToken := createJWTFromPayload(in.idTokenPayload)
|
||||
|
||||
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
|
||||
return claimExtractor, closeServer, err
|
||||
}
|
||||
|
||||
func createJWTFromPayload(payload string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
|
||||
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
|
||||
}
|
||||
|
||||
func newAuthorizedHeader() http.Header {
|
||||
headers := make(http.Header)
|
||||
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
|
||||
return headers
|
||||
}
|
||||
|
||||
func hasAuthorizedHeader(headers http.Header) bool {
|
||||
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
|
||||
}
|
||||
|
||||
// ***********************
|
||||
// Typed Pointer Functions
|
||||
// ***********************
|
||||
|
||||
func stringPointer(in string) *string {
|
||||
return &in
|
||||
}
|
||||
|
||||
func stringSlicePointer(in []string) *[]string {
|
||||
return &in
|
||||
}
|
||||
|
||||
func boolPointer(in bool) *bool {
|
||||
return &in
|
||||
}
|
||||
|
||||
// ******************************
|
||||
// Different profile URL handlers
|
||||
// ******************************
|
||||
|
||||
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
|
||||
}
|
||||
|
||||
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
if !hasAuthorizedHeader(req.Header) {
|
||||
rw.WriteHeader(403)
|
||||
rw.Write([]byte("Unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
rw.Write([]byte(basicProfileURLPayload))
|
||||
}
|
17
pkg/providers/util/util_suite_test.go
Normal file
17
pkg/providers/util/util_suite_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestProviderUtilSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
logger.SetErrOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Provider Utils")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user