From fde09bea4ec847d4611836d657215cfc60dbc2ef Mon Sep 17 00:00:00 2001 From: Alexander Block Date: Mon, 14 Sep 2020 13:48:17 +0200 Subject: [PATCH] Move azure specific resource parameter handling into azure provider --- providers/azure.go | 9 +++++++++ providers/azure_test.go | 7 +++++++ providers/provider_default.go | 3 --- providers/provider_default_test.go | 18 ------------------ 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/providers/azure.go b/providers/azure.go index 0ae0cba..c9940d6 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -210,3 +210,12 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session return email, err } + +func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { + a, params := DefaultGetLoginURL(p.ProviderData, redirectURI, state) + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + params.Add("resource", p.ProtectedResource.String()) + } + a.RawQuery = params.Encode() + return a.String() +} diff --git a/providers/azure_test.go b/providers/azure_test.go index fe9bbb4..6e2e4e9 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -213,3 +213,10 @@ func TestAzureProviderRedeemReturnsIdToken(t *testing.T) { assert.Equal(t, timestamp, s.ExpiresOn.UTC()) assert.Equal(t, "refresh1234", s.RefreshToken) } + +func TestAzureProviderProtectedResourceConfigured(t *testing.T) { + p := testAzureProvider("") + p.ProtectedResource, _ = url.Parse("http://my.resource.test") + result := p.GetLoginURL("https://my.test.app/oauth", "") + assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) +} diff --git a/providers/provider_default.go b/providers/provider_default.go index 6e898a8..65c7f72 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -89,9 +89,6 @@ func DefaultGetLoginURL(p *ProviderData, redirectURI, state string) (url.URL, ur params.Set("client_id", p.ClientID) params.Set("response_type", "code") params.Add("state", state) - if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - params.Add("resource", p.ProtectedResource.String()) - } return a, params } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index e699a68..74d7096 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -47,21 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) { result := p.GetLoginURL("https://my.test.app/oauth", "") assert.Contains(t, result, "acr_values=testValue") } - -func TestProtectedResourceConfigured(t *testing.T) { - p := &ProviderData{ - LoginURL: &url.URL{ - Scheme: "http", - Host: "my.test.idp", - Path: "/oauth/authorize", - }, - AcrValues: "testValue", - ProtectedResource: &url.URL{ - Scheme: "http", - Host: "my.resource.test", - }, - } - - result := p.GetLoginURL("https://my.test.app/oauth", "") - assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) -}