Migrate all requests to new builder pattern

This commit is contained in:
Joel Speed 2020-07-03 19:27:25 +01:00
parent 21ef86b594
commit 53142455b6
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
15 changed files with 194 additions and 399 deletions

View File

@ -83,34 +83,33 @@ func Validate(o *options.Options) error {
logger.Printf("Performing OIDC Discovery...")
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil {
if body, err := requests.Request(req); err == nil {
// Prefer manually configured URLs. It's a bit unclear
// why you'd be doing discovery and also providing the URLs
// explicitly though...
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
} else {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
}
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
body, err := requests.New(requestURL).
WithContext(ctx).
UnmarshalJSON()
if err != nil {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
} else {
logger.Printf("error: failed parsing OIDC discovery URL: %v", err)
// Prefer manually configured URLs. It's a bit unclear
// why you'd be doing discovery and also providing the URLs
// explicitly though...
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
}
}
@ -385,10 +384,12 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
if err != nil {
// Try as JWKS URI
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
_, err := http.NewRequest("GET", jwksURI, nil)
resp, err := requests.New(jwksURI).Do()
if err != nil {
return nil, err
}
resp.Body.Close()
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
} else {
verifier = provider.Verifier(config)

View File

@ -3,10 +3,8 @@ package providers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
@ -91,39 +89,21 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresOn int64 `json:"expires_on,string"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
created := time.Now()
@ -169,26 +149,21 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getAzureHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil {
return "", err
}
email, err = getEmailFromJSON(json)
if err == nil && email != "" {
return email, err
}
email, err = json.Get("userPrincipalName").String()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"net/http"
"net/url"
"strings"
@ -85,15 +84,13 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
FullName string `json:"full_name"`
}
}
req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&emails)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &emails)
if err != nil {
logger.Printf("failed making request %s", err)
logger.Printf("failed making request: %v", err)
return "", err
}
@ -101,15 +98,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
teamURL := &url.URL{}
*teamURL = *p.ValidateURL
teamURL.Path = "/2.0/teams"
req, err = http.NewRequestWithContext(ctx, "GET",
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&teams)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &teams)
if err != nil {
logger.Printf("failed requesting teams membership %s", err)
logger.Printf("failed requesting teams membership: %v", err)
return "", err
}
var found = false
@ -129,20 +125,19 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
repositoriesURL := &url.URL{}
*repositoriesURL = *p.ValidateURL
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
req, err = http.NewRequestWithContext(ctx, "GET",
repositoriesURL.String()+"?role=contributor"+
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
"&access_token="+s.AccessToken,
nil)
requestURL := repositoriesURL.String() + "?role=contributor" +
"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
"&access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&repositories)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &repositories)
if err != nil {
logger.Printf("failed checking repository access %s", err)
logger.Printf("failed checking repository access: %v", err)
return "", err
}
var found = false
for _, repository := range repositories.Values {
if p.Repository == repository.FullName {

View File

@ -60,13 +60,11 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getDigitalOceanHeader(s.AccessToken)
json, err := requests.Request(req)
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil {
return "", err
}

View File

@ -62,20 +62,21 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil {
return "", err
}
req.Header = getFacebookHeader(s.AccessToken)
type result struct {
Email string
}
var r result
err = requests.RequestJSON(req, &r)
requestURL := p.ProfileURL.String() + "?fields=name,email"
err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getFacebookHeader(s.AccessToken)).
UnmarshalInto(&r)
if err != nil {
return "", err
}
if r.Email == "" {
return "", errors.New("no email")
}

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
@ -15,6 +14,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
// GitHubProvider represents an GitHub based Identity Provider
@ -111,27 +111,16 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
RawQuery: params.Encode(),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var op orgsPage
if err := json.Unmarshal(body, &op); err != nil {
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&op)
if err != nil {
return false, err
}
if len(op) == 0 {
break
}
@ -187,9 +176,13 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
RawQuery: params.Encode(),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
// bodyclose cannot detect that the body is being closed later in requests.Into,
// so have to skip the linting for the next line.
// nolint:bodyclose
resp, err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil {
return false, err
}
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
}
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return false, err
}
resp.Body.Close()
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var tp teamsPage
if err := json.Unmarshal(body, &tp); err != nil {
return false, fmt.Errorf("%s unmarshaling %s", err, body)
if err := requests.UnmarshalInto(resp, &tp); err != nil {
return false, err
}
if len(tp) == 0 {
break
@ -297,25 +278,12 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var repo repository
if err := json.Unmarshal(body, &repo); err != nil {
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&repo)
if err != nil {
return false, err
}
@ -337,26 +305,14 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&user)
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, stripToken(endpoint.String()), body)
}
if err := json.Unmarshal(body, &user); err != nil {
return false, err
}
if p.isVerifiedUser(user.Login) {
return true, nil
@ -372,12 +328,14 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
resp, err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
@ -440,28 +398,13 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
UnmarshalInto(&emails)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
returnEmail := ""
for _, email := range emails {
@ -489,34 +432,14 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
UnmarshalInto(&user)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
// Now that we have the username we can check collaborator status
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {

View File

@ -2,15 +2,13 @@ package providers
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
oidc "github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2"
)
@ -131,31 +129,13 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
userInfoURL := *p.LoginURL
userInfoURL.Path = "/oauth/userinfo"
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create user info request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform user info request: %v", err)
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to read user info response: %v", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
}
var userInfo gitlabUserInfo
err = json.Unmarshal(body, &userInfo)
err := requests.New(userInfoURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
UnmarshalInto(&userInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse user info: %v", err)
return nil, fmt.Errorf("error getting user info: %v", err)
}
return &userInfo, nil

View File

@ -9,13 +9,13 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
params.Add("client_secret", clientSecret)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
@ -145,10 +123,17 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
c, err := claimsFromIDToken(jsonResponse.IDToken)
if err != nil {
return
@ -283,38 +268,23 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
params.Add("client_secret", clientSecret)
params.Add("refresh_token", refreshToken)
params.Add("grant_type", "refresh_token")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var data struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &data)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&data)
if err != nil {
return
return "", "", 0, err
}
token = data.AccessToken
idToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second

View File

@ -56,7 +56,11 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode()
}
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
resp, err := requests.New(endpoint).
WithContext(ctx).
WithHeaders(header).
Do()
if err != nil {
logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err)

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"net/http"
"net/url"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
@ -51,14 +50,10 @@ func (p *KeycloakProvider) SetGroup(group string) {
}
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
json, err := requests.Request(req)
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
UnmarshalJSON()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err

View File

@ -58,13 +58,12 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil {
return "", err
}
req.Header = getLinkedInHeader(s.AccessToken)
json, err := requests.Request(req)
requestURL := p.ProfileURL.String() + "?format=json"
json, err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getLinkedInHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil {
return "", err
}

View File

@ -15,6 +15,7 @@ import (
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"gopkg.in/square/go-jose.v2"
)
@ -128,51 +129,33 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
return
}
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
// query the user info endpoint for user attributes
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
if err != nil {
return
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
return
}
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
// parse the user attributes from the data we got and make sure that
// the email address has been validated.
var emailData struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
err = json.Unmarshal(body, &emailData)
// query the user info endpoint for user attributes
err := requests.New(userInfoEndpoint).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
UnmarshalInto(&emailData)
if err != nil {
return
return "", err
}
if emailData.Email == "" {
err = fmt.Errorf("missing email")
return
email := emailData.Email
if email == "" {
return "", fmt.Errorf("missing email")
}
email = emailData.Email
if !emailData.EmailVerified {
err = fmt.Errorf("email %s not listed as verified", email)
return
return "", fmt.Errorf("email %s not listed as verified", email)
}
return
return email, nil
}
// Redeem exchanges the OAuth2 authentication token for an ID token
@ -201,30 +184,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
// Get the token from the body that we got from the token endpoint.
var jsonResponse struct {
AccessToken string `json:"access_token"`
@ -232,9 +191,14 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
// check nonce here

View File

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
@ -31,18 +30,14 @@ func getNextcloudHeader(accessToken string) http.Header {
// GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String(), nil)
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
WithHeaders(getNextcloudHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
req.Header = getNextcloudHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
return "", fmt.Errorf("error making request: %v", err)
}
email, err := json.Get("ocs").Get("data").Get("email").String()
return email, err
}

View File

@ -256,13 +256,10 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
// contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
if err != nil {
return nil, err
}
req.Header = getOIDCHeader(accessToken)
respJSON, err := requests.Request(req)
respJSON, err := requests.New(profileURL).
WithContext(ctx).
WithHeaders(getOIDCHeader(accessToken)).
UnmarshalJSON()
if err != nil {
return nil, err
}

View File

@ -7,13 +7,13 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
var _ Provider = (*ProviderData)(nil)
@ -39,18 +39,16 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
resp, err := requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do()
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()