Merge pull request #660 from oauth2-proxy/request-builder
Use builder pattern to simplify requests to external endpoints
This commit is contained in:
commit
d29766609b
@ -8,6 +8,7 @@
|
||||
|
||||
## Changes since v6.0.0
|
||||
|
||||
- [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed)
|
||||
- [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed)
|
||||
- [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed)
|
||||
- [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves)
|
||||
|
118
pkg/requests/builder.go
Normal file
118
pkg/requests/builder.go
Normal file
@ -0,0 +1,118 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Builder allows users to construct a request and then execute the
|
||||
// request via Do().
|
||||
// Do returns a Result which allows the user to get the body,
|
||||
// unmarshal the body into an interface, or into a simplejson.Json.
|
||||
type Builder interface {
|
||||
WithContext(context.Context) Builder
|
||||
WithBody(io.Reader) Builder
|
||||
WithMethod(string) Builder
|
||||
WithHeaders(http.Header) Builder
|
||||
SetHeader(key, value string) Builder
|
||||
Do() Result
|
||||
}
|
||||
|
||||
type builder struct {
|
||||
context context.Context
|
||||
method string
|
||||
endpoint string
|
||||
body io.Reader
|
||||
header http.Header
|
||||
result *result
|
||||
}
|
||||
|
||||
// New provides a new Builder for the given endpoint.
|
||||
func New(endpoint string) Builder {
|
||||
return &builder{
|
||||
endpoint: endpoint,
|
||||
method: "GET",
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext adds a context to the request.
|
||||
// If no context is provided, context.Background() is used instead.
|
||||
func (r *builder) WithContext(ctx context.Context) Builder {
|
||||
r.context = ctx
|
||||
return r
|
||||
}
|
||||
|
||||
// WithBody adds a body to the request.
|
||||
func (r *builder) WithBody(body io.Reader) Builder {
|
||||
r.body = body
|
||||
return r
|
||||
}
|
||||
|
||||
// WithMethod sets the request method. Defaults to "GET".
|
||||
func (r *builder) WithMethod(method string) Builder {
|
||||
r.method = method
|
||||
return r
|
||||
}
|
||||
|
||||
// WithHeaders replaces the request header map with the given header map.
|
||||
func (r *builder) WithHeaders(header http.Header) Builder {
|
||||
r.header = header
|
||||
return r
|
||||
}
|
||||
|
||||
// SetHeader sets a single header to the given value.
|
||||
// May be used to add multiple headers.
|
||||
func (r *builder) SetHeader(key, value string) Builder {
|
||||
if r.header == nil {
|
||||
r.header = make(http.Header)
|
||||
}
|
||||
r.header.Set(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// Do performs the request and returns the response in its raw form.
|
||||
// If the request has already been performed, returns the previous result.
|
||||
// This will not allow you to repeat a request.
|
||||
func (r *builder) Do() Result {
|
||||
if r.result != nil {
|
||||
// Request has already been done
|
||||
return r.result
|
||||
}
|
||||
|
||||
// Must provide a non-nil context to NewRequestWithContext
|
||||
if r.context == nil {
|
||||
r.context = context.Background()
|
||||
}
|
||||
|
||||
return r.do()
|
||||
}
|
||||
|
||||
// do creates the request, executes it with the default client and extracts the
|
||||
// the body into the response
|
||||
func (r *builder) do() Result {
|
||||
req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body)
|
||||
if err != nil {
|
||||
r.result = &result{err: fmt.Errorf("error creating request: %v", err)}
|
||||
return r.result
|
||||
}
|
||||
req.Header = r.header
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
r.result = &result{err: fmt.Errorf("error performing request: %v", err)}
|
||||
return r.result
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
r.result = &result{err: fmt.Errorf("error reading response body: %v", err)}
|
||||
return r.result
|
||||
}
|
||||
|
||||
r.result = &result{response: resp, body: body}
|
||||
return r.result
|
||||
}
|
376
pkg/requests/builder_test.go
Normal file
376
pkg/requests/builder_test.go
Normal file
@ -0,0 +1,376 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Builder suite", func() {
|
||||
var b Builder
|
||||
getBuilder := func() Builder { return b }
|
||||
|
||||
baseHeaders := http.Header{
|
||||
"Accept-Encoding": []string{"gzip"},
|
||||
"User-Agent": []string{"Go-http-client/1.1"},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
// Most tests will request the server address
|
||||
b = New(serverAddr + "/json/path")
|
||||
})
|
||||
|
||||
Context("with a basic request", func() {
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: baseHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a context", func() {
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
b = b.WithContext(ctx)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: baseHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
|
||||
Context("if the context is cancelled", func() {
|
||||
BeforeEach(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
assertRequestError(getBuilder, "context canceled")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a body", func() {
|
||||
const body = "{\"some\": \"body\"}"
|
||||
header := baseHeaders.Clone()
|
||||
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
|
||||
BeforeEach(func() {
|
||||
buf := bytes.NewBuffer([]byte(body))
|
||||
b = b.WithBody(buf)
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: header,
|
||||
Body: []byte(body),
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a method", func() {
|
||||
Context("POST with a body", func() {
|
||||
const body = "{\"some\": \"body\"}"
|
||||
header := baseHeaders.Clone()
|
||||
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
|
||||
BeforeEach(func() {
|
||||
buf := bytes.NewBuffer([]byte(body))
|
||||
b = b.WithMethod("POST").WithBody(buf)
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "POST",
|
||||
Header: header,
|
||||
Body: []byte(body),
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
|
||||
Context("POST without a body", func() {
|
||||
header := baseHeaders.Clone()
|
||||
header.Set("Content-Length", "0")
|
||||
|
||||
BeforeEach(func() {
|
||||
b = b.WithMethod("POST")
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "POST",
|
||||
Header: header,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
|
||||
Context("OPTIONS", func() {
|
||||
BeforeEach(func() {
|
||||
b = b.WithMethod("OPTIONS")
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "OPTIONS",
|
||||
Header: baseHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
|
||||
Context("INVALID-\\t-METHOD", func() {
|
||||
BeforeEach(func() {
|
||||
b = b.WithMethod("INVALID-\t-METHOD")
|
||||
})
|
||||
|
||||
assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"")
|
||||
})
|
||||
})
|
||||
|
||||
Context("with headers", func() {
|
||||
Context("setting a header", func() {
|
||||
header := baseHeaders.Clone()
|
||||
header.Set("header", "value")
|
||||
|
||||
BeforeEach(func() {
|
||||
b = b.SetHeader("header", "value")
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: header,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
|
||||
Context("then replacing the headers", func() {
|
||||
replacementHeaders := http.Header{
|
||||
"Accept-Encoding": []string{"*"},
|
||||
"User-Agent": []string{"test-agent"},
|
||||
"Foo": []string{"bar, baz"},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
b = b.WithHeaders(replacementHeaders)
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: replacementHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("replacing the header", func() {
|
||||
replacementHeaders := http.Header{
|
||||
"Accept-Encoding": []string{"*"},
|
||||
"User-Agent": []string{"test-agent"},
|
||||
"Foo": []string{"bar, baz"},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
b = b.WithHeaders(replacementHeaders)
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: replacementHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
|
||||
Context("then setting a header", func() {
|
||||
header := replacementHeaders.Clone()
|
||||
header.Set("User-Agent", "different-agent")
|
||||
|
||||
BeforeEach(func() {
|
||||
b = b.SetHeader("User-Agent", "different-agent")
|
||||
})
|
||||
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: header,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("if the request has been completed and then modified", func() {
|
||||
BeforeEach(func() {
|
||||
result := b.Do()
|
||||
Expect(result.Error()).ToNot(HaveOccurred())
|
||||
|
||||
b.WithMethod("POST")
|
||||
})
|
||||
|
||||
Context("should not redo the request", func() {
|
||||
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||
Method: "GET",
|
||||
Header: baseHeaders,
|
||||
Body: []byte{},
|
||||
RequestURI: "/json/path",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("when the requested page is not found", func() {
|
||||
BeforeEach(func() {
|
||||
b = New(serverAddr + "/not-found")
|
||||
})
|
||||
|
||||
assertJSONError(getBuilder, "404 page not found")
|
||||
})
|
||||
|
||||
Context("when the requested page is not valid JSON", func() {
|
||||
BeforeEach(func() {
|
||||
b = New(serverAddr + "/string/path")
|
||||
})
|
||||
|
||||
assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value")
|
||||
})
|
||||
})
|
||||
|
||||
func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) {
|
||||
Context("Do", func() {
|
||||
var result Result
|
||||
|
||||
BeforeEach(func() {
|
||||
result = builder().Do()
|
||||
Expect(result.Error()).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns a successful status", func() {
|
||||
Expect(result.StatusCode()).To(Equal(http.StatusOK))
|
||||
})
|
||||
|
||||
It("made the expected request", func() {
|
||||
actualRequest := testHTTPRequest{}
|
||||
Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed())
|
||||
|
||||
Expect(actualRequest).To(Equal(expectedRequest))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalInto", func() {
|
||||
var actualRequest testHTTPRequest
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed())
|
||||
})
|
||||
|
||||
It("made the expected request", func() {
|
||||
Expect(actualRequest).To(Equal(expectedRequest))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalJSON", func() {
|
||||
var response *simplejson.Json
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
response, err = builder().Do().UnmarshalJSON()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("made the expected reqest", func() {
|
||||
header := http.Header{}
|
||||
for key, value := range response.Get("Header").MustMap() {
|
||||
vs, ok := value.([]interface{})
|
||||
Expect(ok).To(BeTrue())
|
||||
svs := []string{}
|
||||
for _, v := range vs {
|
||||
sv, ok := v.(string)
|
||||
Expect(ok).To(BeTrue())
|
||||
svs = append(svs, sv)
|
||||
}
|
||||
header[key] = svs
|
||||
}
|
||||
|
||||
// Other json unmarhsallers base64 decode byte slices automatically
|
||||
body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
actualRequest := testHTTPRequest{
|
||||
Method: response.Get("Method").MustString(),
|
||||
Header: header,
|
||||
Body: body,
|
||||
RequestURI: response.Get("RequestURI").MustString(),
|
||||
}
|
||||
|
||||
Expect(actualRequest).To(Equal(expectedRequest))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func assertRequestError(builder func() Builder, errorMessage string) {
|
||||
Context("Do", func() {
|
||||
It("returns an error", func() {
|
||||
result := builder().Do()
|
||||
Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalInto", func() {
|
||||
It("returns an error", func() {
|
||||
var actualRequest testHTTPRequest
|
||||
err := builder().Do().UnmarshalInto(&actualRequest)
|
||||
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||
|
||||
// Should be empty
|
||||
Expect(actualRequest).To(Equal(testHTTPRequest{}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalJSON", func() {
|
||||
It("returns an error", func() {
|
||||
resp, err := builder().Do().UnmarshalJSON()
|
||||
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||
Expect(resp).To(BeNil())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func assertJSONError(builder func() Builder, errorMessage string) {
|
||||
Context("Do", func() {
|
||||
It("does not return an error", func() {
|
||||
result := builder().Do()
|
||||
Expect(result.Error()).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalInto", func() {
|
||||
It("returns an error", func() {
|
||||
var actualRequest testHTTPRequest
|
||||
err := builder().Do().UnmarshalInto(&actualRequest)
|
||||
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||
|
||||
// Should be empty
|
||||
Expect(actualRequest).To(Equal(testHTTPRequest{}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UnmarshalJSON", func() {
|
||||
It("returns an error", func() {
|
||||
resp, err := builder().Do().UnmarshalJSON()
|
||||
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||
Expect(resp).To(BeNil())
|
||||
})
|
||||
})
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
)
|
||||
|
||||
// Request parses the request body into a simplejson.Json object
|
||||
func Request(req *http.Request) (*simplejson.Json, error) {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Printf("%s %s %s", req.Method, req.URL, err)
|
||||
return nil, err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("problem reading http request body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("got %d %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
data, err := simplejson.NewJson(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling json: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// RequestJSON parses the request body into the given interface
|
||||
func RequestJSON(req *http.Request, v interface{}) error {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Printf("%s %s %s", req.Method, req.URL, err)
|
||||
return err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading body from http response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("got %d %s", resp.StatusCode, body)
|
||||
}
|
||||
return json.Unmarshal(body, v)
|
||||
}
|
||||
|
||||
// RequestUnparsedResponse performs a GET and returns the raw response object
|
||||
func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error performing get request: %w", err)
|
||||
}
|
||||
req.Header = header
|
||||
|
||||
return http.DefaultClient.Do(req)
|
||||
}
|
96
pkg/requests/requests_suite_test.go
Normal file
96
pkg/requests/requests_suite_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
server *httptest.Server
|
||||
serverAddr string
|
||||
)
|
||||
|
||||
func TestRequetsSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
log.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Requests Suite")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
// Set up a webserver that reflects requests
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/json/", &testHTTPUpstream{})
|
||||
mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.Write([]byte("OK"))
|
||||
})
|
||||
server = httptest.NewServer(mux)
|
||||
serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String())
|
||||
})
|
||||
|
||||
var _ = AfterSuite(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
// testHTTPRequest is a struct used to capture the state of a request made to
|
||||
// the test server
|
||||
type testHTTPRequest struct {
|
||||
Method string
|
||||
Header http.Header
|
||||
Body []byte
|
||||
RequestURI string
|
||||
}
|
||||
|
||||
type testHTTPUpstream struct{}
|
||||
|
||||
func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
request, err := toTestHTTPRequest(req)
|
||||
if err != nil {
|
||||
t.writeError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.writeError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(data)
|
||||
}
|
||||
|
||||
func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) {
|
||||
rw.WriteHeader(500)
|
||||
if err != nil {
|
||||
rw.Write([]byte(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) {
|
||||
requestBody := []byte{}
|
||||
if req.Body != http.NoBody {
|
||||
var err error
|
||||
requestBody, err = ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return testHTTPRequest{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return testHTTPRequest{
|
||||
Method: req.Method,
|
||||
Header: req.Header,
|
||||
Body: requestBody,
|
||||
RequestURI: req.RequestURI,
|
||||
}, nil
|
||||
}
|
@ -1,136 +0,0 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(responseCode)
|
||||
_, err := w.Write([]byte(payload))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestRequest(t *testing.T) {
|
||||
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
|
||||
defer backend.Close()
|
||||
|
||||
req, _ := http.NewRequest("GET", backend.URL, nil)
|
||||
response, err := Request(req)
|
||||
assert.Equal(t, nil, err)
|
||||
result, err := response.Get("foo").String()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "bar", result)
|
||||
}
|
||||
|
||||
func TestRequestFailure(t *testing.T) {
|
||||
// Create a backend to generate a test URL, then close it to cause a
|
||||
// connection error.
|
||||
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
|
||||
backend.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
resp, err := Request(req)
|
||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
||||
assert.NotEqual(t, nil, err)
|
||||
if !strings.Contains(err.Error(), "refused") {
|
||||
t.Error("expected error when a connection fails: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpErrorCode(t *testing.T) {
|
||||
backend := testBackend(t, 404, "{\"foo\": \"bar\"}")
|
||||
defer backend.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
resp, err := Request(req)
|
||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestJsonParsingError(t *testing.T) {
|
||||
backend := testBackend(t, 200, "not well-formed JSON")
|
||||
defer backend.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
resp, err := Request(req)
|
||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
// Parsing a URL practically never fails, so we won't cover that test case.
|
||||
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.FormValue("access_token")
|
||||
if r.URL.Path == "/" && token == "my_token" {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte("some payload"))
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
w.WriteHeader(403)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
response, err := RequestUnparsedResponse(
|
||||
context.Background(), backend.URL+"?access_token=my_token", nil)
|
||||
assert.Equal(t, nil, err)
|
||||
defer response.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, response.StatusCode)
|
||||
body, err := ioutil.ReadAll(response.Body)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "some payload", string(body))
|
||||
}
|
||||
|
||||
func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) {
|
||||
backend := testBackend(t, 200, "some payload")
|
||||
// Close the backend now to force a request failure.
|
||||
backend.Close()
|
||||
|
||||
response, err := RequestUnparsedResponse(
|
||||
context.Background(), backend.URL+"?access_token=my_token", nil)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, (*http.Response)(nil), response)
|
||||
}
|
||||
|
||||
func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte("some payload"))
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
w.WriteHeader(403)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Auth", "my_token")
|
||||
response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers)
|
||||
assert.Equal(t, nil, err)
|
||||
defer response.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, response.StatusCode)
|
||||
body, err := ioutil.ReadAll(response.Body)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, "some payload", string(body))
|
||||
}
|
98
pkg/requests/result.go
Normal file
98
pkg/requests/result.go
Normal file
@ -0,0 +1,98 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
)
|
||||
|
||||
// Result is the result of a request created by a Builder
|
||||
type Result interface {
|
||||
Error() error
|
||||
StatusCode() int
|
||||
Headers() http.Header
|
||||
Body() []byte
|
||||
UnmarshalInto(interface{}) error
|
||||
UnmarshalJSON() (*simplejson.Json, error)
|
||||
}
|
||||
|
||||
type result struct {
|
||||
err error
|
||||
response *http.Response
|
||||
body []byte
|
||||
}
|
||||
|
||||
// Error returns an error from the result if present
|
||||
func (r *result) Error() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// StatusCode returns the response's status code
|
||||
func (r *result) StatusCode() int {
|
||||
if r.response != nil {
|
||||
return r.response.StatusCode
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Headers returns the response's headers
|
||||
func (r *result) Headers() http.Header {
|
||||
if r.response != nil {
|
||||
return r.response.Header
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Body returns the response's body
|
||||
func (r *result) Body() []byte {
|
||||
return r.body
|
||||
}
|
||||
|
||||
// UnmarshalInto attempts to unmarshal the response into the the given interface.
|
||||
// The response body is assumed to be JSON.
|
||||
// The response must have a 200 status otherwise an error will be returned.
|
||||
func (r *result) UnmarshalInto(into interface{}) error {
|
||||
body, err := r.getBodyForUnmarshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, into); err != nil {
|
||||
return fmt.Errorf("error unmarshalling body: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON performs the request and attempts to unmarshal the response into a
|
||||
// simplejson.Json. The response body is assume to be JSON.
|
||||
// The response must have a 200 status otherwise an error will be returned.
|
||||
func (r *result) UnmarshalJSON() (*simplejson.Json, error) {
|
||||
body, err := r.getBodyForUnmarshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := simplejson.NewJson(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading json: %v", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getBodyForUnmarshal returns the body if there wasn't an error and the status
|
||||
// code was 200.
|
||||
func (r *result) getBodyForUnmarshal() ([]byte, error) {
|
||||
if r.Error() != nil {
|
||||
return nil, r.Error()
|
||||
}
|
||||
|
||||
// Only unmarshal body if the response was successful
|
||||
if r.StatusCode() != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body())
|
||||
}
|
||||
|
||||
return r.Body(), nil
|
||||
}
|
326
pkg/requests/result_test.go
Normal file
326
pkg/requests/result_test.go
Normal file
@ -0,0 +1,326 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Result suite", func() {
|
||||
Context("with a result", func() {
|
||||
type resultTableInput struct {
|
||||
result Result
|
||||
expectedError error
|
||||
expectedStatusCode int
|
||||
expectedHeaders http.Header
|
||||
expectedBody []byte
|
||||
}
|
||||
|
||||
DescribeTable("accessors should return expected results",
|
||||
func(in resultTableInput) {
|
||||
if in.expectedError != nil {
|
||||
Expect(in.result.Error()).To(MatchError(in.expectedError))
|
||||
} else {
|
||||
Expect(in.result.Error()).To(BeNil())
|
||||
}
|
||||
|
||||
Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode))
|
||||
Expect(in.result.Headers()).To(Equal(in.expectedHeaders))
|
||||
Expect(in.result.Body()).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("with an empty result", resultTableInput{
|
||||
result: &result{},
|
||||
expectedError: nil,
|
||||
expectedStatusCode: 0,
|
||||
expectedHeaders: nil,
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("with an error", resultTableInput{
|
||||
result: &result{
|
||||
err: errors.New("error"),
|
||||
},
|
||||
expectedError: errors.New("error"),
|
||||
expectedStatusCode: 0,
|
||||
expectedHeaders: nil,
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("with a response with no headers", resultTableInput{
|
||||
result: &result{
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusTeapot,
|
||||
},
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedStatusCode: http.StatusTeapot,
|
||||
expectedHeaders: nil,
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("with a response with no status code", resultTableInput{
|
||||
result: &result{
|
||||
response: &http.Response{
|
||||
Header: http.Header{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedStatusCode: 0,
|
||||
expectedHeaders: http.Header{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("with a response with a body", resultTableInput{
|
||||
result: &result{
|
||||
body: []byte("some body"),
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedStatusCode: 0,
|
||||
expectedHeaders: nil,
|
||||
expectedBody: []byte("some body"),
|
||||
}),
|
||||
Entry("with all fields", resultTableInput{
|
||||
result: &result{
|
||||
err: errors.New("some error"),
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusFound,
|
||||
Header: http.Header{
|
||||
"header": []string{"value"},
|
||||
},
|
||||
},
|
||||
body: []byte("a body"),
|
||||
},
|
||||
expectedError: errors.New("some error"),
|
||||
expectedStatusCode: http.StatusFound,
|
||||
expectedHeaders: http.Header{
|
||||
"header": []string{"value"},
|
||||
},
|
||||
expectedBody: []byte("a body"),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("UnmarshalInto", func() {
|
||||
type testStruct struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
}
|
||||
|
||||
type unmarshalIntoTableInput struct {
|
||||
result Result
|
||||
expectedErr error
|
||||
expectedOutput *testStruct
|
||||
}
|
||||
|
||||
DescribeTable("with a result",
|
||||
func(in unmarshalIntoTableInput) {
|
||||
input := &testStruct{}
|
||||
err := in.result.UnmarshalInto(input)
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(input).To(Equal(in.expectedOutput))
|
||||
},
|
||||
Entry("with an error", unmarshalIntoTableInput{
|
||||
result: &result{
|
||||
err: errors.New("got an error"),
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\"}"),
|
||||
},
|
||||
expectedErr: errors.New("got an error"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("with a 409 status code", unmarshalIntoTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusConflict,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\"}"),
|
||||
},
|
||||
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("when the response has a valid json response", unmarshalIntoTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectedOutput: &testStruct{A: "foo", B: 1},
|
||||
}),
|
||||
Entry("when the response body is empty", unmarshalIntoTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte(""),
|
||||
},
|
||||
expectedErr: errors.New("error unmarshalling body: unexpected end of JSON input"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("when the response body is not json", unmarshalIntoTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("not json"),
|
||||
},
|
||||
expectedErr: errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("UnmarshalJSON", func() {
|
||||
type testStruct struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
}
|
||||
|
||||
type unmarshalJSONTableInput struct {
|
||||
result Result
|
||||
expectedErr error
|
||||
expectedOutput *testStruct
|
||||
}
|
||||
|
||||
DescribeTable("with a result",
|
||||
func(in unmarshalJSONTableInput) {
|
||||
j, err := in.result.UnmarshalJSON()
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
Expect(j).To(BeNil())
|
||||
return
|
||||
}
|
||||
|
||||
// No error so j should not be nil
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
input := &testStruct{
|
||||
A: j.Get("a").MustString(),
|
||||
B: j.Get("b").MustInt(),
|
||||
}
|
||||
Expect(input).To(Equal(in.expectedOutput))
|
||||
},
|
||||
Entry("with an error", unmarshalJSONTableInput{
|
||||
result: &result{
|
||||
err: errors.New("got an error"),
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\"}"),
|
||||
},
|
||||
expectedErr: errors.New("got an error"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("with a 409 status code", unmarshalJSONTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusConflict,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\"}"),
|
||||
},
|
||||
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("when the response has a valid json response", unmarshalJSONTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectedOutput: &testStruct{A: "foo", B: 1},
|
||||
}),
|
||||
Entry("when the response body is empty", unmarshalJSONTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte(""),
|
||||
},
|
||||
expectedErr: errors.New("error reading json: EOF"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
Entry("when the response body is not json", unmarshalJSONTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("not json"),
|
||||
},
|
||||
expectedErr: errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"),
|
||||
expectedOutput: &testStruct{},
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("getBodyForUnmarshal", func() {
|
||||
type getBodyForUnmarshalTableInput struct {
|
||||
result *result
|
||||
expectedErr error
|
||||
expectedBody []byte
|
||||
}
|
||||
|
||||
DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) {
|
||||
body, err := in.result.getBodyForUnmarshal()
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(body).To(Equal(in.expectedBody))
|
||||
},
|
||||
Entry("when the result has an error", getBodyForUnmarshalTableInput{
|
||||
result: &result{
|
||||
err: errors.New("got an error"),
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("body"),
|
||||
},
|
||||
expectedErr: errors.New("got an error"),
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusConflict,
|
||||
},
|
||||
body: []byte("body"),
|
||||
},
|
||||
expectedErr: errors.New("unexpected status \"409\": body"),
|
||||
expectedBody: nil,
|
||||
}),
|
||||
Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{
|
||||
result: &result{
|
||||
err: nil,
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
body: []byte("body"),
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectedBody: []byte("body"),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
@ -83,34 +83,34 @@ func Validate(o *options.Options) error {
|
||||
|
||||
logger.Printf("Performing OIDC Discovery...")
|
||||
|
||||
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil {
|
||||
if body, err := requests.Request(req); err == nil {
|
||||
|
||||
// Prefer manually configured URLs. It's a bit unclear
|
||||
// why you'd be doing discovery and also providing the URLs
|
||||
// explicitly though...
|
||||
if o.LoginURL == "" {
|
||||
o.LoginURL = body.Get("authorization_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.RedeemURL == "" {
|
||||
o.RedeemURL = body.Get("token_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.OIDCJwksURL == "" {
|
||||
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
|
||||
}
|
||||
|
||||
if o.ProfileURL == "" {
|
||||
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
|
||||
}
|
||||
|
||||
o.SkipOIDCDiscovery = true
|
||||
} else {
|
||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||
}
|
||||
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||
body, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||
} else {
|
||||
logger.Printf("error: failed parsing OIDC discovery URL: %v", err)
|
||||
// Prefer manually configured URLs. It's a bit unclear
|
||||
// why you'd be doing discovery and also providing the URLs
|
||||
// explicitly though...
|
||||
if o.LoginURL == "" {
|
||||
o.LoginURL = body.Get("authorization_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.RedeemURL == "" {
|
||||
o.RedeemURL = body.Get("token_endpoint").MustString()
|
||||
}
|
||||
|
||||
if o.OIDCJwksURL == "" {
|
||||
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
|
||||
}
|
||||
|
||||
if o.ProfileURL == "" {
|
||||
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
|
||||
}
|
||||
|
||||
o.SkipOIDCDiscovery = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
|
||||
if err != nil {
|
||||
// Try as JWKS URI
|
||||
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
|
||||
_, err := http.NewRequest("GET", jwksURI, nil)
|
||||
if err != nil {
|
||||
if err := requests.New(jwksURI).Do().Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
|
||||
} else {
|
||||
verifier = provider.Verifier(config)
|
||||
|
@ -3,10 +3,8 @@ package providers
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresOn int64 `json:"expires_on,string"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
@ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getAzureHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getAzureHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
email, err = getEmailFromJSON(json)
|
||||
|
||||
if err == nil && email != "" {
|
||||
return email, err
|
||||
}
|
||||
|
||||
email, err = json.Get("userPrincipalName").String()
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
FullName string `json:"full_name"`
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||
|
||||
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
logger.Printf("failed making request: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
teamURL := &url.URL{}
|
||||
*teamURL = *p.ValidateURL
|
||||
teamURL.Path = "/2.0/teams"
|
||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
||||
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
||||
|
||||
requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed requesting teams membership %s", err)
|
||||
logger.Printf("failed requesting teams membership: %v", err)
|
||||
return "", err
|
||||
}
|
||||
var found = false
|
||||
@ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
||||
repositoriesURL := &url.URL{}
|
||||
*repositoriesURL = *p.ValidateURL
|
||||
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
|
||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
||||
repositoriesURL.String()+"?role=contributor"+
|
||||
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
|
||||
"&access_token="+s.AccessToken,
|
||||
nil)
|
||||
|
||||
requestURL := repositoriesURL.String() + "?role=contributor" +
|
||||
"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
|
||||
"&access_token=" + s.AccessToken
|
||||
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
Do().
|
||||
UnmarshalInto(&repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
err = requests.RequestJSON(req, &repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed checking repository access %s", err)
|
||||
logger.Printf("failed checking repository access: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var found = false
|
||||
for _, repository := range repositories.Values {
|
||||
if p.Repository == repository.FullName {
|
||||
|
@ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getDigitalOceanHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
json, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getFacebookHeader(s.AccessToken)
|
||||
|
||||
type result struct {
|
||||
Email string
|
||||
}
|
||||
var r result
|
||||
err = requests.RequestJSON(req, &r)
|
||||
|
||||
requestURL := p.ProfileURL.String() + "?fields=name,email"
|
||||
err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getFacebookHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if r.Email == "" {
|
||||
return "", errors.New("no email")
|
||||
}
|
||||
|
@ -2,10 +2,8 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
@ -15,6 +13,7 @@ import (
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
// GitHubProvider represents an GitHub based Identity Provider
|
||||
@ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var op orgsPage
|
||||
if err := json.Unmarshal(body, &op); err != nil {
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&op)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(op) == 0 {
|
||||
break
|
||||
}
|
||||
@ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// bodyclose cannot detect that the body is being closed later in requests.Into,
|
||||
// so have to skip the linting for the next line.
|
||||
// nolint:bodyclose
|
||||
result := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if result.Error() != nil {
|
||||
return false, result.Error()
|
||||
}
|
||||
|
||||
if last == 0 {
|
||||
@ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
// link header at last page (doesn't exist last info)
|
||||
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
|
||||
|
||||
link := resp.Header.Get("Link")
|
||||
link := result.Headers().Get("Link")
|
||||
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
|
||||
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
|
||||
|
||||
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return false, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var tp teamsPage
|
||||
if err := json.Unmarshal(body, &tp); err != nil {
|
||||
return false, fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
if err := result.UnmarshalInto(&tp); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(tp) == 0 {
|
||||
break
|
||||
@ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
|
||||
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf(
|
||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
var repo repository
|
||||
if err := json.Unmarshal(body, &repo); err != nil {
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&repo)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return false, fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, stripToken(endpoint.String()), body)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if p.isVerifiedUser(user.Login) {
|
||||
return true, nil
|
||||
@ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return false, err
|
||||
result := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(accessToken)).
|
||||
Do()
|
||||
if result.Error() != nil {
|
||||
return false, result.Error()
|
||||
}
|
||||
|
||||
if resp.StatusCode != 204 {
|
||||
if result.StatusCode() != 204 {
|
||||
return false, fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
result.StatusCode(), endpoint.String(), result.Body())
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
req.Header = getGitHubHeader(s.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&emails)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
returnEmail := ""
|
||||
for _, email := range emails {
|
||||
@ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
|
||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not create new GET request: %v", err)
|
||||
}
|
||||
|
||||
req.Header = getGitHubHeader(s.AccessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
err := requests.New(endpoint.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalInto(&user)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
// Now that we have the username we can check collaborator status
|
||||
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
|
||||
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {
|
||||
|
@ -2,15 +2,13 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
|
||||
userInfoURL := *p.LoginURL
|
||||
userInfoURL.Path = "/oauth/userinfo"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user info request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform user info request: %v", err)
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read user info response: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var userInfo gitlabUserInfo
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
err := requests.New(userInfoURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
Do().
|
||||
UnmarshalInto(&userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info: %v", err)
|
||||
return nil, fmt.Errorf("error getting user info: %v", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
|
@ -9,13 +9,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("code", code)
|
||||
params.Add("grant_type", "authorization_code")
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||
if err != nil {
|
||||
return
|
||||
@ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &data)
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&data)
|
||||
if err != nil {
|
||||
return
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
token = data.AccessToken
|
||||
idToken = data.IDToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
@ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
|
||||
params := url.Values{"access_token": {accessToken}}
|
||||
endpoint = endpoint + "?" + params.Encode()
|
||||
}
|
||||
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
|
||||
if err != nil {
|
||||
|
||||
result := requests.New(endpoint).
|
||||
WithContext(ctx).
|
||||
WithHeaders(header).
|
||||
Do()
|
||||
if result.Error() != nil {
|
||||
logger.Printf("GET %s", stripToken(endpoint))
|
||||
logger.Printf("token validation request failed: %s", err)
|
||||
logger.Printf("token validation request failed: %s", result.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
|
||||
logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
if result.StatusCode() == 200 {
|
||||
return true
|
||||
}
|
||||
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
|
||||
logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
|
||||
return false
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
@ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) {
|
||||
}
|
||||
|
||||
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
json, err := requests.Request(req)
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
|
@ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = getLinkedInHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
requestURL := p.ProfileURL.String() + "?format=json"
|
||||
json, err := requests.New(requestURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getLinkedInHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
@ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
|
||||
// query the user info endpoint for user attributes
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
|
||||
return
|
||||
}
|
||||
|
||||
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
|
||||
// parse the user attributes from the data we got and make sure that
|
||||
// the email address has been validated.
|
||||
var emailData struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
err = json.Unmarshal(body, &emailData)
|
||||
|
||||
// query the user info endpoint for user attributes
|
||||
err := requests.New(userInfoEndpoint).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
Do().
|
||||
UnmarshalInto(&emailData)
|
||||
if err != nil {
|
||||
return
|
||||
return "", err
|
||||
}
|
||||
if emailData.Email == "" {
|
||||
err = fmt.Errorf("missing email")
|
||||
return
|
||||
|
||||
email := emailData.Email
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("missing email")
|
||||
}
|
||||
email = emailData.Email
|
||||
|
||||
if !emailData.EmailVerified {
|
||||
err = fmt.Errorf("email %s not listed as verified", email)
|
||||
return
|
||||
return "", fmt.Errorf("email %s not listed as verified", email)
|
||||
}
|
||||
return
|
||||
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
@ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
params.Add("code", code)
|
||||
params.Add("grant_type", "authorization_code")
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the token from the body that we got from the token endpoint.
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check nonce here
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
@ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header {
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
||||
p.ValidateURL.String(), nil)
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getNextcloudHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
}
|
||||
req.Header = getNextcloudHeader(s.AccessToken)
|
||||
json, err := requests.Request(req)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
return "", fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
email, err := json.Get("ocs").Get("data").Get("email").String()
|
||||
return email, err
|
||||
}
|
||||
|
@ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
|
||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
||||
// contents at the profileURL contains the email.
|
||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header = getOIDCHeader(accessToken)
|
||||
|
||||
respJSON, err := requests.Request(req)
|
||||
respJSON, err := requests.New(profileURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(getOIDCHeader(accessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -3,17 +3,15 @@ package providers
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||
)
|
||||
|
||||
var _ Provider = (*ProviderData)(nil)
|
||||
@ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
result := requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do()
|
||||
if result.Error() != nil {
|
||||
return nil, result.Error()
|
||||
}
|
||||
|
||||
// blindly try json and x-www-form-urlencoded
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
err = result.UnmarshalInto(&jsonResponse)
|
||||
if err == nil {
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
@ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
}
|
||||
|
||||
var v url.Values
|
||||
v, err = url.ParseQuery(string(body))
|
||||
v, err = url.ParseQuery(string(result.Body()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||
created := time.Now()
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
err = fmt.Errorf("no access token found %s", result.Body())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user