oauth2-proxy/pkg/upstream/upstream_suite_test.go

182 lines
4.2 KiB
Go

package upstream
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"path"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"golang.org/x/net/websocket"
)
// Fixtures shared by every spec in the suite. They are populated by
// BeforeSuite and released by AfterSuite.
var (
	// filesDir is the temporary directory holding the files used by the
	// file server tests.
	filesDir string

	// server is the shared httptest server running testHTTPUpstream.
	server *httptest.Server

	// serverAddr is the "http://host:port" address of server.
	serverAddr string
)

// invalidServer is an upstream address the tests treat as invalid.
// NOTE(review): the IPv6 literal is deliberately not bracketed; confirm how
// callers rely on it failing.
var invalidServer = "http://::1"
func TestUpstreamSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
log.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Upstream Suite")
}
// BeforeSuite creates the fixtures shared by every spec:
//   - a temporary directory (filesDir) containing the files "foo", "bar" and
//     "subdir/baz", each holding its own name as content;
//   - an httptest server running testHTTPUpstream, reachable at serverAddr.
var _ = BeforeSuite(func() {
	// Set up files for serving via file servers
	dir, err := ioutil.TempDir("", "oauth2-proxy-upstream-suite")
	Expect(err).ToNot(HaveOccurred())
	// Record the directory before populating it. Ginkgo runs AfterSuite even
	// when BeforeSuite fails, so a failed write below must not leave
	// filesDir empty; that would leak the temporary directory.
	filesDir = dir

	Expect(ioutil.WriteFile(path.Join(dir, "foo"), []byte("foo"), 0644)).To(Succeed())
	Expect(ioutil.WriteFile(path.Join(dir, "bar"), []byte("bar"), 0644)).To(Succeed())
	Expect(os.Mkdir(path.Join(dir, "subdir"), os.ModePerm)).To(Succeed())
	Expect(ioutil.WriteFile(path.Join(dir, "subdir", "baz"), []byte("baz"), 0644)).To(Succeed())

	// Set up a webserver that reflects requests
	server = httptest.NewServer(&testHTTPUpstream{})
	serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String())
})
// AfterSuite releases the fixtures created by BeforeSuite. Ginkgo also runs
// it when BeforeSuite failed part-way, so each fixture may not exist and is
// checked before use.
var _ = AfterSuite(func() {
	// Calling Close on a nil *httptest.Server panics. That would hide the
	// original BeforeSuite failure and skip the directory cleanup below.
	if server != nil {
		server.Close()
	}
	// os.RemoveAll returns nil for an empty path, so this is a no-op when the
	// temporary directory was never created.
	Expect(os.RemoveAll(filesDir)).To(Succeed())
})
// HTTP header names referenced by the tests in this package.
const (
	contentType    = "Content-Type"
	contentLength  = "Content-Length"
	acceptEncoding = "Accept-Encoding"
	gapAuth        = "Gap-Auth"
	gapSignature   = "Gap-Signature"
)

// Content-Type values referenced by the tests in this package.
const (
	applicationJSON = "application/json"
	textPlainUTF8   = "text/plain; charset=utf-8"
	textHTMLUTF8    = "text/html; charset=utf-8"
)
// testHTTPResponse is a struct used for checking responses in table tests.
// The tests that build it are not in this file, so the field notes below
// are inferred from names and types.
// NOTE(review): confirm them against the table tests.
type testHTTPResponse struct {
	// code is presumably the expected HTTP status code.
	code int
	// header presumably holds the expected response headers.
	header http.Header
	// raw presumably holds the expected response body as a string.
	raw string
	// request is presumably the testHTTPRequest the upstream is expected to
	// have received, as reflected back by testHTTPUpstream.
	request testHTTPRequest
}
// testHTTPRequest is a struct used to capture the state of a request made to
// an upstream during a test. toTestHTTPRequest builds it from an
// *http.Request, and testHTTPUpstream.serveHTTP returns it JSON-encoded.
//
// The struct has no json tags, so the exported field names are the JSON
// keys. Keep the field order stable: it also fixes the key order of the
// encoded output.
type testHTTPRequest struct {
	// Method is req.Method.
	Method string
	// URL is req.URL rendered with URL.String().
	URL string
	// Header is the request's header map. toTestHTTPRequest shares it with
	// the request rather than copying it.
	Header http.Header
	// Body is the request body read by toTestHTTPRequest. It is an empty
	// slice for requests without a body. encoding/json encodes []byte as a
	// base64 string.
	Body []byte
	// Host is req.Host.
	Host string
	// RequestURI is req.RequestURI.
	RequestURI string
}
// testWebSocketResponse is the JSON reply that testHTTPUpstream's WebSocket
// handler sends back to the client.
type testWebSocketResponse struct {
	// Message is the message the handler received from the client.
	Message string
	// Origin is the Origin header of the WebSocket handshake request.
	Origin string
}
type testHTTPUpstream struct{}
func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("Upgrade") == "websocket" {
t.websocketHandler().ServeHTTP(rw, req)
} else {
t.serveHTTP(rw, req)
}
}
// serveHTTP reflects req back to the client. It captures the request as a
// testHTTPRequest and writes that value as an application/json body. If
// capturing or encoding fails, the error is reported through writeError
// instead.
func (t *testHTTPUpstream) serveHTTP(rw http.ResponseWriter, req *http.Request) {
	captured, err := toTestHTTPRequest(req)
	var payload []byte
	if err == nil {
		payload, err = json.Marshal(captured)
	}
	if err != nil {
		t.writeError(rw, err)
		return
	}

	rw.Header().Set("Content-Type", "application/json")
	// The error from Write is intentionally discarded.
	_, _ = rw.Write(payload)
}
// websocketHandler returns a handler that reads one message from the client.
// It replies with a JSON-encoded testWebSocketResponse holding that message
// and the Origin header of the handshake request. If receiving the message
// or sending the reply fails, the error text is sent back as a plain message
// instead.
func (t *testHTTPUpstream) websocketHandler() http.Handler {
	return websocket.Handler(func(conn *websocket.Conn) {
		defer conn.Close()

		// sendError reports err to the client; a failure to do so is ignored.
		sendError := func(err error) {
			websocket.Message.Send(conn, []byte(err.Error()))
		}

		var msg []byte
		if err := websocket.Message.Receive(conn, &msg); err != nil {
			sendError(err)
			return
		}

		reply := testWebSocketResponse{
			Message: string(msg),
			Origin:  conn.Request().Header.Get("Origin"),
		}
		if err := websocket.JSON.Send(conn, reply); err != nil {
			sendError(err)
		}
	})
}
// writeError responds with status 500 Internal Server Error. When err is
// non-nil, its message becomes the response body; otherwise the body is left
// empty.
func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) {
	rw.WriteHeader(http.StatusInternalServerError)
	if err == nil {
		return
	}
	_, _ = rw.Write([]byte(err.Error()))
}
// toTestHTTPRequest captures the parts of req that the tests inspect.
//
// The whole request body is read, and so consumed. Body in the result is
// always non-nil: a request without a body yields an empty slice.
// encoding/json renders an empty slice as "" and a nil slice as null.
// Header is the same map as req.Header, not a copy. The only error returned
// is one from reading the body.
func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) {
	requestBody := []byte{}
	// A nil Body is valid for client requests built with http.NewRequest and
	// a nil body. Reading from it would panic, so treat it like http.NoBody.
	if req.Body != nil && req.Body != http.NoBody {
		var err error
		requestBody, err = ioutil.ReadAll(req.Body)
		if err != nil {
			return testHTTPRequest{}, err
		}
	}
	return testHTTPRequest{
		Method:     req.Method,
		URL:        req.URL.String(),
		Header:     req.Header,
		Body:       requestBody,
		Host:       req.Host,
		RequestURI: req.RequestURI,
	}, nil
}
// testSanitizeResponseHeader strips headers that the HTTP server and the file
// server add to responses but which the tests do not compare. h is modified
// in place; absent headers are ignored.
func testSanitizeResponseHeader(h http.Header) {
	for _, name := range []string{
		// Added to HTTP responses
		"Date",
		contentLength,
		// Added to file server responses
		"Accept-Ranges",
		"Last-Modified",
	} {
		h.Del(name)
	}
}
// testSanitizeRequestHeader strips the Accept-Encoding header, which the HTTP
// Transport adds to outgoing requests, from h in place. Other headers are left
// untouched.
func testSanitizeRequestHeader(h http.Header) {
	// Header keys are stored in canonical form, so delete by canonical key.
	// This is exactly what Header.Del does.
	delete(h, http.CanonicalHeaderKey(acceptEncoding))
}