2495 lines
63 KiB
Go
2495 lines
63 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/mbland/hmacauth"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// Shared fixture values for the tests in this file.
const (
	// rawCookieSecret is a 32-byte cookie secret; base64CookieSecret is the
	// base64 encoded version of the same bytes.
	rawCookieSecret    = "secretthirtytwobytes+abcdefghijk"
	base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams"

	// clientID and clientSecret are placeholder client credentials; their
	// consumers are outside the part of the file shown here.
	clientID     = "3984n253984d7348dm8234yf982t"
	clientSecret = "gv3498mfc9t23y23974dm2394dm9"
)
|
|
|
|
func init() {
|
|
logger.SetFlags(logger.Lshortfile)
|
|
}
|
|
|
|
func TestRobotsTxt(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/robots.txt", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "User-agent: *\nDisallow: /\n", rw.Body.String())
|
|
}
|
|
|
|
type TestProvider struct {
|
|
*providers.ProviderData
|
|
EmailAddress string
|
|
ValidToken bool
|
|
GroupValidator func(string) bool
|
|
}
|
|
|
|
var _ providers.Provider = (*TestProvider)(nil)
|
|
|
|
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|
return &TestProvider{
|
|
ProviderData: &providers.ProviderData{
|
|
ProviderName: "Test Provider",
|
|
LoginURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/authorize",
|
|
},
|
|
RedeemURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/oauth/token",
|
|
},
|
|
ProfileURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: providerURL.Host,
|
|
Path: "/api/v1/profile",
|
|
},
|
|
Scope: "profile.email",
|
|
},
|
|
EmailAddress: emailAddress,
|
|
GroupValidator: func(s string) bool {
|
|
return true
|
|
},
|
|
}
|
|
}
|
|
|
|
func (tp *TestProvider) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
|
|
return tp.EmailAddress, nil
|
|
}
|
|
|
|
func (tp *TestProvider) ValidateSession(_ context.Context, _ *sessions.SessionState) bool {
|
|
return tp.ValidToken
|
|
}
|
|
|
|
func Test_redeemCode(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
_, err = proxy.redeemCode(req)
|
|
assert.Equal(t, providers.ErrMissingCode, err)
|
|
}
|
|
|
|
func Test_enrichSession(t *testing.T) {
|
|
const (
|
|
sessionUser = "Mr Session"
|
|
sessionEmail = "session@example.com"
|
|
providerEmail = "provider@example.com"
|
|
)
|
|
|
|
testCases := map[string]struct {
|
|
session *sessions.SessionState
|
|
expectedUser string
|
|
expectedEmail string
|
|
}{
|
|
"Session already has enrichable fields": {
|
|
session: &sessions.SessionState{
|
|
User: sessionUser,
|
|
Email: sessionEmail,
|
|
},
|
|
expectedUser: sessionUser,
|
|
expectedEmail: sessionEmail,
|
|
},
|
|
"Session is missing Email and GetEmailAddress is implemented": {
|
|
session: &sessions.SessionState{
|
|
User: sessionUser,
|
|
},
|
|
expectedUser: sessionUser,
|
|
expectedEmail: providerEmail,
|
|
},
|
|
"Session is missing User and GetUserName is not implemented": {
|
|
session: &sessions.SessionState{
|
|
Email: sessionEmail,
|
|
},
|
|
expectedUser: "",
|
|
expectedEmail: sessionEmail,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
// intentionally set after validation.Validate(opts) since it will clobber
|
|
// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
|
|
opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail))
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = proxy.enrichSessionState(context.Background(), tc.session)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expectedUser, tc.session.User)
|
|
assert.Equal(t, tc.expectedEmail, tc.session.Email)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBasicAuthPassword(t *testing.T) {
|
|
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
logger.Printf("%#v", r)
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("Authorization")
|
|
if payload == "" {
|
|
payload = "No Authorization header found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
|
|
basicAuthPassword := "This is a secure password"
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: providerServer.URL,
|
|
Path: "/",
|
|
URI: providerServer.URL,
|
|
},
|
|
}
|
|
|
|
opts.Cookie.Secure = false
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte(basicAuthPassword),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
providerURL, _ := url.Parse(providerServer.URL)
|
|
const emailAddress = "john.doe@example.com"
|
|
|
|
opts.SetProvider(NewTestProvider(providerURL, emailAddress))
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Save the required session
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{
|
|
Email: emailAddress,
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// Extract the cookie value to inject into the test request
|
|
cookie := rw.Header().Values("Set-Cookie")[0]
|
|
|
|
req, _ = http.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Cookie", cookie)
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
// The username in the basic auth credentials is expected to be equal to the email address from the
|
|
// auth response, so we use the same variable here.
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+basicAuthPassword))
|
|
assert.Equal(t, expectedHeader, rw.Body.String())
|
|
providerServer.Close()
|
|
}
|
|
|
|
func TestPassGroupsHeadersWithGroups(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "X-Forwarded-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
const emailAddress = "john.doe@example.com"
|
|
const userName = "9fcab5c9b889a557"
|
|
|
|
groups := []string{"a", "b"}
|
|
created := time.Now()
|
|
session := &sessions.SessionState{
|
|
User: userName,
|
|
Groups: groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// Save the required session
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
err = proxy.sessionStore.Save(rw, req, session)
|
|
assert.NoError(t, err)
|
|
|
|
// Extract the cookie value to inject into the test request
|
|
cookie := rw.Header().Values("Set-Cookie")[0]
|
|
|
|
req, _ = http.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Cookie", cookie)
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, []string{"a,b"}, req.Header["X-Forwarded-Groups"])
|
|
}
|
|
|
|
type PassAccessTokenTest struct {
|
|
providerServer *httptest.Server
|
|
proxy *OAuthProxy
|
|
opts *options.Options
|
|
}
|
|
|
|
type PassAccessTokenTestOptions struct {
|
|
PassAccessToken bool
|
|
ProxyUpstream options.Upstream
|
|
}
|
|
|
|
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
|
|
patt := &PassAccessTokenTest{}
|
|
|
|
patt.providerServer = httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload string
|
|
switch r.URL.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("X-Forwarded-Access-Token")
|
|
if payload == "" {
|
|
payload = "No access token found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte(payload))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}))
|
|
|
|
patt.opts = baseTestOptions()
|
|
patt.opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: patt.providerServer.URL,
|
|
Path: "/",
|
|
URI: patt.providerServer.URL,
|
|
},
|
|
}
|
|
if opts.ProxyUpstream.ID != "" {
|
|
patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream)
|
|
}
|
|
|
|
patt.opts.Cookie.Secure = false
|
|
if opts.PassAccessToken {
|
|
patt.opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "X-Forwarded-Access-Token",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "access_token",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
err := validation.Validate(patt.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerURL, _ := url.Parse(patt.providerServer.URL)
|
|
const emailAddress = "michael.bland@gsa.gov"
|
|
|
|
patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress))
|
|
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
|
return email == emailAddress
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return patt, nil
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) Close() {
|
|
patTest.providerServer.Close()
|
|
}
|
|
|
|
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
|
|
rw := httptest.NewRecorder()
|
|
|
|
csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
req, err := http.NewRequest(
|
|
http.MethodGet,
|
|
fmt.Sprintf(
|
|
"/oauth2/callback?code=callback_code&state=%s",
|
|
encodeState(csrf.HashOAuthState(), "%2F"),
|
|
),
|
|
strings.NewReader(""),
|
|
)
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
|
|
// rw is a dummy here, we just want the csrfCookie to add to our req
|
|
csrfCookie, err := csrf.SetCookie(httptest.NewRecorder(), req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
req.AddCookie(csrfCookie)
|
|
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
|
|
return rw.Code, rw.Header().Values("Set-Cookie")[1]
|
|
}
|
|
|
|
// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
|
|
// and cookie and returns body and status code.
|
|
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
|
|
cookieName := patTest.proxy.CookieOptions.Name
|
|
var value string
|
|
keyPrefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, keyPrefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
if value == "" {
|
|
return 0, ""
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
|
|
rw := httptest.NewRecorder()
|
|
patTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotNil(t, cookie)
|
|
|
|
// Now we make a regular request; the access_token from the cookie is
|
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
|
// read by the test provider server and written in the response body.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "my_auth_token", payload)
|
|
}
|
|
|
|
func TestStaticProxyUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
ProxyUpstream: options.Upstream{
|
|
ID: "static-proxy",
|
|
Path: "/static-proxy",
|
|
Static: true,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request against the upstream proxy; And validate
|
|
// the returned status code through the static proxy.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "Authenticated", payload)
|
|
}
|
|
|
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|
patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(patTest.Close)
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := patTest.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request, but the access token header should
|
|
// not be present.
|
|
code, payload := patTest.getEndpointWithCookie(cookie, "/")
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "No access token found.", payload)
|
|
}
|
|
|
|
type SignInPageTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
signInRegexp *regexp.Regexp
|
|
signInProviderRegexp *regexp.Regexp
|
|
}
|
|
|
|
const (
	// signInRedirectPattern captures the value of the hidden "rd" input of
	// the sign-in form.
	signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
	// signInSkipProvider is the text looked for in the body of a redirect
	// response when the provider button is skipped.
	signInSkipProvider = `>Found<`
)
|
|
|
|
func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
|
|
var sipTest SignInPageTest
|
|
|
|
sipTest.opts = baseTestOptions()
|
|
sipTest.opts.SkipProviderButton = skipProvider
|
|
err := validation.Validate(sipTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
|
|
sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
|
|
|
|
return &sipTest, nil
|
|
}
|
|
|
|
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
sipTest.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
const endpoint = "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 403, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != endpoint {
|
|
t.Fatal(`expected redirect to "` + endpoint +
|
|
`", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
code, body := sipTest.GetEndpoint("/oauth2/sign_in")
|
|
assert.Equal(t, 200, code)
|
|
|
|
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != "/" {
|
|
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProvider(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/some/random/endpoint"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
|
sipTest, err := NewSignInPageTest(true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
endpoint := "/sign_in"
|
|
|
|
code, body := sipTest.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
type ProcessCookieTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
rw *httptest.ResponseRecorder
|
|
req *http.Request
|
|
validateUser bool
|
|
}
|
|
|
|
type ProcessCookieTestOpts struct {
|
|
providerValidateCookieResponse bool
|
|
}
|
|
|
|
type OptionsModifier func(*options.Options)
|
|
|
|
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
for _, modifier := range modifiers {
|
|
modifier(pcTest.opts)
|
|
}
|
|
// First, set the CookieRefresh option so proxy.AesCipher is created,
|
|
// needed to encrypt the access_token.
|
|
pcTest.opts.Cookie.Refresh = time.Hour
|
|
err := validation.Validate(pcTest.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: opts.providerValidateCookieResponse,
|
|
}
|
|
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups)
|
|
|
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
|
// access_token validation.
|
|
pcTest.proxy.CookieOptions.Refresh = time.Duration(0)
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
pcTest.validateUser = true
|
|
return &pcTest, nil
|
|
}
|
|
|
|
func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
})
|
|
}
|
|
|
|
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
providerValidateCookieResponse: true,
|
|
}, modifiers...)
|
|
}
|
|
|
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
|
|
err := p.proxy.SaveSession(p.rw, p.req, s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, cookie := range p.rw.Result().Cookies() {
|
|
p.req.AddCookie(cookie)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
|
|
return p.proxy.LoadCookiedSession(p.req)
|
|
}
|
|
|
|
func TestLoadCookiedSession(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
assert.Equal(t, "", session.User)
|
|
assert.Equal(t, startSession.AccessToken, session.AccessToken)
|
|
}
|
|
|
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Error(t, err, "cookie \"_oauth2_proxy\" not present")
|
|
if session != nil {
|
|
t.Errorf("expected nil session. got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(23) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
|
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.Equal(t, nil, err)
|
|
if session.Age() < time.Duration(-2)*time.Hour {
|
|
t.Errorf("cookie too young %v", session.Age())
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
}
|
|
|
|
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.CookieOptions.Refresh = time.Hour
|
|
session, err := pcTest.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func NewUserInfoEndpointTest() (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithDefaults()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/userinfo", nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestUserInfoEndpointAccepted(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
session *sessions.SessionState
|
|
expectedResponse string
|
|
}{
|
|
{
|
|
name: "Full session",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"]}\n",
|
|
},
|
|
{
|
|
name: "Minimal session",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"]}\n",
|
|
},
|
|
{
|
|
name: "No groups",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
Email: "john.doe@example.com",
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\"}\n",
|
|
},
|
|
{
|
|
name: "With Preferred Username",
|
|
session: &sessions.SessionState{
|
|
User: "john.doe",
|
|
PreferredUsername: "john",
|
|
Email: "john.doe@example.com",
|
|
Groups: []string{"example", "groups"},
|
|
AccessToken: "my_access_token",
|
|
},
|
|
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\"}\n",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = test.SaveSession(tc.session)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, tc.expectedResponse, string(bodyBytes))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewUserInfoEndpointTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
}
|
|
|
|
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pcTest.req, _ = http.NewRequest(
|
|
"GET",
|
|
fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring),
|
|
nil)
|
|
return pcTest, nil
|
|
}
|
|
|
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
|
test, err := NewAuthOnlyEndpointTest("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
|
|
err = test.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
test.validateUser = false
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
pcTest.opts.Providers[0].AllowedGroups = []string{"oauth_groups"}
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
|
|
assert.Equal(t, startSession.Groups, pcTest.rw.Header().Values("X-Auth-Request-Groups"))
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte("This is a secure password"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:This is a secure password"))
|
|
assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
|
var pcTest ProcessCookieTest
|
|
|
|
pcTest.opts = baseTestOptions()
|
|
pcTest.opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Groups",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "groups",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Preferred-Username",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "preferred_username",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
err := validation.Validate(pcTest.opts)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
|
return pcTest.validateUser
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pcTest.proxy.provider = &TestProvider{
|
|
ProviderData: &providers.ProviderData{},
|
|
ValidToken: true,
|
|
}
|
|
|
|
pcTest.validateUser = true
|
|
|
|
pcTest.rw = httptest.NewRecorder()
|
|
pcTest.req, _ = http.NewRequest("GET",
|
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
created := time.Now()
|
|
startSession := &sessions.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
|
|
err = pcTest.SaveSession(startSession)
|
|
assert.NoError(t, err)
|
|
|
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
|
assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
|
|
assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
|
|
assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
|
|
}
|
|
|
|
func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte("response"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
}
|
|
opts.SkipAuthPreflight = true
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
|
opts.SetProvider(NewTestProvider(upstreamURL, ""))
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "response", rw.Body.String())
|
|
}
|
|
|
|
type SignatureAuthenticator struct {
|
|
auth hmacauth.HmacAuth
|
|
}
|
|
|
|
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
|
|
result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
|
|
|
|
var msg string
|
|
switch result {
|
|
case hmacauth.ResultNoSignature:
|
|
msg = "no signature received"
|
|
case hmacauth.ResultMatch:
|
|
msg = "signatures match"
|
|
case hmacauth.ResultMismatch:
|
|
msg = fmt.Sprintf(
|
|
"signatures do not match:\n received: %s\n computed: %s",
|
|
headerSig,
|
|
computedSig)
|
|
default:
|
|
panic("unknown result value: " + result.String())
|
|
}
|
|
|
|
_, err := w.Write([]byte(msg))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
type SignatureTest struct {
|
|
opts *options.Options
|
|
upstream *httptest.Server
|
|
upstreamHost string
|
|
provider *httptest.Server
|
|
header http.Header
|
|
rw *httptest.ResponseRecorder
|
|
authenticator *SignatureAuthenticator
|
|
}
|
|
|
|
func NewSignatureTest() (*SignatureTest, error) {
|
|
opts := baseTestOptions()
|
|
opts.EmailDomains = []string{"acm.org"}
|
|
|
|
authenticator := &SignatureAuthenticator{}
|
|
upstreamServer := httptest.NewServer(
|
|
http.HandlerFunc(authenticator.Authenticate))
|
|
upstreamURL, err := url.Parse(upstreamServer.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
}
|
|
|
|
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
|
|
providerURL, err := url.Parse(provider.URL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))
|
|
|
|
return &SignatureTest{
|
|
opts,
|
|
upstreamServer,
|
|
upstreamURL.Host,
|
|
provider,
|
|
make(http.Header),
|
|
httptest.NewRecorder(),
|
|
authenticator,
|
|
}, nil
|
|
}
|
|
|
|
func (st *SignatureTest) Close() {
|
|
st.provider.Close()
|
|
st.upstream.Close()
|
|
}
|
|
|
|
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
type fakeNetConn struct {
	// reqBody is the part of the body that has not been read yet.
	reqBody string
}

// Read copies as much of the remaining body as fits into p and consumes it.
// io.EOF is returned together with the final bytes (or with n == 0 once the
// body is exhausted), so a buffer large enough for the whole body still gets
// everything, plus io.EOF, in a single call.
//
// n never exceeds len(p): when p is smaller than the remaining body, the rest
// stays available for subsequent reads instead of being dropped.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if fnc.reqBody == "" {
		return n, io.EOF
	}
	return n, nil
}
|
|
|
|
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
|
|
err := validation.Validate(st.opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var bodyBuf io.ReadCloser
|
|
if body != "" {
|
|
bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
|
|
}
|
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
|
req.Header = st.header
|
|
|
|
state := &sessions.SessionState{
|
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
|
err = proxy.SaveSession(st.rw, req, state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, c := range st.rw.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
// This is used by the upstream to validate the signature.
|
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
|
crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
|
|
proxy.ServeHTTP(st.rw, req)
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestRequestSignature(t *testing.T) {
|
|
testCases := map[string]struct {
|
|
method string
|
|
body string
|
|
key string
|
|
resp string
|
|
}{
|
|
"No request signature": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "",
|
|
resp: "no signature received",
|
|
},
|
|
"Get request": {
|
|
method: "GET",
|
|
body: "",
|
|
key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d",
|
|
resp: "signatures match",
|
|
},
|
|
"Post request": {
|
|
method: "POST",
|
|
body: `{ "hello": "world!" }`,
|
|
key: "d90df39e2d19282840252612dd7c81421a372f61",
|
|
resp: "signatures match",
|
|
},
|
|
}
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
st, err := NewSignatureTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(st.Close)
|
|
if tc.key != "" {
|
|
st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key)
|
|
}
|
|
err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, tc.resp, st.rw.Body.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
type ajaxRequestTest struct {
|
|
opts *options.Options
|
|
proxy *OAuthProxy
|
|
}
|
|
|
|
func newAjaxRequestTest() (*ajaxRequestTest, error) {
|
|
test := &ajaxRequestTest{}
|
|
test.opts = baseTestOptions()
|
|
err := validation.Validate(test.opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return test, nil
|
|
}
|
|
|
|
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
|
|
rw := httptest.NewRecorder()
|
|
req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
req.Header = header
|
|
test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Header(), nil
|
|
}
|
|
|
|
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
|
|
test, err := newAjaxRequestTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
|
|
code, rh, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusUnauthorized, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.Equal(t, applicationJSON, mime)
|
|
}
|
|
func TestAjaxUnauthorizedRequest1(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header)
|
|
}
|
|
|
|
func TestAjaxUnauthorizedRequest2(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("Accept", applicationJSON)
|
|
|
|
testAjaxUnauthorizedRequest(t, header)
|
|
}
|
|
|
|
func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
|
|
header := make(http.Header)
|
|
header.Add("Accept", "application/json, text/plain, */*")
|
|
|
|
testAjaxUnauthorizedRequest(t, header)
|
|
}
|
|
|
|
func TestAjaxForbiddendRequest(t *testing.T) {
|
|
test, err := newAjaxRequestTest()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
endpoint := "/test"
|
|
header := make(http.Header)
|
|
code, rh, err := test.getEndpoint(endpoint, header)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, http.StatusForbidden, code)
|
|
mime := rh.Get("Content-Type")
|
|
assert.NotEqual(t, applicationJSON, mime)
|
|
}
|
|
|
|
func TestClearSplitCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Secret = base64CookieSecret
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_0",
|
|
Value: "oauth2_0",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2_1",
|
|
Value: "oauth2_1",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
|
|
}
|
|
|
|
func TestClearSingleCookie(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.Cookie.Name = "oauth2"
|
|
opts.Cookie.Domains = []string{"abc"}
|
|
store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
|
|
var rw = httptest.NewRecorder()
|
|
req := httptest.NewRequest("get", "/", nil)
|
|
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "test1",
|
|
Value: "test1",
|
|
})
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "oauth2",
|
|
Value: "oauth2",
|
|
})
|
|
|
|
err = p.ClearSessionCookie(rw, req)
|
|
assert.NoError(t, err)
|
|
header := rw.Header()
|
|
|
|
assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
|
|
}
|
|
|
|
// NoOpKeySet is a key set for tests whose VerifySignature performs no
// cryptographic verification.
type NoOpKeySet struct {
}

// VerifySignature returns the base64url-decoded second dot-separated segment
// of jwt without checking any signature. The context is ignored.
//
// An error is returned when jwt has fewer than two segments or when the
// segment is not valid unpadded base64url.
func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
	splitStrings := strings.Split(jwt, ".")
	// Guard the index below: input without a "." would otherwise panic.
	if len(splitStrings) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 segments, got %d", len(splitStrings))
	}
	payloadString := splitStrings[1]
	return base64.RawURLEncoding.DecodeString(payloadString)
}
|
|
|
|
func TestGetJwtSession(t *testing.T) {
|
|
/* token payload:
|
|
{
|
|
"sub": "1234567890",
|
|
"aud": "https://test.myapp.com",
|
|
"name": "John Doe",
|
|
"email": "john@example.com",
|
|
"iss": "https://issuer.example.com",
|
|
"iat": 1553691215,
|
|
"exp": 1912151821
|
|
}
|
|
*/
|
|
goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
|
"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
|
|
"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
|
|
"E1LCJleHAiOjE5MTIxNTE4MjF9." +
|
|
"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
|
|
"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"
|
|
|
|
keyset := NoOpKeySet{}
|
|
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
|
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
|
|
|
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "id_token",
|
|
Prefix: "Bearer ",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
opts.InjectResponseHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "id_token",
|
|
Prefix: "Bearer ",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Auth-Request-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
opts.SkipJwtBearerTokens = true
|
|
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tp, _ := test.proxy.provider.(*TestProvider)
|
|
tp.GroupValidator = func(s string) bool {
|
|
return true
|
|
}
|
|
|
|
authHeader := fmt.Sprintf("Bearer %s", goodJwt)
|
|
test.req.Header = map[string][]string{
|
|
"Authorization": {authHeader},
|
|
}
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
if test.rw.Code >= 400 {
|
|
t.Fatalf("expected 3xx got %d", test.rw.Code)
|
|
}
|
|
|
|
// Check PassAuthorization, should overwrite Basic header
|
|
assert.Equal(t, test.req.Header.Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890")
|
|
assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com")
|
|
|
|
// SetAuthorization and SetXAuthRequest
|
|
assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader)
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890")
|
|
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
|
|
}
|
|
|
|
func Test_prepareNoCache(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
prepareNoCache(w)
|
|
})
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/", handler)
|
|
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
for k, v := range noCacheHeaders {
|
|
assert.Equal(t, rec.Header().Get(k), v)
|
|
}
|
|
}
|
|
|
|
func Test_noCacheHeaders(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("upstream"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
}
|
|
opts.SkipAuthRegex = []string{".*"}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Run("not exist in response from upstream", func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/upstream", nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "upstream", rec.Body.String())
|
|
|
|
// checking noCacheHeaders does not exists in response headers from upstream
|
|
for k := range noCacheHeaders {
|
|
assert.Equal(t, "", rec.Header().Get(k))
|
|
}
|
|
})
|
|
|
|
t.Run("has no-cache", func(t *testing.T) {
|
|
tests := []struct {
|
|
path string
|
|
hasNoCache bool
|
|
}{
|
|
{
|
|
path: "/oauth2/sign_in",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/sign_out",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/start",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/callback",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/oauth2/auth",
|
|
hasNoCache: false,
|
|
},
|
|
{
|
|
path: "/oauth2/userinfo",
|
|
hasNoCache: true,
|
|
},
|
|
{
|
|
path: "/upstream",
|
|
hasNoCache: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.path, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
|
|
proxy.ServeHTTP(rec, req)
|
|
cacheControl := rec.Result().Header.Get("Cache-Control")
|
|
if tt.hasNoCache != (strings.Contains(cacheControl, "no-cache")) {
|
|
t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
|
|
}
|
|
})
|
|
}
|
|
|
|
})
|
|
}
|
|
|
|
func baseTestOptions() *options.Options {
|
|
opts := options.NewOptions()
|
|
opts.Cookie.Secret = rawCookieSecret
|
|
opts.Providers[0].ID = "providerID"
|
|
opts.Providers[0].ClientID = clientID
|
|
opts.Providers[0].ClientSecret = clientSecret
|
|
opts.EmailDomains = []string{"*"}
|
|
|
|
// Default injected headers for legacy configuration
|
|
opts.InjectRequestHeaders = []options.Header{
|
|
{
|
|
Name: "Authorization",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
BasicAuthPassword: &options.SecretSource{
|
|
Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-User",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "user",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Name: "X-Forwarded-Email",
|
|
Values: []options.HeaderValue{
|
|
{
|
|
ClaimSource: &options.ClaimSource{
|
|
Claim: "email",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return opts
|
|
}
|
|
|
|
func TestTrustedIPs(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
trustedIPs []string
|
|
reverseProxy bool
|
|
realClientIPHeader string
|
|
req *http.Request
|
|
expectTrusted bool
|
|
}{
|
|
// Check unconfigured behavior.
|
|
{
|
|
name: "Default",
|
|
trustedIPs: nil,
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check using req.RemoteAddr (Options.ReverseProxy == false).
|
|
{
|
|
name: "WithRemoteAddr",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:43670"
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
|
|
{
|
|
name: "IgnoresRemoteAddrInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Real-IP", // Default value
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "127.0.0.1:44324"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check successful trusting of localhost in IPv4.
|
|
{
|
|
name: "TrustsLocalhostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "127.0.0.1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check successful trusting of localhost in IPv6.
|
|
{
|
|
name: "TrustsIP6LocalostInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: true,
|
|
},
|
|
// Check does not trust random IPv4 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP4Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "12.34.56.78")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check does not trust random IPv6 address.
|
|
{
|
|
name: "DoesNotTrustRandomIP6Address",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "::2")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check respects correct header.
|
|
{
|
|
name: "RespectsCorrectHeaderInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Real-IP", "::1")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided.
|
|
{
|
|
name: "DoesNotTrustGarbageInReverseProxyMode",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: true,
|
|
realClientIPHeader: "X-Forwarded-For",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.Header.Add("X-Forwarded-For", "adsfljk29242as!!")
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
// Check doesn't trust if garbage is provided (no reverse-proxy).
|
|
{
|
|
name: "DoesNotTrustGarbage",
|
|
trustedIPs: []string{"127.0.0.0/8", "::1"},
|
|
reverseProxy: false,
|
|
realClientIPHeader: "X-Real-IP",
|
|
req: func() *http.Request {
|
|
req, _ := http.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = "adsfljk29242as!!"
|
|
return req
|
|
}(),
|
|
expectTrusted: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: "static",
|
|
Path: "/",
|
|
Static: true,
|
|
},
|
|
}
|
|
opts.TrustedIPs = tt.trustedIPs
|
|
opts.ReverseProxy = tt.reverseProxy
|
|
opts.RealClientIPHeader = tt.realClientIPHeader
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
|
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
|
assert.NoError(t, err)
|
|
rw := httptest.NewRecorder()
|
|
|
|
proxy.ServeHTTP(rw, tt.req)
|
|
if tt.expectTrusted {
|
|
assert.Equal(t, 200, rw.Code)
|
|
} else {
|
|
assert.Equal(t, 403, rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_buildRoutesAllowlist(t *testing.T) {
|
|
type expectedAllowedRoute struct {
|
|
method string
|
|
regexString string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
skipAuthRegex []string
|
|
skipAuthRoutes []string
|
|
expectedRoutes []expectedAllowedRoute
|
|
shouldError bool
|
|
}{
|
|
{
|
|
name: "No skip auth configured",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Only skipAuthRegex configured",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar",
|
|
"^/baz/[0-9]+/thing",
|
|
},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Only skipAuthRoutes configured",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
"WEIRD=^/methods/are/allowed",
|
|
"PATCH=/second/equals?are=handled&just=fine",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "GET",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "POST",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/all/methods$",
|
|
},
|
|
{
|
|
method: "WEIRD",
|
|
regexString: "^/methods/are/allowed",
|
|
},
|
|
{
|
|
method: "PATCH",
|
|
regexString: "/second/equals?are=handled&just=fine",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Both skipAuthRegexes and skipAuthRoutes configured",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar/regex",
|
|
"^/baz/[0-9]+/thing/regex",
|
|
},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{
|
|
{
|
|
method: "",
|
|
regexString: "^/foo/bar/regex",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/baz/[0-9]+/thing/regex",
|
|
},
|
|
{
|
|
method: "GET",
|
|
regexString: "^/foo/bar",
|
|
},
|
|
{
|
|
method: "POST",
|
|
regexString: "^/baz/[0-9]+/thing",
|
|
},
|
|
{
|
|
method: "",
|
|
regexString: "^/all/methods$",
|
|
},
|
|
},
|
|
shouldError: false,
|
|
},
|
|
{
|
|
name: "Invalid skipAuthRegex entry",
|
|
skipAuthRegex: []string{
|
|
"^/foo/bar",
|
|
"^/baz/[0-9]+/thing",
|
|
"(bad[regex",
|
|
},
|
|
skipAuthRoutes: []string{},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: true,
|
|
},
|
|
{
|
|
name: "Invalid skipAuthRoutes entry",
|
|
skipAuthRegex: []string{},
|
|
skipAuthRoutes: []string{
|
|
"GET=^/foo/bar",
|
|
"POST=^/baz/[0-9]+/thing",
|
|
"^/all/methods$",
|
|
"PUT=(bad[regex",
|
|
},
|
|
expectedRoutes: []expectedAllowedRoute{},
|
|
shouldError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
opts := &options.Options{
|
|
SkipAuthRegex: tc.skipAuthRegex,
|
|
SkipAuthRoutes: tc.skipAuthRoutes,
|
|
}
|
|
routes, err := buildRoutesAllowlist(opts)
|
|
if tc.shouldError {
|
|
assert.Error(t, err)
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
|
|
for i, route := range routes {
|
|
assert.Greater(t, len(tc.expectedRoutes), i)
|
|
assert.Equal(t, route.method, tc.expectedRoutes[i].method)
|
|
assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAllowedRequest(t *testing.T) {
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
_, err := w.Write([]byte("Allowed Request"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
opts := baseTestOptions()
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
}
|
|
opts.SkipAuthRegex = []string{
|
|
"^/skip/auth/regex$",
|
|
}
|
|
opts.SkipAuthRoutes = []string{
|
|
"GET=^/skip/auth/routes/get",
|
|
}
|
|
err := validation.Validate(opts)
|
|
assert.NoError(t, err)
|
|
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
method string
|
|
url string
|
|
allowed bool
|
|
}{
|
|
{
|
|
name: "Regex GET allowed",
|
|
method: "GET",
|
|
url: "/skip/auth/regex",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Regex POST allowed ",
|
|
method: "POST",
|
|
url: "/skip/auth/regex",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Regex denied",
|
|
method: "GET",
|
|
url: "/wrong/denied",
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Route allowed",
|
|
method: "GET",
|
|
url: "/skip/auth/routes/get",
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Route denied with wrong method",
|
|
method: "PATCH",
|
|
url: "/skip/auth/routes/get",
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Route denied with wrong path",
|
|
method: "GET",
|
|
url: "/skip/auth/routes/wrong/path",
|
|
allowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, err := http.NewRequest(tc.method, tc.url, nil)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))
|
|
|
|
rw := httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
if tc.allowed {
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "Allowed Request", rw.Body.String())
|
|
} else {
|
|
assert.Equal(t, 403, rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProxyAllowedGroups(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedGroups []string
|
|
groups []string
|
|
expectUnauthorized bool
|
|
}{
|
|
{"NoAllowedGroups", []string{}, []string{}, false},
|
|
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
|
|
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
|
|
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
emailAddress := "test"
|
|
created := time.Now()
|
|
|
|
session := &sessions.SessionState{
|
|
Groups: tt.groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
}))
|
|
t.Cleanup(upstreamServer.Close)
|
|
|
|
test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
|
|
opts.Providers[0].AllowedGroups = tt.allowedGroups
|
|
opts.UpstreamServers = options.Upstreams{
|
|
{
|
|
ID: upstreamServer.URL,
|
|
Path: "/",
|
|
URI: upstreamServer.URL,
|
|
},
|
|
}
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
test.req, _ = http.NewRequest("GET", "/", nil)
|
|
|
|
test.req.Header.Add("accept", applicationJSON)
|
|
err = test.SaveSession(session)
|
|
assert.NoError(t, err)
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
if tt.expectUnauthorized {
|
|
assert.Equal(t, http.StatusForbidden, test.rw.Code)
|
|
} else {
|
|
assert.Equal(t, http.StatusOK, test.rw.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthOnlyAllowedGroups(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
allowedGroups []string
|
|
groups []string
|
|
querystring string
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "NoAllowedGroups",
|
|
allowedGroups: []string{},
|
|
groups: []string{},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "NoAllowedGroupsUserHasGroups",
|
|
allowedGroups: []string{},
|
|
groups: []string{"a", "b"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInAllowedGroup",
|
|
allowedGroups: []string{"a"},
|
|
groups: []string{"a", "b"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInAllowedGroup",
|
|
allowedGroups: []string{"a"},
|
|
groups: []string{"c"},
|
|
querystring: "",
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "UserInQuerystringGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"a", "c"},
|
|
querystring: "?allowed_groups=a",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInMultiParamQuerystringGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"b"},
|
|
querystring: "?allowed_groups=a&allowed_groups=b,d",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInOnlyQuerystringGroup",
|
|
allowedGroups: []string{},
|
|
groups: []string{"a", "c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserInDelimitedQuerystringGroup",
|
|
allowedGroups: []string{"a", "b", "c"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,c",
|
|
expectedStatusCode: http.StatusAccepted,
|
|
},
|
|
{
|
|
name: "UserNotInQuerystringGroup",
|
|
allowedGroups: []string{},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInConfigGroupNotInQuerystringGroup",
|
|
allowedGroups: []string{"a", "b", "c"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=a,b",
|
|
expectedStatusCode: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "UserInQuerystringGroupNotInConfigGroup",
|
|
allowedGroups: []string{"a", "b"},
|
|
groups: []string{"c"},
|
|
querystring: "?allowed_groups=b,c",
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
emailAddress := "test"
|
|
created := time.Now()
|
|
|
|
session := &sessions.SessionState{
|
|
Groups: tc.groups,
|
|
Email: emailAddress,
|
|
AccessToken: "oauth_token",
|
|
CreatedAt: &created,
|
|
}
|
|
|
|
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
|
|
opts.Providers[0].AllowedGroups = tc.allowedGroups
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = test.SaveSession(session)
|
|
assert.NoError(t, err)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
|
|
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
|
})
|
|
}
|
|
}
|