diff --git a/allsrv/client_http.go b/allsrv/client_http.go index 958e8b5..b582ad5 100644 --- a/allsrv/client_http.go +++ b/allsrv/client_http.go @@ -1,182 +1,107 @@ package allsrv import ( - "bytes" "context" - "encoding/json" - "io" "net/http" "time" "github.com/jsteenb2/errors" + + "github.com/jsteenb2/allsrvc" ) type ClientHTTP struct { - addr string - c *http.Client + c *allsrvc.ClientHTTP } var _ SVC = (*ClientHTTP)(nil) -func NewClientHTTP(addr string, c *http.Client) *ClientHTTP { +func NewClientHTTP(addr, origin string, c *http.Client, opts ...func(*allsrvc.ClientHTTP)) *ClientHTTP { return &ClientHTTP{ - addr: addr, - c: c, + c: allsrvc.NewClientHTTP(addr, origin, c, opts...), } } func (c *ClientHTTP) CreateFoo(ctx context.Context, f Foo) (Foo, error) { - req, err := jsonReq(ctx, "POST", c.fooPath(""), toReqCreateFooV1(f)) + resp, err := c.c.CreateFoo(ctx, allsrvc.FooCreateAttrs{ + Name: f.Name, + Note: f.Note, + }) if err != nil { return Foo{}, InternalErr(err.Error()) } - return returnsFooReq(c.c, req) + newFoo, err := takeRespFoo(resp) + return newFoo, errors.Wrap(err) } func (c *ClientHTTP) ReadFoo(ctx context.Context, id string) (Foo, error) { - if id == "" { - return Foo{}, errIDRequired - } - - req, err := http.NewRequestWithContext(ctx, "GET", c.fooPath(id), nil) + resp, err := c.c.ReadFoo(ctx, id) if err != nil { - return Foo{}, InternalErr(err.Error()) + if errors.Is(err, allsrvc.ErrIDRequired) { + return Foo{}, errIDRequired + } } - return returnsFooReq(c.c, req) + + newFoo, err := takeRespFoo(resp) + return newFoo, errors.Wrap(err) } func (c *ClientHTTP) UpdateFoo(ctx context.Context, f FooUpd) (Foo, error) { - req, err := jsonReq(ctx, "PATCH", c.fooPath(f.ID), toReqUpdateFooV1(f)) + resp, err := c.c.UpdateFoo(ctx, f.ID, allsrvc.FooUpdAttrs{ + Name: f.Name, + Note: f.Note, + }) if err != nil { return Foo{}, InternalErr(err.Error()) } - return returnsFooReq(c.c, req) + newFoo, err := takeRespFoo(resp) + return newFoo, errors.Wrap(err) } func (c *ClientHTTP) DelFoo(ctx context.Context, id string) error { - if id == "" { - return errIDRequired - } - - req, err := http.NewRequestWithContext(ctx, "DELETE", c.fooPath(id), nil) + resp, err := c.c.DelFoo(ctx, id) if err != nil { - return InternalErr(err.Error()) - } - - _, err = doReq[any](c.c, req) - return err -} - -func (c *ClientHTTP) fooPath(id string) string { - u := c.addr + "/v1/foos" - if id == "" { - return u - } - return u + "/" + id -} - -func jsonReq(ctx context.Context, method, path string, v any) (*http.Request, error) { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(v); err != nil { - return nil, InvalidErr("failed to marshal payload: " + err.Error()) - } - - req, err := http.NewRequestWithContext(ctx, method, path, &buf) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - - return req, nil -} - -func returnsFooReq(c *http.Client, req *http.Request) (Foo, error) { - data, err := doReq[ResourceFooAttrs](c, req) - if err != nil { - return Foo{}, err - } - return toFoo(data), nil -} - -func doReq[Attr Attrs](c *http.Client, req *http.Request) (Data[Attr], error) { - resp, err := c.Do(req) - if err != nil { - return *new(Data[Attr]), InternalErr(err.Error()) - } - defer func() { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - }() - - if resp.Header.Get("Content-Type") != "application/json" { - b, err := io.ReadAll(io.LimitReader(resp.Body, 500<<10)) - if err != nil { - return *new(Data[Attr]), InternalErr("failed to read response body: ", err.Error()) + if errors.Is(err, allsrvc.ErrIDRequired) { + return errIDRequired } - return *new(Data[Attr]), InternalErr("invalid content type received; content=" + string(b)) - } - // TODO(berg): handle unexpected status code (502|503|etc) - - var respBody RespBody[Attr] - err = json.NewDecoder(resp.Body).Decode(&respBody) - if err != nil { - return *new(Data[Attr]), InternalErr(err.Error()) } + + return errors.Wrap(convertSDKErrors(resp.Errs)) +} - var errs []error - for _, respErr := range respBody.Errs { - errs = append(errs, toErr(respErr)) +func takeRespFoo(respBody allsrvc.RespBody[allsrvc.ResourceFooAttrs]) (Foo, error) { + if err := convertSDKErrors(respBody.Errs); err != nil { + return Foo{}, errors.Wrap(err) } - if len(errs) == 1 { - return *new(Data[Attr]), errs[0] - } - if len(errs) > 1 { - return *new(Data[Attr]), errors.Join(errs) - } - + if respBody.Data == nil { - return *new(Data[Attr]), nil - } - - return *respBody.Data, nil -} - -func toReqCreateFooV1(f Foo) ReqCreateFooV1 { - return ReqCreateFooV1{ - Data: Data[FooCreateAttrs]{ - Type: "foo", - Attrs: FooCreateAttrs{ - Name: f.Name, - Note: f.Note, - }, - }, + return Foo{}, nil } -} - -func toReqUpdateFooV1(f FooUpd) ReqUpdateFooV1 { - return ReqUpdateFooV1{ - Data: Data[FooUpdAttrs]{ - Type: "foo", - ID: f.ID, - Attrs: FooUpdAttrs{ - Name: f.Name, - Note: f.Note, - }, - }, + + f := Foo{ + ID: respBody.Data.ID, + Name: respBody.Data.Attrs.Name, + Note: respBody.Data.Attrs.Note, + CreatedAt: toTime(respBody.Data.Attrs.CreatedAt), + UpdatedAt: toTime(respBody.Data.Attrs.UpdatedAt), } + + return f, nil } -func toFoo(d Data[ResourceFooAttrs]) Foo { - return Foo{ - ID: d.ID, - Name: d.Attrs.Name, - Note: d.Attrs.Note, - CreatedAt: toTime(d.Attrs.CreatedAt), - UpdatedAt: toTime(d.Attrs.UpdatedAt), +func convertSDKErrors(errs []allsrvc.RespErr) error { + // TODO(@berg): update this to slices pkg when 1.23 lands + switch out := toSlc(errs, toErr); { + case len(out) == 1: + return out[0] + case len(out) > 1: + return errors.Join(out) + default: + return nil } } -func toErr(respErr RespErr) error { +func toErr(respErr allsrvc.RespErr) error { errFn := InternalErr switch respErr.Code { case errCodeExist: @@ -199,3 +124,11 @@ func toTime(in string) time.Time { t, _ := time.Parse(time.RFC3339, in) return t } + +func toSlc[In, Out any](in []In, to func(In) Out) []Out { + out := make([]Out, len(in)) + for _, v := range in { + out = append(out, to(v)) + } + return out +} diff --git a/allsrv/cmd/allsrvc/main.go b/allsrv/cmd/allsrvc/main.go index 0cb49e9..2065204 100644 --- a/allsrv/cmd/allsrvc/main.go +++ b/allsrv/cmd/allsrvc/main.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" + "github.com/jsteenb2/allsrvc" "github.com/jsteenb2/mess/allsrv" ) @@ -23,8 +24,15 @@ func newCmd() *cobra.Command { return c.cmd() } +const name = "allsrvc" + type cli struct { + // base flags addr string + pass string + user string + + // foo flags id string name string note string @@ -32,7 +40,8 @@ type cli struct { func (c *cli) cmd() *cobra.Command { cmd := cobra.Command{ - Use: "allsrvc", + Use: name, + SilenceUsage: true, } cmd.AddCommand( @@ -65,7 +74,7 @@ func (c *cli) cmdCreateFoo() *cobra.Command { return json.NewEncoder(cmd.OutOrStderr()).Encode(f) }, } - cmd.Flags().StringVar(&c.addr, "addr", "http://localhost:8091", "addr for foo svc") + c.registerCommonFlags(&cmd) cmd.Flags().StringVar(&c.name, "name", "", "name of the new foo") cmd.Flags().StringVar(&c.note, "note", "", "optional foo note") @@ -88,8 +97,7 @@ func (c *cli) cmdReadFoo() *cobra.Command { return json.NewEncoder(cmd.OutOrStderr()).Encode(f) }, } - cmd.Flags().StringVar(&c.addr, "addr", "http://localhost:8091", "addr for foo svc") - + c.registerCommonFlags(&cmd) return &cmd } @@ -119,7 +127,7 @@ func (c *cli) cmdUpdateFoo() *cobra.Command { return json.NewEncoder(cmd.OutOrStderr()).Encode(f) }, } - cmd.Flags().StringVar(&c.addr, "addr", "http://localhost:8091", "addr for foo svc") + c.registerCommonFlags(&cmd) cmd.Flags().StringVar(&c.id, "id", "", "id of the foo resource") cmd.Flags().StringVar(&c.name, "name", "", "optional foo name") cmd.Flags().StringVar(&c.note, "note", "", "optional foo note") @@ -137,11 +145,21 @@ func (c *cli) cmdRmFoo() *cobra.Command { return client.DelFoo(cmd.Context(), args[0]) }, } - cmd.Flags().StringVar(&c.addr, "addr", "http://localhost:8091", "addr for foo svc") - + c.registerCommonFlags(&cmd) return &cmd } func (c *cli) newClient() *allsrv.ClientHTTP { - return allsrv.NewClientHTTP(c.addr, &http.Client{Timeout: 5 * time.Second}) + return allsrv.NewClientHTTP( + c.addr, + name, + &http.Client{Timeout: 5 * time.Second}, + allsrvc.WithBasicAuth(c.user, c.pass), + ) +} + +func (c *cli) registerCommonFlags(cmd *cobra.Command) { + cmd.Flags().StringVar(&c.addr, "addr", "http://localhost:8091", "addr for foo svc") + cmd.Flags().StringVar(&c.user, "user", "admin", "user for basic auth") + cmd.Flags().StringVar(&c.pass, "password", "pass", "password for basic auth") } diff --git a/allsrv/server_v2.go b/allsrv/server_v2.go index 2c383a2..499b077 100644 --- a/allsrv/server_v2.go +++ b/allsrv/server_v2.go @@ -5,10 +5,12 @@ import ( "encoding/json" "net/http" "time" - + "github.com/gofrs/uuid" "github.com/hashicorp/go-metrics" "github.com/jsteenb2/errors" + + "github.com/jsteenb2/allsrvc" ) type SvrOptFn func(o *serverOpts) @@ -38,32 +40,32 @@ func NewServerV2(svc SVC, opts ...SvrOptFn) *ServerV2 { for _, o := range opts { o(&opt) } - + s := ServerV2{ svc: svc, mux: opt.mux, } - + var mw []func(http.Handler) http.Handler if opt.authFn != nil { mw = append(mw, opt.authFn) } - mw = append(mw, withTraceID, withStartTime) + mw = append(mw, withOriginUserAgent, withTraceID, withStartTime) if opt.met != nil { // put metrics last since these are executed LIFO mw = append(mw, ObserveHandler("v2", opt.met)) } mw = append(mw, recoverer) - + s.mw = applyMW(mw...) - + s.routes() - + return &s } func (s *ServerV2) routes() { withContentTypeJSON := applyMW(contentTypeJSON, s.mw) - + // 9) s.mux.Handle("POST /v1/foos", withContentTypeJSON(jsonIn(resourceTypeFoo, http.StatusCreated, s.createFooV1))) s.mux.Handle("GET /v1/foos/{id}", s.mw(read(s.readFooV1))) @@ -76,95 +78,11 @@ func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.mux.ServeHTTP(w, r) } -// API envelope types -type ( - // RespBody represents a JSON-API response body. - // https://jsonapi.org/format/#document-top-level - // - // note: data can be either an array or a single resource object. This allows for both. - RespBody[Attr Attrs] struct { - Meta RespMeta `json:"meta"` - Errs []RespErr `json:"errors,omitempty"` - Data *Data[Attr] `json:"data,omitempty"` - } - - // Attrs can be either a document or a collection of documents. - Attrs interface { - any | []Attrs - } - - // RespMeta represents a JSON-API meta object. The data here is - // useful for our example service. You can add whatever non-standard - // context that is relevant to your domain here. - // https://jsonapi.org/format/#document-meta - RespMeta struct { - TookMilli int `json:"took_ms"` - TraceID string `json:"trace_id"` - } - - // RespErr represents a JSON-API error object. Do note that we - // aren't implementing the entire error type. Just the most impactful - // bits for this workshop. Mainly, skipping Title & description separation. - // https://jsonapi.org/format/#error-objects - RespErr struct { - Status int `json:"status,string"` - Code int `json:"code"` - Msg string `json:"message"` - Source *RespErrSource `json:"source"` - } - - // RespErrSource represents a JSON-API err source. - // https://jsonapi.org/format/#error-objects - RespErrSource struct { - Pointer string `json:"pointer"` - Parameter string `json:"parameter,omitempty"` - Header string `json:"header,omitempty"` - } - - // ReqBody represents a JSON-API request body. - // https://jsonapi.org/format/#crud-creating - ReqBody[Attr Attrs] struct { - Data Data[Attr] `json:"data"` - } -) - -// Data represents a JSON-API data response. -// -// https://jsonapi.org/format/#document-top-level -type Data[Attr Attrs] struct { - Type string `json:"type"` - ID string `json:"id"` - Attrs Attr `json:"attributes"` - - // omitting the relationships here for brevity not at lvl 3 RMM -} - -func (d Data[Attr]) getType() string { - return d.Type -} - const ( resourceTypeFoo = "foo" ) -type ( - ReqCreateFooV1 = ReqBody[FooCreateAttrs] - - FooCreateAttrs struct { - Name string `json:"name"` - Note string `json:"note"` - } - - // ResourceFooAttrs are the attributes of a foo resource. - ResourceFooAttrs struct { - Name string `json:"name"` - Note string `json:"note"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` - } -) - -func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (*Data[ResourceFooAttrs], []RespErr) { +func (s *ServerV2) createFooV1(ctx context.Context, req allsrvc.ReqBody[allsrvc.FooCreateAttrs]) (*allsrvc.Data[allsrvc.ResourceFooAttrs], []allsrvc.RespErr) { newFoo, err := s.svc.CreateFoo(ctx, Foo{ Name: req.Data.Attrs.Name, Note: req.Data.Attrs.Note, @@ -172,35 +90,26 @@ func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (*Data[R if err != nil { respErr := toRespErr(err) if errors.Is(err, ErrKindExists) { - respErr.Source = &RespErrSource{Pointer: "/data/attributes/name"} + respErr.Source = &allsrvc.RespErrSource{Pointer: "/data/attributes/name"} } - return nil, []RespErr{respErr} + return nil, []allsrvc.RespErr{respErr} } - + out := fooToData(newFoo) return &out, nil } -func (s *ServerV2) readFooV1(ctx context.Context, r *http.Request) (*Data[ResourceFooAttrs], []RespErr) { +func (s *ServerV2) readFooV1(ctx context.Context, r *http.Request) (*allsrvc.Data[allsrvc.ResourceFooAttrs], []allsrvc.RespErr) { f, err := s.svc.ReadFoo(ctx, r.PathValue("id")) if err != nil { - return nil, []RespErr{toRespErr(err)} + return nil, []allsrvc.RespErr{toRespErr(err)} } - + out := fooToData(f) return &out, nil } -type ( - ReqUpdateFooV1 = ReqBody[FooUpdAttrs] - - FooUpdAttrs struct { - Name *string `json:"name"` - Note *string `json:"note"` - } -) - -func (s *ServerV2) updateFooV1(ctx context.Context, req ReqUpdateFooV1) (*Data[ResourceFooAttrs], []RespErr) { +func (s *ServerV2) updateFooV1(ctx context.Context, req allsrvc.ReqBody[allsrvc.FooUpdAttrs]) (*allsrvc.Data[allsrvc.ResourceFooAttrs], []allsrvc.RespErr) { existing, err := s.svc.UpdateFoo(ctx, FooUpd{ ID: req.Data.ID, Name: req.Data.Attrs.Name, @@ -209,37 +118,33 @@ func (s *ServerV2) updateFooV1(ctx context.Context, req ReqUpdateFooV1) (*Data[R if err != nil { respErr := toRespErr(err) if errors.Is(err, ErrKindExists) { - respErr.Source = &RespErrSource{Pointer: "/data/attributes/name"} + respErr.Source = &allsrvc.RespErrSource{Pointer: "/data/attributes/name"} } - return nil, []RespErr{respErr} + return nil, []allsrvc.RespErr{respErr} } - + out := fooToData(existing) return &out, nil } -func (s *ServerV2) delFooV1(ctx context.Context, r *http.Request) []RespErr { +func (s *ServerV2) delFooV1(ctx context.Context, r *http.Request) []allsrvc.RespErr { id := r.PathValue("id") if err := s.svc.DelFoo(ctx, id); err != nil { - return []RespErr{toRespErr(err)} + return []allsrvc.RespErr{toRespErr(err)} } return nil } -func fooToData(f Foo) Data[ResourceFooAttrs] { - return toFooData(f.ID, ResourceFooAttrs{ - Name: f.Name, - Note: f.Note, - CreatedAt: toTimestamp(f.CreatedAt), - UpdatedAt: toTimestamp(f.UpdatedAt), - }) -} - -func toFooData(id string, attrs ResourceFooAttrs) Data[ResourceFooAttrs] { - return Data[ResourceFooAttrs]{ - Type: resourceTypeFoo, - ID: id, - Attrs: attrs, +func fooToData(f Foo) allsrvc.Data[allsrvc.ResourceFooAttrs] { + return allsrvc.Data[allsrvc.ResourceFooAttrs]{ + Type: resourceTypeFoo, + ID: f.ID, + Attrs: allsrvc.ResourceFooAttrs{ + Name: f.Name, + Note: f.Note, + CreatedAt: toTimestamp(f.CreatedAt), + UpdatedAt: toTimestamp(f.UpdatedAt), + }, } } @@ -247,50 +152,52 @@ func toTimestamp(t time.Time) string { return t.Format(time.RFC3339) } -func jsonIn[ReqAttr, RespAttr Attrs](resource string, successCode int, fn func(context.Context, ReqBody[ReqAttr]) (*Data[RespAttr], []RespErr)) http.Handler { - return handler(successCode, func(ctx context.Context, r *http.Request) (*Data[RespAttr], []RespErr) { - var reqBody ReqBody[ReqAttr] +func jsonIn[ReqAttr, RespAttr allsrvc.Attrs]( + resource string, + successCode int, + fn func(context.Context, allsrvc.ReqBody[ReqAttr]) (*allsrvc.Data[RespAttr], []allsrvc.RespErr), +) http.Handler { + return handler(successCode, func(ctx context.Context, r *http.Request) (*allsrvc.Data[RespAttr], []allsrvc.RespErr) { + var reqBody allsrvc.ReqBody[ReqAttr] if respErr := decodeReq(r, &reqBody); respErr != nil { - return nil, []RespErr{*respErr} + return nil, []allsrvc.RespErr{*respErr} } if reqBody.Data.Type != resource { - return nil, []RespErr{{ + return nil, []allsrvc.RespErr{{ Status: http.StatusUnprocessableEntity, Code: errCode(ErrKindInvalid), Msg: "type must be " + resource, - Source: &RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data/type", }, }} } - + return fn(r.Context(), reqBody) }) } -func read[Attr any | []Attr](fn func(ctx context.Context, r *http.Request) (*Data[Attr], []RespErr)) http.Handler { +func read[Attr any | []Attr](fn func(ctx context.Context, r *http.Request) (*allsrvc.Data[Attr], []allsrvc.RespErr)) http.Handler { return handler(http.StatusOK, fn) } -func del(fn func(ctx context.Context, r *http.Request) []RespErr) http.Handler { - return handler(http.StatusOK, func(ctx context.Context, r *http.Request) (*Data[any], []RespErr) { +func del(fn func(ctx context.Context, r *http.Request) []allsrvc.RespErr) http.Handler { + return handler(http.StatusOK, func(ctx context.Context, r *http.Request) (*allsrvc.Data[any], []allsrvc.RespErr) { return nil, fn(ctx, r) }) } -func handler[Attr Attrs](successCode int, fn func(ctx context.Context, req *http.Request) (*Data[Attr], []RespErr)) http.Handler { +func handler[Attr allsrvc.Attrs](successCode int, fn func(ctx context.Context, req *http.Request) (*allsrvc.Data[Attr], []allsrvc.RespErr)) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { out, errs := fn(r.Context(), r) - + status := successCode for _, e := range errs { if e.Status > status { status = e.Status } } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - json.NewEncoder(w).Encode(RespBody[Attr]{ + writeResp(w, status, allsrvc.RespBody[Attr]{ Meta: getMeta(r.Context()), Errs: errs, Data: out, @@ -298,12 +205,18 @@ func handler[Attr Attrs](successCode int, fn func(ctx context.Context, req *http }) } -func decodeReq[Attr Attrs](r *http.Request, v *ReqBody[Attr]) *RespErr { +func writeResp(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(body) // 10.b) +} + +func decodeReq[Attr allsrvc.Attrs](r *http.Request, v *allsrvc.ReqBody[Attr]) *allsrvc.RespErr { if err := json.NewDecoder(r.Body).Decode(v); err != nil { - respErr := RespErr{ + respErr := allsrvc.RespErr{ Status: http.StatusBadRequest, Msg: "failed to decode request body: " + err.Error(), - Source: &RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data", }, Code: errCode(ErrKindInvalid), @@ -314,21 +227,21 @@ func decodeReq[Attr Attrs](r *http.Request, v *ReqBody[Attr]) *RespErr { return &respErr } if r.Method == http.MethodPatch && r.PathValue("id") != v.Data.ID { - return &RespErr{ + return &allsrvc.RespErr{ Status: http.StatusBadRequest, Msg: "path id and data id must match", - Source: &RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data/id", }, Code: errCode(ErrKindInvalid), } } - + return nil } -func toRespErr(err error) RespErr { - return RespErr{ +func toRespErr(err error) allsrvc.RespErr { + return allsrvc.RespErr{ Status: errStatus(err), Code: errCode(err), Msg: err.Error(), @@ -357,14 +270,13 @@ func WithBasicAuthV2(adminUser, adminPass string) func(*serverOpts) { s.authFn = func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if user, pass, ok := r.BasicAuth(); !(ok && user == adminUser && pass == adminPass) { - w.WriteHeader(http.StatusUnauthorized) // 9) - json.NewEncoder(w).Encode(RespBody[any]{ + writeResp(w, http.StatusUnauthorized, allsrvc.RespBody[any]{ Meta: getMeta(r.Context()), - Errs: []RespErr{{ + Errs: []allsrvc.RespErr{{ Status: http.StatusUnauthorized, Code: errCode(ErrKindUnAuthed), Msg: "unauthorized access", - Source: &RespErrSource{ + Source: &allsrvc.RespErrSource{ Header: "Authorization", }, }}, @@ -381,10 +293,9 @@ func contentTypeJSON(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ct := r.Header.Get("Content-Type") if ct != "application/json" { - w.WriteHeader(http.StatusUnsupportedMediaType) - json.NewEncoder(w).Encode(RespBody[any]{ + writeResp(w, http.StatusUnsupportedMediaType, allsrvc.RespBody[any]{ Meta: getMeta(r.Context()), - Errs: []RespErr{{ + Errs: []allsrvc.RespErr{{ Code: http.StatusUnsupportedMediaType, Msg: "received invalid media type", }}, @@ -395,8 +306,8 @@ func contentTypeJSON(next http.Handler) http.Handler { }) } -func getMeta(ctx context.Context) RespMeta { - return RespMeta{ +func getMeta(ctx context.Context) allsrvc.RespMeta { + return allsrvc.RespMeta{ TookMilli: int(took(ctx).Milliseconds()), TraceID: getTraceID(ctx), } @@ -409,23 +320,27 @@ func recoverer(next http.Handler) http.Handler { if rvr == nil { return } - + if rvr == http.ErrAbortHandler { // we don't recover http.ErrAbortHandler so the response // to the client is aborted, this should not be logged panic(rvr) } - + w.WriteHeader(http.StatusInternalServerError) }() - + next.ServeHTTP(w, r) }) } +type ctxKey string + const ( - ctxStartTime = "start" - ctxTraceID = "trace-id" + ctxKeyOrigin ctxKey = "origin" + ctxStartTime ctxKey = "start" + ctxTraceID ctxKey = "trace-id" + ctxKeyUserAgent ctxKey = "user_agent" ) func withTraceID(next http.Handler) http.Handler { @@ -439,11 +354,6 @@ func withTraceID(next http.Handler) http.Handler { }) } -func getTraceID(ctx context.Context) string { - traceID, _ := ctx.Value(ctxTraceID).(string) - return traceID -} - func withStartTime(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), ctxStartTime, time.Now()) @@ -451,6 +361,31 @@ func withStartTime(next http.Handler) http.Handler { }) } +func withOriginUserAgent(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKeyOrigin, r.Header.Get("Origin")) + ctx = context.WithValue(ctx, ctxKeyUserAgent, r.Header.Get("User-Agent")) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func getTraceID(ctx context.Context) string { + traceID, _ := ctx.Value(ctxTraceID).(string) + return traceID +} + +func getOrigin(ctx context.Context) string { + origin, _ := ctx.Value(ctxKeyOrigin).(string) + return origin +} + +func getUserAgent(ctx context.Context) string { + userAgent, _ := ctx.Value(ctxKeyUserAgent).(string) + return userAgent +} + func took(ctx context.Context) time.Duration { t, _ := ctx.Value(ctxStartTime).(time.Time) return time.Since(t) diff --git a/allsrv/server_v2_test.go b/allsrv/server_v2_test.go index 789038a..1bffd6f 100644 --- a/allsrv/server_v2_test.go +++ b/allsrv/server_v2_test.go @@ -7,10 +7,12 @@ import ( "net/http/httptest" "testing" "time" - + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - + + "github.com/jsteenb2/allsrvc" + "github.com/jsteenb2/mess/allsrv" "github.com/jsteenb2/mess/allsrv/allsrvtesting" ) @@ -20,9 +22,9 @@ func TestServerV2HttpClient(t *testing.T) { svc := allsrvtesting.NewInmemSVC(t, opts) srv := httptest.NewServer(allsrv.NewServerV2(svc)) t.Cleanup(srv.Close) - + return allsrvtesting.SVCDeps{ - SVC: allsrv.NewClientHTTP(srv.URL, &http.Client{Timeout: time.Second}), + SVC: allsrv.NewClientHTTP(srv.URL, "allsrv_test", &http.Client{Timeout: time.Second}), } }) } @@ -32,9 +34,9 @@ func TestServerV2(t *testing.T) { inputs struct { req *http.Request } - + wantFn func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) - + testCase struct { name string prepare func(t *testing.T, db allsrv.DB) @@ -44,30 +46,30 @@ func TestServerV2(t *testing.T) { want wantFn } ) - + start := time.Time{}.Add(time.Hour).UTC() - + testSvr := func(t *testing.T, tt testCase) { db := new(allsrv.InmemDB) - + if tt.prepare != nil { tt.prepare(t, db) } - + svcOpts := append(allsrvtesting.DefaultSVCOpts(start), tt.svcOpts...) svc := allsrv.NewService(db, svcOpts...) - + defaultSvrOpts := []allsrv.SvrOptFn{allsrv.WithMetrics(newTestMetrics(t))} svrOpts := append(defaultSvrOpts, tt.svrOpts...) - + rec := httptest.NewRecorder() - + svr := allsrv.NewServerV2(svc, svrOpts...) svr.ServeHTTP(rec, tt.inputs.req) - + tt.want(t, rec, db) } - + t.Run("foo create", func(t *testing.T) { tests := []testCase{ { @@ -75,10 +77,10 @@ func TestServerV2(t *testing.T) { svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, inputs: inputs{ req: newJSONReq("POST", "/v1/foos", - newJSONBody(t, allsrv.ReqCreateFooV1{ - Data: allsrv.Data[allsrv.FooCreateAttrs]{ + newJSONBody(t, allsrvc.ReqBody[allsrvc.FooCreateAttrs]{ + Data: allsrvc.Data[allsrvc.FooCreateAttrs]{ Type: "foo", - Attrs: allsrv.FooCreateAttrs{ + Attrs: allsrvc.FooCreateAttrs{ Name: "first-foo", Note: "some note", }, @@ -89,17 +91,17 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusCreated, rec.Code) - expectData[allsrv.ResourceFooAttrs](t, rec.Body, allsrv.Data[allsrv.ResourceFooAttrs]{ + expectData[allsrvc.ResourceFooAttrs](t, rec.Body, allsrvc.Data[allsrvc.ResourceFooAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.ResourceFooAttrs{ + Attrs: allsrvc.ResourceFooAttrs{ Name: "first-foo", Note: "some note", CreatedAt: start.Format(time.RFC3339), UpdatedAt: start.Format(time.RFC3339), }, }) - + dbHasFoo(t, db, allsrv.Foo{ ID: "1", Name: "first-foo", @@ -114,10 +116,10 @@ func TestServerV2(t *testing.T) { svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, inputs: inputs{ req: newJSONReq("POST", "/v1/foos", - newJSONBody(t, allsrv.ReqCreateFooV1{ - Data: allsrv.Data[allsrv.FooCreateAttrs]{ + newJSONBody(t, allsrvc.ReqBody[allsrvc.FooCreateAttrs]{ + Data: allsrvc.Data[allsrvc.FooCreateAttrs]{ Type: "foo", - Attrs: allsrv.FooCreateAttrs{ + Attrs: allsrvc.FooCreateAttrs{ Name: "first-foo", Note: "some note", }, @@ -128,15 +130,15 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusUnauthorized, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusUnauthorized, Code: 4, Msg: "unauthorized access", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Header: "Authorization", }, }) - + _, err := db.ReadFoo(context.TODO(), "1") require.Error(t, err) }, @@ -145,10 +147,10 @@ func TestServerV2(t *testing.T) { name: "when creating foo with name that collides with existing should fail", prepare: allsrvtesting.CreateFoos(allsrv.Foo{ID: "9000", Name: "existing-foo"}), inputs: inputs{ - req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ - Data: allsrv.Data[allsrv.FooCreateAttrs]{ + req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrvc.ReqBody[allsrvc.FooCreateAttrs]{ + Data: allsrvc.Data[allsrvc.FooCreateAttrs]{ Type: "foo", - Attrs: allsrv.FooCreateAttrs{ + Attrs: allsrvc.FooCreateAttrs{ Name: "existing-foo", Note: "some note", }, @@ -157,15 +159,15 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusConflict, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusConflict, Code: 1, Msg: "foo existing-foo exists", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data/attributes/name", }, }) - + _, err := db.ReadFoo(context.TODO(), "1") require.Error(t, err) }, @@ -173,10 +175,10 @@ func TestServerV2(t *testing.T) { { name: "when creating foo with invalid resource type should fail", inputs: inputs{ - req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ - Data: allsrv.Data[allsrv.FooCreateAttrs]{ + req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrvc.ReqBody[allsrvc.FooCreateAttrs]{ + Data: allsrvc.Data[allsrvc.FooCreateAttrs]{ Type: "WRONGO", - Attrs: allsrv.FooCreateAttrs{ + Attrs: allsrvc.FooCreateAttrs{ Name: "first-foo", Note: "some note", }, @@ -185,28 +187,28 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusUnprocessableEntity, Code: 2, Msg: "type must be foo", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data/type", }, }) - + _, err := db.ReadFoo(context.TODO(), "1") require.Error(t, err) }, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testSvr(t, tt) }) } }) - + t.Run("foo read", func(t *testing.T) { tests := []testCase{ { @@ -224,10 +226,10 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, _ allsrv.DB) { assert.Equal(t, http.StatusOK, rec.Code) - expectData[allsrv.ResourceFooAttrs](t, rec.Body, allsrv.Data[allsrv.ResourceFooAttrs]{ + expectData[allsrvc.ResourceFooAttrs](t, rec.Body, allsrvc.Data[allsrvc.ResourceFooAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.ResourceFooAttrs{ + Attrs: allsrvc.ResourceFooAttrs{ Name: "first-foo", Note: "some note", CreatedAt: start.Format(time.RFC3339), @@ -250,11 +252,11 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusUnauthorized, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusUnauthorized, Code: 4, Msg: "unauthorized access", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Header: "Authorization", }, }) @@ -267,7 +269,7 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, _ allsrv.DB) { assert.Equal(t, http.StatusNotFound, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusNotFound, Code: 3, Msg: "foo not found for id: 1", @@ -275,14 +277,14 @@ func TestServerV2(t *testing.T) { }, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testSvr(t, tt) }) } }) - + t.Run("foo update", func(t *testing.T) { tests := []testCase{ { @@ -297,11 +299,11 @@ func TestServerV2(t *testing.T) { svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, inputs: inputs{ req: newJSONReq("PATCH", "/v1/foos/1", - newJSONBody(t, allsrv.ReqUpdateFooV1{ - Data: allsrv.Data[allsrv.FooUpdAttrs]{ + newJSONBody(t, allsrvc.ReqBody[allsrvc.FooUpdAttrs]{ + Data: allsrvc.Data[allsrvc.FooUpdAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.FooUpdAttrs{ + Attrs: allsrvc.FooUpdAttrs{ Name: allsrvtesting.Ptr("new-name"), Note: allsrvtesting.Ptr("new note"), }, @@ -312,17 +314,17 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusOK, rec.Code) - expectData[allsrv.ResourceFooAttrs](t, rec.Body, allsrv.Data[allsrv.ResourceFooAttrs]{ + expectData[allsrvc.ResourceFooAttrs](t, rec.Body, allsrvc.Data[allsrvc.ResourceFooAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.ResourceFooAttrs{ + Attrs: allsrvc.ResourceFooAttrs{ Name: "new-name", Note: "new note", CreatedAt: start.Format(time.RFC3339), UpdatedAt: start.Add(time.Hour).Format(time.RFC3339), }, }) - + dbHasFoo(t, db, allsrv.Foo{ ID: "1", Name: "new-name", @@ -342,11 +344,11 @@ func TestServerV2(t *testing.T) { svcOpts: []func(*allsrv.Service){allsrv.WithSVCNowFn(allsrvtesting.NowFn(start.Add(time.Hour), time.Hour))}, inputs: inputs{ req: newJSONReq("PATCH", "/v1/foos/1", - newJSONBody(t, allsrv.ReqUpdateFooV1{ - Data: allsrv.Data[allsrv.FooUpdAttrs]{ + newJSONBody(t, allsrvc.ReqBody[allsrvc.FooUpdAttrs]{ + Data: allsrvc.Data[allsrvc.FooUpdAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.FooUpdAttrs{ + Attrs: allsrvc.FooUpdAttrs{ Note: allsrvtesting.Ptr("new note"), }, }, @@ -356,17 +358,17 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusOK, rec.Code) - expectData[allsrv.ResourceFooAttrs](t, rec.Body, allsrv.Data[allsrv.ResourceFooAttrs]{ + expectData[allsrvc.ResourceFooAttrs](t, rec.Body, allsrvc.Data[allsrvc.ResourceFooAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.ResourceFooAttrs{ + Attrs: allsrvc.ResourceFooAttrs{ Name: "first-name", Note: "new note", CreatedAt: start.Format(time.RFC3339), UpdatedAt: start.Add(time.Hour).Format(time.RFC3339), }, }) - + dbHasFoo(t, db, allsrv.Foo{ ID: "1", Name: "first-name", @@ -387,11 +389,11 @@ func TestServerV2(t *testing.T) { svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, inputs: inputs{ req: newJSONReq("PATCH", "/v1/foos/1", - newJSONBody(t, allsrv.ReqUpdateFooV1{ - Data: allsrv.Data[allsrv.FooUpdAttrs]{ + newJSONBody(t, allsrvc.ReqBody[allsrvc.FooUpdAttrs]{ + Data: allsrvc.Data[allsrvc.FooUpdAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.FooUpdAttrs{ + Attrs: allsrvc.FooUpdAttrs{ Note: allsrvtesting.Ptr("new note"), }, }, @@ -401,11 +403,11 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusUnauthorized, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusUnauthorized, Code: 4, Msg: "unauthorized access", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Header: "Authorization", }, }) @@ -415,11 +417,11 @@ func TestServerV2(t *testing.T) { name: "when updating foo too a name that collides with existing should fail", prepare: allsrvtesting.CreateFoos(allsrv.Foo{ID: "1", Name: "start-foo"}, allsrv.Foo{ID: "9000", Name: "existing-foo"}), inputs: inputs{ - req: newJSONReq("PATCH", "/v1/foos/1", newJSONBody(t, allsrv.ReqUpdateFooV1{ - Data: allsrv.Data[allsrv.FooUpdAttrs]{ + req: newJSONReq("PATCH", "/v1/foos/1", newJSONBody(t, allsrvc.ReqBody[allsrvc.FooUpdAttrs]{ + Data: allsrvc.Data[allsrvc.FooUpdAttrs]{ Type: "foo", ID: "1", - Attrs: allsrv.FooUpdAttrs{ + Attrs: allsrvc.FooUpdAttrs{ Name: allsrvtesting.Ptr("existing-foo"), Note: allsrvtesting.Ptr("some note"), }, @@ -428,15 +430,15 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusConflict, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusConflict, Code: 1, Msg: "foo existing-foo exists", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Pointer: "/data/attributes/name", }, }) - + dbHasFoo(t, db, allsrv.Foo{ ID: "1", Name: "start-foo", @@ -444,14 +446,14 @@ func TestServerV2(t *testing.T) { }, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testSvr(t, tt) }) } }) - + t.Run("foo delete", func(t *testing.T) { tests := []testCase{ { @@ -468,12 +470,12 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusOK, rec.Code) - expectJSONBody(t, rec.Body, func(t *testing.T, got allsrv.RespBody[any]) { + expectJSONBody(t, rec.Body, func(t *testing.T, got allsrvc.RespBody[any]) { require.Nil(t, got.Data) require.Nil(t, got.Errs) require.NotZero(t, got.Meta.TraceID) }) - + _, err := db.ReadFoo(context.TODO(), "1") require.Error(t, err) }, @@ -492,11 +494,11 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { assert.Equal(t, http.StatusUnauthorized, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusUnauthorized, Code: 4, Msg: "unauthorized access", - Source: &allsrv.RespErrSource{ + Source: &allsrvc.RespErrSource{ Header: "Authorization", }, }) @@ -509,7 +511,7 @@ func TestServerV2(t *testing.T) { }, want: func(t *testing.T, rec *httptest.ResponseRecorder, _ allsrv.DB) { assert.Equal(t, http.StatusNotFound, rec.Code) - expectErrs(t, rec.Body, allsrv.RespErr{ + expectErrs(t, rec.Body, allsrvc.RespErr{ Status: http.StatusNotFound, Code: 3, Msg: "foo not found for id: 1", @@ -517,7 +519,7 @@ func TestServerV2(t *testing.T) { }, }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testSvr(t, tt) @@ -526,38 +528,38 @@ func TestServerV2(t *testing.T) { }) } -func expectErrs(t *testing.T, r io.Reader, want ...allsrv.RespErr) { +func expectErrs(t *testing.T, r io.Reader, want ...allsrvc.RespErr) { t.Helper() - - expectJSONBody(t, r, func(t *testing.T, got allsrv.RespBody[any]) { + + expectJSONBody(t, r, func(t *testing.T, got allsrvc.RespBody[any]) { t.Helper() - + require.Nil(t, got.Data) require.NotEmpty(t, got.Errs) - + assert.Equal(t, want, got.Errs) }) } -func expectData[Attrs any | []any](t *testing.T, r io.Reader, want allsrv.Data[Attrs]) { +func expectData[Attrs allsrvc.Attrs](t *testing.T, r io.Reader, want allsrvc.Data[Attrs]) { t.Helper() - - expectJSONBody(t, r, func(t *testing.T, got allsrv.RespBody[Attrs]) { + + expectJSONBody(t, r, func(t *testing.T, got allsrvc.RespBody[Attrs]) { t.Helper() - + require.Empty(t, got.Errs) require.NotNil(t, got.Data) - + assert.Equal(t, want, *got.Data) }) } func dbHasFoo(t *testing.T, db allsrv.DB, want allsrv.Foo) { t.Helper() - + got, err := db.ReadFoo(context.TODO(), want.ID) require.NoError(t, err) - + assert.Equal(t, want, got) } diff --git a/allsrv/svc_mw_logging.go b/allsrv/svc_mw_logging.go index 9878538..2492fa9 100644 --- a/allsrv/svc_mw_logging.go +++ b/allsrv/svc_mw_logging.go @@ -24,8 +24,8 @@ type svcMWLogger struct { } func (s *svcMWLogger) CreateFoo(ctx context.Context, f Foo) (Foo, error) { - logFn := s.logFn("input_name", f.Name, "input_note", f.Note) - + logFn := s.logFn(ctx, "input_name", f.Name, "input_note", f.Note) + f, err := s.next.CreateFoo(ctx, f) logger := logFn(err) if err != nil { @@ -33,19 +33,19 @@ func (s *svcMWLogger) CreateFoo(ctx context.Context, f Foo) (Foo, error) { } else { logger.Info("foo created successfully", "new_foo_id", f.ID) } - + return f, err } func (s *svcMWLogger) ReadFoo(ctx context.Context, id string) (Foo, error) { - logFn := s.logFn("input_id", id) - + logFn := s.logFn(ctx, "input_id", id) + f, err := s.next.ReadFoo(ctx, id) logger := logFn(err) if err != nil { logger.Error("failed to read foo") } - + return f, err } @@ -57,9 +57,9 @@ func (s *svcMWLogger) UpdateFoo(ctx context.Context, f FooUpd) (Foo, error) { if f.Note != nil { fields = append(fields, "input_note", *f.Note) } - - logFn := s.logFn(fields...) - + + logFn := s.logFn(ctx, fields...) + updatedFoo, err := s.next.UpdateFoo(ctx, f) logger := logFn(err) if err != nil { @@ -67,13 +67,13 @@ func (s *svcMWLogger) UpdateFoo(ctx context.Context, f FooUpd) (Foo, error) { } else { logger.Info("foo updated successfully") } - + return updatedFoo, err } func (s *svcMWLogger) DelFoo(ctx context.Context, id string) error { - logFn := s.logFn("input_id", id) - + logFn := s.logFn(ctx, "input_id", id) + err := s.next.DelFoo(ctx, id) logger := logFn(err) if err != nil { @@ -81,16 +81,21 @@ func (s *svcMWLogger) DelFoo(ctx context.Context, id string) error { } else { logger.Info("foo deleted successfully") } - + return err } -func (s *svcMWLogger) logFn(fields ...any) func(error) *slog.Logger { +func (s *svcMWLogger) logFn(ctx context.Context, fields ...any) func(error) *slog.Logger { start := time.Now() return func(err error) *slog.Logger { logger := s.logger. With(fields...). - With("took_ms", time.Since(start).Round(time.Millisecond).String()) + With( + "took_ms", time.Since(start).Round(time.Millisecond).String(), + "origin", getOrigin(ctx), + "user_agent", getUserAgent(ctx), + "trace_id", getTraceID(ctx), + ) if err != nil { logger = logger.With("err", err.Error()) logger = logger.WithGroup("err_fields").With(errors.Fields(err)...) diff --git a/go.mod b/go.mod index d1e62e7..5365842 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,8 @@ require ( github.com/golang-migrate/migrate/v4 v4.17.0 github.com/hashicorp/go-metrics v0.5.3 github.com/jmoiron/sqlx v1.3.5 - github.com/jsteenb2/errors v0.2.0 + github.com/jsteenb2/allsrvc v0.4.0 + github.com/jsteenb2/errors v0.3.0 github.com/mattn/go-sqlite3 v1.14.19 github.com/opentracing/opentracing-go v1.2.0 github.com/spf13/cobra v1.8.0 diff --git a/go.sum b/go.sum index a5e4dd5..a70d5a4 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,10 @@ github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/jsteenb2/errors v0.2.0 h1:7LImy2u+6CAKJnw6Ug8xuW/THKH2fwWf1BAUwQuNaeQ= -github.com/jsteenb2/errors v0.2.0/go.mod h1:vLm/10zo41mY2s7yGpB654h094ShSoG9LKwbivG0joU= +github.com/jsteenb2/allsrvc v0.4.0 h1:hz+es8ZQBlPHmc646j/ilwJOEiqwAVQBaUEkQifGqyQ= +github.com/jsteenb2/allsrvc v0.4.0/go.mod h1:q72Q/DWXKY+UyvgEEfEx1sdn2m0osVDUNA9MlSccFQg= +github.com/jsteenb2/errors v0.3.0 h1:m45UhWJUnlrHMLu2JA9xYDJ7PaYwkdwoowTZZ+34hSs= +github.com/jsteenb2/errors v0.3.0/go.mod h1:vLm/10zo41mY2s7yGpB654h094ShSoG9LKwbivG0joU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=