Move RedirectToHTTPS to middleware package
Moves the logic for redirecting to HTTPs to a middleware package and adds tests for this logic. Also makes the functionality more useful, previously it always redirected to the HTTPS address of the proxy, which may not have been intended, now it will redirect based on if a port is provided in the URL (assume public facing 80 to 443 or 4180 to 8443 for example)
This commit is contained in:
parent
39c01d5930
commit
1c1106721e
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v6.0.0
|
||||
|
||||
- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed)
|
||||
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
|
||||
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
|
||||
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)
|
||||
|
18
http.go
18
http.go
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
)
|
||||
@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return redirectToHTTPS(opts, next)
|
||||
}
|
||||
}
|
||||
|
||||
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) {
|
||||
http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
51
http_test.go
51
http_test.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -11,56 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPSTrue(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
opts.ForceHTTPS = true
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
rw := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(rw, r)
|
||||
|
||||
assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code)
|
||||
}
|
||||
|
||||
func TestRedirectToHTTPSFalse(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
rw := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(rw, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code)
|
||||
}
|
||||
|
||||
func TestRedirectNotWhenHTTPS(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
opts.ForceHTTPS = true
|
||||
handler := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
|
||||
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
|
||||
s := httptest.NewTLSServer(h)
|
||||
defer s.Close()
|
||||
|
||||
opts.HTTPSAddress = s.URL
|
||||
client := s.Client()
|
||||
res, err := client.Get(s.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("request to test server failed with error: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
opts := options.NewOptions()
|
||||
stop := make(chan struct{}, 1)
|
||||
|
7
main.go
7
main.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
@ -79,7 +80,11 @@ func main() {
|
||||
chain := alice.New()
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
chain = chain.Append(newRedirectToHTTPS(opts))
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
if err != nil {
|
||||
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
||||
healthCheckPaths := []string{opts.PingPath}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware")
|
||||
}
|
||||
|
||||
func testHandler() http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte("test"))
|
||||
})
|
||||
}
|
||||
|
50
pkg/middleware/redirect_to_https.go
Normal file
50
pkg/middleware/redirect_to_https.go
Normal file
@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
)
|
||||
|
||||
const httpsScheme = "https"
|
||||
|
||||
// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
|
||||
// HTTP requests to HTTPS
|
||||
func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return redirectToHTTPS(httpsPort, next)
|
||||
}
|
||||
}
|
||||
|
||||
// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
|
||||
// if it is not already HTTPS.
|
||||
// If the request is to a non standard port, the redirection request will be
|
||||
// to the port from the httpsAddress given.
|
||||
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
||||
// Only care about the connection to us being HTTPS if the proto is empty,
|
||||
// otherwise the proto is source of truth
|
||||
next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the request URL
|
||||
targetURL, _ := url.Parse(req.URL.String())
|
||||
// Set the scheme to HTTPS
|
||||
targetURL.Scheme = httpsScheme
|
||||
|
||||
// Overwrite the port if the original request was to a non-standard port
|
||||
if targetURL.Port() != "" {
|
||||
// If Port was not empty, this should be fine to ignore the error
|
||||
host, _, _ := net.SplitHostPort(targetURL.Host)
|
||||
targetURL.Host = net.JoinHostPort(host, httpsPort)
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect)
|
||||
})
|
||||
}
|
158
pkg/middleware/redirect_to_https_test.go
Normal file
158
pkg/middleware/redirect_to_https_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RedirectToHTTPS suite", func() {
|
||||
const httpsPort = "8443"
|
||||
|
||||
var permanentRedirectBody = func(address string) string {
|
||||
return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address)
|
||||
}
|
||||
|
||||
type requestTableInput struct {
|
||||
requestString string
|
||||
useTLS bool
|
||||
headers map[string]string
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectedLocation string
|
||||
}
|
||||
|
||||
DescribeTable("when serving a request",
|
||||
func(in *requestTableInput) {
|
||||
req := httptest.NewRequest("", in.requestString, nil)
|
||||
for k, v := range in.headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
if in.useTLS {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler := NewRedirectToHTTPS(httpsPort)(testHandler())
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(rw.Code).To(Equal(in.expectedStatus))
|
||||
Expect(rw.Body.String()).To(Equal(in.expectedBody))
|
||||
|
||||
if in.expectedLocation != "" {
|
||||
Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation))
|
||||
}
|
||||
},
|
||||
Entry("without TLS", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("without TLS on a non-standard port", &requestTableInput{
|
||||
requestString: "http://example.com:8080",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||
expectedLocation: "https://example.com:8443",
|
||||
}),
|
||||
Entry("with TLS on a non-standard port", &requestTableInput{
|
||||
requestString: "https://example.com:8443",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
)
|
||||
})
|
Loading…
Reference in New Issue
Block a user