Add request scope middleware

This commit is contained in:
Joel Speed 2020-07-04 18:41:58 +01:00
parent 1aac37d2b1
commit 2768321929
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
3 changed files with 157 additions and 0 deletions

View File

@ -0,0 +1,24 @@
package middleware
import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
)
// RequestScope contains information regarding the request that is being made.
// The RequestScope is used to pass information between different middlewares
// within the chain.
type RequestScope struct {
// Session details the authenticated users information (if it exists).
Session *sessions.SessionState
// SaveSession indicates whether the session storage should attempt to save
// the session or not.
SaveSession bool
// ClearSession indicates whether the user should be logged out or not.
ClearSession bool
// SessionRevalidated indicates whether the session has been revalidated since
// it was loaded or not.
SessionRevalidated bool
}

39
pkg/middleware/scope.go Normal file
View File

@ -0,0 +1,39 @@
package middleware
import (
"context"
"net/http"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware"
)
type scopeKey string
// requestScopeKey uses a typed string to reduce likelihood of clasing
// with other context keys
const requestScopeKey scopeKey = "request-scope"
func NewScope() alice.Constructor {
return addScope
}
// addScope injects a new request scope into the request context.
func addScope(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := &middlewareapi.RequestScope{}
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
requestWithScope := req.WithContext(contextWithScope)
next.ServeHTTP(rw, requestWithScope)
})
}
// GetRequestScope returns the current request scope from the given request
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
scope := req.Context().Value(requestScopeKey)
if scope == nil {
return nil
}
return scope.(*middlewareapi.RequestScope)
}

View File

@ -0,0 +1,94 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Scope Suite", func() {
Context("NewScope", func() {
var request, nextRequest *http.Request
var rw http.ResponseWriter
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
rw = httptest.NewRecorder()
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextRequest = r
w.WriteHeader(200)
}))
handler.ServeHTTP(rw, request)
})
It("does not add a scope to the original request", func() {
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
})
It("cannot load a scope from the original request using GetRequestScope", func() {
Expect(GetRequestScope(request)).To(BeNil())
})
It("adds a scope to the request for the next handler", func() {
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
})
It("can load a scope from the next handler's request using GetRequestScope", func() {
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
})
})
Context("GetRequestScope", func() {
var request *http.Request
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("with a scope", func() {
var scope *middlewareapi.RequestScope
BeforeEach(func() {
scope = &middlewareapi.RequestScope{}
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
request = request.WithContext(contextWithScope)
})
It("returns the scope", func() {
s := GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
})
Context("if the scope is then modified", func() {
BeforeEach(func() {
Expect(scope.SaveSession).To(BeFalse())
scope.SaveSession = true
})
It("returns the updated session", func() {
s := GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
Expect(s.SaveSession).To(BeTrue())
})
})
})
Context("without a scope", func() {
It("returns nil", func() {
Expect(GetRequestScope(request)).To(BeNil())
})
})
})
})