diff --git a/allsrv/errors.go b/allsrv/errors.go index d303578..dc53d51 100644 --- a/allsrv/errors.go +++ b/allsrv/errors.go @@ -6,6 +6,7 @@ const ( errTypeInvalid errTypeNotFound errTypeUnAuthed + errTypeInternal ) // Err provides a lightly structured error that we can attach behavior. Additionally, @@ -39,3 +40,8 @@ func NotFoundErr(msg string, fields ...any) error { Fields: fields, } } + +func isErrType(err error, want int) bool { + e, _ := err.(Err) + return err != nil && e.Type == want +} diff --git a/allsrv/server_v2.go b/allsrv/server_v2.go index 68177b3..19a60d1 100644 --- a/allsrv/server_v2.go +++ b/allsrv/server_v2.go @@ -11,19 +11,21 @@ import ( "github.com/hashicorp/go-metrics" ) -func WithMetrics(mets *metrics.Metrics) func(*serverOpts) { +type SvrOptFn func(o *serverOpts) + +func WithMetrics(mets *metrics.Metrics) SvrOptFn { return func(o *serverOpts) { o.met = mets } } -func WithMux(mux *http.ServeMux) func(*serverOpts) { +func WithMux(mux *http.ServeMux) SvrOptFn { return func(o *serverOpts) { o.mux = mux } } -func WithNowFn(fn func() time.Time) func(*serverOpts) { +func WithNowFn(fn func() time.Time) SvrOptFn { return func(o *serverOpts) { o.nowFn = fn } @@ -38,7 +40,7 @@ type ServerV2 struct { nowFn func() time.Time } -func NewServerV2(db DB, opts ...func(*serverOpts)) *ServerV2 { +func NewServerV2(db DB, opts ...SvrOptFn) *ServerV2 { opt := serverOpts{ mux: http.NewServeMux(), idFn: func() string { return uuid.Must(uuid.NewV4()).String() }, @@ -76,7 +78,7 @@ func (s *ServerV2) routes() { withContentTypeJSON := applyMW(contentTypeJSON, s.mw) // 9) - s.mux.Handle("POST /v1/foos", withContentTypeJSON(jsonIn(http.StatusCreated, s.createFooV1))) + s.mux.Handle("POST /v1/foos", withContentTypeJSON(jsonIn(resourceTypeFoo, http.StatusCreated, s.createFooV1))) } func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -90,20 +92,15 @@ type ( // https://jsonapi.org/format/#document-top-level // // note: data can be either an array or a single resource object. This allows for both. - RespResourceBody[Attrs any | []any] struct { - Meta RespMeta `json:"meta"` - Errs []RespErr `json:"errors,omitempty"` - Data *RespData[Attrs] `json:"data,omitempty"` + RespResourceBody[Attr Attrs] struct { + Meta RespMeta `json:"meta"` + Errs []RespErr `json:"errors,omitempty"` + Data *Data[Attr] `json:"data,omitempty"` } - // RespData represents a JSON-API data response. - // https://jsonapi.org/format/#document-top-level - RespData[Attr any | []Attr] struct { - Type string `json:"type"` - ID string `json:"id"` - Attributes Attr `json:"attributes"` - - // omitting the relationships here for brevity not at lvl 3 RMM + // Attrs can be either a document or a collection of documents. + Attrs interface { + any | []Attrs } // RespMeta represents a JSON-API meta object. The data here is @@ -135,30 +132,47 @@ type ( } ) -type ( - // ReqCreateFooV1 represents the request body for the create foo API. - ReqCreateFooV1 struct { - Name string `json:"name"` - Note string `json:"note"` - } +// Data represents a JSON-API data response. +// +// https://jsonapi.org/format/#document-top-level +type Data[Attr Attrs] struct { + Type string `json:"type"` + ID string `json:"id"` + Attrs Attr `json:"attributes"` - // FooAttrs are the attributes for foo data. - FooAttrs struct { - Name string `json:"name"` - Note string `json:"note"` - CreatedAt string `json:"created_at"` - } + // omitting the relationships here for brevity not at lvl 3 RMM +} + +func (d Data[Attr]) getType() string { + return d.Type +} + +const ( + resourceTypeFoo = "foo" ) -func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (RespData[FooAttrs], []RespErr) { +type ReqCreateFooV1 = Data[FooAttrs] + +// FooAttrs are the attributes of a foo resource. +type FooAttrs struct { + Name string `json:"name"` + Note string `json:"note"` + CreatedAt string `json:"created_at"` +} + +func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (Data[FooAttrs], []RespErr) { newFoo := Foo{ ID: s.idFn(), - Name: req.Name, - Note: req.Note, + Name: req.Attrs.Name, + Note: req.Attrs.Note, CreatedAt: s.nowFn(), } if err := s.db.CreateFoo(ctx, newFoo); err != nil { - return RespData[FooAttrs]{}, toRespErrs(err) + respErr := toRespErr(err) + if isErrType(err, errTypeExists) { + respErr.Source = &RespErrSource{Pointer: "/data/attributes/name"} + } + return Data[FooAttrs]{}, []RespErr{respErr} } out := newFooData(newFoo.ID, FooAttrs{ @@ -169,11 +183,11 @@ func (s *ServerV2) createFooV1(ctx context.Context, req ReqCreateFooV1) (RespDat return out, nil } -func newFooData(id string, attrs FooAttrs) RespData[FooAttrs] { - return RespData[FooAttrs]{ - Type: "foo", - ID: id, - Attributes: attrs, +func newFooData(id string, attrs FooAttrs) Data[FooAttrs] { + return Data[FooAttrs]{ + Type: resourceTypeFoo, + ID: id, + Attrs: attrs, } } @@ -181,17 +195,36 @@ func toTimestamp(t time.Time) string { return t.Format(time.RFC3339) } -func jsonIn[ReqBody, Attr any](successCode int, fn func(context.Context, ReqBody) (RespData[Attr], []RespErr)) http.Handler { +func jsonIn[ + Attr Attrs, + ReqBody interface { + Data[Attr] + // this is limited by go's generics in a big way, which is very unfortunate :-( + // https://github.com/golang/go/issues/48522 + getType() string + }, +](resource string, successCode int, fn func(context.Context, ReqBody) (Data[Attr], []RespErr)) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ( reqBody ReqBody errs []RespErr - out *RespData[Attr] + out *Data[Attr] ) if respErr := decodeReq(r, &reqBody); respErr != nil { errs = append(errs, *respErr) - } else { - var data RespData[Attr] + } + if len(errs) == 0 && reqBody.getType() != resource { + errs = append(errs, RespErr{ + Status: http.StatusUnprocessableEntity, + Code: errTypeInvalid, + Msg: "type must be " + resource, + Source: &RespErrSource{ + Pointer: "/data/type", + }, + }) + } + if len(errs) == 0 { + var data Data[Attr] data, errs = fn(r.Context(), reqBody) if len(errs) == 0 { out = &data @@ -225,7 +258,7 @@ func decodeReq(r *http.Request, v any) *RespErr { Code: errTypeInvalid, } if unmarshErr := new(json.UnmarshalTypeError); errors.As(err, &unmarshErr) { - respErr.Source.Pointer += "/attributes/" + unmarshErr.Field + respErr.Source.Pointer += "/data" } return &respErr } @@ -233,28 +266,19 @@ func decodeReq(r *http.Request, v any) *RespErr { return nil } -func toRespErrs(err error) []RespErr { - if e := new(Err); errors.As(err, e) { - return []RespErr{{ - Code: errCode(e), - Msg: e.Msg, - }} +func toRespErr(err error) RespErr { + out := RespErr{ + Status: http.StatusInternalServerError, + Code: errTypeInternal, + Msg: err.Error(), } - - errs, ok := err.(interface{ Unwrap() []error }) - if !ok { - return nil - } - - var out []RespErr - for _, e := range errs.Unwrap() { - out = append(out, toRespErrs(e)...) + if e := new(Err); errors.As(err, e) { + out.Status, out.Code = errStatus(e), e.Type } - return out } -func errCode(err *Err) int { +func errStatus(err *Err) int { switch err.Type { case errTypeExists: return http.StatusConflict diff --git a/allsrv/server_v2_test.go b/allsrv/server_v2_test.go index f35de81..986bf4e 100644 --- a/allsrv/server_v2_test.go +++ b/allsrv/server_v2_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -15,114 +16,243 @@ import ( ) func TestServerV2(t *testing.T) { + start := time.Time{}.Add(time.Hour).UTC() + t.Run("foo create", func(t *testing.T) { - t.Run("when provided a valid foo should pass", func(t *testing.T) { - db := new(allsrv.InmemDB) - - var svr http.Handler = allsrv.NewServerV2( - db, - allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd"), - allsrv.WithMetrics(newTestMetrics(t)), - allsrv.WithIDFn(func() string { - return "id1" - }), - allsrv.WithNowFn(func() time.Time { - return time.Time{}.UTC().Add(time.Hour) - }), - ) - - req := httptest.NewRequest("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ - Name: "first-foo", - Note: "some note", - })) - req.Header.Set("Content-Type", "application/json") - req.SetBasicAuth("dodgers@stink.com", "PaSsWoRd") - rec := httptest.NewRecorder() - - svr.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusCreated, rec.Code) - expectData[allsrv.FooAttrs](t, rec.Body, allsrv.RespData[allsrv.FooAttrs]{ - Type: "foo", - ID: "id1", - Attributes: allsrv.FooAttrs{ - Name: "first-foo", - Note: "some note", - CreatedAt: time.Time{}.UTC().Add(time.Hour).Format(time.RFC3339), + type ( + inputs struct { + req *http.Request + } + + wantFn func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) + ) + + tests := []struct { + name string + prepare func(t *testing.T, db allsrv.DB) + svrOpts []allsrv.SvrOptFn + inputs inputs + want wantFn + }{ + { + name: "when provided a valid foo and authorized user should pass", + svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, + inputs: inputs{ + req: newJSONReq("POST", "/v1/foos", + newJSONBody(t, allsrv.ReqCreateFooV1{ + Type: "foo", + Attrs: allsrv.FooAttrs{ + Name: "first-foo", + Note: "some note", + }, + }), + withBasicAuth("dodgers@stink.com", "PaSsWoRd"), + ), }, - }) + want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { + assert.Equal(t, http.StatusCreated, rec.Code) + expectData[allsrv.FooAttrs](t, rec.Body, allsrv.Data[allsrv.FooAttrs]{ + Type: "foo", + ID: "1", + Attrs: allsrv.FooAttrs{ + Name: "first-foo", + Note: "some note", + CreatedAt: start.Format(time.RFC3339), + }, + }) - got, err := db.ReadFoo(context.TODO(), "id1") - require.NoError(t, err) + dbHasFoo(t, db, allsrv.Foo{ + ID: "1", + Name: "first-foo", + Note: "some note", + CreatedAt: start, + }) + }, + }, + { + name: "when missing required auth should fail", + svrOpts: []allsrv.SvrOptFn{allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd")}, + inputs: inputs{ + req: newJSONReq("POST", "/v1/foos", + newJSONBody(t, allsrv.ReqCreateFooV1{ + Type: "foo", + Attrs: allsrv.FooAttrs{ + Name: "first-foo", + Note: "some note", + }, + }), + withBasicAuth("dodgers@stink.com", "WRONGO"), + ), + }, + want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { + assert.Equal(t, http.StatusUnauthorized, rec.Code) + expectErrs(t, rec.Body, allsrv.RespErr{ + Status: http.StatusUnauthorized, + Code: 4, + Msg: "unauthorized access", + Source: &allsrv.RespErrSource{ + Header: "Authorization", + }, + }) - want := allsrv.Foo{ - ID: "id1", - Name: "first-foo", - Note: "some note", - CreatedAt: time.Time{}.UTC().Add(time.Hour), - } - assert.Equal(t, want, got) - }) - - t.Run("when missing required auth should fail", func(t *testing.T) { - var svr http.Handler = allsrv.NewServerV2( - new(allsrv.InmemDB), - allsrv.WithBasicAuthV2("dodgers@stink.com", "PaSsWoRd"), - allsrv.WithMetrics(newTestMetrics(t)), - allsrv.WithIDFn(func() string { - return "id1" - }), - allsrv.WithNowFn(func() time.Time { - return time.Time{}.UTC().Add(time.Hour) - }), - ) - - req := httptest.NewRequest("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ - Name: "first-foo", - Note: "some note", - })) - req.Header.Set("Content-Type", "application/json") - req.SetBasicAuth("dodgers@stink.com", "WRONGO") - rec := httptest.NewRecorder() - - svr.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) - expectErrs(t, rec.Body, func(t *testing.T, got []allsrv.RespErr) { - require.Len(t, got, 1) - - want := allsrv.RespErr{ - Status: 401, - Code: 4, - Msg: "unauthorized access", - Source: &allsrv.RespErrSource{ - Header: "Authorization", - }, + _, err := db.ReadFoo(context.TODO(), "1") + require.Error(t, err) + }, + }, + { + name: "when creating foo with name that collides with existing should fail", + prepare: createFoos(allsrv.Foo{ID: "9000", Name: "existing-foo"}), + inputs: inputs{ + req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ + Type: "foo", + Attrs: allsrv.FooAttrs{ + Name: "existing-foo", + Note: "some note", + }, + })), + }, + want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { + assert.Equal(t, http.StatusConflict, rec.Code) + expectErrs(t, rec.Body, allsrv.RespErr{ + Status: http.StatusConflict, + Code: 1, + Msg: "foo existing-foo exists", + Source: &allsrv.RespErrSource{ + Pointer: "/data/attributes/name", + }, + }) + + _, err := db.ReadFoo(context.TODO(), "1") + require.Error(t, err) + }, + }, + { + name: "when creating foo with invalid resource type should fail", + inputs: inputs{ + req: newJSONReq("POST", "/v1/foos", newJSONBody(t, allsrv.ReqCreateFooV1{ + Type: "WRONGO", + Attrs: allsrv.FooAttrs{ + Name: "first-foo", + Note: "some note", + }, + })), + }, + want: func(t *testing.T, rec *httptest.ResponseRecorder, db allsrv.DB) { + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) + expectErrs(t, rec.Body, allsrv.RespErr{ + Status: http.StatusUnprocessableEntity, + Code: 2, + Msg: "type must be foo", + Source: &allsrv.RespErrSource{ + Pointer: "/data/type", + }, + }) + + _, err := db.ReadFoo(context.TODO(), "1") + require.Error(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := new(allsrv.InmemDB) + + if tt.prepare != nil { + tt.prepare(t, db) } - assert.Equal(t, want, got[0]) + + defaultOpts := []allsrv.SvrOptFn{ + allsrv.WithIDFn(newIDGen(1, 1)), + allsrv.WithNowFn(newNowFn(start, time.Hour)), + allsrv.WithMetrics(newTestMetrics(t)), + } + opts := append(defaultOpts, tt.svrOpts...) + + rec := httptest.NewRecorder() + + svr := allsrv.NewServerV2(db, opts...) + svr.ServeHTTP(rec, tt.inputs.req) + + tt.want(t, rec, db) }) - }) + } }) } -func expectErrs(t *testing.T, r io.Reader, fn func(t *testing.T, got []allsrv.RespErr)) { +func expectErrs(t *testing.T, r io.Reader, want ...allsrv.RespErr) { t.Helper() expectJSONBody(t, r, func(t *testing.T, got allsrv.RespResourceBody[any]) { + t.Helper() + require.Nil(t, got.Data) require.NotEmpty(t, got.Errs) - fn(t, got.Errs) + assert.Equal(t, want, got.Errs) }) } -func expectData[Attrs any | []any](t *testing.T, r io.Reader, want allsrv.RespData[Attrs]) { +func expectData[Attrs any | []any](t *testing.T, r io.Reader, want allsrv.Data[Attrs]) { t.Helper() expectJSONBody(t, r, func(t *testing.T, got allsrv.RespResourceBody[Attrs]) { + t.Helper() + require.Empty(t, got.Errs) require.NotNil(t, got.Data) assert.Equal(t, want, *got.Data) }) } + +func dbHasFoo(t *testing.T, db allsrv.DB, want allsrv.Foo) { + t.Helper() + + got, err := db.ReadFoo(context.TODO(), want.ID) + require.NoError(t, err) + + assert.Equal(t, want, got) +} + +func createFoos(foos ...allsrv.Foo) func(t *testing.T, db allsrv.DB) { + return func(t *testing.T, db allsrv.DB) { + t.Helper() + + for _, f := range foos { + err := db.CreateFoo(context.TODO(), f) + require.NoError(t, err) + } + } +} + +func newJSONReq(method, target string, body io.Reader, opts ...func(*http.Request)) *http.Request { + req := httptest.NewRequest(method, target, body) + req.Header.Set("Content-Type", "application/json") + for _, o := range opts { + o(req) + } + return req +} + +func withBasicAuth(user, pass string) func(*http.Request) { + return func(req *http.Request) { + req.SetBasicAuth(user, pass) + } +} + +func newIDGen(start, incr int) func() string { + return func() string { + id := strconv.Itoa(start) + start += incr + return id + } +} + +func newNowFn(start time.Time, incr time.Duration) func() time.Time { + return func() time.Time { + t := start + start = start.Add(incr) + return t + } +}