diff --git a/oauthproxy.go b/oauthproxy.go index b0c94eb..c3a5693 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e if err != nil { return nil, err } + + // Force setting these in case the Provider didn't + if s.CreatedAt == nil { + s.CreatedAtNow() + } + if s.ExpiresOn == nil { + s.ExpiresIn(p.CookieOptions.Expire) + } + return s, nil } @@ -861,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ } // prepareNoCache prepares headers for preventing browser caching. diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index e1ee4a6..9e77609 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -11,6 +11,7 @@ import ( "time" "unicode/utf8" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" @@ -32,7 +33,8 @@ type SessionState struct { Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` - Lock Lock `msgpack:"-"` + Clock clock.Clock `msgpack:"-"` + Lock Lock `msgpack:"-"` } func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { @@ -63,9 +65,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { return s.Lock.Peek(ctx) } +// CreatedAtNow sets a SessionState's CreatedAt to now +func (s *SessionState) CreatedAtNow() { + now := s.Clock.Now() + s.CreatedAt = &now +} + +// SetExpiresOn sets an expiration +func (s *SessionState) SetExpiresOn(exp time.Time) { + s.ExpiresOn = &exp +} + +// ExpiresIn sets an expiration a certain duration from CreatedAt. +// CreatedAt will be set to time.Now if it is unset. +func (s *SessionState) ExpiresIn(d time.Duration) { + if s.CreatedAt == nil { + s.CreatedAtNow() + } + exp := s.CreatedAt.Add(d) + s.ExpiresOn = &exp +} + // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { - if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { + if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { return true } return false @@ -74,7 +97,7 @@ func (s *SessionState) IsExpired() bool { // Age returns the age of a session func (s *SessionState) Age() time.Duration { if s.CreatedAt != nil && !s.CreatedAt.IsZero() { - return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) + return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index b373758..9f69ba6 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -142,8 +142,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R } // If we refreshed, update the `CreatedAt` time to reset the refresh timer - // TODO: Implement - // session.CreatedAtNow() + session.CreatedAtNow() // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index ce51ed0..1b3c12d 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -36,8 +36,7 @@ type SessionStore struct { // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { - now := time.Now() - ss.CreatedAt = &now + ss.CreatedAtNow() } value, err := s.cookieForSession(ss) if err != nil { diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 4922517..3215b25 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager { // from the persistent data store. func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { if s.CreatedAt == nil || s.CreatedAt.IsZero() { - now := time.Now() - s.CreatedAt = &now + s.CreatedAtNow() } tckt, err := decodeTicketFromRequest(req, m.Options) diff --git a/providers/azure.go b/providers/azure.go index f66d376..46d7e30 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* return nil, err } - created := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) - session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, } + session.CreatedAtNow() + session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) @@ -239,10 +236,9 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st return email, nil } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -257,7 +253,7 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions. return true, nil } -func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) @@ -271,25 +267,23 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess IDToken string `json:"id_token"` } - err = requests.New(p.RedeemURL.String()). + err := requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). Do(). UnmarshalInto(&jsonResponse) - if err != nil { - return + return err } - now := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) s.AccessToken = jsonResponse.AccessToken s.IDToken = jsonResponse.IDToken s.RefreshToken = jsonResponse.RefreshToken - s.CreatedAt = &now - s.ExpiresOn = &expires + + s.CreatedAtNow() + s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) @@ -312,7 +306,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess } } - return + return nil } func makeAzureHeader(accessToken string) http.Header { diff --git a/providers/gitlab.go b/providers/gitlab.go index ca9a8bf..a2b11df 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -259,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) } } - created := time.Now() - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: getIDToken(token), RefreshToken: token.RefreshToken, - CreatedAt: &created, - ExpiresOn: &idToken.Expiry, - }, nil + } + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) + + return ss, nil } // ValidateSession checks that the session's IDToken is still valid diff --git a/providers/google.go b/providers/google.go index 49eae1c..0cfd3e1 100644 --- a/providers/google.go +++ b/providers/google.go @@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, Email: c.Email, User: c.Subject, - }, nil + } + ss.CreatedAtNow() + ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return ss, nil } // EnrichSession checks the listed Google Groups configured and adds any // that the user is a member of to session.Groups. -func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { +func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error { // TODO (@NickMeves) - Move to pure EnrichSession logic and stop // reusing legacy `groupValidator`. // @@ -272,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, nil } - newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) + newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken) if err != nil { return false, err } @@ -285,12 +284,12 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } - origExpiration := s.ExpiresOn - expires := time.Now().Add(duration).Truncate(time.Second) s.AccessToken = newToken s.IDToken = newIDToken - s.ExpiresOn = &expires - logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) + + s.CreatedAtNow() + s.ExpiresIn(ttl) + return true, nil } diff --git a/providers/logingov.go b/providers/logingov.go index 0f62520..43f361f 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - // Store the data that we found in the session state - return &sessions.SessionState{ + session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, Email: email, - }, nil + } + + session.CreatedAtNow() + session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return session, nil } // GetLoginURL overrides GetLoginURL to add login.gov parameters diff --git a/providers/oidc.go b/providers/oidc.go index 3e1e79a..2cbbd00 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -226,7 +226,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) ss.AccessToken = token ss.IDToken = token ss.RefreshToken = "" - ss.ExpiresOn = &idToken.Expiry + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) return ss, nil } @@ -256,9 +258,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r ss.RefreshToken = token.RefreshToken ss.IDToken = getIDToken(token) - created := time.Now() - ss.CreatedAt = &created - ss.ExpiresOn = &token.Expiry + ss.CreatedAtNow() + ss.SetExpiresOn(token.Expiry) return ss, nil } diff --git a/providers/provider_default.go b/providers/provider_default.go index be57f0e..0a62c24 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/url" - "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s if err != nil { return nil, err } + // TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration if token := values.Get("access_token"); token != "" { - created := time.Now() - return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil + ss := &sessions.SessionState{ + AccessToken: token, + } + ss.CreatedAtNow() + return ss, nil } return nil, fmt.Errorf("no access token found %s", result.Body())