Ensure required PKCE information is exposed from provider

This commit is contained in:
Joel Speed 2022-02-18 15:04:38 +00:00
parent 474a3b049e
commit c3158ebc48
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
2 changed files with 69 additions and 13 deletions

View File

@ -16,6 +16,7 @@ type providerJSON struct {
TokenURL string `json:"token_endpoint"`
JWKsURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported"`
}
// Endpoints represents the endpoints discovered as part of the OIDC discovery process
@ -27,10 +28,17 @@ type Endpoints struct {
UserInfoURL string
}
// PKCE holds information relevant to the PKCE (code challenge) support of the
// provider.
type PKCE struct {
CodeChallengeAlgs []string
}
// DiscoveryProvider holds information about an identity provider having
// used OIDC discovery to retrieve the information.
type DiscoveryProvider interface {
Endpoints() Endpoints
PKCE() PKCE
}
// NewProvider allows a user to perform an OIDC discovery and returns the DiscoveryProvider.
@ -59,6 +67,7 @@ func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification b
tokenURL: p.TokenURL,
jwksURL: p.JWKsURL,
userInfoURL: p.UserInfoURL,
codeChallengeAlgs: p.CodeChallengeAlgs,
}, nil
}
@ -68,6 +77,7 @@ type discoveryProvider struct {
tokenURL string
jwksURL string
userInfoURL string
codeChallengeAlgs []string
}
// Endpoints returns the discovered endpoints needed for an authentication provider.
@ -79,3 +89,10 @@ func (p *discoveryProvider) Endpoints() Endpoints {
UserInfoURL: p.userInfoURL,
}
}
// PKCE returns information related to the PKCE (code challenge) support of the provider.
func (p *discoveryProvider) PKCE() PKCE {
return PKCE{
CodeChallengeAlgs: p.codeChallengeAlgs,
}
}

View File

@ -84,6 +84,25 @@ var _ = Describe("Provider", func() {
expectedError: "failed to discover OIDC configuration: unexpected status \"400\"",
}),
)
It("with code challenges supported on the provider, shold populate PKCE information", func() {
m, err := mockoidc.NewServer(nil)
Expect(err).ToNot(HaveOccurred())
m.AddMiddleware(newCodeChallengeIssuerMiddleware(m))
ln, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
Expect(m.Start(ln, nil)).To(Succeed())
defer func() {
Expect(m.Shutdown()).To(Succeed())
}()
provider, err := NewProvider(context.Background(), m.Issuer(), false)
Expect(err).ToNot(HaveOccurred())
Expect(provider.PKCE().CodeChallengeAlgs).To(ConsistOf("S256", "plain"))
})
})
func newInvalidIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler {
@ -105,6 +124,26 @@ func newInvalidIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Ha
}
}
func newCodeChallengeIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
p := providerJSON{
Issuer: m.Issuer(),
AuthURL: m.AuthorizationEndpoint(),
TokenURL: m.TokenEndpoint(),
JWKsURL: m.JWKSEndpoint(),
UserInfoURL: m.UserinfoEndpoint(),
CodeChallengeAlgs: []string{"S256", "plain"},
}
data, err := json.Marshal(p)
if err != nil {
rw.WriteHeader(500)
}
rw.Write(data)
})
}
}
func newBadRequestMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {