Merge branch 'master' into fix-set-basic-default

This commit is contained in:
Dan Bond 2020-04-14 09:37:11 +01:00 committed by GitHub
commit 5fc6bd0e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 88 additions and 79 deletions

View File

@ -9,5 +9,27 @@ linters:
- deadcode - deadcode
- gofmt - gofmt
- goimports - goimports
enable-all: false - gosimple
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
- bodyclose
- dogsled
- goprintffuncname
- misspell
- prealloc
- scopelint
- stylecheck
- unconvert
- gocritic
disable-all: true disable-all: true
issues:
exclude-rules:
- path: _test\.go
linters:
- scopelint
- bodyclose
- unconvert
- gocritic

View File

@ -19,6 +19,7 @@
## Changes since v5.1.0 ## Changes since v5.1.0
- [#486](https://github.com/oauth2-proxy/oauth2-proxy/pull/486) Add new linters (@johejo)
- [#440](https://github.com/oauth2-proxy/oauth2-proxy/pull/440) Switch Azure AD Graph API to Microsoft Graph API (@johejo) - [#440](https://github.com/oauth2-proxy/oauth2-proxy/pull/440) Switch Azure AD Graph API to Microsoft Graph API (@johejo)
- [#453](https://github.com/oauth2-proxy/oauth2-proxy/pull/453) Prevent browser caching during auth flow (@johejo) - [#453](https://github.com/oauth2-proxy/oauth2-proxy/pull/453) Prevent browser caching during auth flow (@johejo)
- [#481](https://github.com/oauth2-proxy/oauth2-proxy/pull/481) Update Okta docs (@trevorbox) - [#481](https://github.com/oauth2-proxy/oauth2-proxy/pull/481) Update Okta docs (@trevorbox)

View File

@ -324,8 +324,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
if p.redirectURL.Host != "" { if p.redirectURL.Host != "" {
return p.redirectURL.String() return p.redirectURL.String()
} }
var u url.URL u := *p.redirectURL
u = *p.redirectURL
if u.Scheme == "" { if u.Scheme == "" {
if p.CookieSecure { if p.CookieSecure {
u.Scheme = httpsScheme u.Scheme = httpsScheme
@ -695,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
if ok { if ok {
session := &sessionsapi.SessionState{User: user} session := &sessionsapi.SessionState{User: user}
p.SaveSession(rw, req, session) p.SaveSession(rw, req, session)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, http.StatusFound)
} else { } else {
if p.SkipProviderButton { if p.SkipProviderButton {
p.OAuthStart(rw, req) p.OAuthStart(rw, req)
@ -734,7 +733,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
return return
} }
p.ClearSessionCookie(rw, req) p.ClearSessionCookie(rw, req)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, http.StatusFound)
} }
// OAuthStart starts the OAuth2 authentication flow // OAuthStart starts the OAuth2 authentication flow
@ -754,7 +753,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
return return
} }
redirectURI := p.GetRedirectURI(req.Host) redirectURI := p.GetRedirectURI(req.Host)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
} }
// OAuthCallback is the OAuth2 authentication flow callback that finishes the // OAuthCallback is the OAuth2 authentication flow callback that finishes the
@ -817,7 +816,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, 500, "Internal Error", "Internal Error") p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return return
} }
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, http.StatusFound)
} else { } else {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized")
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
@ -1096,10 +1095,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionStat
// isAjax checks if a request is an ajax request // isAjax checks if a request is an ajax request
func isAjax(req *http.Request) bool { func isAjax(req *http.Request) bool {
acceptValues, ok := req.Header["accept"] acceptValues := req.Header.Values("Accept")
if !ok {
acceptValues = req.Header["Accept"]
}
const ajaxReq = applicationJSON const ajaxReq = applicationJSON
for _, v := range acceptValues { for _, v := range acceptValues {
if v == ajaxReq { if v == ajaxReq {

View File

@ -64,7 +64,6 @@ func TestWebSocketProxy(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err %s", err) t.Fatalf("err %s", err)
} }
return
}), }),
} }
backend := httptest.NewServer(&handler) backend := httptest.NewServer(&handler)
@ -426,7 +425,7 @@ func TestBasicAuthPassword(t *testing.T) {
if rw.Code >= 400 { if rw.Code >= 400 {
t.Fatalf("expected 3xx got %d", rw.Code) t.Fatalf("expected 3xx got %d", rw.Code)
} }
cookie := rw.HeaderMap["Set-Cookie"][1] cookie := rw.Header().Values("Set-Cookie")[1]
cookieName := proxy.CookieName cookieName := proxy.CookieName
var value string var value string
@ -614,7 +613,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
} }
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
patTest.proxy.ServeHTTP(rw, req) patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][1] return rw.Code, rw.Header().Values("Set-Cookie")[1]
} }
// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath // getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
@ -691,7 +690,7 @@ func TestStaticProxyUpstream(t *testing.T) {
} }
assert.NotEqual(t, nil, cookie) assert.NotEqual(t, nil, cookie)
// Now we make a regular request againts the upstream proxy; And validate // Now we make a regular request against the upstream proxy; And validate
// the returned status code through the static proxy. // the returned status code through the static proxy.
code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy") code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
if code != 200 { if code != 200 {
@ -824,8 +823,6 @@ type ProcessCookieTest struct {
proxy *OAuthProxy proxy *OAuthProxy
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
req *http.Request req *http.Request
provider TestProvider
responseCode int
validateUser bool validateUser bool
} }
@ -910,7 +907,7 @@ func TestProcessCookieNoCookieError(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
session, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) assert.Equal(t, "cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil { if session != nil {
t.Errorf("expected nil session. got %#v", session) t.Errorf("expected nil session. got %#v", session)
} }
@ -1072,8 +1069,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
} }
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
@ -1103,10 +1100,10 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:"+pcTest.opts.BasicAuthPassword)) expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:"+pcTest.opts.BasicAuthPassword))
assert.Equal(t, expectedHeader, pcTest.rw.HeaderMap["Authorization"][0]) assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
} }
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
@ -1136,9 +1133,9 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
assert.Equal(t, 0, len(pcTest.rw.HeaderMap["Authorization"]), "should not have Authorization header entries") assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
} }
func TestAuthSkippedForPreflightRequests(t *testing.T) { func TestAuthSkippedForPreflightRequests(t *testing.T) {
@ -1462,8 +1459,7 @@ type NoOpKeySet struct {
func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
splitStrings := strings.Split(jwt, ".") splitStrings := strings.Split(jwt, ".")
payloadString := splitStrings[1] payloadString := splitStrings[1]
jsonString, err := base64.RawURLEncoding.DecodeString(payloadString) return base64.RawURLEncoding.DecodeString(payloadString)
return []byte(jsonString), err
} }
func TestGetJwtSession(t *testing.T) { func TestGetJwtSession(t *testing.T) {

View File

@ -290,7 +290,7 @@ func (o *Options) Validate() error {
} }
} }
if o.PreferEmailToUser == true && o.PassBasicAuth == false && o.PassUserHeaders == false { if o.PreferEmailToUser && !o.PassBasicAuth && !o.PassUserHeaders {
msgs = append(msgs, "PreferEmailToUser should only be used with PassBasicAuth or PassUserHeaders") msgs = append(msgs, "PreferEmailToUser should only be used with PassBasicAuth or PassUserHeaders")
} }
@ -349,7 +349,7 @@ func (o *Options) Validate() error {
if string(secretBytes(o.CookieSecret)) != o.CookieSecret { if string(secretBytes(o.CookieSecret)) != o.CookieSecret {
decoded = true decoded = true
} }
if validCookieSecretSize == false { if !validCookieSecretSize {
var suffix string var suffix string
if decoded { if decoded {
suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret)
@ -414,7 +414,7 @@ func (o *Options) Validate() error {
msgs = setupLogger(o, msgs) msgs = setupLogger(o, msgs)
if len(msgs) != 0 { if len(msgs) != 0 {
return fmt.Errorf("Invalid configuration:\n %s", return fmt.Errorf("invalid configuration:\n %s",
strings.Join(msgs, "\n ")) strings.Join(msgs, "\n "))
} }
return nil return nil
@ -545,7 +545,7 @@ func parseSignatureKey(o *Options, msgs []string) []string {
// parseJwtIssuers takes in an array of strings in the form of issuer=audience // parseJwtIssuers takes in an array of strings in the form of issuer=audience
// and parses to an array of jwtIssuer structs. // and parses to an array of jwtIssuer structs.
func parseJwtIssuers(issuers []string, msgs []string) ([]jwtIssuer, []string) { func parseJwtIssuers(issuers []string, msgs []string) ([]jwtIssuer, []string) {
var parsedIssuers []jwtIssuer parsedIssuers := make([]jwtIssuer, 0, len(issuers))
for _, jwtVerifier := range issuers { for _, jwtVerifier := range issuers {
components := strings.Split(jwtVerifier, "=") components := strings.Split(jwtVerifier, "=")
if len(components) < 2 { if len(components) < 2 {

View File

@ -31,7 +31,7 @@ func testOptions() *Options {
func errorMsg(msgs []string) string { func errorMsg(msgs []string) string {
result := make([]string, 0) result := make([]string, 0)
result = append(result, "Invalid configuration:") result = append(result, "invalid configuration:")
result = append(result, msgs...) result = append(result, msgs...)
return strings.Join(result, "\n ") return strings.Join(result, "\n ")
} }
@ -278,7 +278,7 @@ func TestValidateSignatureKeyInvalidSpec(t *testing.T) {
o := testOptions() o := testOptions()
o.SignatureKey = "invalid spec" o.SignatureKey = "invalid spec"
err := o.Validate() err := o.Validate()
assert.Equal(t, err.Error(), "Invalid configuration:\n"+ assert.Equal(t, err.Error(), "invalid configuration:\n"+
" invalid signature hash:key spec: "+o.SignatureKey) " invalid signature hash:key spec: "+o.SignatureKey)
} }
@ -286,7 +286,7 @@ func TestValidateSignatureKeyUnsupportedAlgorithm(t *testing.T) {
o := testOptions() o := testOptions()
o.SignatureKey = "unsupported:default secret" o.SignatureKey = "unsupported:default secret"
err := o.Validate() err := o.Validate()
assert.Equal(t, err.Error(), "Invalid configuration:\n"+ assert.Equal(t, err.Error(), "invalid configuration:\n"+
" unsupported signature hash algorithm: "+o.SignatureKey) " unsupported signature hash algorithm: "+o.SignatureKey)
} }
@ -300,7 +300,7 @@ func TestValidateCookieBadName(t *testing.T) {
o := testOptions() o := testOptions()
o.CookieName = "_bad_cookie_name{}" o.CookieName = "_bad_cookie_name{}"
err := o.Validate() err := o.Validate()
assert.Equal(t, err.Error(), "Invalid configuration:\n"+ assert.Equal(t, err.Error(), "invalid configuration:\n"+
fmt.Sprintf(" invalid cookie name: %q", o.CookieName)) fmt.Sprintf(" invalid cookie name: %q", o.CookieName))
} }
@ -311,8 +311,8 @@ func TestSkipOIDCDiscovery(t *testing.T) {
o.SkipOIDCDiscovery = true o.SkipOIDCDiscovery = true
err := o.Validate() err := o.Validate()
assert.Equal(t, "Invalid configuration:\n"+ assert.Equal(t, "invalid configuration:\n"+
fmt.Sprintf(" missing setting: login-url\n missing setting: redeem-url\n missing setting: oidc-jwks-url"), err.Error()) " missing setting: login-url\n missing setting: redeem-url\n missing setting: oidc-jwks-url", err.Error())
o.LoginURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/authorize?p=b2c_1_sign_in" o.LoginURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/authorize?p=b2c_1_sign_in"
o.RedeemURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/token?p=b2c_1_sign_in" o.RedeemURL = "https://login.microsoftonline.com/fabrikamb2c.onmicrosoft.com/oauth2/v2.0/token?p=b2c_1_sign_in"

View File

@ -139,10 +139,10 @@ func (l *Logger) Output(calldepth int, message string) {
l.writer.Write([]byte("\n")) l.writer.Write([]byte("\n"))
} }
// PrintAuth writes auth info to the logger. Requires an http.Request to // PrintAuthf writes auth info to the logger. Requires an http.Request to
// log request details. Remaining arguments are handled in the manner of // log request details. Remaining arguments are handled in the manner of
// fmt.Sprintf. Writes a final newline to the end of every message. // fmt.Sprintf. Writes a final newline to the end of every message.
func (l *Logger) PrintAuth(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) { func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) {
if !l.authEnabled { if !l.authEnabled {
return return
} }
@ -166,7 +166,7 @@ func (l *Logger) PrintAuth(username string, req *http.Request, status AuthStatus
Timestamp: FormatTimestamp(now), Timestamp: FormatTimestamp(now),
UserAgent: fmt.Sprintf("%q", req.UserAgent()), UserAgent: fmt.Sprintf("%q", req.UserAgent()),
Username: username, Username: username,
Status: fmt.Sprintf("%s", status), Status: string(status),
Message: fmt.Sprintf(format, a...), Message: fmt.Sprintf(format, a...),
}) })
@ -185,7 +185,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
return return
} }
duration := float64(time.Now().Sub(ts)) / float64(time.Second) duration := float64(time.Since(ts)) / float64(time.Second)
if username == "" { if username == "" {
username = "-" username = "-"
@ -481,7 +481,7 @@ func Panicln(v ...interface{}) {
// PrintAuthf writes authentication details to the standard logger. // PrintAuthf writes authentication details to the standard logger.
// Arguments are handled in the manner of fmt.Printf. // Arguments are handled in the manner of fmt.Printf.
func PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) { func PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) {
std.PrintAuth(username, req, status, format, a...) std.PrintAuthf(username, req, status, format, a...)
} }
// PrintReq writes request details to the standard logger. // PrintReq writes request details to the standard logger.

View File

@ -88,9 +88,9 @@ func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
response, err := RequestUnparsedResponse( response, err := RequestUnparsedResponse(
backend.URL+"?access_token=my_token", nil) backend.URL+"?access_token=my_token", nil)
assert.Equal(t, nil, err)
defer response.Body.Close() defer response.Body.Close()
assert.Equal(t, nil, err)
assert.Equal(t, 200, response.StatusCode) assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body) body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
@ -124,9 +124,9 @@ func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Auth", "my_token") headers.Set("Auth", "my_token")
response, err := RequestUnparsedResponse(backend.URL, headers) response, err := RequestUnparsedResponse(backend.URL, headers)
assert.Equal(t, nil, err)
defer response.Body.Close() defer response.Body.Close()
assert.Equal(t, nil, err)
assert.Equal(t, 200, response.StatusCode) assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body) body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)

View File

@ -52,11 +52,11 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
c, err := loadCookie(req, s.CookieOptions.CookieName) c, err := loadCookie(req, s.CookieOptions.CookieName)
if err != nil { if err != nil {
// always http.ErrNoCookie // always http.ErrNoCookie
return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName) return nil, fmt.Errorf("cookie %q not present", s.CookieOptions.CookieName)
} }
val, _, ok := encryption.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire) val, _, ok := encryption.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire)
if !ok { if !ok {
return nil, errors.New("Cookie Signature not valid") return nil, errors.New("cookie signature not valid")
} }
session, err := utils.SessionFromCookie(val, s.CookieCipher) session, err := utils.SessionFromCookie(val, s.CookieCipher)
@ -69,8 +69,6 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
// Clear clears any saved session information by writing a cookie to // Clear clears any saved session information by writing a cookie to
// clear the session // clear the session
func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
var cookies []*http.Cookie
// matches CookieName, CookieName_<number> // matches CookieName, CookieName_<number>
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName)) var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName))
@ -79,7 +77,6 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clearCookie) http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie)
} }
} }
@ -174,7 +171,7 @@ func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
} }
} }
if len(cookies) == 0 { if len(cookies) == 0 {
return nil, fmt.Errorf("Could not find cookie %s", cookieName) return nil, fmt.Errorf("could not find cookie %s", cookieName)
} }
return joinCookies(cookies) return joinCookies(cookies)
} }

View File

@ -79,7 +79,7 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
return nil, fmt.Errorf("unable to parse redis url: %s", err) return nil, fmt.Errorf("unable to parse redis url: %s", err)
} }
if opts.RedisInsecureTLS != false { if opts.RedisInsecureTLS {
opt.TLSConfig.InsecureSkipVerify = true opt.TLSConfig.InsecureSkipVerify = true
} }
@ -149,7 +149,7 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
val, _, ok := encryption.Validate(requestCookie, store.CookieOptions.CookieSecret, store.CookieOptions.CookieExpire) val, _, ok := encryption.Validate(requestCookie, store.CookieOptions.CookieSecret, store.CookieOptions.CookieExpire)
if !ok { if !ok {
return nil, fmt.Errorf("Cookie Signature not valid") return nil, fmt.Errorf("cookie signature not valid")
} }
ctx := req.Context() ctx := req.Context()
session, err := store.loadSessionFromString(ctx, val) session, err := store.loadSessionFromString(ctx, val)
@ -209,7 +209,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
val, _, ok := encryption.Validate(requestCookie, store.CookieOptions.CookieSecret, store.CookieOptions.CookieExpire) val, _, ok := encryption.Validate(requestCookie, store.CookieOptions.CookieSecret, store.CookieOptions.CookieExpire)
if !ok { if !ok {
return fmt.Errorf("Cookie Signature not valid") return fmt.Errorf("cookie signature not valid")
} }
// We only return an error if we had an issue with redis // We only return an error if we had an issue with redis

View File

@ -116,7 +116,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
break break
} }
} }
if found != true { if !found {
logger.Print("team membership test failed, access denied") logger.Print("team membership test failed, access denied")
return "", nil return "", nil
} }
@ -147,7 +147,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
break break
} }
} }
if found != true { if !found {
logger.Print("repository access test failed, access denied") logger.Print("repository access test failed, access denied")
return "", nil return "", nil
} }

View File

@ -122,7 +122,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
pn++ pn++
} }
var presentOrgs []string presentOrgs := make([]string, 0, len(orgs))
for _, org := range orgs { for _, org := range orgs {
if p.Org == org.Login { if p.Org == org.Login {
logger.Printf("Found Github Organization: %q", org.Login) logger.Printf("Found Github Organization: %q", org.Login)

View File

@ -222,11 +222,7 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T
func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background() ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil { return err == nil
return false
}
return true
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address

View File

@ -2,6 +2,7 @@ package providers
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -15,10 +16,10 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
"google.golang.org/api/option"
) )
// GoogleProvider represents an Google based Identity Provider // GoogleProvider represents an Google based Identity Provider
@ -184,8 +185,9 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv
} }
conf.Subject = adminEmail conf.Subject = adminEmail
client := conf.Client(oauth2.NoContext) ctx := context.Background()
adminService, err := admin.New(client) client := conf.Client(ctx)
adminService, err := admin.NewService(ctx, option.WithHTTPClient(client))
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)
} }

View File

@ -76,7 +76,7 @@ func (p *KeycloakProvider) GetEmailAddress(s *sessions.SessionState) (string, er
} }
} }
if found != true { if !found {
logger.Printf("group not found, access denied") logger.Printf("group not found, access denied")
return "", nil return "", nil
} }

View File

@ -183,7 +183,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
Issuer: p.ClientID, Issuer: p.ClientID,
Subject: p.ClientID, Subject: p.ClientID,
Audience: p.RedeemURL.String(), Audience: p.RedeemURL.String(),
ExpiresAt: int64(time.Now().Add(time.Duration(5 * time.Minute)).Unix()), ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
Id: randSeq(32), Id: randSeq(32),
} }
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
@ -260,8 +260,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
// GetLoginURL overrides GetLoginURL to add login.gov parameters // GetLoginURL overrides GetLoginURL to add login.gov parameters
func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string {
var a url.URL a := *p.LoginURL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery) params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", redirectURI) params.Set("redirect_uri", redirectURI)
params.Set("approval_prompt", p.ApprovalPrompt) params.Set("approval_prompt", p.ApprovalPrompt)

View File

@ -180,11 +180,7 @@ func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDT
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background() ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil { return err == nil
return false
}
return true
} }
func getOIDCHeader(accessToken string) http.Header { func getOIDCHeader(accessToken string) http.Header {

View File

@ -62,6 +62,9 @@ type fakeKeySetStub struct{}
func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
if err != nil {
return nil, err
}
tokenClaims := &idTokenClaims{} tokenClaims := &idTokenClaims{}
err = json.Unmarshal(decodeString, tokenClaims) err = json.Unmarshal(decodeString, tokenClaims)
@ -242,7 +245,9 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken) verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
assert.Equal(t, true, err == nil) assert.Equal(t, true, err == nil)
assert.Equal(t, true, verifiedIDToken != nil) if verifiedIDToken == nil {
t.Fatal("verifiedIDToken is nil")
}
assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer) assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer)
assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject) assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject)

View File

@ -31,7 +31,7 @@ type ProviderData struct {
// Data returns the ProviderData // Data returns the ProviderData
func (p *ProviderData) Data() *ProviderData { return p } func (p *ProviderData) Data() *ProviderData { return p }
func (p *ProviderData) GetClientSecret() (ClientSecret string, err error) { func (p *ProviderData) GetClientSecret() (clientSecret string, err error) {
if p.ClientSecret != "" || p.ClientSecretFile == "" { if p.ClientSecret != "" || p.ClientSecretFile == "" {
return p.ClientSecret, nil return p.ClientSecret, nil
} }

View File

@ -86,8 +86,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
// GetLoginURL with typical oauth parameters // GetLoginURL with typical oauth parameters
func (p *ProviderData) GetLoginURL(redirectURI, state string) string { func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
var a url.URL a := *p.LoginURL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery) params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", redirectURI) params.Set("redirect_uri", redirectURI)
params.Add("acr_values", p.AcrValues) params.Add("acr_values", p.AcrValues)

View File

@ -44,7 +44,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
defer watcher.Close() defer watcher.Close()
for { for {
select { select {
case _ = <-done: case <-done:
logger.Printf("Shutting down watcher for: %s", filename) logger.Printf("Shutting down watcher for: %s", filename)
return return
case event := <-watcher.Events: case event := <-watcher.Events: