From 584430174d957f48ae02be81484908c204ef8a28 Mon Sep 17 00:00:00 2001 From: Johnny Steenbergen Date: Sun, 14 Jan 2024 12:44:46 -0600 Subject: [PATCH] fix(allsrv): serialize access for in-mem value access to rm race condition This is pretty straight forward. We just add the `-race` flag to our `go test` invocation after we add some tests that access the db concurrently. There are many more test cases to add for this specific instance, but the point is made with the existing one. We're starting to get somewhat comfortable with our existing test suite. If this were a real world service, you'd have tests at a higher level and more integration and end to end testing. These tests are slow but are the most valuable as they are validating more integration points! Unit tests can easily create a false sense of safety. Especially when unit testing that is MOCKING ALL THE THINGS! Note, we can use other mechanisms to address the race condition. However, it gets complex... FAST. Katherine Cox-Buday's "Concurrency in Go" is a wonderful deep dive on the subject :-): Refs: [Concurrency in Go - Katherine Cox-Buday](https://katherine.cox-buday.com/concurrency-in-go/) --- allsrv/db_inmem.go | 16 +++++- allsrv/db_inmem_test.go | 119 ++++++++++++++++++++++++++++++++++++++++ allsrv/server.go | 4 +- 3 files changed, 136 insertions(+), 3 deletions(-) diff --git a/allsrv/db_inmem.go b/allsrv/db_inmem.go index 98f6b73..11f5a5c 100644 --- a/allsrv/db_inmem.go +++ b/allsrv/db_inmem.go @@ -3,14 +3,19 @@ package allsrv import ( "context" "errors" + "sync" ) // InmemDB is an in-memory store. type InmemDB struct { - m []Foo // 12) + mu sync.Mutex + m []Foo // 12) } func (db *InmemDB) CreateFoo(_ context.Context, f Foo) error { + db.mu.Lock() + defer db.mu.Unlock() + for _, existing := range db.m { if f.Name == existing.Name { return errors.New("foo " + f.Name + " exists") // 8) @@ -23,6 +28,9 @@ func (db *InmemDB) CreateFoo(_ context.Context, f Foo) error { } func (db *InmemDB) ReadFoo(_ context.Context, id string) (Foo, error) { + db.mu.Lock() + defer db.mu.Unlock() + for _, f := range db.m { if id == f.ID { return f, nil @@ -32,6 +40,9 @@ func (db *InmemDB) ReadFoo(_ context.Context, id string) (Foo, error) { } func (db *InmemDB) UpdateFoo(_ context.Context, f Foo) error { + db.mu.Lock() + defer db.mu.Unlock() + for i, existing := range db.m { if f.ID == existing.ID { db.m[i] = f @@ -42,6 +53,9 @@ func (db *InmemDB) UpdateFoo(_ context.Context, f Foo) error { } func (db *InmemDB) DelFoo(_ context.Context, id string) error { + db.mu.Lock() + defer db.mu.Unlock() + for i, f := range db.m { if id == f.ID { db.m = append(db.m[:i], db.m[i+1:]...) diff --git a/allsrv/db_inmem_test.go b/allsrv/db_inmem_test.go index 1f3e249..d4b1fb6 100644 --- a/allsrv/db_inmem_test.go +++ b/allsrv/db_inmem_test.go @@ -3,6 +3,7 @@ package allsrv_test import ( "context" "errors" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -30,6 +31,28 @@ func TestInmemDB(t *testing.T) { assert.Equal(t, want, got) }) + t.Run("with concurrent valid foo creates should pass", func(t *testing.T) { + db := new(allsrv.InmemDB) + + newFoo := func(id string) allsrv.Foo { + return allsrv.Foo{ + ID: id, + Name: "name-" + id, + Note: "note-" + id, + } + } + + var wg sync.WaitGroup + for _, f := range []allsrv.Foo{newFoo("1"), newFoo("2"), newFoo("3"), newFoo("4"), newFoo("5")} { + wg.Add(1) + go func(f allsrv.Foo) { + defer wg.Done() + require.NoError(t, db.CreateFoo(context.TODO(), f)) + }(f) + } + wg.Wait() + }) + t.Run("with foo containing name that already exists should fail", func(t *testing.T) { db := new(allsrv.InmemDB) @@ -68,6 +91,38 @@ func TestInmemDB(t *testing.T) { assert.Equal(t, want, got) }) + t.Run("with concurrent valid foo update the reading should pass", func(t *testing.T) { + db := new(allsrv.InmemDB) + require.NoError(t, db.CreateFoo(context.TODO(), allsrv.Foo{ + ID: "1", + Name: "one", + Note: "note", + })) + + newFoo := func(note string) allsrv.Foo { + return allsrv.Foo{ + ID: "1", + Name: "one", + Note: note, + } + } + + var wg sync.WaitGroup + for _, f := range []allsrv.Foo{newFoo("a"), newFoo("b"), newFoo("c"), newFoo("d"), newFoo("e")} { + wg.Add(1) + go func(f allsrv.Foo) { + defer wg.Done() + require.NoError(t, db.UpdateFoo(context.TODO(), f)) + }(f) + } + + got, err := db.ReadFoo(context.TODO(), "1") + require.NoError(t, err) + + assert.Contains(t, []string{"note", "a", "b", "c", "d", "e"}, got.Note) + wg.Wait() + }) + t.Run("with id for non-existent foo should fail", func(t *testing.T) { db := new(allsrv.InmemDB) @@ -102,6 +157,38 @@ func TestInmemDB(t *testing.T) { assert.Equal(t, want, got) }) + t.Run("with concurrent valid foo updates should pass", func(t *testing.T) { + db := new(allsrv.InmemDB) + require.NoError(t, db.CreateFoo(context.TODO(), allsrv.Foo{ + ID: "1", + Name: "one", + Note: "note", + })) + + newFoo := func(note string) allsrv.Foo { + return allsrv.Foo{ + ID: "1", + Name: "one", + Note: note, + } + } + + var wg sync.WaitGroup + for _, f := range []allsrv.Foo{newFoo("a"), newFoo("b"), newFoo("c"), newFoo("d"), newFoo("e")} { + wg.Add(1) + go func(f allsrv.Foo) { + defer wg.Done() + require.NoError(t, db.UpdateFoo(context.TODO(), f)) + }(f) + } + + got, err := db.ReadFoo(context.TODO(), "1") + require.NoError(t, err) + wg.Wait() + + assert.Contains(t, []string{"note", "a", "b", "c", "d", "e"}, got.Note) + }) + t.Run("with update for non-existent foo should fail", func(t *testing.T) { db := new(allsrv.InmemDB) @@ -140,6 +227,38 @@ func TestInmemDB(t *testing.T) { assert.Equal(t, want, err) }) + t.Run("with concurrent valid foo creates should pass", func(t *testing.T) { + db := new(allsrv.InmemDB) + + newFoo := func(id string) allsrv.Foo { + return allsrv.Foo{ + ID: id, + Name: "name-" + id, + Note: "note-" + id, + } + } + + for _, f := range []allsrv.Foo{newFoo("1"), newFoo("2"), newFoo("3"), newFoo("4"), newFoo("5")} { + require.NoError(t, db.CreateFoo(context.TODO(), f)) + } + + var wg sync.WaitGroup + for _, id := range []string{"1", "2", "3", "4", "5"} { + wg.Add(1) + go func(id string) { + defer wg.Done() + require.NoError(t, db.DelFoo(context.TODO(), id)) + }(id) + } + wg.Wait() + + for _, id := range []string{"1", "2", "3", "4", "5"} { + err := db.DelFoo(context.TODO(), id) + wantErr := errors.New("foo not found for id: " + id) + require.Error(t, wantErr, err) + } + }) + t.Run("with id for non-existent foo should fail", func(t *testing.T) { db := new(allsrv.InmemDB) diff --git a/allsrv/server.go b/allsrv/server.go index 63bedc8..4ccc7ae 100644 --- a/allsrv/server.go +++ b/allsrv/server.go @@ -21,7 +21,7 @@ import ( ✅4) router being used is the GLOBAL http.DefaultServeMux a) should avoid globals b) what happens if you have multiple servers in this go module who reference default serve mux? - 5) no tests + ✅5) no tests a) how do we ensure things work? b) how do we know what is intended by the current implementation? 6) http/db are coupled to the same type @@ -38,7 +38,7 @@ import ( b) logging ✅c) tracing ✅11) hard coding UUID generation into db - 12) possible race conditions in inmem store + ✅12) possible race conditions in inmem store ✅13) there is a bug in the delete foo inmem db implementation Praises: