66bfd8ebd5
* add azure china support Signed-off-by: Markus Blaschke <mblaschke82@gmail.com> * update changelog Signed-off-by: Markus Blaschke <mblaschke82@gmail.com> * fix lint Signed-off-by: Markus Blaschke <mblaschke82@gmail.com> --------- Signed-off-by: Markus Blaschke <mblaschke82@gmail.com> Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
460 lines
14 KiB
Go
460 lines
14 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
"github.com/bitly/go-simplejson"
|
|
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
|
)
|
|
|
|
// AzureProvider represents an Azure based Identity Provider.
// Whether the v1 or v2.0 endpoint is used is decided in NewAzureProvider
// from the configured login URL.
type AzureProvider struct {
	*ProviderData
	// Tenant is the Azure AD tenant substituted into the default login and
	// redeem URLs ("common" unless configured otherwise).
	Tenant string
	// GraphGroupField is the Microsoft Graph group attribute (for example
	// "id" or "displayName") whose value is stored for each group.
	GraphGroupField string
	// isV2Endpoint is true when the login URL contains "v2.0". It suppresses
	// the `resource` parameter and enables group lookup via Microsoft Graph.
	isV2Endpoint bool
}
|
|
|
|
// Compile-time check that *AzureProvider implements Provider.
var _ Provider = (*AzureProvider)(nil)

const (
	// azureProviderName is the provider name passed to setProviderDefaults.
	azureProviderName = "Azure"
	// azureDefaultScope is the default scope passed to setProviderDefaults.
	azureDefaultScope = "openid"
	// azureDefaultGraphGroupField is the Microsoft Graph group attribute used
	// when options.AzureOptions.GraphGroupField is empty.
	azureDefaultGraphGroupField = "id"
)
|
|
|
|
// Pre-parsed default Azure AD (v1) endpoints handed to setProviderDefaults by
// NewAzureProvider. The login and redeem URLs target the "common" tenant and
// may be rewritten for a specific tenant by overrideTenantURL.
var (
	// azureDefaultLoginURL is https://login.microsoftonline.com/common/oauth2/authorize.
	azureDefaultLoginURL = &url.URL{
		Host:   "login.microsoftonline.com",
		Path:   "/common/oauth2/authorize",
		Scheme: "https",
	}

	// azureDefaultRedeemURL is https://login.microsoftonline.com/common/oauth2/token.
	azureDefaultRedeemURL = &url.URL{
		Host:   "login.microsoftonline.com",
		Path:   "/common/oauth2/token",
		Scheme: "https",
	}

	// azureDefaultProfileURL is the Microsoft Graph "me" endpoint,
	// https://graph.microsoft.com/v1.0/me.
	azureDefaultProfileURL = &url.URL{
		Host:   "graph.microsoft.com",
		Path:   "/v1.0/me",
		Scheme: "https",
	}
)
|
|
|
|
// NewAzureProvider initiates a new AzureProvider
|
|
func NewAzureProvider(p *ProviderData, opts options.AzureOptions) *AzureProvider {
|
|
p.setProviderDefaults(providerDefaults{
|
|
name: azureProviderName,
|
|
loginURL: azureDefaultLoginURL,
|
|
redeemURL: azureDefaultRedeemURL,
|
|
profileURL: azureDefaultProfileURL,
|
|
validateURL: nil,
|
|
scope: azureDefaultScope,
|
|
})
|
|
|
|
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
|
p.ValidateURL = p.ProfileURL
|
|
}
|
|
p.getAuthorizationHeaderFunc = makeAzureHeader
|
|
|
|
tenant := "common"
|
|
if opts.Tenant != "" {
|
|
tenant = opts.Tenant
|
|
overrideTenantURL(p.LoginURL, azureDefaultLoginURL, tenant, "authorize")
|
|
overrideTenantURL(p.RedeemURL, azureDefaultRedeemURL, tenant, "token")
|
|
}
|
|
|
|
graphGroupField := azureDefaultGraphGroupField
|
|
if opts.GraphGroupField != "" {
|
|
graphGroupField = opts.GraphGroupField
|
|
}
|
|
|
|
isV2Endpoint := false
|
|
if strings.Contains(p.LoginURL.String(), "v2.0") {
|
|
isV2Endpoint = true
|
|
azureV2GraphScope := fmt.Sprintf("https://%s/.default", p.ProfileURL.Host)
|
|
|
|
if strings.Contains(p.Scope, " groups") {
|
|
logger.Print("WARNING: `groups` scope is not an accepted scope when using Azure OAuth V2 endpoint. Removing it from the scope list")
|
|
p.Scope = strings.ReplaceAll(p.Scope, " groups", "")
|
|
}
|
|
|
|
if !strings.Contains(p.Scope, " "+azureV2GraphScope) {
|
|
// In order to be able to query MS Graph we must pass the ms graph default endpoint
|
|
p.Scope += " " + azureV2GraphScope
|
|
}
|
|
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
|
|
logger.Print("WARNING: `--resource` option has no effect when using the Azure OAuth V2 endpoint.")
|
|
}
|
|
}
|
|
|
|
return &AzureProvider{
|
|
ProviderData: p,
|
|
Tenant: tenant,
|
|
GraphGroupField: graphGroupField,
|
|
isV2Endpoint: isV2Endpoint,
|
|
}
|
|
}
|
|
|
|
// overrideTenantURL rewrites current in place to
// https://<host>/<tenant>/oauth2/<path> when current is empty or still equal
// to defaultURL; an explicitly configured URL is left untouched.
//
// The host is taken from current, or from defaultURL when current has none,
// so an empty URL does not become "https:///...". A nil current cannot be
// updated in place, so it is ignored instead of being dereferenced.
func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
	if current == nil {
		return
	}
	if current.String() != "" && current.String() != defaultURL.String() {
		return
	}

	host := current.Host
	if host == "" {
		host = defaultURL.Host
	}
	*current = url.URL{
		Scheme: "https",
		Host:   host,
		Path:   "/" + tenant + "/oauth2/" + path,
	}
}
|
|
|
|
// getMicrosoftGraphGroupsURL builds the Microsoft Graph transitiveMemberOf
// query URL on the same host as profileURL. The $select list always contains
// displayName and id, plus graphGroupField when it is neither of those.
func getMicrosoftGraphGroupsURL(profileURL *url.URL, graphGroupField string) *url.URL {
	fields := "displayName,id"
	if !slices.Contains([]string{"displayName", "id"}, graphGroupField) {
		fields = fields + "," + graphGroupField
	}

	// Only security groups are requested. Because of the $filter, $count=true
	// has to be sent even though the count itself is never read.
	query := "$count=true&$filter=securityEnabled+eq+true&$select=" + fields

	groupsURL := new(url.URL)
	groupsURL.Scheme = "https"
	groupsURL.Host = profileURL.Host
	groupsURL.Path = "/v1.0/me/transitiveMemberOf"
	groupsURL.RawQuery = query
	return groupsURL
}
|
|
|
|
func (p *AzureProvider) GetLoginURL(redirectURI, state, _ string, extraParams url.Values) string {
|
|
// In azure oauth v2 there is no resource param so add it only if V1 endpoint
|
|
// https://docs.microsoft.com/en-us/azure/active-directory/azuread-dev/azure-ad-endpoint-comparison#scopes-not-resources
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" && !p.isV2Endpoint {
|
|
extraParams.Add("resource", p.ProtectedResource.String())
|
|
}
|
|
a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
|
|
return a.String()
|
|
}
|
|
|
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) {
|
|
params, err := p.prepareRedeem(redirectURL, code, codeVerifier)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// blindly try json and x-www-form-urlencoded
|
|
var jsonResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresOn int64 `json:"expires_on,string"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
|
|
err = requests.New(p.RedeemURL.String()).
|
|
WithContext(ctx).
|
|
WithMethod("POST").
|
|
WithBody(bytes.NewBufferString(params.Encode())).
|
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
|
Do().
|
|
UnmarshalInto(&jsonResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session := &sessions.SessionState{
|
|
AccessToken: jsonResponse.AccessToken,
|
|
IDToken: jsonResponse.IDToken,
|
|
RefreshToken: jsonResponse.RefreshToken,
|
|
}
|
|
session.CreatedAtNow()
|
|
session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
|
|
|
err = p.extractClaimsIntoSession(ctx, session)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to get email and/or groups claims from token: %v", err)
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// EnrichSession enriches the session state with userID, mail and groups
|
|
func (p *AzureProvider) EnrichSession(ctx context.Context, session *sessions.SessionState) error {
|
|
err := p.extractClaimsIntoSession(ctx, session)
|
|
|
|
if err != nil {
|
|
logger.Printf("unable to get email and/or groups claims from token: %v", err)
|
|
}
|
|
|
|
if session.Email == "" {
|
|
email, err := p.getEmailFromProfileAPI(ctx, session.AccessToken)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get email address from profile URL: %v", err)
|
|
}
|
|
session.Email = email
|
|
}
|
|
|
|
// If using the v2.0 oidc endpoint we're also querying Microsoft Graph
|
|
if p.isV2Endpoint {
|
|
groups, err := p.getGroupsFromProfileAPI(ctx, session)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get groups from Microsoft Graph: %v", err)
|
|
}
|
|
session.Groups = util.RemoveDuplicateStr(append(session.Groups, groups...))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *AzureProvider) prepareRedeem(redirectURL, code, codeVerifier string) (url.Values, error) {
|
|
params := url.Values{}
|
|
if code == "" {
|
|
return params, ErrMissingCode
|
|
}
|
|
clientSecret, err := p.GetClientSecret()
|
|
if err != nil {
|
|
return params, err
|
|
}
|
|
|
|
params.Add("redirect_uri", redirectURL)
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("client_secret", clientSecret)
|
|
params.Add("code", code)
|
|
params.Add("grant_type", "authorization_code")
|
|
if codeVerifier != "" {
|
|
params.Add("code_verifier", codeVerifier)
|
|
}
|
|
|
|
// In azure oauth v2 there is no resource param so add it only if V1 endpoint
|
|
// https://docs.microsoft.com/en-us/azure/active-directory/azuread-dev/azure-ad-endpoint-comparison#scopes-not-resources
|
|
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" && !p.isV2Endpoint {
|
|
params.Add("resource", p.ProtectedResource.String())
|
|
}
|
|
|
|
return params, nil
|
|
}
|
|
|
|
// extractClaimsIntoSession tries to extract email and groups claims from either id_token or access token
|
|
// when oidc verifier is configured
|
|
func (p *AzureProvider) extractClaimsIntoSession(ctx context.Context, session *sessions.SessionState) error {
|
|
|
|
var s *sessions.SessionState
|
|
|
|
// First let's verify session token
|
|
if err := p.verifySessionToken(ctx, session); err != nil {
|
|
return fmt.Errorf("unable to verify token: %v", err)
|
|
}
|
|
|
|
// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
|
|
// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
|
|
// due to above issues, id_token may not be signed by AAD
|
|
// in that case, we will fallback to access token
|
|
var err error
|
|
s, err = p.buildSessionFromClaims(session.IDToken, session.AccessToken)
|
|
if err != nil || s.Email == "" {
|
|
s, err = p.buildSessionFromClaims(session.AccessToken, session.AccessToken)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get claims from token: %v", err)
|
|
}
|
|
|
|
session.Email = s.Email
|
|
if s.Groups != nil {
|
|
session.Groups = s.Groups
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// verifySessionToken tries to validate id_token if present or access token when oidc verifier is configured
|
|
func (p *AzureProvider) verifySessionToken(ctx context.Context, session *sessions.SessionState) error {
|
|
// Without a verifier there's no way to verify
|
|
if p.Verifier == nil {
|
|
return nil
|
|
}
|
|
|
|
if session.IDToken != "" {
|
|
if _, err := p.Verifier.Verify(ctx, session.IDToken); err != nil {
|
|
logger.Printf("unable to verify ID token, fallback to access token: %v", err)
|
|
if _, err = p.Verifier.Verify(ctx, session.AccessToken); err != nil {
|
|
return fmt.Errorf("unable to verify access token: %v", err)
|
|
}
|
|
}
|
|
} else if _, err := p.Verifier.Verify(ctx, session.AccessToken); err != nil {
|
|
return fmt.Errorf("unable to verify access token: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
|
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
|
if s == nil || s.RefreshToken == "" {
|
|
return false, nil
|
|
}
|
|
|
|
err := p.redeemRefreshToken(ctx, s)
|
|
if err != nil {
|
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
|
clientSecret, err := p.GetClientSecret()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
params := url.Values{}
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("client_secret", clientSecret)
|
|
params.Add("refresh_token", s.RefreshToken)
|
|
params.Add("grant_type", "refresh_token")
|
|
|
|
var jsonResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresOn int64 `json:"expires_on,string"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
|
|
err = requests.New(p.RedeemURL.String()).
|
|
WithContext(ctx).
|
|
WithMethod("POST").
|
|
WithBody(bytes.NewBufferString(params.Encode())).
|
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
|
Do().
|
|
UnmarshalInto(&jsonResponse)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.AccessToken = jsonResponse.AccessToken
|
|
s.IDToken = jsonResponse.IDToken
|
|
s.RefreshToken = jsonResponse.RefreshToken
|
|
|
|
s.CreatedAtNow()
|
|
s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
|
|
|
|
err = p.extractClaimsIntoSession(ctx, s)
|
|
|
|
if err != nil {
|
|
logger.Printf("unable to get email and/or groups claims from token: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func makeAzureHeader(accessToken string) http.Header {
|
|
return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil)
|
|
}
|
|
|
|
func (p *AzureProvider) getGroupsFromProfileAPI(ctx context.Context, s *sessions.SessionState) ([]string, error) {
|
|
if s.AccessToken == "" {
|
|
return nil, fmt.Errorf("missing access token")
|
|
}
|
|
|
|
groupsURL := getMicrosoftGraphGroupsURL(p.ProfileURL, p.GraphGroupField).String()
|
|
|
|
// Need and extra header while talking with MS Graph. For more context see
|
|
// https://docs.microsoft.com/en-us/graph/api/group-list-transitivememberof?view=graph-rest-1.0&tabs=http#request-headers
|
|
extraHeader := makeAzureHeader(s.AccessToken)
|
|
extraHeader.Add("ConsistencyLevel", "eventual")
|
|
|
|
var groups []string
|
|
|
|
for groupsURL != "" {
|
|
jsonRequest, err := requests.New(groupsURL).
|
|
WithContext(ctx).
|
|
WithHeaders(extraHeader).
|
|
Do().
|
|
UnmarshalSimpleJSON()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to unmarshal Microsoft Graph response: %v", err)
|
|
|
|
}
|
|
groupsURL, err = jsonRequest.Get("@odata.nextLink").String()
|
|
if err != nil {
|
|
groupsURL = ""
|
|
}
|
|
groupsPage := getGroupsFromJSON(jsonRequest, p.GraphGroupField)
|
|
groups = append(groups, groupsPage...)
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func getGroupsFromJSON(json *simplejson.Json, graphGroupField string) []string {
|
|
groups := []string{}
|
|
|
|
for i := range json.Get("value").MustArray() {
|
|
value := json.Get("value").GetIndex(i).Get(graphGroupField).MustString()
|
|
groups = append(groups, value)
|
|
}
|
|
|
|
return groups
|
|
}
|
|
|
|
func (p *AzureProvider) getEmailFromProfileAPI(ctx context.Context, accessToken string) (string, error) {
|
|
if accessToken == "" {
|
|
return "", fmt.Errorf("missing access token")
|
|
}
|
|
|
|
json, err := requests.New(p.ProfileURL.String()).
|
|
WithContext(ctx).
|
|
WithHeaders(makeAzureHeader(accessToken)).
|
|
Do().
|
|
UnmarshalSimpleJSON()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
email, err := getEmailFromJSON(json)
|
|
if email == "" {
|
|
return "", fmt.Errorf("empty email address: %v", err)
|
|
}
|
|
return email, nil
|
|
}
|
|
|
|
func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
|
email, err := json.Get("mail").String()
|
|
|
|
if err != nil || email == "" {
|
|
otherMails, otherMailsErr := json.Get("otherMails").Array()
|
|
if len(otherMails) > 0 {
|
|
email = otherMails[0].(string)
|
|
}
|
|
err = otherMailsErr
|
|
}
|
|
|
|
if err != nil || email == "" {
|
|
email, err = json.Get("userPrincipalName").String()
|
|
if err != nil {
|
|
logger.Errorf("unable to find userPrincipalName: %s", err)
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
return email, nil
|
|
}
|
|
|
|
// ValidateSession validates the AccessToken
|
|
func (p *AzureProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
|
return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
|
|
}
|