Migrate all requests to new builder pattern
This commit is contained in:
parent
21ef86b594
commit
53142455b6
@ -83,34 +83,33 @@ func Validate(o *options.Options) error {
|
||||
|
||||
logger.Printf("Performing OIDC Discovery...")
|
||||
|
||||
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil {
|
||||
if body, err := requests.Request(req); err == nil {
|
||||
|
||||
// Prefer manually configured URLs. It's a bit unclear
|
||||
// why you'd be doing discovery and also providing the URLs
|
||||
// explicitly though...
|
||||
if o.LoginURL == "" {
|
||||
o.LoginURL = body.Get("authorization_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.RedeemURL == "" {
|
||||
o.RedeemURL = body.Get("token_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.OIDCJwksURL == "" {
|
||||
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
|
||||
}
|
||||
|
||||
if o.ProfileURL == "" {
|
||||
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
|
||||
}
|
||||
|
||||
o.SkipOIDCDiscovery = true
|
||||
} else {
|
||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||
}
|
||||
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||
body, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||
} else {
|
||||
logger.Printf("error: failed parsing OIDC discovery URL: %v", err)
|
||||
// Prefer manually configured URLs. It's a bit unclear
|
||||
// why you'd be doing discovery and also providing the URLs
|
||||
// explicitly though...
|
||||
if o.LoginURL == "" {
|
||||
o.LoginURL = body.Get("authorization_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.RedeemURL == "" {
|
||||
o.RedeemURL = body.Get("token_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.OIDCJwksURL == "" {
|
||||
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
|
||||
}
|
||||
|
||||
if o.ProfileURL == "" {
|
||||
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
|
||||
}
|
||||
|
||||
o.SkipOIDCDiscovery = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -385,10 +384,12 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
|
||||
if err != nil {
|
||||
// Try as JWKS URI
|
||||
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
|
||||
_, err := http.NewRequest("GET", jwksURI, nil)
|
||||
resp, err := requests.New(jwksURI).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
|
||||
} else {
|
||||
verifier = provider.Verifier(config)
|
||||
|
@ -3,10 +3,8 @@ package providers
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@ -91,39 +89,21 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresOn int64 `json:"expires_on,string"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
@ -169,26 +149,21 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getAzureHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getAzureHeader(s.AccessToken)).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
email, err = getEmailFromJSON(json)
|
||||
|
||||
if err == nil && email != "" {
|
||||
return email, err
|
||||
}
|
||||
|
||||
email, err = json.Get("userPrincipalName").String()
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@ -85,15 +84,13 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
FullName string `json:"full_name"`
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||
|
||||
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
logger.Printf("failed making request: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -101,15 +98,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
teamURL := &url.URL{}
|
||||
*teamURL = *p.ValidateURL
|
||||
teamURL.Path = "/2.0/teams"
|
||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
||||
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
||||
|
||||
requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
UnmarshalInto(&teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed requesting teams membership %s", err)
|
||||
logger.Printf("failed requesting teams membership: %v", err)
|
||||
return "", err
|
||||
}
|
||||
var found = false
|
||||
@ -129,20 +125,19 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
repositoriesURL := &url.URL{}
|
||||
*repositoriesURL = *p.ValidateURL
|
||||
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
|
||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
||||
repositoriesURL.String()+"?role=contributor"+
|
||||
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
|
||||
"&access_token="+s.AccessToken,
|
||||
nil)
|
||||
|
||||
requestURL := repositoriesURL.String() + "?role=contributor" +
|
||||
"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
|
||||
"&access_token=" + s.AccessToken
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
UnmarshalInto(&repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed checking repository access %s", err)
|
||||
logger.Printf("failed checking repository access: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var found = false
|
||||
for _, repository := range repositories.Values {
|
||||
if p.Repository == repository.FullName {
|
||||
|
@ -60,13 +60,11 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getDigitalOceanHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -62,20 +62,21 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getFacebookHeader(s.AccessToken)
|
||||
|
||||
type result struct {
|
||||
Email string
|
||||
}
|
||||
var r result
|
||||
err = requests.RequestJSON(req, &r)
|
||||
|
||||
requestURL := p.ProfileURL.String() + "?fields=name,email"
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getFacebookHeader(s.AccessToken)).
|
||||
UnmarshalInto(&r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if r.Email == "" {
|
||||
return "", errors.New("no email")
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@ -15,6 +14,7 @@ import (
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
// GitHubProvider represents an GitHub based Identity Provider
|
||||
@ -111,27 +111,16 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var op orgsPage
|
||||
if err := json.Unmarshal(body, &op); err != nil {
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
UnmarshalInto(&op)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(op) == 0 {
|
||||
break
|
||||
}
|
||||
@ -187,9 +176,13 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
// bodyclose cannot detect that the body is being closed later in requests.Into,
|
||||
// so have to skip the linting for the next line.
|
||||
// nolint:bodyclose
|
||||
resp, err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return false, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var tp teamsPage
|
||||
if err := json.Unmarshal(body, &tp); err != nil {
|
||||
return false, fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
if err := requests.UnmarshalInto(resp, &tp); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(tp) == 0 {
|
||||
break
|
||||
@ -297,25 +278,12 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
|
||||
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var repo repository
|
||||
if err := json.Unmarshal(body, &repo); err != nil {
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
UnmarshalInto(&repo)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@ -337,26 +305,14 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, stripToken(endpoint.String()), body)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if p.isVerifiedUser(user.Login) {
|
||||
return true, nil
|
||||
@ -372,12 +328,14 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
@ -440,28 +398,13 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(s.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
returnEmail := ""
|
||||
for _, email := range emails {
|
||||
@ -489,34 +432,14 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
|
||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not create new GET request: %v", err)
|
||||
}
|
||||
|
||||
req.Header = getGitHubHeader(s.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
// Now that we have the username we can check collaborator status
|
||||
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
|
||||
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {
|
||||
|
@ -2,15 +2,13 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -131,31 +129,13 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
|
||||
userInfoURL := *p.LoginURL
|
||||
userInfoURL.Path = "/oauth/userinfo"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user info request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform user info request: %v", err)
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read user info response: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var userInfo gitlabUserInfo
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
err := requests.New(userInfoURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
UnmarshalInto(&userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info: %v", err)
|
||||
return nil, fmt.Errorf("error getting user info: %v", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
|
@ -9,13 +9,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("code", code)
|
||||
params.Add("grant_type", "authorization_code")
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@ -145,10 +123,17 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||
if err != nil {
|
||||
return
|
||||
@ -283,38 +268,23 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &data)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
UnmarshalInto(&data)
|
||||
if err != nil {
|
||||
return
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
token = data.AccessToken
|
||||
idToken = data.IDToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
|
@ -56,7 +56,11 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
|
||||
params := url.Values{"access_token": {accessToken}}
|
||||
endpoint = endpoint + "?" + params.Encode()
|
||||
}
|
||||
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
|
||||
|
||||
resp, err := requests.New(endpoint).
|
||||
WithContext(ctx).
|
||||
WithHeaders(header).
|
||||
Do()
|
||||
if err != nil {
|
||||
logger.Printf("GET %s", stripToken(endpoint))
|
||||
logger.Printf("token validation request failed: %s", err)
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
@ -51,14 +50,10 @@ func (p *KeycloakProvider) SetGroup(group string) {
|
||||
}
|
||||
|
||||
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
json, err := requests.Request(req)
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
|
@ -58,13 +58,12 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getLinkedInHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
requestURL := p.ProfileURL.String() + "?format=json"
|
||||
json, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getLinkedInHeader(s.AccessToken)).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
@ -128,51 +129,33 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
|
||||
// query the user info endpoint for user attributes
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
|
||||
return
|
||||
}
|
||||
|
||||
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
|
||||
// parse the user attributes from the data we got and make sure that
|
||||
// the email address has been validated.
|
||||
var emailData struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
err = json.Unmarshal(body, &emailData)
|
||||
|
||||
// query the user info endpoint for user attributes
|
||||
err := requests.New(userInfoEndpoint).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
UnmarshalInto(&emailData)
|
||||
if err != nil {
|
||||
return
|
||||
return "", err
|
||||
}
|
||||
if emailData.Email == "" {
|
||||
err = fmt.Errorf("missing email")
|
||||
return
|
||||
|
||||
email := emailData.Email
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("missing email")
|
||||
}
|
||||
email = emailData.Email
|
||||
|
||||
if !emailData.EmailVerified {
|
||||
err = fmt.Errorf("email %s not listed as verified", email)
|
||||
return
|
||||
return "", fmt.Errorf("email %s not listed as verified", email)
|
||||
}
|
||||
return
|
||||
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
@ -201,30 +184,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
params.Add("code", code)
|
||||
params.Add("grant_type", "authorization_code")
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the token from the body that we got from the token endpoint.
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@ -232,9 +191,14 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check nonce here
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
@ -31,18 +30,14 @@ func getNextcloudHeader(accessToken string) http.Header {
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
||||
p.ValidateURL.String(), nil)
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getNextcloudHeader(s.AccessToken)).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
req.Header = getNextcloudHeader(s.AccessToken)
|
||||
json, err := requests.Request(req)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
return "", fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
email, err := json.Get("ocs").Get("data").Get("email").String()
|
||||
return email, err
|
||||
}
|
||||
|
@ -256,13 +256,10 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
|
||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
||||
// contents at the profileURL contains the email.
|
||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header = getOIDCHeader(accessToken)
|
||||
|
||||
respJSON, err := requests.Request(req)
|
||||
respJSON, err := requests.New(profileURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getOIDCHeader(accessToken)).
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -7,13 +7,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
var _ Provider = (*ProviderData)(nil)
|
||||
@ -39,18 +39,16 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
resp, err := requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user