diff --git a/allsrv/db_inmem.go b/allsrv/db_inmem.go index 29d6093..98f6b73 100644 --- a/allsrv/db_inmem.go +++ b/allsrv/db_inmem.go @@ -1,6 +1,7 @@ package allsrv import ( + "context" "errors" ) @@ -9,7 +10,7 @@ type InmemDB struct { m []Foo // 12) } -func (db *InmemDB) CreateFoo(f Foo) error { +func (db *InmemDB) CreateFoo(_ context.Context, f Foo) error { for _, existing := range db.m { if f.Name == existing.Name { return errors.New("foo " + f.Name + " exists") // 8) @@ -21,7 +22,7 @@ func (db *InmemDB) CreateFoo(f Foo) error { return nil } -func (db *InmemDB) ReadFoo(id string) (Foo, error) { +func (db *InmemDB) ReadFoo(_ context.Context, id string) (Foo, error) { for _, f := range db.m { if id == f.ID { return f, nil @@ -30,7 +31,7 @@ func (db *InmemDB) ReadFoo(id string) (Foo, error) { return Foo{}, errors.New("foo not found for id: " + id) // 8) } -func (db *InmemDB) UpdateFoo(f Foo) error { +func (db *InmemDB) UpdateFoo(_ context.Context, f Foo) error { for i, existing := range db.m { if f.ID == existing.ID { db.m[i] = f @@ -40,7 +41,7 @@ func (db *InmemDB) UpdateFoo(f Foo) error { return errors.New("foo not found for id: " + f.ID) // 8) } -func (db *InmemDB) DelFoo(id string) error { +func (db *InmemDB) DelFoo(_ context.Context, id string) error { for i, f := range db.m { if id == f.ID { db.m = append(db.m[:i], db.m[i+1:]...) diff --git a/allsrv/observe_db.go b/allsrv/observe_db.go index 1ea2be8..5cd38b9 100644 --- a/allsrv/observe_db.go +++ b/allsrv/observe_db.go @@ -1,9 +1,11 @@ package allsrv import ( + "context" "time" "github.com/hashicorp/go-metrics" + "github.com/opentracing/opentracing-go" ) const ( @@ -27,25 +29,37 @@ type dbMW struct { met *metrics.Metrics } -func (d *dbMW) CreateFoo(f Foo) error { +func (d *dbMW) CreateFoo(ctx context.Context, f Foo) error { + span, ctx := opentracing.StartSpanFromContext(ctx, d.name+"_foo_create") + defer span.Finish() + rec := d.record("create") - return rec(d.next.CreateFoo(f)) + return rec(d.next.CreateFoo(ctx, f)) } -func (d *dbMW) ReadFoo(id string) (Foo, error) { +func (d *dbMW) ReadFoo(ctx context.Context, id string) (Foo, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, d.name+"_foo_read") + defer span.Finish() + rec := d.record("read") - f, err := d.next.ReadFoo(id) + f, err := d.next.ReadFoo(ctx, id) return f, rec(err) } -func (d *dbMW) UpdateFoo(f Foo) error { +func (d *dbMW) UpdateFoo(ctx context.Context, f Foo) error { + span, ctx := opentracing.StartSpanFromContext(ctx, d.name+"_foo_update") + defer span.Finish() + rec := d.record("update") - return rec(d.next.UpdateFoo(f)) + return rec(d.next.UpdateFoo(ctx, f)) } -func (d *dbMW) DelFoo(id string) error { +func (d *dbMW) DelFoo(ctx context.Context, id string) error { + span, ctx := opentracing.StartSpanFromContext(ctx, d.name+"_foo_delete") + defer span.Finish() + rec := d.record("delete") - return rec(d.next.DelFoo(id)) + return rec(d.next.DelFoo(ctx, id)) } func (d *dbMW) record(op string) func(error) error { diff --git a/allsrv/server.go b/allsrv/server.go index c6bc612..49f5dd3 100644 --- a/allsrv/server.go +++ b/allsrv/server.go @@ -1,6 +1,7 @@ package allsrv import ( + "context" "encoding/json" "log" "net/http" @@ -51,10 +52,10 @@ import ( type ( // DB represents the foo persistence layer. DB interface { - CreateFoo(f Foo) error - ReadFoo(id string) (Foo, error) - UpdateFoo(f Foo) error - DelFoo(id string) error + CreateFoo(ctx context.Context, f Foo) error + ReadFoo(ctx context.Context, id string) (Foo, error) + UpdateFoo(ctx context.Context, f Foo) error + DelFoo(ctx context.Context, id string) error } ) @@ -135,7 +136,7 @@ func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { f.ID = s.idFn() // 11) - if err := s.db.CreateFoo(f); err != nil { + if err := s.db.CreateFoo(r.Context(), f); err != nil { w.WriteHeader(http.StatusInternalServerError) // 9) return } @@ -147,7 +148,7 @@ func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { } func (s *Server) readFoo(w http.ResponseWriter, r *http.Request) { - f, err := s.db.ReadFoo(r.URL.Query().Get("id")) + f, err := s.db.ReadFoo(r.Context(), r.URL.Query().Get("id")) if err != nil { w.WriteHeader(http.StatusNotFound) // 9) return @@ -165,14 +166,14 @@ func (s *Server) updateFoo(w http.ResponseWriter, r *http.Request) { return } - if err := s.db.UpdateFoo(f); err != nil { + if err := s.db.UpdateFoo(r.Context(), f); err != nil { w.WriteHeader(http.StatusInternalServerError) // 9) return } } func (s *Server) delFoo(w http.ResponseWriter, r *http.Request) { - if err := s.db.DelFoo(r.URL.Query().Get("id")); err != nil { + if err := s.db.DelFoo(r.Context(), r.URL.Query().Get("id")); err != nil { w.WriteHeader(http.StatusNotFound) // 9) return } diff --git a/allsrv/server_test.go b/allsrv/server_test.go index 470f6d7..ca99f21 100644 --- a/allsrv/server_test.go +++ b/allsrv/server_test.go @@ -2,6 +2,7 @@ package allsrv_test import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -65,7 +66,7 @@ func TestServer(t *testing.T) { t.Run("foo read", func(t *testing.T) { t.Run("when querying for existing foo id should pass", func(t *testing.T) { db := allsrv.ObserveDB("inmem", newTestMetrics(t))(new(allsrv.InmemDB)) - err := db.CreateFoo(allsrv.Foo{ + err := db.CreateFoo(context.TODO(), allsrv.Foo{ ID: "reader1", Name: "read", Note: "another note", @@ -107,7 +108,7 @@ func TestServer(t *testing.T) { t.Run("foo update", func(t *testing.T) { t.Run("when updating an existing foo with valid changes should pass", func(t *testing.T) { db := allsrv.ObserveDB("inmem", newTestMetrics(t))(new(allsrv.InmemDB)) - err := db.CreateFoo(allsrv.Foo{ + err := db.CreateFoo(context.TODO(), allsrv.Foo{ ID: "id1", Name: "first_name", Note: "first note", @@ -132,7 +133,7 @@ func TestServer(t *testing.T) { t.Run("when provided invalid basic auth should fail", func(t *testing.T) { db := new(allsrv.InmemDB) - err := db.CreateFoo(allsrv.Foo{ + err := db.CreateFoo(context.TODO(), allsrv.Foo{ ID: "id1", Name: "first_name", Note: "first note", @@ -158,7 +159,7 @@ func TestServer(t *testing.T) { t.Run("foo delete", func(t *testing.T) { t.Run("when deleting an existing foo should pass", func(t *testing.T) { db := allsrv.ObserveDB("inmem", newTestMetrics(t))(new(allsrv.InmemDB)) - err := db.CreateFoo(allsrv.Foo{ + err := db.CreateFoo(context.TODO(), allsrv.Foo{ ID: "id1", Name: "first_name", Note: "first note", diff --git a/go.mod b/go.mod index 8ea2ae1..d5da4f1 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/gofrs/uuid v4.4.0+incompatible github.com/hashicorp/go-metrics v0.5.3 + github.com/opentracing/opentracing-go v1.2.0 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 3e87732..69503c8 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=