Skip to content

Commit

Permalink
Added a mutex around backend varible for concurrent reads
Browse files Browse the repository at this point in the history
  • Loading branch information
natdm committed Oct 17, 2017
1 parent 2568e0f commit 2fc83ed
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
all: test bench vet check-gofmt

bench:
go test -bench . -run "Benchmark" ./form
go test -race -bench . -run "Benchmark" ./form

build:
go build ./...
Expand All @@ -10,7 +10,7 @@ check-gofmt:
scripts/check_gofmt.sh

test:
go test ./...
go test -race ./...

vet:
go vet ./...
2 changes: 1 addition & 1 deletion error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestErrorResponse(t *testing.T) {
}))
defer ts.Close()

SetBackend("api", BackendConfiguration{
SetBackend("api", &BackendConfiguration{
APIBackend,
ts.URL,
&http.Client{},
Expand Down
38 changes: 25 additions & 13 deletions stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os/exec"
"runtime"
"strings"
"sync"
"time"

"github.com/stripe/stripe-go/form"
Expand Down Expand Up @@ -101,6 +102,7 @@ const (
// Backends are the currently supported endpoints.
type Backends struct {
API, Uploads Backend
mu sync.RWMutex
}

// stripeClientUserAgent contains information about the current runtime which
Expand Down Expand Up @@ -157,31 +159,41 @@ func SetHTTPClient(client *http.Client) {
// should only need to use this for testing purposes or on App Engine.
func NewBackends(httpClient *http.Client) *Backends {
return &Backends{
API: BackendConfiguration{
API: &BackendConfiguration{
APIBackend, APIURL, httpClient},
Uploads: BackendConfiguration{
Uploads: &BackendConfiguration{
UploadsBackend, UploadsURL, httpClient},
}
}

// GetBackend returns the currently used backend in the binding.
func GetBackend(backend SupportedBackend) Backend {
var ret Backend
switch backend {
case APIBackend:
if backends.API == nil {
backends.API = BackendConfiguration{backend, apiURL, httpClient}
backends.mu.RLock()
ret := backends.API
backends.mu.RUnlock()
if ret != nil {
return ret
}

ret = backends.API
backends.mu.Lock()
defer backends.mu.Unlock()
backends.API = &BackendConfiguration{backend, apiURL, httpClient}
return backends.API
case UploadsBackend:
if backends.Uploads == nil {
backends.Uploads = BackendConfiguration{backend, uploadsURL, httpClient}
backends.mu.RLock()
ret := backends.Uploads
backends.mu.RUnlock()
if ret != nil {
return ret
}
ret = backends.Uploads
backends.mu.Lock()
defer backends.mu.Unlock()
backends.Uploads = &BackendConfiguration{backend, uploadsURL, httpClient}
return backends.Uploads
}

return ret
return nil
}

// SetBackend sets the backend used in the binding.
Expand All @@ -195,7 +207,7 @@ func SetBackend(backend SupportedBackend, b Backend) {
}

// Call is the Backend.Call implementation for invoking Stripe APIs.
func (s BackendConfiguration) Call(method, path, key string, form *form.Values, params *Params, v interface{}) error {
func (s *BackendConfiguration) Call(method, path, key string, form *form.Values, params *Params, v interface{}) error {
var body io.Reader
if form != nil && !form.Empty() {
data := form.Encode()
Expand All @@ -219,7 +231,7 @@ func (s BackendConfiguration) Call(method, path, key string, form *form.Values,
}

// CallMultipart is the Backend.CallMultipart implementation for invoking Stripe APIs.
func (s BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error {
func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error {
contentType := "multipart/form-data; boundary=" + boundary

req, err := s.NewRequest(method, path, key, contentType, body, params)
Expand Down
20 changes: 20 additions & 0 deletions stripe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"regexp"
"runtime"
"sync"
"testing"

assert "github.com/stretchr/testify/require"
Expand All @@ -22,6 +23,25 @@ func TestCheckinUseBearerAuth(t *testing.T) {
assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization"))
}

// TestMultipleAPICalls will fail the test run if a race condition is thrown while running multople NewRequest calls.
func TestMultipleAPICalls(t *testing.T) {
wg := &sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
c := &stripe.BackendConfiguration{URL: stripe.APIURL}
key := "apiKey"

req, err := c.NewRequest("", "", key, "", nil, nil)
assert.NoError(t, err)

assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization"))
}()
}
wg.Wait()
}

func TestIdempotencyKey(t *testing.T) {
c := &stripe.BackendConfiguration{URL: stripe.APIURL}
p := &stripe.Params{IdempotencyKey: "idempotency-key"}
Expand Down
2 changes: 1 addition & 1 deletion testing/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func init() {
}

stripe.Key = "sk_test_myTestKey"
stripe.SetBackend("api", stripe.BackendConfiguration{
stripe.SetBackend("api", &stripe.BackendConfiguration{
Type: stripe.APIBackend,
URL: "http://localhost:" + port + "/v1",
HTTPClient: &http.Client{},
Expand Down

0 comments on commit 2fc83ed

Please sign in to comment.