Introduce Duration so that marshalling works for duration strings
This commit is contained in:
parent
ed92df3537
commit
b6d6f31ac1
@ -1,5 +1,11 @@
|
||||
package options
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecretSource references an individual secret value.
|
||||
// Only one source within the struct should be defined at any time.
|
||||
type SecretSource struct {
|
||||
@ -12,3 +18,45 @@ type SecretSource struct {
|
||||
// FromFile expects a path to a file containing the secret value.
|
||||
FromFile string
|
||||
}
|
||||
|
||||
// Duration is an alias for time.Duration so that we can ensure the marshalling
|
||||
// and unmarshalling of string durations is done as users expect.
|
||||
// Intentional blank line below to keep this first part of the comment out of
|
||||
// any generated references.
|
||||
|
||||
// Duration is as string representation of a period of time.
|
||||
// A duration string is a is a possibly signed sequence of decimal numbers,
|
||||
// each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
type Duration time.Duration
|
||||
|
||||
// UnmarshalJSON parses the duration string and sets the value of duration
|
||||
// to the value of the duration string.
|
||||
func (d *Duration) UnmarshalJSON(data []byte) error {
|
||||
input := string(data)
|
||||
if unquoted, err := strconv.Unquote(input); err == nil {
|
||||
input = unquoted
|
||||
}
|
||||
|
||||
du, err := time.ParseDuration(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = Duration(du)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON ensures that when the string is marshalled to JSON as a human
|
||||
// readable string.
|
||||
func (d *Duration) MarshalJSON() ([]byte, error) {
|
||||
dStr := fmt.Sprintf("%q", d.Duration().String())
|
||||
return []byte(dStr), nil
|
||||
}
|
||||
|
||||
// Duration returns the time.Duration version of this Duration
|
||||
func (d *Duration) Duration() time.Duration {
|
||||
if d == nil {
|
||||
return time.Duration(0)
|
||||
}
|
||||
return time.Duration(*d)
|
||||
}
|
||||
|
88
pkg/apis/options/common_test.go
Normal file
88
pkg/apis/options/common_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package options
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Common", func() {
|
||||
Context("Duration", func() {
|
||||
type marshalJSONTableInput struct {
|
||||
duration Duration
|
||||
expectedJSON string
|
||||
}
|
||||
|
||||
DescribeTable("MarshalJSON",
|
||||
func(in marshalJSONTableInput) {
|
||||
data, err := in.duration.MarshalJSON()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(data)).To(Equal(in.expectedJSON))
|
||||
|
||||
var d Duration
|
||||
Expect(json.Unmarshal(data, &d)).To(Succeed())
|
||||
Expect(d).To(Equal(in.duration))
|
||||
},
|
||||
Entry("30 seconds", marshalJSONTableInput{
|
||||
duration: Duration(30 * time.Second),
|
||||
expectedJSON: "\"30s\"",
|
||||
}),
|
||||
Entry("1 minute", marshalJSONTableInput{
|
||||
duration: Duration(1 * time.Minute),
|
||||
expectedJSON: "\"1m0s\"",
|
||||
}),
|
||||
Entry("1 hour 15 minutes", marshalJSONTableInput{
|
||||
duration: Duration(75 * time.Minute),
|
||||
expectedJSON: "\"1h15m0s\"",
|
||||
}),
|
||||
Entry("A zero Duration", marshalJSONTableInput{
|
||||
duration: Duration(0),
|
||||
expectedJSON: "\"0s\"",
|
||||
}),
|
||||
)
|
||||
|
||||
type unmarshalJSONTableInput struct {
|
||||
json string
|
||||
expectedErr error
|
||||
expectedDuration Duration
|
||||
}
|
||||
|
||||
DescribeTable("UnmarshalJSON",
|
||||
func(in unmarshalJSONTableInput) {
|
||||
// A duration must be initialised pointer before UnmarshalJSON will work.
|
||||
zero := Duration(0)
|
||||
d := &zero
|
||||
|
||||
err := d.UnmarshalJSON([]byte(in.json))
|
||||
if in.expectedErr != nil {
|
||||
Expect(err).To(MatchError(in.expectedErr.Error()))
|
||||
} else {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(d).ToNot(BeNil())
|
||||
Expect(*d).To(Equal(in.expectedDuration))
|
||||
},
|
||||
Entry("1m", unmarshalJSONTableInput{
|
||||
json: "\"1m\"",
|
||||
expectedDuration: Duration(1 * time.Minute),
|
||||
}),
|
||||
Entry("30s", unmarshalJSONTableInput{
|
||||
json: "\"30s\"",
|
||||
expectedDuration: Duration(30 * time.Second),
|
||||
}),
|
||||
Entry("1h15m", unmarshalJSONTableInput{
|
||||
json: "\"1h15m\"",
|
||||
expectedDuration: Duration(75 * time.Minute),
|
||||
}),
|
||||
Entry("am", unmarshalJSONTableInput{
|
||||
json: "\"am\"",
|
||||
expectedErr: errors.New("time: invalid duration \"am\""),
|
||||
expectedDuration: Duration(0),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
@ -84,6 +84,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) {
|
||||
u.Path = "/"
|
||||
}
|
||||
|
||||
flushInterval := Duration(l.FlushInterval)
|
||||
upstream := Upstream{
|
||||
ID: u.Path,
|
||||
Path: u.Path,
|
||||
@ -91,7 +92,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) {
|
||||
InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify,
|
||||
PassHostHeader: &l.PassHostHeader,
|
||||
ProxyWebSockets: &l.ProxyWebSockets,
|
||||
FlushInterval: &l.FlushInterval,
|
||||
FlushInterval: &flushInterval,
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
|
@ -17,8 +17,8 @@ var _ = Describe("Legacy Options", func() {
|
||||
legacyOpts := NewLegacyOptions()
|
||||
|
||||
// Set upstreams and related options to test their conversion
|
||||
flushInterval := 5 * time.Second
|
||||
legacyOpts.LegacyUpstreams.FlushInterval = flushInterval
|
||||
flushInterval := Duration(5 * time.Second)
|
||||
legacyOpts.LegacyUpstreams.FlushInterval = time.Duration(flushInterval)
|
||||
legacyOpts.LegacyUpstreams.PassHostHeader = true
|
||||
legacyOpts.LegacyUpstreams.ProxyWebSockets = true
|
||||
legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true
|
||||
@ -124,7 +124,7 @@ var _ = Describe("Legacy Options", func() {
|
||||
skipVerify := true
|
||||
passHostHeader := false
|
||||
proxyWebSockets := true
|
||||
flushInterval := 5 * time.Second
|
||||
flushInterval := Duration(5 * time.Second)
|
||||
|
||||
// Test cases and expected outcomes
|
||||
validHTTP := "http://foo.bar/baz"
|
||||
@ -199,7 +199,7 @@ var _ = Describe("Legacy Options", func() {
|
||||
SSLUpstreamInsecureSkipVerify: skipVerify,
|
||||
PassHostHeader: passHostHeader,
|
||||
ProxyWebSockets: proxyWebSockets,
|
||||
FlushInterval: flushInterval,
|
||||
FlushInterval: time.Duration(flushInterval),
|
||||
}
|
||||
|
||||
upstreams, err := legacyUpstreams.convert()
|
||||
|
@ -1,7 +1,5 @@
|
||||
package options
|
||||
|
||||
import "time"
|
||||
|
||||
// Upstreams is a collection of definitions for upstream servers.
|
||||
type Upstreams []Upstream
|
||||
|
||||
@ -47,7 +45,7 @@ type Upstream struct {
|
||||
// FlushInterval is the period between flushing the response buffer when
|
||||
// streaming response from the upstream.
|
||||
// Defaults to 1 second.
|
||||
FlushInterval *time.Duration `json:"flushInterval,omitempty"`
|
||||
FlushInterval *Duration `json:"flushInterval,omitempty"`
|
||||
|
||||
// PassHostHeader determines whether the request host header should be proxied
|
||||
// to the upstream server.
|
||||
|
@ -98,7 +98,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr
|
||||
|
||||
// Configure options on the SingleHostReverseProxy
|
||||
if upstream.FlushInterval != nil {
|
||||
proxy.FlushInterval = *upstream.FlushInterval
|
||||
proxy.FlushInterval = upstream.FlushInterval.Duration()
|
||||
} else {
|
||||
proxy.FlushInterval = 1 * time.Second
|
||||
}
|
||||
|
@ -22,8 +22,8 @@ import (
|
||||
|
||||
var _ = Describe("HTTP Upstream Suite", func() {
|
||||
|
||||
const flushInterval5s = 5 * time.Second
|
||||
const flushInterval1s = 1 * time.Second
|
||||
const flushInterval5s = options.Duration(5 * time.Second)
|
||||
const flushInterval1s = options.Duration(1 * time.Second)
|
||||
truth := true
|
||||
falsum := false
|
||||
|
||||
@ -52,7 +52,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
flush := 1 * time.Second
|
||||
flush := options.Duration(1 * time.Second)
|
||||
|
||||
upstream := options.Upstream{
|
||||
ID: in.id,
|
||||
@ -258,7 +258,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
||||
req := httptest.NewRequest("", "http://example.localhost/foo", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
flush := 1 * time.Second
|
||||
flush := options.Duration(1 * time.Second)
|
||||
upstream := options.Upstream{
|
||||
ID: "noPassHost",
|
||||
PassHostHeader: &falsum,
|
||||
@ -290,7 +290,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
||||
|
||||
type newUpstreamTableInput struct {
|
||||
proxyWebSockets bool
|
||||
flushInterval time.Duration
|
||||
flushInterval options.Duration
|
||||
skipVerify bool
|
||||
sigData *options.SignatureData
|
||||
errorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
@ -319,7 +319,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
||||
|
||||
proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(proxy.FlushInterval).To(Equal(in.flushInterval))
|
||||
Expect(proxy.FlushInterval).To(Equal(in.flushInterval.Duration()))
|
||||
Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil))
|
||||
if in.skipVerify {
|
||||
Expect(proxy.Transport).To(Equal(&http.Transport{
|
||||
@ -370,7 +370,7 @@ var _ = Describe("HTTP Upstream Suite", func() {
|
||||
var proxyServer *httptest.Server
|
||||
|
||||
BeforeEach(func() {
|
||||
flush := 1 * time.Second
|
||||
flush := options.Duration(1 * time.Second)
|
||||
upstream := options.Upstream{
|
||||
ID: "websocketProxy",
|
||||
PassHostHeader: &truth,
|
||||
|
@ -70,7 +70,7 @@ func validateStaticUpstream(upstream options.Upstream) []string {
|
||||
if upstream.InsecureSkipTLSVerify {
|
||||
msgs = append(msgs, fmt.Sprintf("upstream %q has insecureSkipTLSVerify, but is a static upstream, this will have no effect.", upstream.ID))
|
||||
}
|
||||
if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second {
|
||||
if upstream.FlushInterval != nil && upstream.FlushInterval.Duration() != time.Second {
|
||||
msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID))
|
||||
}
|
||||
if upstream.PassHostHeader != nil {
|
||||
|
@ -15,7 +15,7 @@ var _ = Describe("Upstreams", func() {
|
||||
errStrings []string
|
||||
}
|
||||
|
||||
flushInterval := 5 * time.Second
|
||||
flushInterval := options.Duration(5 * time.Second)
|
||||
staticCode200 := 200
|
||||
truth := true
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user