diff --git a/Makefile b/Makefile index 79b9ba3ecf..c2cadaed60 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ all: test bench vet check-gofmt bench: - go test -bench . -run "Benchmark" ./form + go test -race -bench . -run "Benchmark" ./form build: go build ./... @@ -10,7 +10,7 @@ check-gofmt: scripts/check_gofmt.sh test: - go test ./... + go test -race ./... vet: go vet ./... diff --git a/error_test.go b/error_test.go index 427aab772e..693579e4c5 100644 --- a/error_test.go +++ b/error_test.go @@ -22,7 +22,7 @@ func TestErrorResponse(t *testing.T) { })) defer ts.Close() - SetBackend("api", BackendConfiguration{ + SetBackend("api", &BackendConfiguration{ APIBackend, ts.URL, &http.Client{}, diff --git a/stripe.go b/stripe.go index b7eb62c9f7..8a6f731fb8 100644 --- a/stripe.go +++ b/stripe.go @@ -14,6 +14,7 @@ import ( "os/exec" "runtime" "strings" + "sync" "time" "github.com/stripe/stripe-go/form" @@ -101,6 +102,7 @@ const ( // Backends are the currently supported endpoints. type Backends struct { API, Uploads Backend + mu sync.RWMutex } // stripeClientUserAgent contains information about the current runtime which @@ -157,31 +159,41 @@ func SetHTTPClient(client *http.Client) { // should only need to use this for testing purposes or on App Engine. func NewBackends(httpClient *http.Client) *Backends { return &Backends{ - API: BackendConfiguration{ + API: &BackendConfiguration{ APIBackend, APIURL, httpClient}, - Uploads: BackendConfiguration{ + Uploads: &BackendConfiguration{ UploadsBackend, UploadsURL, httpClient}, } } // GetBackend returns the currently used backend in the binding. func GetBackend(backend SupportedBackend) Backend { - var ret Backend switch backend { case APIBackend: - if backends.API == nil { - backends.API = BackendConfiguration{backend, apiURL, httpClient} + backends.mu.RLock() + ret := backends.API + backends.mu.RUnlock() + if ret != nil { + return ret } - - ret = backends.API + backends.mu.Lock() + defer backends.mu.Unlock() + backends.API = &BackendConfiguration{backend, apiURL, httpClient} + return backends.API case UploadsBackend: - if backends.Uploads == nil { - backends.Uploads = BackendConfiguration{backend, uploadsURL, httpClient} + backends.mu.RLock() + ret := backends.Uploads + backends.mu.RUnlock() + if ret != nil { + return ret } - ret = backends.Uploads + backends.mu.Lock() + defer backends.mu.Unlock() + backends.Uploads = &BackendConfiguration{backend, uploadsURL, httpClient} + return backends.Uploads } - return ret + return nil } // SetBackend sets the backend used in the binding. @@ -195,7 +207,7 @@ func SetBackend(backend SupportedBackend, b Backend) { } // Call is the Backend.Call implementation for invoking Stripe APIs. -func (s BackendConfiguration) Call(method, path, key string, form *form.Values, params *Params, v interface{}) error { +func (s *BackendConfiguration) Call(method, path, key string, form *form.Values, params *Params, v interface{}) error { var body io.Reader if form != nil && !form.Empty() { data := form.Encode() @@ -219,7 +231,7 @@ func (s BackendConfiguration) Call(method, path, key string, form *form.Values, } // CallMultipart is the Backend.CallMultipart implementation for invoking Stripe APIs. -func (s BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error { +func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error { contentType := "multipart/form-data; boundary=" + boundary req, err := s.NewRequest(method, path, key, contentType, body, params) diff --git a/stripe_test.go b/stripe_test.go index b88e546dea..9f77cc2d7e 100644 --- a/stripe_test.go +++ b/stripe_test.go @@ -5,6 +5,7 @@ import ( "net/http" "regexp" "runtime" + "sync" "testing" assert "github.com/stretchr/testify/require" @@ -22,6 +23,25 @@ func TestCheckinUseBearerAuth(t *testing.T) { assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization")) } +// TestMultipleAPICalls will fail the test run if a race condition is thrown while running multople NewRequest calls. +func TestMultipleAPICalls(t *testing.T) { + wg := &sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c := &stripe.BackendConfiguration{URL: stripe.APIURL} + key := "apiKey" + + req, err := c.NewRequest("", "", key, "", nil, nil) + assert.NoError(t, err) + + assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization")) + }() + } + wg.Wait() +} + func TestIdempotencyKey(t *testing.T) { c := &stripe.BackendConfiguration{URL: stripe.APIURL} p := &stripe.Params{IdempotencyKey: "idempotency-key"} diff --git a/testing/testing.go b/testing/testing.go index 55dc955fbb..3fcecc5e51 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -55,7 +55,7 @@ func init() { } stripe.Key = "sk_test_myTestKey" - stripe.SetBackend("api", stripe.BackendConfiguration{ + stripe.SetBackend("api", &stripe.BackendConfiguration{ Type: stripe.APIBackend, URL: "http://localhost:" + port + "/v1", HTTPClient: &http.Client{},