Merge pull request #660 from oauth2-proxy/request-builder

Use builder pattern to simplify requests to external endpoints
This commit is contained in:
Joel Speed 2020-07-06 21:01:55 +01:00 committed by GitHub
commit d29766609b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1246 additions and 650 deletions

View File

@ -8,6 +8,7 @@
## Changes since v6.0.0
- [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed)
- [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed)
- [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed)
- [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves)

118
pkg/requests/builder.go Normal file
View File

@ -0,0 +1,118 @@
package requests
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
)
// Builder allows users to construct a request and then execute the
// request via Do().
// Do returns a Result which allows the user to get the body,
// unmarshal the body into an interface, or into a simplejson.Json.
type Builder interface {
WithContext(context.Context) Builder
WithBody(io.Reader) Builder
WithMethod(string) Builder
WithHeaders(http.Header) Builder
SetHeader(key, value string) Builder
Do() Result
}
type builder struct {
context context.Context
method string
endpoint string
body io.Reader
header http.Header
result *result
}
// New provides a new Builder for the given endpoint.
func New(endpoint string) Builder {
return &builder{
endpoint: endpoint,
method: "GET",
}
}
// WithContext adds a context to the request.
// If no context is provided, context.Background() is used instead.
func (r *builder) WithContext(ctx context.Context) Builder {
r.context = ctx
return r
}
// WithBody adds a body to the request.
func (r *builder) WithBody(body io.Reader) Builder {
r.body = body
return r
}
// WithMethod sets the request method. Defaults to "GET".
func (r *builder) WithMethod(method string) Builder {
r.method = method
return r
}
// WithHeaders replaces the request header map with the given header map.
func (r *builder) WithHeaders(header http.Header) Builder {
r.header = header
return r
}
// SetHeader sets a single header to the given value.
// May be used to add multiple headers.
func (r *builder) SetHeader(key, value string) Builder {
if r.header == nil {
r.header = make(http.Header)
}
r.header.Set(key, value)
return r
}
// Do performs the request and returns the response in its raw form.
// If the request has already been performed, returns the previous result.
// This will not allow you to repeat a request.
func (r *builder) Do() Result {
if r.result != nil {
// Request has already been done
return r.result
}
// Must provide a non-nil context to NewRequestWithContext
if r.context == nil {
r.context = context.Background()
}
return r.do()
}
// do creates the request, executes it with the default client and extracts the
// the body into the response
func (r *builder) do() Result {
req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body)
if err != nil {
r.result = &result{err: fmt.Errorf("error creating request: %v", err)}
return r.result
}
req.Header = r.header
resp, err := http.DefaultClient.Do(req)
if err != nil {
r.result = &result{err: fmt.Errorf("error performing request: %v", err)}
return r.result
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
r.result = &result{err: fmt.Errorf("error reading response body: %v", err)}
return r.result
}
r.result = &result{response: resp, body: body}
return r.result
}

View File

@ -0,0 +1,376 @@
package requests
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"github.com/bitly/go-simplejson"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Builder suite", func() {
var b Builder
getBuilder := func() Builder { return b }
baseHeaders := http.Header{
"Accept-Encoding": []string{"gzip"},
"User-Agent": []string{"Go-http-client/1.1"},
}
BeforeEach(func() {
// Most tests will request the server address
b = New(serverAddr + "/json/path")
})
Context("with a basic request", func() {
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: baseHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
})
Context("with a context", func() {
var ctx context.Context
var cancel context.CancelFunc
BeforeEach(func() {
ctx, cancel = context.WithCancel(context.Background())
b = b.WithContext(ctx)
})
AfterEach(func() {
cancel()
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: baseHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
Context("if the context is cancelled", func() {
BeforeEach(func() {
cancel()
})
assertRequestError(getBuilder, "context canceled")
})
})
Context("with a body", func() {
const body = "{\"some\": \"body\"}"
header := baseHeaders.Clone()
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
BeforeEach(func() {
buf := bytes.NewBuffer([]byte(body))
b = b.WithBody(buf)
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: header,
Body: []byte(body),
RequestURI: "/json/path",
})
})
Context("with a method", func() {
Context("POST with a body", func() {
const body = "{\"some\": \"body\"}"
header := baseHeaders.Clone()
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
BeforeEach(func() {
buf := bytes.NewBuffer([]byte(body))
b = b.WithMethod("POST").WithBody(buf)
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "POST",
Header: header,
Body: []byte(body),
RequestURI: "/json/path",
})
})
Context("POST without a body", func() {
header := baseHeaders.Clone()
header.Set("Content-Length", "0")
BeforeEach(func() {
b = b.WithMethod("POST")
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "POST",
Header: header,
Body: []byte{},
RequestURI: "/json/path",
})
})
Context("OPTIONS", func() {
BeforeEach(func() {
b = b.WithMethod("OPTIONS")
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "OPTIONS",
Header: baseHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
})
Context("INVALID-\\t-METHOD", func() {
BeforeEach(func() {
b = b.WithMethod("INVALID-\t-METHOD")
})
assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"")
})
})
Context("with headers", func() {
Context("setting a header", func() {
header := baseHeaders.Clone()
header.Set("header", "value")
BeforeEach(func() {
b = b.SetHeader("header", "value")
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: header,
Body: []byte{},
RequestURI: "/json/path",
})
Context("then replacing the headers", func() {
replacementHeaders := http.Header{
"Accept-Encoding": []string{"*"},
"User-Agent": []string{"test-agent"},
"Foo": []string{"bar, baz"},
}
BeforeEach(func() {
b = b.WithHeaders(replacementHeaders)
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: replacementHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
})
})
Context("replacing the header", func() {
replacementHeaders := http.Header{
"Accept-Encoding": []string{"*"},
"User-Agent": []string{"test-agent"},
"Foo": []string{"bar, baz"},
}
BeforeEach(func() {
b = b.WithHeaders(replacementHeaders)
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: replacementHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
Context("then setting a header", func() {
header := replacementHeaders.Clone()
header.Set("User-Agent", "different-agent")
BeforeEach(func() {
b = b.SetHeader("User-Agent", "different-agent")
})
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: header,
Body: []byte{},
RequestURI: "/json/path",
})
})
})
})
Context("if the request has been completed and then modified", func() {
BeforeEach(func() {
result := b.Do()
Expect(result.Error()).ToNot(HaveOccurred())
b.WithMethod("POST")
})
Context("should not redo the request", func() {
assertSuccessfulRequest(getBuilder, testHTTPRequest{
Method: "GET",
Header: baseHeaders,
Body: []byte{},
RequestURI: "/json/path",
})
})
})
Context("when the requested page is not found", func() {
BeforeEach(func() {
b = New(serverAddr + "/not-found")
})
assertJSONError(getBuilder, "404 page not found")
})
Context("when the requested page is not valid JSON", func() {
BeforeEach(func() {
b = New(serverAddr + "/string/path")
})
assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value")
})
})
func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) {
Context("Do", func() {
var result Result
BeforeEach(func() {
result = builder().Do()
Expect(result.Error()).ToNot(HaveOccurred())
})
It("returns a successful status", func() {
Expect(result.StatusCode()).To(Equal(http.StatusOK))
})
It("made the expected request", func() {
actualRequest := testHTTPRequest{}
Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed())
Expect(actualRequest).To(Equal(expectedRequest))
})
})
Context("UnmarshalInto", func() {
var actualRequest testHTTPRequest
BeforeEach(func() {
Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed())
})
It("made the expected request", func() {
Expect(actualRequest).To(Equal(expectedRequest))
})
})
Context("UnmarshalJSON", func() {
var response *simplejson.Json
BeforeEach(func() {
var err error
response, err = builder().Do().UnmarshalJSON()
Expect(err).ToNot(HaveOccurred())
})
It("made the expected reqest", func() {
header := http.Header{}
for key, value := range response.Get("Header").MustMap() {
vs, ok := value.([]interface{})
Expect(ok).To(BeTrue())
svs := []string{}
for _, v := range vs {
sv, ok := v.(string)
Expect(ok).To(BeTrue())
svs = append(svs, sv)
}
header[key] = svs
}
// Other json unmarhsallers base64 decode byte slices automatically
body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString())
Expect(err).ToNot(HaveOccurred())
actualRequest := testHTTPRequest{
Method: response.Get("Method").MustString(),
Header: header,
Body: body,
RequestURI: response.Get("RequestURI").MustString(),
}
Expect(actualRequest).To(Equal(expectedRequest))
})
})
}
func assertRequestError(builder func() Builder, errorMessage string) {
Context("Do", func() {
It("returns an error", func() {
result := builder().Do()
Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage)))
})
})
Context("UnmarshalInto", func() {
It("returns an error", func() {
var actualRequest testHTTPRequest
err := builder().Do().UnmarshalInto(&actualRequest)
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
// Should be empty
Expect(actualRequest).To(Equal(testHTTPRequest{}))
})
})
Context("UnmarshalJSON", func() {
It("returns an error", func() {
resp, err := builder().Do().UnmarshalJSON()
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
Expect(resp).To(BeNil())
})
})
}
func assertJSONError(builder func() Builder, errorMessage string) {
Context("Do", func() {
It("does not return an error", func() {
result := builder().Do()
Expect(result.Error()).To(BeNil())
})
})
Context("UnmarshalInto", func() {
It("returns an error", func() {
var actualRequest testHTTPRequest
err := builder().Do().UnmarshalInto(&actualRequest)
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
// Should be empty
Expect(actualRequest).To(Equal(testHTTPRequest{}))
})
})
Context("UnmarshalJSON", func() {
It("returns an error", func() {
resp, err := builder().Do().UnmarshalJSON()
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
Expect(resp).To(BeNil())
})
})
}

View File

@ -1,74 +0,0 @@
package requests
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
)
// Request parses the request body into a simplejson.Json object
func Request(req *http.Request) (*simplejson.Json, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Printf("%s %s %s", req.Method, req.URL, err)
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
if body != nil {
defer resp.Body.Close()
}
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
if err != nil {
return nil, fmt.Errorf("problem reading http request body: %w", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("got %d %s", resp.StatusCode, body)
}
data, err := simplejson.NewJson(body)
if err != nil {
return nil, fmt.Errorf("error unmarshalling json: %w", err)
}
return data, nil
}
// RequestJSON parses the request body into the given interface
func RequestJSON(req *http.Request, v interface{}) error {
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Printf("%s %s %s", req.Method, req.URL, err)
return err
}
body, err := ioutil.ReadAll(resp.Body)
if body != nil {
defer resp.Body.Close()
}
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
if err != nil {
return fmt.Errorf("error reading body from http response: %w", err)
}
if resp.StatusCode != 200 {
return fmt.Errorf("got %d %s", resp.StatusCode, body)
}
return json.Unmarshal(body, v)
}
// RequestUnparsedResponse performs a GET and returns the raw response object
func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error performing get request: %w", err)
}
req.Header = header
return http.DefaultClient.Do(req)
}

View File

@ -0,0 +1,96 @@
package requests
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
server *httptest.Server
serverAddr string
)
func TestRequetsSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
log.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Requests Suite")
}
var _ = BeforeSuite(func() {
// Set up a webserver that reflects requests
mux := http.NewServeMux()
mux.Handle("/json/", &testHTTPUpstream{})
mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) {
rw.Write([]byte("OK"))
})
server = httptest.NewServer(mux)
serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String())
})
var _ = AfterSuite(func() {
server.Close()
})
// testHTTPRequest is a struct used to capture the state of a request made to
// the test server
type testHTTPRequest struct {
Method string
Header http.Header
Body []byte
RequestURI string
}
type testHTTPUpstream struct{}
func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
request, err := toTestHTTPRequest(req)
if err != nil {
t.writeError(rw, err)
return
}
data, err := json.Marshal(request)
if err != nil {
t.writeError(rw, err)
return
}
rw.Header().Set("Content-Type", "application/json")
rw.Write(data)
}
func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) {
rw.WriteHeader(500)
if err != nil {
rw.Write([]byte(err.Error()))
}
}
func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) {
requestBody := []byte{}
if req.Body != http.NoBody {
var err error
requestBody, err = ioutil.ReadAll(req.Body)
if err != nil {
return testHTTPRequest{}, err
}
}
return testHTTPRequest{
Method: req.Method,
Header: req.Header,
Body: requestBody,
RequestURI: req.RequestURI,
}, nil
}

View File

@ -1,136 +0,0 @@
package requests
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/bitly/go-simplejson"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(responseCode)
_, err := w.Write([]byte(payload))
require.NoError(t, err)
}))
}
func TestRequest(t *testing.T) {
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
defer backend.Close()
req, _ := http.NewRequest("GET", backend.URL, nil)
response, err := Request(req)
assert.Equal(t, nil, err)
result, err := response.Get("foo").String()
assert.Equal(t, nil, err)
assert.Equal(t, "bar", result)
}
func TestRequestFailure(t *testing.T) {
// Create a backend to generate a test URL, then close it to cause a
// connection error.
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
backend.Close()
req, err := http.NewRequest("GET", backend.URL, nil)
assert.Equal(t, nil, err)
resp, err := Request(req)
assert.Equal(t, (*simplejson.Json)(nil), resp)
assert.NotEqual(t, nil, err)
if !strings.Contains(err.Error(), "refused") {
t.Error("expected error when a connection fails: ", err)
}
}
func TestHttpErrorCode(t *testing.T) {
backend := testBackend(t, 404, "{\"foo\": \"bar\"}")
defer backend.Close()
req, err := http.NewRequest("GET", backend.URL, nil)
assert.Equal(t, nil, err)
resp, err := Request(req)
assert.Equal(t, (*simplejson.Json)(nil), resp)
assert.NotEqual(t, nil, err)
}
func TestJsonParsingError(t *testing.T) {
backend := testBackend(t, 200, "not well-formed JSON")
defer backend.Close()
req, err := http.NewRequest("GET", backend.URL, nil)
assert.Equal(t, nil, err)
resp, err := Request(req)
assert.Equal(t, (*simplejson.Json)(nil), resp)
assert.NotEqual(t, nil, err)
}
// Parsing a URL practically never fails, so we won't cover that test case.
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
token := r.FormValue("access_token")
if r.URL.Path == "/" && token == "my_token" {
w.WriteHeader(200)
_, err := w.Write([]byte("some payload"))
require.NoError(t, err)
} else {
w.WriteHeader(403)
}
}))
defer backend.Close()
response, err := RequestUnparsedResponse(
context.Background(), backend.URL+"?access_token=my_token", nil)
assert.Equal(t, nil, err)
defer response.Body.Close()
assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err)
assert.Equal(t, "some payload", string(body))
}
func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) {
backend := testBackend(t, 200, "some payload")
// Close the backend now to force a request failure.
backend.Close()
response, err := RequestUnparsedResponse(
context.Background(), backend.URL+"?access_token=my_token", nil)
assert.NotEqual(t, nil, err)
assert.Equal(t, (*http.Response)(nil), response)
}
func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" {
w.WriteHeader(200)
_, err := w.Write([]byte("some payload"))
require.NoError(t, err)
} else {
w.WriteHeader(403)
}
}))
defer backend.Close()
headers := make(http.Header)
headers.Set("Auth", "my_token")
response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers)
assert.Equal(t, nil, err)
defer response.Body.Close()
assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err)
assert.Equal(t, "some payload", string(body))
}

98
pkg/requests/result.go Normal file
View File

@ -0,0 +1,98 @@
package requests
import (
"encoding/json"
"fmt"
"net/http"
"github.com/bitly/go-simplejson"
)
// Result is the result of a request created by a Builder
type Result interface {
Error() error
StatusCode() int
Headers() http.Header
Body() []byte
UnmarshalInto(interface{}) error
UnmarshalJSON() (*simplejson.Json, error)
}
type result struct {
err error
response *http.Response
body []byte
}
// Error returns an error from the result if present
func (r *result) Error() error {
return r.err
}
// StatusCode returns the response's status code
func (r *result) StatusCode() int {
if r.response != nil {
return r.response.StatusCode
}
return 0
}
// Headers returns the response's headers
func (r *result) Headers() http.Header {
if r.response != nil {
return r.response.Header
}
return nil
}
// Body returns the response's body
func (r *result) Body() []byte {
return r.body
}
// UnmarshalInto attempts to unmarshal the response into the the given interface.
// The response body is assumed to be JSON.
// The response must have a 200 status otherwise an error will be returned.
func (r *result) UnmarshalInto(into interface{}) error {
body, err := r.getBodyForUnmarshal()
if err != nil {
return err
}
if err := json.Unmarshal(body, into); err != nil {
return fmt.Errorf("error unmarshalling body: %v", err)
}
return nil
}
// UnmarshalJSON performs the request and attempts to unmarshal the response into a
// simplejson.Json. The response body is assume to be JSON.
// The response must have a 200 status otherwise an error will be returned.
func (r *result) UnmarshalJSON() (*simplejson.Json, error) {
body, err := r.getBodyForUnmarshal()
if err != nil {
return nil, err
}
data, err := simplejson.NewJson(body)
if err != nil {
return nil, fmt.Errorf("error reading json: %v", err)
}
return data, nil
}
// getBodyForUnmarshal returns the body if there wasn't an error and the status
// code was 200.
func (r *result) getBodyForUnmarshal() ([]byte, error) {
if r.Error() != nil {
return nil, r.Error()
}
// Only unmarshal body if the response was successful
if r.StatusCode() != http.StatusOK {
return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body())
}
return r.Body(), nil
}

326
pkg/requests/result_test.go Normal file
View File

@ -0,0 +1,326 @@
package requests
import (
"errors"
"net/http"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
var _ = Describe("Result suite", func() {
Context("with a result", func() {
type resultTableInput struct {
result Result
expectedError error
expectedStatusCode int
expectedHeaders http.Header
expectedBody []byte
}
DescribeTable("accessors should return expected results",
func(in resultTableInput) {
if in.expectedError != nil {
Expect(in.result.Error()).To(MatchError(in.expectedError))
} else {
Expect(in.result.Error()).To(BeNil())
}
Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode))
Expect(in.result.Headers()).To(Equal(in.expectedHeaders))
Expect(in.result.Body()).To(Equal(in.expectedBody))
},
Entry("with an empty result", resultTableInput{
result: &result{},
expectedError: nil,
expectedStatusCode: 0,
expectedHeaders: nil,
expectedBody: nil,
}),
Entry("with an error", resultTableInput{
result: &result{
err: errors.New("error"),
},
expectedError: errors.New("error"),
expectedStatusCode: 0,
expectedHeaders: nil,
expectedBody: nil,
}),
Entry("with a response with no headers", resultTableInput{
result: &result{
response: &http.Response{
StatusCode: http.StatusTeapot,
},
},
expectedError: nil,
expectedStatusCode: http.StatusTeapot,
expectedHeaders: nil,
expectedBody: nil,
}),
Entry("with a response with no status code", resultTableInput{
result: &result{
response: &http.Response{
Header: http.Header{
"foo": []string{"bar"},
},
},
},
expectedError: nil,
expectedStatusCode: 0,
expectedHeaders: http.Header{
"foo": []string{"bar"},
},
expectedBody: nil,
}),
Entry("with a response with a body", resultTableInput{
result: &result{
body: []byte("some body"),
},
expectedError: nil,
expectedStatusCode: 0,
expectedHeaders: nil,
expectedBody: []byte("some body"),
}),
Entry("with all fields", resultTableInput{
result: &result{
err: errors.New("some error"),
response: &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{
"header": []string{"value"},
},
},
body: []byte("a body"),
},
expectedError: errors.New("some error"),
expectedStatusCode: http.StatusFound,
expectedHeaders: http.Header{
"header": []string{"value"},
},
expectedBody: []byte("a body"),
}),
)
})
Context("UnmarshalInto", func() {
type testStruct struct {
A string `json:"a"`
B int `json:"b"`
}
type unmarshalIntoTableInput struct {
result Result
expectedErr error
expectedOutput *testStruct
}
DescribeTable("with a result",
func(in unmarshalIntoTableInput) {
input := &testStruct{}
err := in.result.UnmarshalInto(input)
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {
Expect(err).ToNot(HaveOccurred())
}
Expect(input).To(Equal(in.expectedOutput))
},
Entry("with an error", unmarshalIntoTableInput{
result: &result{
err: errors.New("got an error"),
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("{\"a\": \"foo\"}"),
},
expectedErr: errors.New("got an error"),
expectedOutput: &testStruct{},
}),
Entry("with a 409 status code", unmarshalIntoTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusConflict,
},
body: []byte("{\"a\": \"foo\"}"),
},
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
expectedOutput: &testStruct{},
}),
Entry("when the response has a valid json response", unmarshalIntoTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
},
expectedErr: nil,
expectedOutput: &testStruct{A: "foo", B: 1},
}),
Entry("when the response body is empty", unmarshalIntoTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte(""),
},
expectedErr: errors.New("error unmarshalling body: unexpected end of JSON input"),
expectedOutput: &testStruct{},
}),
Entry("when the response body is not json", unmarshalIntoTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("not json"),
},
expectedErr: errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"),
expectedOutput: &testStruct{},
}),
)
})
Context("UnmarshalJSON", func() {
type testStruct struct {
A string `json:"a"`
B int `json:"b"`
}
type unmarshalJSONTableInput struct {
result Result
expectedErr error
expectedOutput *testStruct
}
DescribeTable("with a result",
func(in unmarshalJSONTableInput) {
j, err := in.result.UnmarshalJSON()
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
Expect(j).To(BeNil())
return
}
// No error so j should not be nil
Expect(err).ToNot(HaveOccurred())
input := &testStruct{
A: j.Get("a").MustString(),
B: j.Get("b").MustInt(),
}
Expect(input).To(Equal(in.expectedOutput))
},
Entry("with an error", unmarshalJSONTableInput{
result: &result{
err: errors.New("got an error"),
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("{\"a\": \"foo\"}"),
},
expectedErr: errors.New("got an error"),
expectedOutput: &testStruct{},
}),
Entry("with a 409 status code", unmarshalJSONTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusConflict,
},
body: []byte("{\"a\": \"foo\"}"),
},
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
expectedOutput: &testStruct{},
}),
Entry("when the response has a valid json response", unmarshalJSONTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
},
expectedErr: nil,
expectedOutput: &testStruct{A: "foo", B: 1},
}),
Entry("when the response body is empty", unmarshalJSONTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte(""),
},
expectedErr: errors.New("error reading json: EOF"),
expectedOutput: &testStruct{},
}),
Entry("when the response body is not json", unmarshalJSONTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("not json"),
},
expectedErr: errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"),
expectedOutput: &testStruct{},
}),
)
})
Context("getBodyForUnmarshal", func() {
type getBodyForUnmarshalTableInput struct {
result *result
expectedErr error
expectedBody []byte
}
DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) {
body, err := in.result.getBodyForUnmarshal()
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {
Expect(err).ToNot(HaveOccurred())
}
Expect(body).To(Equal(in.expectedBody))
},
Entry("when the result has an error", getBodyForUnmarshalTableInput{
result: &result{
err: errors.New("got an error"),
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("body"),
},
expectedErr: errors.New("got an error"),
expectedBody: nil,
}),
Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusConflict,
},
body: []byte("body"),
},
expectedErr: errors.New("unexpected status \"409\": body"),
expectedBody: nil,
}),
Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{
result: &result{
err: nil,
response: &http.Response{
StatusCode: http.StatusOK,
},
body: []byte("body"),
},
expectedErr: nil,
expectedBody: []byte("body"),
}),
)
})
})

View File

@ -83,34 +83,34 @@ func Validate(o *options.Options) error {
logger.Printf("Performing OIDC Discovery...")
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil {
if body, err := requests.Request(req); err == nil {
// Prefer manually configured URLs. It's a bit unclear
// why you'd be doing discovery and also providing the URLs
// explicitly though...
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
} else {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
}
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
body, err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalJSON()
if err != nil {
logger.Printf("error: failed to discover OIDC configuration: %v", err)
} else {
logger.Printf("error: failed parsing OIDC discovery URL: %v", err)
// Prefer manually configured URLs. It's a bit unclear
// why you'd be doing discovery and also providing the URLs
// explicitly though...
if o.LoginURL == "" {
o.LoginURL = body.Get("authorization_endpoint").MustString()
}
if o.RedeemURL == "" {
o.RedeemURL = body.Get("token_endpoint").MustString()
}
if o.OIDCJwksURL == "" {
o.OIDCJwksURL = body.Get("jwks_uri").MustString()
}
if o.ProfileURL == "" {
o.ProfileURL = body.Get("userinfo_endpoint").MustString()
}
o.SkipOIDCDiscovery = true
}
}
@ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
if err != nil {
// Try as JWKS URI
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
_, err := http.NewRequest("GET", jwksURI, nil)
if err != nil {
if err := requests.New(jwksURI).Do().Error(); err != nil {
return nil, err
}
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
} else {
verifier = provider.Verifier(config)

View File

@ -3,10 +3,8 @@ package providers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
@ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresOn int64 `json:"expires_on,string"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
created := time.Now()
@ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getAzureHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err
}
email, err = getEmailFromJSON(json)
if err == nil && email != "" {
return email, err
}
email, err = json.Get("userPrincipalName").String()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"net/http"
"net/url"
"strings"
@ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
FullName string `json:"full_name"`
}
}
req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&emails)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &emails)
if err != nil {
logger.Printf("failed making request %s", err)
logger.Printf("failed making request: %v", err)
return "", err
}
@ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
teamURL := &url.URL{}
*teamURL = *p.ValidateURL
teamURL.Path = "/2.0/teams"
req, err = http.NewRequestWithContext(ctx, "GET",
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&teams)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &teams)
if err != nil {
logger.Printf("failed requesting teams membership %s", err)
logger.Printf("failed requesting teams membership: %v", err)
return "", err
}
var found = false
@ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
repositoriesURL := &url.URL{}
*repositoriesURL = *p.ValidateURL
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
req, err = http.NewRequestWithContext(ctx, "GET",
repositoriesURL.String()+"?role=contributor"+
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
"&access_token="+s.AccessToken,
nil)
requestURL := repositoriesURL.String() + "?role=contributor" +
"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
"&access_token=" + s.AccessToken
err := requests.New(requestURL).
WithContext(ctx).
Do().
UnmarshalInto(&repositories)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
err = requests.RequestJSON(req, &repositories)
if err != nil {
logger.Printf("failed checking repository access %s", err)
logger.Printf("failed checking repository access: %v", err)
return "", err
}
var found = false
for _, repository := range repositories.Values {
if p.Repository == repository.FullName {

View File

@ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
}
req.Header = getDigitalOceanHeader(s.AccessToken)
json, err := requests.Request(req)
json, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err
}

View File

@ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil {
return "", err
}
req.Header = getFacebookHeader(s.AccessToken)
type result struct {
Email string
}
var r result
err = requests.RequestJSON(req, &r)
requestURL := p.ProfileURL.String() + "?fields=name,email"
err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getFacebookHeader(s.AccessToken)).
Do().
UnmarshalInto(&r)
if err != nil {
return "", err
}
if r.Email == "" {
return "", errors.New("no email")
}

View File

@ -2,10 +2,8 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
@ -15,6 +13,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
// GitHubProvider represents an GitHub based Identity Provider
@ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
RawQuery: params.Encode(),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var op orgsPage
if err := json.Unmarshal(body, &op); err != nil {
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&op)
if err != nil {
return false, err
}
if len(op) == 0 {
break
}
@ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
RawQuery: params.Encode(),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
// bodyclose cannot detect that the body is being closed later in requests.Into,
// so have to skip the linting for the next line.
// nolint:bodyclose
result := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if result.Error() != nil {
return false, result.Error()
}
if last == 0 {
@ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
// link header at last page (doesn't exist last info)
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
link := resp.Header.Get("Link")
link := result.Headers().Get("Link")
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
}
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return false, err
}
resp.Body.Close()
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var tp teamsPage
if err := json.Unmarshal(body, &tp); err != nil {
return false, fmt.Errorf("%s unmarshaling %s", err, body)
if err := result.UnmarshalInto(&tp); err != nil {
return false, err
}
if len(tp) == 0 {
break
@ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf(
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}
var repo repository
if err := json.Unmarshal(body, &repo); err != nil {
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&repo)
if err != nil {
return false, err
}
@ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do().
UnmarshalInto(&user)
if err != nil {
return false, err
}
if resp.StatusCode != 200 {
return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, stripToken(endpoint.String()), body)
}
if err := json.Unmarshal(body, &user); err != nil {
return false, err
}
if p.isVerifiedUser(user.Login) {
return true, nil
@ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return false, err
result := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(accessToken)).
Do()
if result.Error() != nil {
return false, result.Error()
}
if resp.StatusCode != 204 {
if result.StatusCode() != 204 {
return false, fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
result.StatusCode(), endpoint.String(), result.Body())
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
return true, nil
}
@ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
}
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&emails)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
returnEmail := ""
for _, email := range emails {
@ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}
req.Header = getGitHubHeader(s.AccessToken)
resp, err := http.DefaultClient.Do(req)
err := requests.New(endpoint.String()).
WithContext(ctx).
WithHeaders(getGitHubHeader(s.AccessToken)).
Do().
UnmarshalInto(&user)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
// Now that we have the username we can check collaborator status
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {

View File

@ -2,15 +2,13 @@ package providers
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
oidc "github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2"
)
@ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
userInfoURL := *p.LoginURL
userInfoURL.Path = "/oauth/userinfo"
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create user info request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform user info request: %v", err)
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to read user info response: %v", err)
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
}
var userInfo gitlabUserInfo
err = json.Unmarshal(body, &userInfo)
err := requests.New(userInfoURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalInto(&userInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse user info: %v", err)
return nil, fmt.Errorf("error getting user info: %v", err)
}
return &userInfo, nil

View File

@ -9,13 +9,13 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
params.Add("client_secret", clientSecret)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
@ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
c, err := claimsFromIDToken(jsonResponse.IDToken)
if err != nil {
return
@ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
params.Add("client_secret", clientSecret)
params.Add("refresh_token", refreshToken)
params.Add("grant_type", "refresh_token")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var data struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &data)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&data)
if err != nil {
return
return "", "", 0, err
}
token = data.AccessToken
idToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"io/ioutil"
"net/http"
"net/url"
@ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode()
}
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
if err != nil {
result := requests.New(endpoint).
WithContext(ctx).
WithHeaders(header).
Do()
if result.Error() != nil {
logger.Printf("GET %s", stripToken(endpoint))
logger.Printf("token validation request failed: %s", err)
logger.Printf("token validation request failed: %s", result.Error())
return false
}
body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
if resp.StatusCode == 200 {
if result.StatusCode() == 200 {
return true
}
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
return false
}

View File

@ -2,7 +2,6 @@ package providers
import (
"context"
"net/http"
"net/url"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
@ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) {
}
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
json, err := requests.Request(req)
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalJSON()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err

View File

@ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil {
return "", err
}
req.Header = getLinkedInHeader(s.AccessToken)
json, err := requests.Request(req)
requestURL := p.ProfileURL.String() + "?format=json"
json, err := requests.New(requestURL).
WithContext(ctx).
WithHeaders(getLinkedInHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return "", err
}

View File

@ -15,6 +15,7 @@ import (
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
"gopkg.in/square/go-jose.v2"
)
@ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
return
}
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
// query the user info endpoint for user attributes
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
if err != nil {
return
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
return
}
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
// parse the user attributes from the data we got and make sure that
// the email address has been validated.
var emailData struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
err = json.Unmarshal(body, &emailData)
// query the user info endpoint for user attributes
err := requests.New(userInfoEndpoint).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
Do().
UnmarshalInto(&emailData)
if err != nil {
return
return "", err
}
if emailData.Email == "" {
err = fmt.Errorf("missing email")
return
email := emailData.Email
if email == "" {
return "", fmt.Errorf("missing email")
}
email = emailData.Email
if !emailData.EmailVerified {
err = fmt.Errorf("email %s not listed as verified", email)
return
return "", fmt.Errorf("email %s not listed as verified", email)
}
return
return email, nil
}
// Redeem exchanges the OAuth2 authentication token for an ID token
@ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
params.Add("code", code)
params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
// Get the token from the body that we got from the token endpoint.
var jsonResponse struct {
AccessToken string `json:"access_token"`
@ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
err = json.Unmarshal(body, &jsonResponse)
err = requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do().
UnmarshalInto(&jsonResponse)
if err != nil {
return
return nil, err
}
// check nonce here

View File

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
@ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header {
// GetEmailAddress returns the Account email address
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET",
p.ValidateURL.String(), nil)
json, err := requests.New(p.ValidateURL.String()).
WithContext(ctx).
WithHeaders(getNextcloudHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
}
req.Header = getNextcloudHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
return "", fmt.Errorf("error making request: %v", err)
}
email, err := json.Get("ocs").Get("data").Get("email").String()
return email, err
}

View File

@ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
// contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
if err != nil {
return nil, err
}
req.Header = getOIDCHeader(accessToken)
respJSON, err := requests.Request(req)
respJSON, err := requests.New(profileURL).
WithContext(ctx).
WithHeaders(getOIDCHeader(accessToken)).
Do().
UnmarshalJSON()
if err != nil {
return nil, err
}

View File

@ -3,17 +3,15 @@ package providers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
)
var _ Provider = (*ProviderData)(nil)
@ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
params.Add("resource", p.ProtectedResource.String())
}
var req *http.Request
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
result := requests.New(p.RedeemURL.String()).
WithContext(ctx).
WithMethod("POST").
WithBody(bytes.NewBufferString(params.Encode())).
SetHeader("Content-Type", "application/x-www-form-urlencoded").
Do()
if result.Error() != nil {
return nil, result.Error()
}
// blindly try json and x-www-form-urlencoded
var jsonResponse struct {
AccessToken string `json:"access_token"`
}
err = json.Unmarshal(body, &jsonResponse)
err = result.UnmarshalInto(&jsonResponse)
if err == nil {
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
@ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
}
var v url.Values
v, err = url.ParseQuery(string(body))
v, err = url.ParseQuery(string(result.Body()))
if err != nil {
return
}
@ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", body)
err = fmt.Errorf("no access token found %s", result.Body())
}
return
}