package main import ( "bufio" "context" "crypto" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" "os" "regexp" "strings" "testing" "time" "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/v7/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( // The rawCookieSecret is 32 bytes and the base64CookieSecret is the base64 // encoded version of this. rawCookieSecret = "secretthirtytwobytes+abcdefghijk" base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams" clientID = "3984n253984d7348dm8234yf982t" clientSecret = "gv3498mfc9t23y23974dm2394dm9" ) func init() { logger.SetFlags(logger.Lshortfile) } func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) assert.Equal(t, 200, rw.Code) assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) } func TestIsValidRedirect(t *testing.T) { opts := baseTestOptions() // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", ".bar.foo", "port.bar:8080", ".sub.port.bar:8080", "anyport.bar:*", ".sub.anyport.bar:*", } err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } testCases := []struct { Desc, Redirect string 
ExpectedResult bool }{ { Desc: "noRD", Redirect: "", ExpectedResult: false, }, { Desc: "singleSlash", Redirect: "/redirect", ExpectedResult: true, }, { Desc: "doubleSlash", Redirect: "//redirect", ExpectedResult: false, }, { Desc: "validHTTP", Redirect: "http://foo.bar/redirect", ExpectedResult: true, }, { Desc: "validHTTPS", Redirect: "https://foo.bar/redirect", ExpectedResult: true, }, { Desc: "invalidHTTPSubdomain", Redirect: "http://baz.foo.bar/redirect", ExpectedResult: false, }, { Desc: "invalidHTTPSSubdomain", Redirect: "https://baz.foo.bar/redirect", ExpectedResult: false, }, { Desc: "validHTTPSubdomain", Redirect: "http://baz.bar.foo/redirect", ExpectedResult: true, }, { Desc: "validHTTPSSubdomain", Redirect: "https://baz.bar.foo/redirect", ExpectedResult: true, }, { Desc: "validHTTPDomain", Redirect: "http://bar.foo/redirect", ExpectedResult: true, }, { Desc: "invalidHTTP1", Redirect: "http://foo.bar.evil.corp/redirect", ExpectedResult: false, }, { Desc: "invalidHTTPS1", Redirect: "https://foo.bar.evil.corp/redirect", ExpectedResult: false, }, { Desc: "invalidHTTP2", Redirect: "http://evil.corp/redirect?rd=foo.bar", ExpectedResult: false, }, { Desc: "invalidHTTPS2", Redirect: "https://evil.corp/redirect?rd=foo.bar", ExpectedResult: false, }, { Desc: "invalidPort", Redirect: "https://evil.corp:3838/redirect", ExpectedResult: false, }, { Desc: "invalidEmptyPort", Redirect: "http://foo.bar:3838/redirect", ExpectedResult: false, }, { Desc: "invalidEmptyPortSubdomain", Redirect: "http://baz.bar.foo:3838/redirect", ExpectedResult: false, }, { Desc: "validSpecificPort", Redirect: "http://port.bar:8080/redirect", ExpectedResult: true, }, { Desc: "invalidSpecificPort", Redirect: "http://port.bar:3838/redirect", ExpectedResult: false, }, { Desc: "validSpecificPortSubdomain", Redirect: "http://foo.sub.port.bar:8080/redirect", ExpectedResult: true, }, { Desc: "invalidSpecificPortSubdomain", Redirect: "http://foo.sub.port.bar:3838/redirect", ExpectedResult: false, }, 
{ Desc: "validAnyPort1", Redirect: "http://anyport.bar:8080/redirect", ExpectedResult: true, }, { Desc: "validAnyPort2", Redirect: "http://anyport.bar:8081/redirect", ExpectedResult: true, }, { Desc: "validAnyPortSubdomain1", Redirect: "http://a.sub.anyport.bar:8080/redirect", ExpectedResult: true, }, { Desc: "validAnyPortSubdomain2", Redirect: "http://a.sub.anyport.bar:8081/redirect", ExpectedResult: true, }, { Desc: "openRedirect1", Redirect: "/\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectSpace1", Redirect: "/ /evil.com", ExpectedResult: false, }, { Desc: "openRedirectSpace2", Redirect: "/ \\evil.com", ExpectedResult: false, }, { Desc: "openRedirectTab1", Redirect: "/\t/evil.com", ExpectedResult: false, }, { Desc: "openRedirectTab2", Redirect: "/\t\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectVerticalTab1", Redirect: "/\v/evil.com", ExpectedResult: false, }, { Desc: "openRedirectVerticalTab2", Redirect: "/\v\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectNewLine1", Redirect: "/\n/evil.com", ExpectedResult: false, }, { Desc: "openRedirectNewLine2", Redirect: "/\n\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectCarriageReturn1", Redirect: "/\r/evil.com", ExpectedResult: false, }, { Desc: "openRedirectCarriageReturn2", Redirect: "/\r\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectTripleTab", Redirect: "/\t\t/\t/evil.com", ExpectedResult: false, }, { Desc: "openRedirectTripleTab2", Redirect: "/\t\t\\\t/evil.com", ExpectedResult: false, }, { Desc: "openRedirectQuadTab1", Redirect: "/\t\t/\t\t\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectQuadTab2", Redirect: "/\t\t\\\t\t/evil.com", ExpectedResult: false, }, { Desc: "openRedirectPeriod1", Redirect: "/./\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectPeriod2", Redirect: "/./../../\\evil.com", ExpectedResult: false, }, { Desc: "openRedirectDoubleTab", Redirect: "/\t/\t\\evil.com", ExpectedResult: false, }, } for _, tc := range 
testCases { t.Run(tc.Desc, func(t *testing.T) { result := proxy.IsValidRedirect(tc.Redirect) if result != tc.ExpectedResult { t.Errorf("expected %t got %t", tc.ExpectedResult, result) } }) } } func TestOpenRedirects(t *testing.T) { opts := baseTestOptions() // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", ".bar.foo", "port.bar:8080", ".sub.port.bar:8080", "anyport.bar:*", ".sub.anyport.bar:*", "www.whitelisteddomain.tld", } err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } file, err := os.Open("./testdata/openredirects.txt") if err != nil { t.Fatal(err) } defer func(t *testing.T) { if err := file.Close(); err != nil { t.Fatal(err) } }(t) scanner := bufio.NewScanner(file) for scanner.Scan() { rd := scanner.Text() t.Run(rd, func(t *testing.T) { rdUnescaped, err := url.QueryUnescape(rd) if err != nil { t.Fatal(err) } if proxy.IsValidRedirect(rdUnescaped) { t.Errorf("Expected %q to not be valid (unescaped: %q)", rd, rdUnescaped) } }) } if err := scanner.Err(); err != nil { t.Fatal(err) } } type TestProvider struct { *providers.ProviderData EmailAddress string ValidToken bool GroupValidator func(string) bool } var _ providers.Provider = (*TestProvider)(nil) func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { return &TestProvider{ ProviderData: &providers.ProviderData{ ProviderName: "Test Provider", LoginURL: &url.URL{ Scheme: "http", Host: providerURL.Host, Path: "/oauth/authorize", }, RedeemURL: &url.URL{ Scheme: "http", Host: providerURL.Host, Path: "/oauth/token", }, ProfileURL: &url.URL{ Scheme: "http", Host: providerURL.Host, Path: "/api/v1/profile", }, Scope: "profile.email", }, EmailAddress: emailAddress, GroupValidator: func(s string) bool { return true }, } } func (tp *TestProvider) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, 
error) { return tp.EmailAddress, nil } func (tp *TestProvider) ValidateSessionState(_ context.Context, _ *sessions.SessionState) bool { return tp.ValidToken } func Test_redeemCode(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } _, err = proxy.redeemCode(context.Background(), "www.example.com", "") assert.Error(t, err) } func Test_enrichSession(t *testing.T) { const ( sessionUser = "Mr Session" sessionEmail = "session@example.com" providerEmail = "provider@example.com" ) testCases := map[string]struct { session *sessions.SessionState expectedUser string expectedEmail string }{ "Session already has enrichable fields": { session: &sessions.SessionState{ User: sessionUser, Email: sessionEmail, }, expectedUser: sessionUser, expectedEmail: sessionEmail, }, "Session is missing Email and GetEmailAddress is implemented": { session: &sessions.SessionState{ User: sessionUser, }, expectedUser: sessionUser, expectedEmail: providerEmail, }, "Session is missing User and GetUserName is not implemented": { session: &sessions.SessionState{ Email: sessionEmail, }, expectedUser: "", expectedEmail: sessionEmail, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) assert.NoError(t, err) // intentionally set after validation.Validate(opts) since it will clobber // our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider` opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } err = proxy.enrichSessionState(context.Background(), tc.session) assert.NoError(t, err) assert.Equal(t, tc.expectedUser, tc.session.User) assert.Equal(t, tc.expectedEmail, tc.session.Email) }) } } func TestBasicAuthPassword(t *testing.T) { 
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%#v", r) var payload string switch r.URL.Path { case "/oauth/token": payload = `{"access_token": "my_auth_token"}` default: payload = r.Header.Get("Authorization") if payload == "" { payload = "No Authorization header found." } } w.WriteHeader(200) _, err := w.Write([]byte(payload)) if err != nil { t.Fatal(err) } })) basicAuthPassword := "This is a secure password" opts := baseTestOptions() opts.UpstreamServers = options.Upstreams{ { ID: providerServer.URL, Path: "/", URI: providerServer.URL, }, } opts.Cookie.Secure = false opts.InjectRequestHeaders = []options.Header{ { Name: "Authorization", Values: []options.HeaderValue{ { ClaimSource: &options.ClaimSource{ Claim: "email", BasicAuthPassword: &options.SecretSource{ Value: []byte(base64.StdEncoding.EncodeToString([]byte(basicAuthPassword))), }, }, }, }, }, } err := validation.Validate(opts) assert.NoError(t, err) providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" opts.SetProvider(NewTestProvider(providerURL, emailAddress)) proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) if err != nil { t.Fatal(err) } // Save the required session rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{ Email: emailAddress, }) assert.NoError(t, err) // Extract the cookie value to inject into the test request cookie := rw.Header().Values("Set-Cookie")[0] req, _ = http.NewRequest("GET", "/", nil) req.Header.Set("Cookie", cookie) rw = httptest.NewRecorder() proxy.ServeHTTP(rw, req) // The username in the basic auth credentials is expected to be equal to the email address from the // auth response, so we use the same variable here. 
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+basicAuthPassword)) assert.Equal(t, expectedHeader, rw.Body.String()) providerServer.Close() } func TestPassGroupsHeadersWithGroups(t *testing.T) { opts := baseTestOptions() opts.InjectRequestHeaders = []options.Header{ { Name: "X-Forwarded-Groups", Values: []options.HeaderValue{ { ClaimSource: &options.ClaimSource{ Claim: "groups", }, }, }, }, } err := validation.Validate(opts) assert.NoError(t, err) const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" groups := []string{"a", "b"} created := time.Now() session := &sessions.SessionState{ User: userName, Groups: groups, Email: emailAddress, AccessToken: "oauth_token", CreatedAt: &created, } proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) assert.NoError(t, err) // Save the required session rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) err = proxy.sessionStore.Save(rw, req, session) assert.NoError(t, err) // Extract the cookie value to inject into the test request cookie := rw.Header().Values("Set-Cookie")[0] req, _ = http.NewRequest("GET", "/", nil) req.Header.Set("Cookie", cookie) rw = httptest.NewRecorder() proxy.ServeHTTP(rw, req) assert.Equal(t, groups, req.Header["X-Forwarded-Groups"]) } type PassAccessTokenTest struct { providerServer *httptest.Server proxy *OAuthProxy opts *options.Options } type PassAccessTokenTestOptions struct { PassAccessToken bool ProxyUpstream options.Upstream } func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { patt := &PassAccessTokenTest{} patt.providerServer = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var payload string switch r.URL.Path { case "/oauth/token": payload = `{"access_token": "my_auth_token"}` default: payload = r.Header.Get("X-Forwarded-Access-Token") if payload == "" { payload = "No access token found." 
} } w.WriteHeader(200) _, err := w.Write([]byte(payload)) if err != nil { panic(err) } })) patt.opts = baseTestOptions() patt.opts.UpstreamServers = options.Upstreams{ { ID: patt.providerServer.URL, Path: "/", URI: patt.providerServer.URL, }, } if opts.ProxyUpstream.ID != "" { patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream) } patt.opts.Cookie.Secure = false if opts.PassAccessToken { patt.opts.InjectRequestHeaders = []options.Header{ { Name: "X-Forwarded-Access-Token", Values: []options.HeaderValue{ { ClaimSource: &options.ClaimSource{ Claim: "access_token", }, }, }, }, } } err := validation.Validate(patt.opts) if err != nil { return nil, err } providerURL, _ := url.Parse(patt.providerServer.URL) const emailAddress = "michael.bland@gsa.gov" patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { return email == emailAddress }) if err != nil { return nil, err } return patt, nil } func (patTest *PassAccessTokenTest) Close() { patTest.providerServer.Close() } func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) { rw := httptest.NewRecorder() req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) if err != nil { return 0, "" } req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) patTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.Header().Values("Set-Cookie")[1] } // getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath // and cookie and returns body and status code. 
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) { cookieName := patTest.proxy.CookieName var value string keyPrefix := cookieName + "=" for _, field := range strings.Split(cookie, "; ") { value = strings.TrimPrefix(field, keyPrefix) if value != field { break } else { value = "" } } if value == "" { return 0, "" } req, err := http.NewRequest("GET", endpoint, strings.NewReader("")) if err != nil { return 0, "" } req.AddCookie(&http.Cookie{ Name: cookieName, Value: value, Path: "/", Expires: time.Now().Add(time.Duration(24)), HttpOnly: true, }) rw := httptest.NewRecorder() patTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.Body.String() } func TestForwardAccessTokenUpstream(t *testing.T) { patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, }) if err != nil { t.Fatal(err) } t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } assert.NotNil(t, cookie) // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is // read by the test provider server and written in the response body. code, payload := patTest.getEndpointWithCookie(cookie, "/") if code != 200 { t.Fatalf("expected 200; got %d", code) } assert.Equal(t, "my_auth_token", payload) } func TestStaticProxyUpstream(t *testing.T) { patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, ProxyUpstream: options.Upstream{ ID: "static-proxy", Path: "/static-proxy", Static: true, }, }) if err != nil { t.Fatal(err) } t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. 
code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } assert.NotEqual(t, nil, cookie) // Now we make a regular request against the upstream proxy; And validate // the returned status code through the static proxy. code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy") if code != 200 { t.Fatalf("expected 200; got %d", code) } assert.Equal(t, "Authenticated", payload) } func TestDoNotForwardAccessTokenUpstream(t *testing.T) { patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false, }) if err != nil { t.Fatal(err) } t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } assert.NotEqual(t, nil, cookie) // Now we make a regular request, but the access token header should // not be present. code, payload := patTest.getEndpointWithCookie(cookie, "/") if code != 200 { t.Fatalf("expected 200; got %d", code) } assert.Equal(t, "No access token found.", payload) } type SignInPageTest struct { opts *options.Options proxy *OAuthProxy signInRegexp *regexp.Regexp signInProviderRegexp *regexp.Regexp } const signInRedirectPattern = `` const signInSkipProvider = `>Found<` func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) { var sipTest SignInPageTest sipTest.opts = baseTestOptions() sipTest.opts.SkipProviderButton = skipProvider err := validation.Validate(sipTest.opts) if err != nil { return nil, err } sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { return true }) if err != nil { return nil, err } sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) return &sipTest, nil } func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", endpoint, 
strings.NewReader("")) sipTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.Body.String() } func TestSignInPageIncludesTargetRedirect(t *testing.T) { sipTest, err := NewSignInPageTest(false) if err != nil { t.Fatal(err) } const endpoint = "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 403, code) match := sipTest.signInRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInRedirectPattern + "\nBody:\n" + body) } if match[1] != endpoint { t.Fatal(`expected redirect to "` + endpoint + `", but was "` + match[1] + `"`) } } func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { sipTest, err := NewSignInPageTest(false) if err != nil { t.Fatal(err) } code, body := sipTest.GetEndpoint("/oauth2/sign_in") assert.Equal(t, 200, code) match := sipTest.signInRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInRedirectPattern + "\nBody:\n" + body) } if match[1] != "/" { t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) } } func TestSignInPageSkipProvider(t *testing.T) { sipTest, err := NewSignInPageTest(true) if err != nil { t.Fatal(err) } endpoint := "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) match := sipTest.signInProviderRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) } } func TestSignInPageSkipProviderDirect(t *testing.T) { sipTest, err := NewSignInPageTest(true) if err != nil { t.Fatal(err) } endpoint := "/sign_in" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) match := sipTest.signInProviderRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) } } type ProcessCookieTest struct { opts *options.Options proxy *OAuthProxy rw *httptest.ResponseRecorder req *http.Request validateUser bool } type 
ProcessCookieTestOpts struct { providerValidateCookieResponse bool } type OptionsModifier func(*options.Options) func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() for _, modifier := range modifiers { modifier(pcTest.opts) } // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. pcTest.opts.Cookie.Refresh = time.Hour err := validation.Validate(pcTest.opts) if err != nil { return nil, err } pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { return nil, err } pcTest.proxy.provider = &TestProvider{ ProviderData: &providers.ProviderData{}, ValidToken: opts.providerValidateCookieResponse, } pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. pcTest.proxy.CookieRefresh = time.Duration(0) pcTest.rw = httptest.NewRecorder() pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pcTest.validateUser = true return &pcTest, nil } func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }) } func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }, modifiers...) 
} func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { err := p.proxy.SaveSession(p.rw, p.req, s) if err != nil { return err } for _, cookie := range p.rw.Result().Cookies() { p.req.AddCookie(cookie) } return nil } func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pcTest, err := NewProcessCookieTestWithDefaults() if err != nil { t.Fatal(err) } created := time.Now() startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created} err = pcTest.SaveSession(startSession) assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() if err != nil { t.Fatal(err) } assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "", session.User) assert.Equal(t, startSession.AccessToken, session.AccessToken) } func TestProcessCookieNoCookieError(t *testing.T) { pcTest, err := NewProcessCookieTestWithDefaults() if err != nil { t.Fatal(err) } session, err := pcTest.LoadCookiedSession() assert.Error(t, err, "cookie \"_oauth2_proxy\" not present") if session != nil { t.Errorf("expected nil session. 
got %#v", session) } } func TestProcessCookieRefreshNotSet(t *testing.T) { pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(23) * time.Hour }) if err != nil { t.Fatal(err) } reference := time.Now().Add(time.Duration(-2) * time.Hour) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} err = pcTest.SaveSession(startSession) assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) if session.Age() < time.Duration(-2)*time.Hour { t.Errorf("cookie too young %v", session.Age()) } assert.Equal(t, startSession.Email, session.Email) } func TestProcessCookieFailIfCookieExpired(t *testing.T) { pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) if err != nil { t.Fatal(err) } reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} err = pcTest.SaveSession(startSession) assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) } } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) if err != nil { t.Fatal(err) } reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} err = pcTest.SaveSession(startSession) assert.NoError(t, err) pcTest.proxy.CookieRefresh = time.Hour session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session 
%#v", session) } } func NewUserInfoEndpointTest() (*ProcessCookieTest, error) { pcTest, err := NewProcessCookieTestWithDefaults() if err != nil { return nil, err } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/userinfo", nil) return pcTest, nil } func TestUserInfoEndpointAccepted(t *testing.T) { test, err := NewUserInfoEndpointTest() if err != nil { t.Fatal(err) } startSession := &sessions.SessionState{ Email: "john.doe@example.com", AccessToken: "my_access_token"} err = test.SaveSession(startSession) assert.NoError(t, err) } func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { test, err := NewUserInfoEndpointTest() if err != nil { t.Fatal(err) } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) } func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...) if err != nil { return nil, err } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) return pcTest, nil } func TestAuthOnlyEndpointAccepted(t *testing.T) { test, err := NewAuthOnlyEndpointTest() if err != nil { t.Fatal(err) } created := time.Now() startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} err = test.SaveSession(startSession) assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { test, err := NewAuthOnlyEndpointTest() if err != nil { t.Fatal(err) } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "unauthorized request\n", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { 
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
		opts.Cookie.Expire = time.Duration(24) * time.Hour
	})
	if err != nil {
		t.Fatal(err)
	}
	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
	startSession := &sessions.SessionState{
		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
	err = test.SaveSession(startSession)
	assert.NoError(t, err)
	test.proxy.ServeHTTP(test.rw, test.req)
	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
}

// TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure verifies that the
// auth endpoint answers 401 for a freshly created session when the email
// validator (driven by test.validateUser) rejects the user.
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
	test, err := NewAuthOnlyEndpointTest()
	if err != nil {
		t.Fatal(err)
	}
	created := time.Now()
	startSession := &sessions.SessionState{
		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
	err = test.SaveSession(startSession)
	assert.NoError(t, err)
	test.validateUser = false

	test.proxy.ServeHTTP(test.rw, test.req)
	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
}

// newTestClaimHeader builds an options.Header called name whose only value is
// read from the session claim named claim (no prefix, no basic-auth password).
func newTestClaimHeader(name, claim string) options.Header {
	return options.Header{
		Name: name,
		Values: []options.HeaderValue{
			{
				ClaimSource: &options.ClaimSource{
					Claim: claim,
				},
			},
		},
	}
}

// TestAuthOnlyEndpointSetXAuthRequestHeaders verifies that the auth endpoint
// copies the user, email and groups claims of a valid session into the
// configured X-Auth-Request-* response headers.
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
	var pcTest ProcessCookieTest
	pcTest.opts = baseTestOptions()
	pcTest.opts.InjectResponseHeaders = []options.Header{
		newTestClaimHeader("X-Auth-Request-User", "user"),
		newTestClaimHeader("X-Auth-Request-Email", "email"),
		newTestClaimHeader("X-Auth-Request-Groups", "groups"),
		newTestClaimHeader("X-Forwarded-Preferred-Username", "preferred_username"),
	}
	pcTest.opts.AllowedGroups = []string{"oauth_groups"}
	err := validation.Validate(pcTest.opts)
	assert.NoError(t, err)

	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
		return pcTest.validateUser
	})
	if err != nil {
		t.Fatal(err)
	}
	pcTest.proxy.provider = &TestProvider{
		ProviderData: &providers.ProviderData{},
		ValidToken:   true,
	}

	pcTest.validateUser = true

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET",
		pcTest.opts.ProxyPrefix+"/auth", nil)

	created := time.Now()
	startSession := &sessions.SessionState{
		User: "oauth_user", Groups: []string{"oauth_groups"}, Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
	err = pcTest.SaveSession(startSession)
	assert.NoError(t, err)

	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
	assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, startSession.Groups, pcTest.rw.Header().Values("X-Auth-Request-Groups"))
	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
}

// TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders verifies that an
// Authorization response header configured with a BasicAuthPassword is
// emitted as "Basic base64(user:password)".
func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
	var pcTest ProcessCookieTest
	pcTest.opts = baseTestOptions()
	pcTest.opts.InjectResponseHeaders = []options.Header{
		newTestClaimHeader("X-Auth-Request-User", "user"),
		newTestClaimHeader("X-Auth-Request-Email", "email"),
		newTestClaimHeader("X-Auth-Request-Groups", "groups"),
		newTestClaimHeader("X-Forwarded-Preferred-Username", "preferred_username"),
		{
			Name: "Authorization",
			Values: []options.HeaderValue{
				{
					ClaimSource: &options.ClaimSource{
						Claim: "user",
						BasicAuthPassword: &options.SecretSource{
							Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))),
						},
					},
				},
			},
		},
	}
	err := validation.Validate(pcTest.opts)
	assert.NoError(t, err)

	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
		return pcTest.validateUser
	})
	if err != nil {
		t.Fatal(err)
	}
	pcTest.proxy.provider = &TestProvider{
		ProviderData: &providers.ProviderData{},
		ValidToken:   true,
	}

	pcTest.validateUser = true

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET",
		pcTest.opts.ProxyPrefix+"/auth", nil)

	created := time.Now()
	startSession := &sessions.SessionState{
		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
	err = pcTest.SaveSession(startSession)
	assert.NoError(t, err)

	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
	// Header.Get returns "" for a missing header, so a regression fails the
	// assertion instead of panicking on Values(...)[0].
	assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:This is a secure password"))
	assert.Equal(t, expectedHeader, pcTest.rw.Header().Get("Authorization"))
}

// TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders verifies that no
// Authorization response header is emitted when none is configured in
// InjectResponseHeaders.
func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
	var pcTest ProcessCookieTest
	pcTest.opts = baseTestOptions()
	pcTest.opts.InjectResponseHeaders = []options.Header{
		newTestClaimHeader("X-Auth-Request-User", "user"),
		newTestClaimHeader("X-Auth-Request-Email", "email"),
		newTestClaimHeader("X-Auth-Request-Groups", "groups"),
		newTestClaimHeader("X-Forwarded-Preferred-Username", "preferred_username"),
	}
	err := validation.Validate(pcTest.opts)
	assert.NoError(t, err)

	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
		return pcTest.validateUser
	})
	if err != nil {
		t.Fatal(err)
	}
	pcTest.proxy.provider = &TestProvider{
		ProviderData: &providers.ProviderData{},
		ValidToken:   true,
	}

	pcTest.validateUser = true

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET",
		pcTest.opts.ProxyPrefix+"/auth", nil)

	created := time.Now()
	startSession := &sessions.SessionState{
		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
	err = pcTest.SaveSession(startSession)
	assert.NoError(t, err)

	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
	assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
	assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
}

// TestAuthSkippedForPreflightRequests verifies that, with SkipAuthPreflight
// enabled, an unauthenticated OPTIONS request is proxied to the upstream.
func TestAuthSkippedForPreflightRequests(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		_, err := w.Write([]byte("response"))
		if err != nil {
			// This handler runs on the server's goroutine, where t.Fatal
			// (FailNow) must not be called; t.Error is goroutine-safe.
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.Upstreams{
		{
			ID:   upstreamServer.URL,
			Path: "/",
			URI:  upstreamServer.URL,
		},
	}
	opts.SkipAuthPreflight = true
	err := validation.Validate(opts)
	assert.NoError(t, err)

	upstreamURL, _ := url.Parse(upstreamServer.URL)
	opts.SetProvider(NewTestProvider(upstreamURL, ""))

	proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
	if err != nil {
		t.Fatal(err)
	}
	rw := httptest.NewRecorder()
	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
	proxy.ServeHTTP(rw, req)

	assert.Equal(t, 200, rw.Code)
	assert.Equal(t, "response", rw.Body.String())
}

// SignatureAuthenticator is an upstream handler that checks the HMAC request
// signature with auth and writes a short description of the outcome.
type SignatureAuthenticator struct {
	auth hmacauth.HmacAuth
}

// Authenticate writes "no signature received", "signatures match", or a
// mismatch report containing both signatures. It panics on an unknown result.
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
	result, headerSig, computedSig := v.auth.AuthenticateRequest(r)

	var msg string
	switch result {
	case hmacauth.ResultNoSignature:
		msg = "no signature received"
	case hmacauth.ResultMatch:
		msg = "signatures match"
	case hmacauth.ResultMismatch:
		msg = fmt.Sprintf(
			"signatures do not match:\n received: %s\n computed: %s",
			headerSig,
			computedSig)
	default:
		panic("unknown result value: " + result.String())
	}

	_, err := w.Write([]byte(msg))
	if err != nil {
		panic(err)
	}
}

// SignatureTest bundles the options, fake upstream/provider servers and
// recorder used by TestRequestSignature.
type SignatureTest struct {
	opts          *options.Options
	upstream      *httptest.Server
	upstreamHost  string
	provider      *httptest.Server
	header        http.Header
	rw            *httptest.ResponseRecorder
	authenticator *SignatureAuthenticator
}

// NewSignatureTest starts a signature-checking upstream and a fake provider
// and returns options pointing at them. Callers must Close the result.
func NewSignatureTest() (*SignatureTest, error) {
	opts := baseTestOptions()
	opts.EmailDomains = []string{"acm.org"}

	authenticator := &SignatureAuthenticator{}
	upstreamServer := httptest.NewServer(
		http.HandlerFunc(authenticator.Authenticate))
	upstreamURL, err := url.Parse(upstreamServer.URL)
	if err != nil {
		// Don't leak the already-started server on the error path.
		upstreamServer.Close()
		return nil, err
	}
	opts.UpstreamServers = options.Upstreams{
		{
			ID:   upstreamServer.URL,
			Path: "/",
			URI:  upstreamServer.URL,
		},
	}

	providerHandler := func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
		if err != nil {
			panic(err)
		}
	}
	provider := httptest.NewServer(http.HandlerFunc(providerHandler))
	providerURL, err := url.Parse(provider.URL)
	if err != nil {
		provider.Close()
		upstreamServer.Close()
		return nil, err
	}
	opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))

	return &SignatureTest{
		opts,
		upstreamServer,
		upstreamURL.Host,
		provider,
		make(http.Header),
		httptest.NewRecorder(),
		authenticator,
	}, nil
}

// Close shuts down the fake provider and upstream servers.
func (st *SignatureTest) Close() {
	st.provider.Close()
	st.upstream.Close()
}

// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
type fakeNetConn struct {
	reqBody string
}

// Read drains reqBody into p. Bytes handed out are removed from the buffer,
// so a second reader sees nothing; io.EOF accompanies the final chunk.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	if len(fnc.reqBody) == 0 {
		return 0, io.EOF
	}
	// Honour the io.Reader contract: never report more bytes than fit in p.
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if len(fnc.reqBody) == 0 {
		return n, io.EOF
	}
	return n, nil
}

// MakeRequestWithExpectedKey builds a proxy from st.opts, saves a session for
// mbland@acm.org, arms the upstream authenticator with key, and serves one
// request (with an optional body) through the proxy into st.rw.
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
	err := validation.Validate(st.opts)
	if err != nil {
		return err
	}
	proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
	if err != nil {
		return err
	}

	var bodyBuf io.ReadCloser
	if body != "" {
		bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
	}
	req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
	req.Header = st.header

	state := &sessions.SessionState{
		Email: "mbland@acm.org", AccessToken: "my_access_token"}
	err = proxy.SaveSession(st.rw, req, state)
	if err != nil {
		return err
	}
	for _, c := range st.rw.Result().Cookies() {
		req.AddCookie(c)
	}
	// This is used by the upstream to validate the signature.
	st.authenticator.auth = hmacauth.NewHmacAuth(
		crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders)
	proxy.ServeHTTP(st.rw, req)

	return nil
}

// TestRequestSignature verifies that upstream requests are signed with the
// configured SignatureKey, and left unsigned when no key is set.
func TestRequestSignature(t *testing.T) {
	testCases := map[string]struct {
		method string
		body   string
		key    string
		resp   string
	}{
		"No request signature": {
			method: "GET",
			body:   "",
			key:    "",
			resp:   "no signature received",
		},
		"Get request": {
			method: "GET",
			body:   "",
			key:    "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d",
			resp:   "signatures match",
		},
		"Post request": {
			method: "POST",
			body:   `{ "hello": "world!" }`,
			key:    "d90df39e2d19282840252612dd7c81421a372f61",
			resp:   "signatures match",
		},
	}
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			st, err := NewSignatureTest()
			if err != nil {
				t.Fatal(err)
			}
			t.Cleanup(st.Close)
			if tc.key != "" {
				st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key)
			}
			err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key)
			assert.NoError(t, err)
			assert.Equal(t, 200, st.rw.Code)
			assert.Equal(t, tc.resp, st.rw.Body.String())
		})
	}
}

// TestGetRedirect verifies the redirect target derived from the request URL:
// paths outside ProxyPrefix (with query) are kept, paths under it map to "/".
func TestGetRedirect(t *testing.T) {
	opts := baseTestOptions()
	err := validation.Validate(opts)
	assert.NoError(t, err)
	require.NotEmpty(t, opts.ProxyPrefix)
	proxy, err := NewOAuthProxy(opts, func(s string) bool { return false })
	if err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		name             string
		url              string
		expectedRedirect string
	}{
		{
			name:             "request outside of ProxyPrefix redirects to original URL",
			url:              "/foo/bar",
			expectedRedirect: "/foo/bar",
		},
		{
			name:             "request with query preserves query",
			url:              "/foo?bar",
			expectedRedirect: "/foo?bar",
		},
		{
			name:             "request under ProxyPrefix redirects to root",
			url:              proxy.ProxyPrefix + "/foo/bar",
			expectedRedirect: "/",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, _ := http.NewRequest("GET", tt.url, nil)
			redirect, err := proxy.GetRedirect(req)

			assert.NoError(t, err)
			assert.Equal(t, tt.expectedRedirect, redirect)
		})
	}
}

// ajaxRequestTest holds a proxy built from baseTestOptions for AJAX tests.
type ajaxRequestTest struct {
	opts  *options.Options
	proxy *OAuthProxy
}

// newAjaxRequestTest builds an ajaxRequestTest whose validator accepts all
// emails.
func newAjaxRequestTest() (*ajaxRequestTest, error) {
	test := &ajaxRequestTest{}
	test.opts = baseTestOptions()
	err := validation.Validate(test.opts)
	if err != nil {
		return nil, err
	}

	test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
		return true
	})
	if err != nil {
		return nil, err
	}
	return test, nil
}

// getEndpoint serves a GET for endpoint with the given headers and returns
// the response status code and headers.
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
	rw := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
	if err != nil {
		return 0, nil, err
	}
	req.Header = header
	test.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Header(), nil
}

// testAjaxUnauthorizedRequest asserts that an unauthenticated request with the
// given headers yields 401 with an applicationJSON Content-Type.
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
	test, err := newAjaxRequestTest()
	if err != nil {
		t.Fatal(err)
	}
	endpoint := "/test"

	code, rh, err := test.getEndpoint(endpoint, header)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, code)
	mime := rh.Get("Content-Type")
	assert.Equal(t, applicationJSON, mime)
}

// TestAjaxUnauthorizedRequest1 sends a lower-case "accept" header.
func TestAjaxUnauthorizedRequest1(t *testing.T) {
	header := make(http.Header)
	header.Add("accept", applicationJSON)

	testAjaxUnauthorizedRequest(t, header)
}

// TestAjaxUnauthorizedRequest2 sends a canonical "Accept" header.
func TestAjaxUnauthorizedRequest2(t *testing.T) {
	header := make(http.Header)
	header.Add("Accept", applicationJSON)

	testAjaxUnauthorizedRequest(t, header)
}

// TestAjaxForbiddendRequest verifies that an unauthenticated request without
// an Accept header gets 403 and a non-JSON Content-Type.
func TestAjaxForbiddendRequest(t *testing.T) {
	test, err := newAjaxRequestTest()
	if err != nil {
		t.Fatal(err)
	}
	endpoint := "/test"
	header := make(http.Header)
	code, rh, err := test.getEndpoint(endpoint, header)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, code)
	mime := rh.Get("Content-Type")
	assert.NotEqual(t, applicationJSON, mime)
}

// TestClearSplitCookie verifies that clearing the session when the request
// carries oauth2_0 and oauth2_1 (plus an unrelated cookie) emits exactly two
// Set-Cookie headers.
func TestClearSplitCookie(t *testing.T) {
	opts := baseTestOptions()
	opts.Cookie.Secret = base64CookieSecret
	opts.Cookie.Name = "oauth2"
	opts.Cookie.Domains = []string{"abc"}
	err := validation.Validate(opts)
	assert.NoError(t, err)

	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
	if err != nil {
		t.Fatal(err)
	}

	p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
	var rw = httptest.NewRecorder()
	req := httptest.NewRequest("get", "/", nil)

	req.AddCookie(&http.Cookie{
		Name:  "test1",
		Value: "test1",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_0",
		Value: "oauth2_0",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_1",
		Value: "oauth2_1",
	})

	err = p.ClearSessionCookie(rw, req)
	assert.NoError(t, err)
	header := rw.Header()

	assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 2 set-cookie header entries")
}

// TestClearSingleCookie verifies that clearing a non-split "oauth2" session
// cookie emits exactly one Set-Cookie header.
func TestClearSingleCookie(t *testing.T) {
	opts := baseTestOptions()
	opts.Cookie.Name = "oauth2"
	opts.Cookie.Domains = []string{"abc"}
	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
	if err != nil {
		t.Fatal(err)
	}

	p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
	var rw = httptest.NewRecorder()
	req := httptest.NewRequest("get", "/", nil)

	req.AddCookie(&http.Cookie{
		Name:  "test1",
		Value: "test1",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2",
		Value: "oauth2",
	})

	err = p.ClearSessionCookie(rw, req)
	assert.NoError(t, err)
	header := rw.Header()

	assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
}

// NoOpKeySet is a key set for oidc.NewVerifier that performs no signature
// check: it simply returns the base64url-decoded payload segment of the JWT.
type NoOpKeySet struct {
}

// VerifySignature returns the decoded second ("payload") segment of jwt, or
// an error if jwt has fewer than two dot-separated segments.
func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
	splitStrings := strings.Split(jwt, ".")
	if len(splitStrings) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 segments, got %d", len(splitStrings))
	}
	payloadString := splitStrings[1]
	return base64.RawURLEncoding.DecodeString(payloadString)
}

// TestGetJwtSession verifies that a bearer JWT accepted by a registered
// verifier authenticates the request and that its claims populate both the
// injected request headers and the injected response headers.
func TestGetJwtSession(t *testing.T) {
	/* token payload:
	{
	  "sub": "1234567890",
	  "aud": "https://test.myapp.com",
	  "name": "John Doe",
	  "email": "john@example.com",
	  "iss": "https://issuer.example.com",
	  "iat": 1553691215,
	  "exp": 1912151821
	}
	*/
	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
		"E1LCJleHAiOjE5MTIxNTE4MjF9." +
		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"

	keyset := NoOpKeySet{}
	verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})

	test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
		opts.InjectRequestHeaders = []options.Header{
			{
				Name: "Authorization",
				Values: []options.HeaderValue{
					{
						ClaimSource: &options.ClaimSource{
							Claim:  "id_token",
							Prefix: "Bearer ",
						},
					},
				},
			},
			newTestClaimHeader("X-Forwarded-User", "user"),
			newTestClaimHeader("X-Forwarded-Email", "email"),
		}

		opts.InjectResponseHeaders = []options.Header{
			{
				Name: "Authorization",
				Values: []options.HeaderValue{
					{
						ClaimSource: &options.ClaimSource{
							Claim:  "id_token",
							Prefix: "Bearer ",
						},
					},
				},
			},
			newTestClaimHeader("X-Auth-Request-User", "user"),
			newTestClaimHeader("X-Auth-Request-Email", "email"),
		}

		opts.SkipJwtBearerTokens = true
		opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
	})
	if err != nil {
		t.Fatal(err)
	}
	// Fail cleanly instead of nil-dereferencing if the provider type changes.
	tp, ok := test.proxy.provider.(*TestProvider)
	require.True(t, ok, "expected provider to be a *TestProvider")
	tp.GroupValidator = func(s string) bool {
		return true
	}

	authHeader := fmt.Sprintf("Bearer %s", goodJwt)
	test.req.Header = map[string][]string{
		"Authorization": {authHeader},
	}

	test.proxy.ServeHTTP(test.rw, test.req)
	if test.rw.Code >= 400 {
		t.Fatalf("expected a status below 400, got %d", test.rw.Code)
	}

	// Check PassAuthorization, should overwrite Basic header
	assert.Equal(t, authHeader, test.req.Header.Get("Authorization"))
	assert.Equal(t, "1234567890", test.req.Header.Get("X-Forwarded-User"))
	assert.Equal(t, "john@example.com", test.req.Header.Get("X-Forwarded-Email"))

	// SetAuthorization and SetXAuthRequest
	assert.Equal(t, authHeader, test.rw.Header().Get("Authorization"))
	assert.Equal(t, "1234567890", test.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "john@example.com", test.rw.Header().Get("X-Auth-Request-Email"))
}

// Test_prepareNoCache verifies that prepareNoCache sets every entry of
// noCacheHeaders on the response.
func Test_prepareNoCache(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		prepareNoCache(w)
	})
	mux := http.NewServeMux()
	mux.Handle("/", handler)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	mux.ServeHTTP(rec, req)

	for k, v := range noCacheHeaders {
		assert.Equal(t, v, rec.Header().Get(k))
	}
}

// Test_noCacheHeaders verifies which proxy endpoints add a no-cache
// Cache-Control header, and that upstream responses are left untouched.
func Test_noCacheHeaders(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("upstream"))
		if err != nil {
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.Upstreams{
		{
			ID:   upstreamServer.URL,
			Path: "/",
			URI:  upstreamServer.URL,
		},
	}
	opts.SkipAuthRegex = []string{".*"}
	err := validation.Validate(opts)
	assert.NoError(t, err)
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	t.Run("not exist in response from upstream", func(t *testing.T) {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/upstream", nil)
		proxy.ServeHTTP(rec, req)

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, "upstream", rec.Body.String())

		// check that none of the noCacheHeaders exist in the upstream response
		for k := range noCacheHeaders {
			assert.Equal(t, "", rec.Header().Get(k))
		}
	})

	t.Run("has no-cache", func(t *testing.T) {
		tests := []struct {
			path       string
			hasNoCache bool
		}{
			{
				path:       "/oauth2/sign_in",
				hasNoCache: true,
			},
			{
				path:       "/oauth2/sign_out",
				hasNoCache: true,
			},
			{
				path:       "/oauth2/start",
				hasNoCache: true,
			},
			{
				path:       "/oauth2/callback",
				hasNoCache: true,
			},
			{
				path:       "/oauth2/auth",
				hasNoCache: false,
			},
			{
				path:       "/oauth2/userinfo",
				hasNoCache: true,
			},
			{
				path:       "/upstream",
				hasNoCache: false,
			},
		}

		for _, tt := range tests {
			t.Run(tt.path, func(t *testing.T) {
				rec := httptest.NewRecorder()
				req := httptest.NewRequest(http.MethodGet, tt.path, nil)
				proxy.ServeHTTP(rec, req)
				cacheControl := rec.Result().Header.Get("Cache-Control")
				if tt.hasNoCache != (strings.Contains(cacheControl, "no-cache")) {
					t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
				}
			})
		}
	})
}

// baseTestOptions returns options with the test cookie secret, client
// credentials, all email domains allowed, and the legacy default injected
// request headers (basic-auth Authorization, X-Forwarded-User/Email).
func baseTestOptions() *options.Options {
	opts := options.NewOptions()
	opts.Cookie.Secret = rawCookieSecret
	opts.ClientID = clientID
	opts.ClientSecret = clientSecret
	opts.EmailDomains = []string{"*"}

	// Default injected headers for legacy configuration
	opts.InjectRequestHeaders = []options.Header{
		{
			Name: "Authorization",
			Values: []options.HeaderValue{
				{
					ClaimSource: &options.ClaimSource{
						Claim: "user",
						BasicAuthPassword: &options.SecretSource{
							Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))),
						},
					},
				},
			},
		},
		newTestClaimHeader("X-Forwarded-User", "user"),
		newTestClaimHeader("X-Forwarded-Email", "email"),
	}
	return opts
}

// TestTrustedIPs verifies that requests from TrustedIPs bypass auth (200) and
// all others are rejected (403), using RemoteAddr normally and the configured
// real-client-IP header in reverse-proxy mode.
func TestTrustedIPs(t *testing.T) {
	tests := []struct {
		name               string
		trustedIPs         []string
		reverseProxy       bool
		realClientIPHeader string
		req                *http.Request
		expectTrusted      bool
	}{
		// Check unconfigured behavior.
		{
			name:               "Default",
			trustedIPs:         nil,
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP", // Default value
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				return req
			}(),
			expectTrusted: false,
		},
		// Check using req.RemoteAddr (Options.ReverseProxy == false).
		{
			name:               "WithRemoteAddr",
			trustedIPs:         []string{"127.0.0.1"},
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP", // Default value
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.RemoteAddr = "127.0.0.1:43670"
				return req
			}(),
			expectTrusted: true,
		},
		// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
		{
			name:               "IgnoresRemoteAddrInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Real-IP", // Default value
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.RemoteAddr = "127.0.0.1:44324"
				return req
			}(),
			expectTrusted: false,
		},
		// Check successful trusting of localhost in IPv4.
		{
			name:               "TrustsLocalhostInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Forwarded-For", "127.0.0.1")
				return req
			}(),
			expectTrusted: true,
		},
		// Check successful trusting of localhost in IPv6.
		{
			name:               "TrustsIP6LocalostInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Forwarded-For", "::1")
				return req
			}(),
			expectTrusted: true,
		},
		// Check does not trust random IPv4 address.
		{
			name:               "DoesNotTrustRandomIP4Address",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Forwarded-For", "12.34.56.78")
				return req
			}(),
			expectTrusted: false,
		},
		// Check does not trust random IPv6 address.
		{
			name:               "DoesNotTrustRandomIP6Address",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Forwarded-For", "::2")
				return req
			}(),
			expectTrusted: false,
		},
		// Check respects correct header.
		{
			name:               "RespectsCorrectHeaderInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Real-IP", "::1")
				return req
			}(),
			expectTrusted: false,
		},
		// Check doesn't trust if garbage is provided.
		{
			name:               "DoesNotTrustGarbageInReverseProxyMode",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       true,
			realClientIPHeader: "X-Forwarded-For",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.Header.Add("X-Forwarded-For", "adsfljk29242as!!")
				return req
			}(),
			expectTrusted: false,
		},
		// Check doesn't trust if garbage is provided (no reverse-proxy).
		{
			name:               "DoesNotTrustGarbage",
			trustedIPs:         []string{"127.0.0.0/8", "::1"},
			reverseProxy:       false,
			realClientIPHeader: "X-Real-IP",
			req: func() *http.Request {
				req, _ := http.NewRequest("GET", "/", nil)
				req.RemoteAddr = "adsfljk29242as!!"
				return req
			}(),
			expectTrusted: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			opts := baseTestOptions()
			opts.UpstreamServers = options.Upstreams{
				{
					ID:     "static",
					Path:   "/",
					Static: true,
				},
			}
			opts.TrustedIPs = tt.trustedIPs
			opts.ReverseProxy = tt.reverseProxy
			opts.RealClientIPHeader = tt.realClientIPHeader
			err := validation.Validate(opts)
			assert.NoError(t, err)

			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
			// Stop here rather than dereferencing a nil proxy below.
			require.NoError(t, err)
			rw := httptest.NewRecorder()

			proxy.ServeHTTP(rw, tt.req)
			if tt.expectTrusted {
				assert.Equal(t, 200, rw.Code)
			} else {
				assert.Equal(t, 403, rw.Code)
			}
		})
	}
}

// Test_buildRoutesAllowlist verifies that SkipAuthRegex entries become
// method-less routes, SkipAuthRoutes entries are split on the first "=" into
// method and regex, and invalid regexes produce an error.
func Test_buildRoutesAllowlist(t *testing.T) {
	type expectedAllowedRoute struct {
		method      string
		regexString string
	}

	testCases := []struct {
		name           string
		skipAuthRegex  []string
		skipAuthRoutes []string
		expectedRoutes []expectedAllowedRoute
		shouldError    bool
	}{
		{
			name:           "No skip auth configured",
			skipAuthRegex:  []string{},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    false,
		},
		{
			name: "Only skipAuthRegex configured",
			skipAuthRegex: []string{
				"^/foo/bar",
				"^/baz/[0-9]+/thing",
			},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{
				{
					method:      "",
					regexString: "^/foo/bar",
				},
				{
					method:      "",
					regexString: "^/baz/[0-9]+/thing",
				},
			},
			shouldError: false,
		},
		{
			name:          "Only skipAuthRoutes configured",
			skipAuthRegex: []string{},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
				"WEIRD=^/methods/are/allowed",
				"PATCH=/second/equals?are=handled&just=fine",
			},
			expectedRoutes: []expectedAllowedRoute{
				{
					method:      "GET",
					regexString: "^/foo/bar",
				},
				{
					method:      "POST",
					regexString: "^/baz/[0-9]+/thing",
				},
				{
					method:      "",
					regexString: "^/all/methods$",
				},
				{
					method:      "WEIRD",
					regexString: "^/methods/are/allowed",
				},
				{
					method:      "PATCH",
					regexString: "/second/equals?are=handled&just=fine",
				},
			},
			shouldError: false,
		},
		{
			name: "Both skipAuthRegexes and skipAuthRoutes configured",
			skipAuthRegex: []string{
				"^/foo/bar/regex",
				"^/baz/[0-9]+/thing/regex",
			},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
			},
			expectedRoutes: []expectedAllowedRoute{
				{
					method:      "",
					regexString: "^/foo/bar/regex",
				},
				{
					method:      "",
					regexString: "^/baz/[0-9]+/thing/regex",
				},
				{
					method:      "GET",
					regexString: "^/foo/bar",
				},
				{
					method:      "POST",
					regexString: "^/baz/[0-9]+/thing",
				},
				{
					method:      "",
					regexString: "^/all/methods$",
				},
			},
			shouldError: false,
		},
		{
			name: "Invalid skipAuthRegex entry",
			skipAuthRegex: []string{
				"^/foo/bar",
				"^/baz/[0-9]+/thing",
				"(bad[regex",
			},
			skipAuthRoutes: []string{},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    true,
		},
		{
			name:          "Invalid skipAuthRoutes entry",
			skipAuthRegex: []string{},
			skipAuthRoutes: []string{
				"GET=^/foo/bar",
				"POST=^/baz/[0-9]+/thing",
				"^/all/methods$",
				"PUT=(bad[regex",
			},
			expectedRoutes: []expectedAllowedRoute{},
			shouldError:    true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			opts := &options.Options{
				SkipAuthRegex:  tc.skipAuthRegex,
				SkipAuthRoutes: tc.skipAuthRoutes,
			}
			routes, err := buildRoutesAllowlist(opts)
			if tc.shouldError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Exact length check: catches missing routes as well as extra ones.
			require.Len(t, routes, len(tc.expectedRoutes))
			for i, route := range routes {
				assert.Equal(t, tc.expectedRoutes[i].method, route.method)
				assert.Equal(t, tc.expectedRoutes[i].regexString, route.pathRegex.String())
			}
		})
	}
}

// TestAllowedRequest verifies that requests matching SkipAuthRegex (any
// method) or SkipAuthRoutes (matching method and path) reach the upstream
// unauthenticated, while everything else is rejected with 403.
func TestAllowedRequest(t *testing.T) {
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		_, err := w.Write([]byte("Allowed Request"))
		if err != nil {
			// This handler runs on the server's goroutine, where t.Fatal
			// (FailNow) must not be called; t.Error is goroutine-safe.
			t.Error(err)
		}
	}))
	t.Cleanup(upstreamServer.Close)

	opts := baseTestOptions()
	opts.UpstreamServers = options.Upstreams{
		{
			ID:   upstreamServer.URL,
			Path: "/",
			URI:  upstreamServer.URL,
		},
	}
	opts.SkipAuthRegex = []string{
		"^/skip/auth/regex$",
	}
	opts.SkipAuthRoutes = []string{
		"GET=^/skip/auth/routes/get",
	}
	err := validation.Validate(opts)
	assert.NoError(t, err)
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
	if err != nil {
		t.Fatal(err)
	}

	testCases := []struct {
		name    string
		method  string
		url     string
		allowed bool
	}{
		{
			name:    "Regex GET allowed",
			method:  "GET",
			url:     "/skip/auth/regex",
			allowed: true,
		},
		{
			name:    "Regex POST allowed ",
			method:  "POST",
			url:     "/skip/auth/regex",
			allowed: true,
		},
		{
			name:    "Regex denied",
			method:  "GET",
			url:     "/wrong/denied",
			allowed: false,
		},
		{
			name:    "Route allowed",
			method:  "GET",
			url:     "/skip/auth/routes/get",
			allowed: true,
		},
		{
			name:    "Route denied with wrong method",
			method:  "PATCH",
			url:     "/skip/auth/routes/get",
			allowed: false,
		},
		{
			name:    "Route denied with wrong path",
			method:  "GET",
			url:     "/skip/auth/routes/wrong/path",
			allowed: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(tc.method, tc.url, nil)
			assert.NoError(t, err)
			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req))

			rw := httptest.NewRecorder()
			proxy.ServeHTTP(rw, req)

			if tc.allowed {
				assert.Equal(t, 200, rw.Code)
				assert.Equal(t, "Allowed Request", rw.Body.String())
			} else {
				assert.Equal(t, 403, rw.Code)
			}
		})
	}
}

// TestProxyAllowedGroups verifies AllowedGroups enforcement on proxied JSON
// requests: 401 when the session's groups miss every allowed group, 200
// otherwise (including when no groups are configured).
func TestProxyAllowedGroups(t *testing.T) {
	tests := []struct {
		name               string
		allowedGroups      []string
		groups             []string
		expectUnauthorized bool
	}{
		{"NoAllowedGroups", []string{}, []string{}, false},
		{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
		{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
		{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			emailAddress := "test"
			created := time.Now()

			session := &sessions.SessionState{
				Groups:      tt.groups,
				Email:       emailAddress,
				AccessToken: "oauth_token",
				CreatedAt:   &created,
			}

			upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(200)
			}))
			t.Cleanup(upstreamServer.Close)

			test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
				opts.AllowedGroups = tt.allowedGroups
				opts.UpstreamServers = options.Upstreams{
					{
						ID:   upstreamServer.URL,
						Path: "/",
						URI:  upstreamServer.URL,
					},
				}
			})
			if err != nil {
				t.Fatal(err)
			}

			test.req, _ = http.NewRequest("GET", "/", nil)

			test.req.Header.Add("accept", applicationJSON)
			err = test.SaveSession(session)
			assert.NoError(t, err)
			test.proxy.ServeHTTP(test.rw, test.req)

			if tt.expectUnauthorized {
				assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
			} else {
				assert.Equal(t, http.StatusOK, test.rw.Code)
			}
		})
	}
}

// TestAuthOnlyAllowedGroups verifies AllowedGroups enforcement on the auth
// endpoint: 401 when the session's groups miss every allowed group, 202
// otherwise (including when no groups are configured).
func TestAuthOnlyAllowedGroups(t *testing.T) {
	tests := []struct {
		name               string
		allowedGroups      []string
		groups             []string
		expectUnauthorized bool
	}{
		{"NoAllowedGroups", []string{}, []string{}, false},
		{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
		{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
		{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			emailAddress := "test"
			created := time.Now()

			session := &sessions.SessionState{
				Groups:      tt.groups,
				Email:       emailAddress,
				AccessToken: "oauth_token",
				CreatedAt:   &created,
			}

			test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
				opts.AllowedGroups = tt.allowedGroups
			})
			if err != nil {
				t.Fatal(err)
			}

			err = test.SaveSession(session)
			assert.NoError(t, err)

			test.proxy.ServeHTTP(test.rw, test.req)

			if tt.expectUnauthorized {
				assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
			} else {
				assert.Equal(t, http.StatusAccepted, test.rw.Code)
			}
		})
	}
}