Add claim extractor provider util

This commit is contained in:
Joel Speed 2021-06-26 11:48:49 +01:00 committed by Joel Speed
parent 44dc3cad77
commit 537e596904
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
4 changed files with 758 additions and 0 deletions

1
go.mod
View File

@ -23,6 +23,7 @@ require (
github.com/onsi/gomega v1.10.2
github.com/pierrec/lz4 v2.5.2+incompatible
github.com/prometheus/client_golang v1.9.0
github.com/spf13/cast v1.3.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3
github.com/stretchr/testify v1.6.1

View File

@ -0,0 +1,210 @@
package util
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/spf13/cast"
)
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
// present, from the profile URL.
type ClaimExtractor interface {
// GetClaim fetches a named claim and returns the value.
GetClaim(claim string) (interface{}, bool, error)
// GetClaimInto fetches a named claim and puts the value into the destination.
GetClaimInto(claim string, dst interface{}) (bool, error)
}
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
// If needed, it will use the profile URL to look up a claim if it isn't present
// within the ID Token.
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
payload, err := parseJWT(idToken)
if err != nil {
return nil, fmt.Errorf("failed to parse ID Token: %v", err)
}
tokenClaims, err := simplejson.NewJson(payload)
if err != nil {
return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
}
return &claimExtractor{
ctx: ctx,
profileURL: profileURL,
requestHeaders: profileRequestHeaders,
tokenClaims: tokenClaims,
}, nil
}
// claimExtractor implements the ClaimExtractor interface
type claimExtractor struct {
profileURL *url.URL
ctx context.Context
requestHeaders map[string][]string
tokenClaims *simplejson.Json
profileClaims *simplejson.Json
}
// GetClaim will return the value claim if it exists.
// It will only return an error if the profile URL needs to be fetched due to
// the claim not being present in the ID Token.
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
if claim == "" {
return nil, false, nil
}
if value := getClaimFrom(claim, c.tokenClaims); value != nil {
return value, true, nil
}
if c.profileClaims == nil {
profileClaims, err := c.loadProfileClaims()
if err != nil {
return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
}
c.profileClaims = profileClaims
}
if value := getClaimFrom(claim, c.profileClaims); value != nil {
return value, true, nil
}
return nil, false, nil
}
// loadProfileClaims will fetch the profileURL using the provided headers as
// authentication.
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
if c.profileURL == nil || c.requestHeaders == nil {
// When no profileURL is set, we return a non-empty map so that
// we don't attempt to populate the profile claims again.
// If there are no headers, the request would be unauthorized so we also skip
// in this case too.
return simplejson.New(), nil
}
claims, err := requests.New(c.profileURL.String()).
WithContext(c.ctx).
WithHeaders(c.requestHeaders).
Do().
UnmarshalJSON()
if err != nil {
return nil, fmt.Errorf("error making request to profile URL: %v", err)
}
return claims, nil
}
// GetClaimInto loads a claim and places it into the destination interface.
// This will attempt to coerce the claim into the specified type.
// If it cannot be coerced, an error may be returned.
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
value, exists, err := c.GetClaim(claim)
if err != nil {
return false, fmt.Errorf("could not get claim %q: %v", claim, err)
}
if !exists {
return false, nil
}
if err := coerceClaim(value, dst); err != nil {
return false, fmt.Errorf("could no coerce claim: %v", err)
}
return true, nil
}
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}
// getClaimFrom gets a claim from a Json object.
// It can accept either a single claim name or a json path.
// Paths with indexes are not supported.
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
claimParts := strings.Split(claim, ".")
return src.GetPath(claimParts...).Interface()
}
// coerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func coerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -0,0 +1,530 @@
package util
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
const (
emptyJSON = "{}"
profilePath = "/userinfo"
authorizedAccessToken = "valid_access_token"
basicIDTokenPayload = `{
"user": "idTokenUser",
"email": "idTokenEmail",
"groups": [
"idTokenGroup1",
"idTokenGroup2"
]
}`
basicProfileURLPayload = `{
"user": "profileUser",
"email": "profileEmail",
"groups": [
"profileGroup1",
"profileGroup2"
]
}`
nestedClaimPayload = `{
"auth": {
"user": {
"username": "nestedUser"
}
}
}`
complexGroupsPayload = `{
"groups": [
{
"groupID": "group1",
"roles": ["admin"]
},
{
"groupID": "group2",
"roles": ["user", "employee"]
}
]
}`
)
var _ = Describe("Claim Extractor Suite", func() {
Context("Claim Extractor", func() {
type newClaimExtractorTableInput struct {
idToken string
expectedError error
}
DescribeTable("NewClaimExtractor",
func(in newClaimExtractorTableInput) {
_, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
} else {
Expect(err).ToNot(HaveOccurred())
}
},
Entry("with a valid JWT", newClaimExtractorTableInput{
idToken: createJWTFromPayload(basicIDTokenPayload),
expectedError: nil,
}),
Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{
idToken: createJWTFromPayload("this is not JSON"),
expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"),
}),
Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{
idToken: "eyJeyJ",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"),
}),
Entry("with an non-base64 IDToken", newClaimExtractorTableInput{
idToken: "{metadata}.{payload}.{signature}",
expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"),
}),
)
type getClaimTableInput struct {
testClaimExtractorOpts
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaim",
func(in getClaimTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim(in.claim)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(value).To(Equal(in.expectedValue))
} else {
Expect(value).To(BeNil())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "idTokenUser",
expectedError: nil,
}),
Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
expectExists: true,
expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"},
expectedError: nil,
}),
Entry("when the requested claim is the empty string", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
},
claim: "",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
profileRequestHeaders: newAuthorizedHeader(),
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: nil,
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "not_found",
expectExists: false,
expectedValue: nil,
expectedError: nil,
}),
Entry("when the profile URL is unauthorized", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: false,
expectedValue: nil,
expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
expectExists: true,
expectedValue: "profileUser",
expectedError: nil,
}),
Entry("retrieves a string claim from a nested path", getClaimTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: nestedClaimPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "auth.user.username",
expectExists: true,
expectedValue: "nestedUser",
expectedError: nil,
}),
)
})
It("GetClaim should only call the profile URL once", func() {
var counter int32
countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&counter, 1)
rw.Write([]byte(basicProfileURLPayload))
}
claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{
idTokenPayload: "{}",
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: countRequestsHandler,
})
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
value, exists, err := claimExtractor.GetClaim("user")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileUser"))
Expect(counter).To(BeEquivalentTo(1))
// Check a different claim, but expect the count not to increase
value, exists, err = claimExtractor.GetClaim("email")
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
Expect(value).To(Equal("profileEmail"))
Expect(counter).To(BeEquivalentTo(1))
})
type getClaimIntoTableInput struct {
testClaimExtractorOpts
into interface{}
claim string
expectedValue interface{}
expectExists bool
expectedError error
}
DescribeTable("GetClaimInto",
func(in getClaimIntoTableInput) {
claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts)
Expect(err).ToNot(HaveOccurred())
if serverClose != nil {
defer serverClose()
}
exists, err := claimExtractor.GetClaimInto(in.claim, in.into)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
if in.expectedValue != nil {
Expect(in.into).To(Equal(in.expectedValue))
} else {
Expect(in.into).To(BeEmpty())
}
Expect(exists).To(Equal(in.expectExists))
},
Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("idTokenUser"),
expectedError: nil,
}),
Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenUser"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}),
expectedError: nil,
}),
Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringPointer(""),
expectExists: true,
expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"),
expectedError: nil,
}),
Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "user",
into: "",
expectExists: false,
expectedValue: "",
expectedError: errors.New("could no coerce claim: unknown type for destination: string"),
}),
Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: complexGroupsPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: shouldNotBeRequestedProfileHandler,
},
claim: "groups",
into: stringSlicePointer([]string{}),
expectExists: true,
expectedValue: stringSlicePointer([]string{
"{\"groupID\":\"group1\",\"roles\":[\"admin\"]}",
"{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}",
}),
expectedError: nil,
}),
Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: basicIDTokenPayload,
setProfileURL: true,
profileRequestHeaders: newAuthorizedHeader(),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "not_found",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: nil,
}),
Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{
testClaimExtractorOpts: testClaimExtractorOpts{
idTokenPayload: emptyJSON,
setProfileURL: true,
profileRequestHeaders: make(http.Header),
profileRequestHandler: requiresAuthProfileHandler,
},
claim: "user",
into: stringPointer(""),
expectExists: false,
expectedValue: stringPointer(""),
expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"),
}),
)
type coerceClaimTableInput struct {
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
DescribeTable("coerceClaim",
func(in coerceClaimTableInput) {
err := coerceClaim(in.value, in.dst)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
Expect(in.dst).To(Equal(in.expectedDst))
},
Entry("coerces a string to a string", coerceClaimTableInput{
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
}),
Entry("coerces a slice to a string slice", coerceClaimTableInput{
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
}),
Entry("coerces a bool to a bool", coerceClaimTableInput{
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a string to a bool", coerceClaimTableInput{
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a map to a string", coerceClaimTableInput{
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
}),
)
})
// ******************************************
// Helpers for setting up the claim extractor
// ******************************************
type testClaimExtractorOpts struct {
idTokenPayload string
setProfileURL bool
profileRequestHeaders http.Header
profileRequestHandler http.HandlerFunc
}
func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) {
var profileURL *url.URL
var closeServer func()
if in.setProfileURL {
server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler))
closeServer = server.Close
var err error
profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath)
Expect(err).ToNot(HaveOccurred())
}
rawIDToken := createJWTFromPayload(in.idTokenPayload)
claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders)
return claimExtractor, closeServer, err
}
func createJWTFromPayload(payload string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON))
payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload))
return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header)
}
func newAuthorizedHeader() http.Header {
headers := make(http.Header)
headers.Add("Authorization", "Bearer "+authorizedAccessToken)
return headers
}
func hasAuthorizedHeader(headers http.Header) bool {
return headers.Get("Authorization") == "Bearer "+authorizedAccessToken
}
// ***********************
// Typed Pointer Functions
// ***********************
func stringPointer(in string) *string {
return &in
}
func stringSlicePointer(in []string) *[]string {
return &in
}
func boolPointer(in bool) *bool {
return &in
}
// ******************************
// Different profile URL handlers
// ******************************
func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) {
defer GinkgoRecover()
Expect(true).To(BeFalse(), "Unexpected request to profile URL")
}
func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) {
if !hasAuthorizedHeader(req.Header) {
rw.WriteHeader(403)
rw.Write([]byte("Unauthorized"))
return
}
rw.Write([]byte(basicProfileURLPayload))
}

View File

@ -0,0 +1,17 @@
package util
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func TestProviderUtilSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Provider Utils")
}