Move RedirectToHTTPS to middleware package

Moves the logic for redirecting to HTTPs to a middleware package and adds tests for this logic.
Also makes the functionality more useful, previously it always redirected to the HTTPS address of the proxy, which may not have been intended, now it will redirect based on if a port is provided in the URL (assume public facing 80 to 443 or 4180 to 8443 for example)
This commit is contained in:
Joel Speed 2020-06-12 18:18:41 +01:00
parent 39c01d5930
commit 1c1106721e
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
7 changed files with 222 additions and 70 deletions

View File

@ -8,6 +8,7 @@
## Changes since v6.0.0
- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed)
- [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed)
- [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed)
- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed)

18
http.go
View File

@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
)
@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
return func(next http.Handler) http.Handler {
return redirectToHTTPS(opts, next)
}
}
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proto := r.Header.Get("X-Forwarded-Proto")
if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) {
http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect)
}
h.ServeHTTP(w, r)
})
}

View File

@ -2,7 +2,6 @@ package main
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@ -11,56 +10,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestRedirectToHTTPSTrue(t *testing.T) {
opts := options.NewOptions()
opts.ForceHTTPS = true
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
h.ServeHTTP(rw, r)
assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code)
}
func TestRedirectToHTTPSFalse(t *testing.T) {
opts := options.NewOptions()
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
rw := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
h.ServeHTTP(rw, r)
assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code)
}
func TestRedirectNotWhenHTTPS(t *testing.T) {
opts := options.NewOptions()
opts.ForceHTTPS = true
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))
}
h := redirectToHTTPS(opts, http.HandlerFunc(handler))
s := httptest.NewTLSServer(h)
defer s.Close()
opts.HTTPSAddress = s.URL
client := s.Client()
res, err := client.Get(s.URL)
if err != nil {
t.Fatalf("request to test server failed with error: %v", err)
}
assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode)
}
func TestGracefulShutdown(t *testing.T) {
opts := options.NewOptions()
stop := make(chan struct{}, 1)

View File

@ -3,6 +3,7 @@ package main
import (
"fmt"
"math/rand"
"net"
"os"
"os/signal"
"runtime"
@ -79,7 +80,11 @@ func main() {
chain := alice.New()
if opts.ForceHTTPS {
chain = chain.Append(newRedirectToHTTPS(opts))
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
if err != nil {
logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}
healthCheckPaths := []string{opts.PingPath}

View File

@ -1,6 +1,7 @@
package middleware
import (
"net/http"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware")
}
func testHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte("test"))
})
}

View File

@ -0,0 +1,50 @@
package middleware
import (
"net"
"net/http"
"net/url"
"strings"
"github.com/justinas/alice"
)
const httpsScheme = "https"
// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
// HTTP requests to HTTPS
func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
return func(next http.Handler) http.Handler {
return redirectToHTTPS(httpsPort, next)
}
}
// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
// if it is not already HTTPS.
// If the request is to a non standard port, the redirection request will be
// to the port from the httpsAddress given.
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
proto := req.Header.Get("X-Forwarded-Proto")
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
// Only care about the connection to us being HTTPS if the proto is empty,
// otherwise the proto is source of truth
next.ServeHTTP(rw, req)
return
}
// Copy the request URL
targetURL, _ := url.Parse(req.URL.String())
// Set the scheme to HTTPS
targetURL.Scheme = httpsScheme
// Overwrite the port if the original request was to a non-standard port
if targetURL.Port() != "" {
// If Port was not empty, this should be fine to ignore the error
host, _, _ := net.SplitHostPort(targetURL.Host)
targetURL.Host = net.JoinHostPort(host, httpsPort)
}
http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect)
})
}

View File

@ -0,0 +1,158 @@
package middleware
import (
"crypto/tls"
"fmt"
"net/http/httptest"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
var _ = Describe("RedirectToHTTPS suite", func() {
const httpsPort = "8443"
var permanentRedirectBody = func(address string) string {
return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address)
}
type requestTableInput struct {
requestString string
useTLS bool
headers map[string]string
expectedStatus int
expectedBody string
expectedLocation string
}
DescribeTable("when serving a request",
func(in *requestTableInput) {
req := httptest.NewRequest("", in.requestString, nil)
for k, v := range in.headers {
req.Header.Add(k, v)
}
if in.useTLS {
req.TLS = &tls.ConnectionState{}
}
rw := httptest.NewRecorder()
handler := NewRedirectToHTTPS(httpsPort)(testHandler())
handler.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.expectedStatus))
Expect(rw.Body.String()).To(Equal(in.expectedBody))
if in.expectedLocation != "" {
Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation))
}
},
Entry("without TLS", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{
requestString: "https://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("without TLS on a non-standard port", &requestTableInput{
requestString: "http://example.com:8080",
useTLS: false,
headers: map[string]string{},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com:8443"),
expectedLocation: "https://example.com:8443",
}),
Entry("with TLS on a non-standard port", &requestTableInput{
requestString: "https://example.com:8443",
useTLS: true,
headers: map[string]string{},
expectedStatus: 200,
expectedBody: "test",
}),
)
})