Skip to content

Commit b5ba459

Browse files
authored
Fix race in ClientMock.MockGet() (#309)
* Show race in ClientMock.MockGet() Signed-off-by: Marco Pracucci <marco@pracucci.com> * Fixed race Signed-off-by: Marco Pracucci <marco@pracucci.com>
1 parent 182c73b commit b5ba459

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

pkg/storage/bucket/client_mock.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ func (m *ClientMock) MockIterWithCallback(prefix string, objects []string, err e
7777
// Get mocks objstore.Bucket.Get()
7878
func (m *ClientMock) Get(ctx context.Context, name string) (io.ReadCloser, error) {
7979
args := m.Called(ctx, name)
80+
81+
// Allow to mock the Get() with a function which is called each time.
82+
if fn, ok := args.Get(0).(func(ctx context.Context, name string) (io.ReadCloser, error)); ok {
83+
return fn(ctx, name)
84+
}
85+
8086
val, err := args.Get(0), args.Error(1)
8187
if val == nil {
8288
return nil, err
@@ -96,9 +102,8 @@ func (m *ClientMock) MockGet(name, content string, err error) {
96102
// Since we return an ReadCloser and it can be consumed only once,
97103
// each time the mocked Get() is called we do create a new one, so
98104
// that getting the same mocked object twice works as expected.
99-
mockedGet := m.On("Get", mock.Anything, name)
100-
mockedGet.Run(func(args mock.Arguments) {
101-
mockedGet.Return(ioutil.NopCloser(bytes.NewReader([]byte(content))), err)
105+
m.On("Get", mock.Anything, name).Return(func(_ context.Context, _ string) (io.ReadCloser, error) {
106+
return ioutil.NopCloser(bytes.NewReader([]byte(content))), err
102107
})
103108
} else {
104109
m.On("Exists", mock.Anything, name).Return(false, err)

pkg/storage/bucket/client_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ package bucket
77

88
import (
99
"context"
10+
"io"
11+
"sync"
1012
"testing"
1113

1214
"github.com/grafana/dskit/flagext"
@@ -97,3 +99,31 @@ func TestNewClient(t *testing.T) {
9799
})
98100
}
99101
}
102+
103+
func TestClientMock_MockGet(t *testing.T) {
104+
expected := "body"
105+
106+
m := ClientMock{}
107+
m.MockGet("test", expected, nil)
108+
109+
// Run many goroutines all requesting the same mocked object and
110+
// ensure there's no race.
111+
wg := sync.WaitGroup{}
112+
for i := 0; i < 1000; i++ {
113+
wg.Add(1)
114+
go func() {
115+
defer wg.Done()
116+
117+
reader, err := m.Get(context.Background(), "test")
118+
require.NoError(t, err)
119+
120+
actual, err := io.ReadAll(reader)
121+
require.NoError(t, err)
122+
require.Equal(t, []byte(expected), actual)
123+
124+
require.NoError(t, reader.Close())
125+
}()
126+
}
127+
128+
wg.Wait()
129+
}

0 commit comments

Comments
 (0)