Merge pull request #1142 from oauth2-proxy/writer-funcs
Add pagewriter to upstream proxy
This commit is contained in:
commit
06808704a3
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v7.1.3
|
||||
|
||||
- [#1142](https://github.com/oauth2-proxy/oauth2-proxy/pull/1142) Add pagewriter to upstream proxy (@JoelSpeed)
|
||||
- [#1181](https://github.com/oauth2-proxy/oauth2-proxy/pull/1181) Fix incorrect `cfg` name in show-debug-on-error flag (@iTaybb)
|
||||
|
||||
# V7.1.3
|
||||
|
@ -124,7 +124,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||
return nil, fmt.Errorf("error initialising page writer: %v", err)
|
||||
}
|
||||
|
||||
upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler)
|
||||
upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error initialising upstream proxy: %v", err)
|
||||
}
|
||||
|
@ -101,3 +101,73 @@ func NewWriter(opts Opts) (Writer, error) {
|
||||
staticPageWriter: staticPages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WriterFuncs is an implementation of the PageWriter interface based
|
||||
// on override functions.
|
||||
// If any of the funcs are not provided, a default implementation will be used.
|
||||
// This is primarily for us in testing.
|
||||
type WriterFuncs struct {
|
||||
SignInPageFunc func(rw http.ResponseWriter, req *http.Request, redirectURL string)
|
||||
ErrorPageFunc func(rw http.ResponseWriter, opts ErrorPageOpts)
|
||||
ProxyErrorFunc func(rw http.ResponseWriter, req *http.Request, proxyErr error)
|
||||
RobotsTxtfunc func(rw http.ResponseWriter, req *http.Request)
|
||||
}
|
||||
|
||||
// WriteSignInPage implements the Writer interface.
|
||||
// If the SignInPageFunc is provided, this will be used, else a default
|
||||
// implementation will be used.
|
||||
func (w *WriterFuncs) WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
if w.SignInPageFunc != nil {
|
||||
w.SignInPageFunc(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := rw.Write([]byte("Sign In")); err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorPage implements the Writer interface.
|
||||
// If the ErrorPageFunc is provided, this will be used, else a default
|
||||
// implementation will be used.
|
||||
func (w *WriterFuncs) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) {
|
||||
if w.ErrorPageFunc != nil {
|
||||
w.ErrorPageFunc(rw, opts)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(opts.Status)
|
||||
errMsg := fmt.Sprintf("%d - %v", opts.Status, opts.AppError)
|
||||
if _, err := rw.Write([]byte(errMsg)); err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyErrorHandler implements the Writer interface.
|
||||
// If the ProxyErrorFunc is provided, this will be used, else a default
|
||||
// implementation will be used.
|
||||
func (w *WriterFuncs) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) {
|
||||
if w.ProxyErrorFunc != nil {
|
||||
w.ProxyErrorFunc(rw, req, proxyErr)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteErrorPage(rw, ErrorPageOpts{
|
||||
Status: http.StatusBadGateway,
|
||||
AppError: proxyErr.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteRobotsTxt implements the Writer interface.
|
||||
// If the RobotsTxtfunc is provided, this will be used, else a default
|
||||
// implementation will be used.
|
||||
func (w *WriterFuncs) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) {
|
||||
if w.RobotsTxtfunc != nil {
|
||||
w.RobotsTxtfunc(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := rw.Write([]byte("Allow: *")); err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package pagewriter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -8,6 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
@ -135,4 +138,144 @@ var _ = Describe("Writer", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriterFuncs", func() {
|
||||
type writerFuncsTableInput struct {
|
||||
writer Writer
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}
|
||||
|
||||
DescribeTable("WriteSignInPage",
|
||||
func(in writerFuncsTableInput) {
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("", "/sign-in", nil)
|
||||
redirectURL := "<redirectURL>"
|
||||
in.writer.WriteSignInPage(rw, req, redirectURL)
|
||||
|
||||
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
|
||||
|
||||
body, err := ioutil.ReadAll(rw.Result().Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("With no override", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "Sign In",
|
||||
}),
|
||||
Entry("With an override function", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{
|
||||
SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
rw.WriteHeader(202)
|
||||
rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL)))
|
||||
},
|
||||
},
|
||||
expectedStatus: 202,
|
||||
expectedBody: "/sign-in <redirectURL>",
|
||||
}),
|
||||
)
|
||||
|
||||
DescribeTable("WriteErrorPage",
|
||||
func(in writerFuncsTableInput) {
|
||||
rw := httptest.NewRecorder()
|
||||
in.writer.WriteErrorPage(rw, ErrorPageOpts{
|
||||
Status: http.StatusInternalServerError,
|
||||
RedirectURL: "<redirectURL>",
|
||||
RequestID: "12345",
|
||||
AppError: "application error",
|
||||
})
|
||||
|
||||
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
|
||||
|
||||
body, err := ioutil.ReadAll(rw.Result().Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("With no override", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{},
|
||||
expectedStatus: 500,
|
||||
expectedBody: "500 - application error",
|
||||
}),
|
||||
Entry("With an override function", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{
|
||||
ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) {
|
||||
rw.WriteHeader(503)
|
||||
rw.Write([]byte(fmt.Sprintf("%s %s", opts.RequestID, opts.RedirectURL)))
|
||||
},
|
||||
},
|
||||
expectedStatus: 503,
|
||||
expectedBody: "12345 <redirectURL>",
|
||||
}),
|
||||
)
|
||||
|
||||
DescribeTable("ProxyErrorHandler",
|
||||
func(in writerFuncsTableInput) {
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("", "/proxy", nil)
|
||||
err := errors.New("proxy error")
|
||||
in.writer.ProxyErrorHandler(rw, req, err)
|
||||
|
||||
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
|
||||
|
||||
body, err := ioutil.ReadAll(rw.Result().Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("With no override", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{},
|
||||
expectedStatus: 502,
|
||||
expectedBody: "502 - proxy error",
|
||||
}),
|
||||
Entry("With an override function for the proxy handler", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{
|
||||
ProxyErrorFunc: func(rw http.ResponseWriter, req *http.Request, proxyErr error) {
|
||||
rw.WriteHeader(503)
|
||||
rw.Write([]byte(fmt.Sprintf("%s %v", req.URL.Path, proxyErr)))
|
||||
},
|
||||
},
|
||||
expectedStatus: 503,
|
||||
expectedBody: "/proxy proxy error",
|
||||
}),
|
||||
Entry("With an override function for the error page", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{
|
||||
ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) {
|
||||
rw.WriteHeader(500)
|
||||
rw.Write([]byte("Internal Server Error"))
|
||||
},
|
||||
},
|
||||
expectedStatus: 500,
|
||||
expectedBody: "Internal Server Error",
|
||||
}),
|
||||
)
|
||||
|
||||
DescribeTable("WriteRobotsTxt",
|
||||
func(in writerFuncsTableInput) {
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("", "/robots.txt", nil)
|
||||
in.writer.WriteRobotsTxt(rw, req)
|
||||
|
||||
Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus))
|
||||
|
||||
body, err := ioutil.ReadAll(rw.Result().Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("With no override", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{},
|
||||
expectedStatus: 200,
|
||||
expectedBody: "Allow: *",
|
||||
}),
|
||||
Entry("With an override function", writerFuncsTableInput{
|
||||
writer: &WriterFuncs{
|
||||
RobotsTxtfunc: func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(202)
|
||||
rw.Write([]byte("Disallow: *"))
|
||||
},
|
||||
},
|
||||
expectedStatus: 202,
|
||||
expectedBody: "Disallow: *",
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
|
||||
@ -15,7 +16,7 @@ type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
|
||||
// NewProxy creates a new multiUpstreamProxy that can serve requests directed to
|
||||
// multiple upstreams.
|
||||
func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, errorHandler ProxyErrorHandler) (http.Handler, error) {
|
||||
func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, writer pagewriter.Writer) (http.Handler, error) {
|
||||
m := &multiUpstreamProxy{
|
||||
serveMux: http.NewServeMux(),
|
||||
}
|
||||
@ -34,7 +35,7 @@ func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, error
|
||||
case fileScheme:
|
||||
m.registerFileServer(upstream, u)
|
||||
case httpScheme, httpsScheme:
|
||||
m.registerHTTPUpstreamProxy(upstream, u, sigData, errorHandler)
|
||||
m.registerHTTPUpstreamProxy(upstream, u, sigData, writer)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme)
|
||||
}
|
||||
@ -66,7 +67,7 @@ func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *ur
|
||||
}
|
||||
|
||||
// registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given.
|
||||
func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) {
|
||||
func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) {
|
||||
logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI)
|
||||
m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, errorHandler))
|
||||
m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler))
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
@ -20,9 +21,11 @@ var _ = Describe("Proxy Suite", func() {
|
||||
BeforeEach(func() {
|
||||
sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}
|
||||
|
||||
errorHandler := func(rw http.ResponseWriter, _ *http.Request, _ error) {
|
||||
rw.WriteHeader(502)
|
||||
rw.Write([]byte("Proxy Error"))
|
||||
writer := &pagewriter.WriterFuncs{
|
||||
ProxyErrorFunc: func(rw http.ResponseWriter, _ *http.Request, _ error) {
|
||||
rw.WriteHeader(502)
|
||||
rw.Write([]byte("Proxy Error"))
|
||||
},
|
||||
}
|
||||
|
||||
ok := http.StatusOK
|
||||
@ -58,7 +61,7 @@ var _ = Describe("Proxy Suite", func() {
|
||||
}
|
||||
|
||||
var err error
|
||||
upstreamServer, err = NewProxy(upstreams, sigData, errorHandler)
|
||||
upstreamServer, err = NewProxy(upstreams, sigData, writer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user