Migrate all requests to new builder pattern

This commit is contained in:
Joel Speed 2020-07-03 19:27:25 +01:00
parent 21ef86b594
commit 53142455b6
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
15 changed files with 194 additions and 399 deletions

View File

@ -83,34 +83,33 @@ func Validate(o *options.Options) error {
logger.Printf("Performing OIDC Discovery...") logger.Printf("Performing OIDC Discovery...")
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
if body, err := requests.Request(req); err == nil { body, err := requests.New(requestURL).
WithContext(ctx).
// Prefer manually configured URLs. It's a bit unclear UnmarshalJSON()
// why you'd be doing discovery and also providing the URLs if err != nil {
// explicitly though... logger.Printf("error: failed to discover OIDC configuration: %v", err)
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
} else {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
}
} else { } else {
logger.Printf("error: failed parsing OIDC discovery URL: %v", err) // Prefer manually configured URLs. It's a bit unclear
// why you'd be doing discovery and also providing the URLs
// explicitly though...
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
} }
} }
@ -385,10 +384,12 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
if err != nil { if err != nil {
// Try as JWKS URI // Try as JWKS URI
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
_, err := http.NewRequest("GET", jwksURI, nil) resp, err := requests.New(jwksURI).Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp.Body.Close()
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
} else { } else {
verifier = provider.Verifier(config) verifier = provider.Verifier(config)

View File

@ -3,10 +3,8 @@ package providers
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
@ -91,39 +89,21 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String()) params.Add("resource", p.ProtectedResource.String())
} }
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresOn int64 `json:"expires_on,string"` ExpiresOn int64 `json:"expires_on,string"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return return nil, err
} }
created := time.Now() created := time.Now()
@ -169,26 +149,21 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getAzureHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err
} }
email, err = getEmailFromJSON(json) email, err = getEmailFromJSON(json)
if err == nil && email != "" { if err == nil && email != "" {
return email, err return email, err
} }
email, err = json.Get("userPrincipalName").String() email, err = json.Get("userPrincipalName").String()
if err != nil { if err != nil {
logger.Printf("failed making request %s", err) logger.Printf("failed making request %s", err)
return "", err return "", err

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"context" "context"
"net/http"
"net/url" "net/url"
"strings" "strings"
@ -85,15 +84,13 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
FullName string `json:"full_name"` FullName string `json:"full_name"`
} }
} }
req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&emails)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed making request: %v", err)
return "", err
}
err = requests.RequestJSON(req, &emails)
if err != nil {
logger.Printf("failed making request %s", err)
return "", err return "", err
} }
@ -101,15 +98,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
teamURL := &url.URL{} teamURL := &url.URL{}
*teamURL = *p.ValidateURL *teamURL = *p.ValidateURL
teamURL.Path = "/2.0/teams" teamURL.Path = "/2.0/teams"
req, err = http.NewRequestWithContext(ctx, "GET",
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&teams)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed requesting teams membership: %v", err)
return "", err
}
err = requests.RequestJSON(req, &teams)
if err != nil {
logger.Printf("failed requesting teams membership %s", err)
return "", err return "", err
} }
var found = false var found = false
@ -129,20 +125,19 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
repositoriesURL := &url.URL{} repositoriesURL := &url.URL{}
*repositoriesURL = *p.ValidateURL *repositoriesURL = *p.ValidateURL
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
req, err = http.NewRequestWithContext(ctx, "GET",
repositoriesURL.String()+"?role=contributor"+ requestURL := repositoriesURL.String() + "?role=contributor" +
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ "&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
"&access_token="+s.AccessToken, "&access_token=" + s.AccessToken
nil)
err := requests.New(requestURL).
WithContext(ctx).
UnmarshalInto(&repositories)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed checking repository access: %v", err)
return "", err
}
err = requests.RequestJSON(req, &repositories)
if err != nil {
logger.Printf("failed checking repository access %s", err)
return "", err return "", err
} }
var found = false var found = false
for _, repository := range repositories.Values { for _, repository := range repositories.Values {
if p.Repository == repository.FullName { if p.Repository == repository.FullName {

View File

@ -60,13 +60,11 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getDigitalOceanHeader(s.AccessToken)
json, err := requests.Request(req) json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -62,20 +62,21 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil {
return "", err
}
req.Header = getFacebookHeader(s.AccessToken)
type result struct { type result struct {
Email string Email string
} }
var r result var r result
err = requests.RequestJSON(req, &r)
requestURL := p.ProfileURL.String() + "?fields=name,email"
err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getFacebookHeader(s.AccessToken)).
UnmarshalInto(&r)
if err != nil { if err != nil {
return "", err return "", err
} }
if r.Email == "" { if r.Email == "" {
return "", errors.New("no email") return "", errors.New("no email")
} }

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -15,6 +14,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
) )
// GitHubProvider represents an GitHub based Identity Provider // GitHubProvider represents an GitHub based Identity Provider
@ -111,27 +111,16 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/user/orgs"), Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
RawQuery: params.Encode(), RawQuery: params.Encode(),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var op orgsPage var op orgsPage
if err := json.Unmarshal(body, &op); err != nil { err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&op)
if err != nil {
return false, err return false, err
} }
if len(op) == 0 { if len(op) == 0 {
break break
} }
@ -187,9 +176,13 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
RawQuery: params.Encode(), RawQuery: params.Encode(),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) // bodyclose cannot detect that the body is being closed later in requests.Into,
req.Header = getGitHubHeader(accessToken) // so have to skip the linting for the next line.
resp, err := http.DefaultClient.Do(req) // nolint:bodyclose
resp, err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil { if err != nil {
return false, err return false, err
} }
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
} }
} }
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return false, err
}
resp.Body.Close()
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var tp teamsPage var tp teamsPage
if err := json.Unmarshal(body, &tp); err != nil { if err := requests.UnmarshalInto(resp, &tp); err != nil {
return false, fmt.Errorf("%s unmarshaling %s", err, body) return false, err
} }
if len(tp) == 0 { if len(tp) == 0 {
break break
@ -297,25 +278,12 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo), Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var repo repository var repo repository
if err := json.Unmarshal(body, &repo); err != nil { err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&repo)
if err != nil {
return false, err return false, err
} }
@ -337,26 +305,14 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
Host: p.ValidateURL.Host, Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"), Path: path.Join(p.ValidateURL.Path, "/user"),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
UnmarshalInto(&user)
if err != nil { if err != nil {
return false, err return false, err
} }
if resp.StatusCode != 200 {
return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, stripToken(endpoint.String()), body)
}
if err := json.Unmarshal(body, &user); err != nil {
return false, err
}
if p.isVerifiedUser(user.Login) { if p.isVerifiedUser(user.Login) {
return true, nil return true, nil
@ -372,12 +328,14 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
Host: p.ValidateURL.Host, Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) resp, err := requests.New(endpoint.String()).
req.Header = getGitHubHeader(accessToken) WithContext(ctx).
resp, err := http.DefaultClient.Do(req) WithHeaders(getGitHubHeader(accessToken)).
Do()
if err != nil { if err != nil {
return false, err return false, err
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
@ -440,28 +398,13 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
Host: p.ValidateURL.Host, Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user/emails"), Path: path.Join(p.ValidateURL.Path, "/user/emails"),
} }
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) err := requests.New(endpoint.String()).
req.Header = getGitHubHeader(s.AccessToken) WithContext(ctx).
resp, err := http.DefaultClient.Do(req) WithHeaders(getGitHubHeader(s.AccessToken)).
UnmarshalInto(&emails)
if err != nil { if err != nil {
return "", err return "", err
} }
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
returnEmail := "" returnEmail := ""
for _, email := range emails { for _, email := range emails {
@ -489,34 +432,14 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
Path: path.Join(p.ValidateURL.Path, "/user"), Path: path.Join(p.ValidateURL.Path, "/user"),
} }
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) err := requests.New(endpoint.String()).
if err != nil { WithContext(ctx).
return "", fmt.Errorf("could not create new GET request: %v", err) WithHeaders(getGitHubHeader(s.AccessToken)).
} UnmarshalInto(&user)
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", err return "", err
} }
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
// Now that we have the username we can check collaborator status // Now that we have the username we can check collaborator status
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {

View File

@ -2,15 +2,13 @@ package providers
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http"
"strings" "strings"
"time" "time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -131,31 +129,13 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
userInfoURL := *p.LoginURL userInfoURL := *p.LoginURL
userInfoURL.Path = "/oauth/userinfo" userInfoURL.Path = "/oauth/userinfo"
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create user info request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform user info request: %v", err)
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to read user info response: %v", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
}
var userInfo gitlabUserInfo var userInfo gitlabUserInfo
err = json.Unmarshal(body, &userInfo) err := requests.New(userInfoURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
UnmarshalInto(&userInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse user info: %v", err) return nil, fmt.Errorf("error getting user info: %v", err)
} }
return &userInfo, nil return &userInfo, nil

View File

@ -9,13 +9,13 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
params.Add("client_secret", clientSecret) params.Add("client_secret", clientSecret)
params.Add("code", code) params.Add("code", code)
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
@ -145,10 +123,17 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return return nil, err
} }
c, err := claimsFromIDToken(jsonResponse.IDToken) c, err := claimsFromIDToken(jsonResponse.IDToken)
if err != nil { if err != nil {
return return
@ -283,38 +268,23 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
params.Add("client_secret", clientSecret) params.Add("client_secret", clientSecret)
params.Add("refresh_token", refreshToken) params.Add("refresh_token", refreshToken)
params.Add("grant_type", "refresh_token") params.Add("grant_type", "refresh_token")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var data struct { var data struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &data)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&data)
if err != nil { if err != nil {
return return "", "", 0, err
} }
token = data.AccessToken token = data.AccessToken
idToken = data.IDToken idToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second expires = time.Duration(data.ExpiresIn) * time.Second

View File

@ -56,7 +56,11 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
params := url.Values{"access_token": {accessToken}} params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode() endpoint = endpoint + "?" + params.Encode()
} }
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
resp, err := requests.New(endpoint).
WithContext(ctx).
WithHeaders(header).
Do()
if err != nil { if err != nil {
logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err) logger.Printf("token validation request failed: %s", err)

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"context" "context"
"net/http"
"net/url" "net/url"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
@ -51,14 +50,10 @@ func (p *KeycloakProvider) SetGroup(group string) {
} }
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
json, err := requests.New(p.ValidateURL.String()).
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) WithContext(ctx).
req.Header.Set("Authorization", "Bearer "+s.AccessToken) SetHeader("Authorization", "Bearer "+s.AccessToken).
if err != nil { UnmarshalJSON()
logger.Printf("failed building request %s", err)
return "", err
}
json, err := requests.Request(req)
if err != nil { if err != nil {
logger.Printf("failed making request %s", err) logger.Printf("failed making request %s", err)
return "", err return "", err

View File

@ -58,13 +58,12 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil {
return "", err
}
req.Header = getLinkedInHeader(s.AccessToken)
json, err := requests.Request(req) requestURL := p.ProfileURL.String() + "?format=json"
json, err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getLinkedInHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
@ -128,51 +129,33 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
return return
} }
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
// query the user info endpoint for user attributes
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
if err != nil {
return
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
return
}
// parse the user attributes from the data we got and make sure that // parse the user attributes from the data we got and make sure that
// the email address has been validated. // the email address has been validated.
var emailData struct { var emailData struct {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
err = json.Unmarshal(body, &emailData)
// query the user info endpoint for user attributes
err := requests.New(userInfoEndpoint).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
UnmarshalInto(&emailData)
if err != nil { if err != nil {
return return "", err
} }
if emailData.Email == "" {
err = fmt.Errorf("missing email") email := emailData.Email
return if email == "" {
return "", fmt.Errorf("missing email")
} }
email = emailData.Email
if !emailData.EmailVerified { if !emailData.EmailVerified {
err = fmt.Errorf("email %s not listed as verified", email) return "", fmt.Errorf("email %s not listed as verified", email)
return
} }
return
return email, nil
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
@ -201,30 +184,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
params.Add("code", code) params.Add("code", code)
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
// Get the token from the body that we got from the token endpoint. // Get the token from the body that we got from the token endpoint.
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
@ -232,9 +191,14 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
} }
err = json.Unmarshal(body, &jsonResponse) err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
UnmarshalInto(&jsonResponse)
if err != nil { if err != nil {
return return nil, err
} }
// check nonce here // check nonce here

View File

@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
) )
@ -31,18 +30,14 @@ func getNextcloudHeader(accessToken string) http.Header {
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", json, err := requests.New(p.ValidateURL.String()).
p.ValidateURL.String(), nil) WithContext(ctx).
WithHeaders(getNextcloudHeader(s.AccessToken)).
UnmarshalJSON()
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) return "", fmt.Errorf("error making request: %v", err)
return "", err
}
req.Header = getNextcloudHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
} }
email, err := json.Get("ocs").Get("data").Get("email").String() email, err := json.Get("ocs").Get("data").Get("email").String()
return email, err return email, err
} }

View File

@ -256,13 +256,10 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
// contents at the profileURL contains the email. // contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there. // Make a query to the userinfo endpoint, and attempt to locate the email from there.
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) respJSON, err := requests.New(profileURL).
if err != nil { WithContext(ctx).
return nil, err WithHeaders(getOIDCHeader(accessToken)).
} UnmarshalJSON()
req.Header = getOIDCHeader(accessToken)
respJSON, err := requests.Request(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,13 +7,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
) )
var _ Provider = (*ProviderData)(nil) var _ Provider = (*ProviderData)(nil)
@ -39,18 +39,16 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String()) params.Add("resource", p.ProtectedResource.String())
} }
var req *http.Request resp, err := requests.New(p.RedeemURL.String()).
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) WithContext(ctx).
if err != nil { WithMethod("POST").
return WithBody(bytes.NewBufferString(params.Encode())).
} SetHeader("Content-Type", "application/x-www-form-urlencoded").
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") Do()
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var body []byte var body []byte
body, err = ioutil.ReadAll(resp.Body) body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()