From efaad803217781b99bfeec7b4b70ded7790146d8 Mon Sep 17 00:00:00 2001 From: Antoine Pourchet Date: Tue, 9 Jan 2024 10:39:25 -0700 Subject: [PATCH] Added some helpful default behavior and types (#3) --- after.go | 2 +- after_test.go | 4 +- before.go | 2 +- before_test.go | 2 +- constructor.go | 17 +---- context.go | 6 +- context_test.go | 12 ++-- decoder.go | 4 +- error.go | 31 +++++++++ examples/simple/main.go | 2 +- examples/standard/main.go | 143 ++++++++++++++++++++++++++++++++++++++ go.mod | 7 +- internal/defaults.go | 2 +- internal/deref.go | 4 +- internal/generate.go | 2 +- main.go | 4 +- main_test.go | 4 +- standard.go | 73 +++++++++++++++++++ utils.go | 2 +- wrapper.go | 10 +-- wrapper_test.go | 10 +-- 21 files changed, 291 insertions(+), 52 deletions(-) create mode 100644 error.go create mode 100644 examples/standard/main.go create mode 100644 standard.go diff --git a/after.go b/after.go index 0c82154..0c95a20 100644 --- a/after.go +++ b/after.go @@ -8,7 +8,7 @@ type afterFn struct { outTypes []reflect.Type } -func newAfter(fn interface{}) (afterFn, error) { +func newAfter(fn any) (afterFn, error) { val := reflect.ValueOf(fn) fnType := val.Type() inTypes, outTypes := []reflect.Type{}, []reflect.Type{} diff --git a/after_test.go b/after_test.go index 5e54472..521889f 100644 --- a/after_test.go +++ b/after_test.go @@ -15,7 +15,7 @@ func TestAfter(t *testing.T) { rw := httptest.NewRecorder() ctx := newRunCtx(rw, req, nopConstructor) - after, err := newAfter(func(w http.ResponseWriter, res interface{}) { + after, err := newAfter(func(w http.ResponseWriter, res any) { require.NotNil(t, w) w.WriteHeader(http.StatusOK) }) @@ -31,7 +31,7 @@ func TestAfter(t *testing.T) { ctx := newRunCtx(rw, req, nopConstructor) ctx.response = reflect.ValueOf(1) - after, err := newAfter(func(w http.ResponseWriter, res interface{}) { + after, err := newAfter(func(w http.ResponseWriter, res any) { require.NotNil(t, w) require.Equal(t, res, 1) }) diff --git a/before.go b/before.go index 218f796..784f3b0 100644 --- a/before.go +++ b/before.go @@ -8,7 +8,7 @@ type beforeFn struct { outTypes []reflect.Type } -func newBefore(fn interface{}) (beforeFn, error) { +func newBefore(fn any) (beforeFn, error) { val := reflect.ValueOf(fn) fnType := val.Type() inTypes, outTypes := []reflect.Type{}, []reflect.Type{} diff --git a/before_test.go b/before_test.go index 0ddfafe..42ffbd5 100644 --- a/before_test.go +++ b/before_test.go @@ -25,7 +25,7 @@ func TestBefore(t *testing.T) { }) t.Run("empty interface", func(t *testing.T) { - _, err := newBefore(func(in interface{}) error { + _, err := newBefore(func(in any) error { return fmt.Errorf("error") }) require.Error(t, err) diff --git a/constructor.go b/constructor.go index ea508db..ec08f98 100644 --- a/constructor.go +++ b/constructor.go @@ -4,21 +4,8 @@ import "net/http" // Constructor is the function signature for unmarshalling an http request into // an object. -type Constructor func(http.ResponseWriter, *http.Request, interface{}) error +type Constructor func(http.ResponseWriter, *http.Request, any) error // EmptyConstructor is the default constructor for new wrappers. // It is a no-op. -func EmptyConstructor(http.ResponseWriter, *http.Request, interface{}) error { return nil } - -// The StandardConstructor decodes the request using the following: -// - cookies -// - query params -// - path params -// - headers -// - JSON decoding of the body -func StandardConstructor() Constructor { - decoder := NewDecoder() - return func(rw http.ResponseWriter, req *http.Request, obj interface{}) error { - return decoder.Decode(req, obj) - } -} +func EmptyConstructor(http.ResponseWriter, *http.Request, any) error { return nil } diff --git a/context.go b/context.go index c2b2fdb..6902278 100644 --- a/context.go +++ b/context.go @@ -18,7 +18,7 @@ type runctx struct { type param struct { t reflect.Type v reflect.Value - i interface{} + i any } func newRunCtx( @@ -30,7 +30,7 @@ func newRunCtx( req: req, rw: rw, cons: cons, - response: reflect.Zero(reflect.TypeOf((*interface{})(nil)).Elem()), + response: reflect.Zero(reflect.TypeOf((*any)(nil)).Elem()), results: map[reflect.Type]param{}, resultSlice: []param{}, } @@ -39,7 +39,7 @@ func newRunCtx( return ctx } -func (ctx *runctx) provide(i interface{}) { +func (ctx *runctx) provide(i any) { if i == nil { return } diff --git a/context_test.go b/context_test.go index 9d16388..df2d00b 100644 --- a/context_test.go +++ b/context_test.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -12,19 +12,19 @@ import ( "github.com/stretchr/testify/require" ) -func nopConstructor(http.ResponseWriter, *http.Request, interface{}) error { return nil } +func nopConstructor(http.ResponseWriter, *http.Request, any) error { return nil } -func jsonBodyConstructor(_ http.ResponseWriter, req *http.Request, obj interface{}) error { - body, err := ioutil.ReadAll(req.Body) +func jsonBodyConstructor(_ http.ResponseWriter, req *http.Request, obj any) error { + body, err := io.ReadAll(req.Body) if err != nil { return err } err = json.Unmarshal(body, obj) - req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + req.Body = io.NopCloser(bytes.NewBuffer(body)) return err } -func failedConstructor(http.ResponseWriter, *http.Request, interface{}) error { +func failedConstructor(http.ResponseWriter, *http.Request, any) error { return fmt.Errorf("error") } diff --git a/decoder.go b/decoder.go index e5a134a..e8b9204 100644 --- a/decoder.go +++ b/decoder.go @@ -34,7 +34,7 @@ type Decoder struct { // DecodeFunc is the function signature for decoding a request into an // object. -type DecodeFunc func(req *http.Request, obj interface{}) error +type DecodeFunc func(req *http.Request, obj any) error // NewDecoder returns a new decoder with sensible defaults for the // DecodeBody, Header and Query functions. @@ -60,7 +60,7 @@ func NewDecoder() *Decoder { // request struct // The Limit field will come from the query string // The Resource field will come from the resource value of the path -func (d *Decoder) Decode(req *http.Request, obj interface{}) error { +func (d *Decoder) Decode(req *http.Request, obj any) error { if err := d.DecodeBody(req, obj); err != nil { return err } diff --git a/error.go b/error.go new file mode 100644 index 0000000..d541dd9 --- /dev/null +++ b/error.go @@ -0,0 +1,31 @@ +package httpwrap + +import ( + "fmt" + "io" +) + +// HTTPError implements both the HTTPResponse interface and the standard error +// interface. +type HTTPError struct { + code int + body string +} + +func NewHTTPError(code int, format string, args ...any) HTTPError { + return HTTPError{ + code: code, + body: fmt.Sprintf(format, args...), + } +} + +func (err HTTPError) Error() string { + return fmt.Sprintf("http error: %d: %s", err.code, err.body) +} + +func (err HTTPError) StatusCode() int { return err.code } + +func (err HTTPError) WriteBody(writer io.Writer) error { + _, writeError := io.WriteString(writer, err.body) + return writeError +} diff --git a/examples/simple/main.go b/examples/simple/main.go index b480f14..7582558 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -51,7 +51,7 @@ func (mw *Middlewares) checkAPICreds(creds APICredentials) error { // sendResponse writes out the response to the client given the output // of the handler. -func (mw *Middlewares) sendResponse(w http.ResponseWriter, res interface{}, err error) { +func (mw *Middlewares) sendResponse(w http.ResponseWriter, res any, err error) { switch err { case ErrBadAPICreds: w.WriteHeader(http.StatusUnauthorized) diff --git a/examples/standard/main.go b/examples/standard/main.go new file mode 100644 index 0000000..4623e0c --- /dev/null +++ b/examples/standard/main.go @@ -0,0 +1,143 @@ +package main + +import ( + "log" + "net/http" + + "github.com/apourchet/httpwrap" + "github.com/gorilla/mux" +) + +// ***** Type Definitions ***** +type APICredentials struct { + Key string `http:"header=X-PETSTORE-KEY"` +} + +type PetStoreHandler struct { + pets map[string]*Pet +} + +type Pet struct { + Name string `json:"name"` + Category int `json:"category"` + PhotoURLs []string `json:"photoUrls"` +} + +func (pet Pet) IsInCategories(categories []int) bool { + for _, c := range categories { + if pet.Category == c { + return true + } + } + return false +} + +var ErrBadAPICreds = httpwrap.NewHTTPError(http.StatusUnauthorized, "bad API credentials") +var ErrPetConflict = httpwrap.NewHTTPError(http.StatusConflict, "duplicate pet") +var ErrPetNotFound = httpwrap.NewHTTPError(http.StatusNotFound, "pet not found") + +// ***** Middleware Definitions ***** +// checkAPICreds checks the api credentials passed into the request. +func checkAPICreds(creds APICredentials) error { + if creds.Key == "my-secret-key" { + return nil + } + return ErrBadAPICreds +} + +// ***** Handler Methods ***** +// AddPet adds a new pet to the store. +func (h *PetStoreHandler) AddPet(pet Pet) error { + if _, found := h.pets[pet.Name]; found { + return ErrPetConflict + } + h.pets[pet.Name] = &pet + return nil +} + +// GetPets returns the list of pets in the store. +func (h *PetStoreHandler) GetPets() (res []Pet, err error) { + res = make([]Pet, 0, len(h.pets)) + for _, pet := range h.pets { + res = append(res, *pet) + } + return res, nil +} + +type GetByNameParams struct { + Name string `http:"segment=name"` +} + +// GetPetByName returns a pet given its name. +func (h *PetStoreHandler) GetPetByName(params GetByNameParams) (pet *Pet, err error) { + pet, found := h.pets[params.Name] + if !found { + return nil, ErrPetNotFound + } + return pet, nil +} + +type UpdateParams struct { + Name string `http:"segment=name"` + + Category *int `json:"category"` + PhotoURLs *[]string `json:"photoUrls"` +} + +// UpdatePet updates a pet given its name. +func (h *PetStoreHandler) UpdatePet(params UpdateParams) error { + pet, found := h.pets[params.Name] + if !found { + return ErrPetNotFound + } + + if params.Category != nil { + pet.Category = *params.Category + } + if params.PhotoURLs != nil { + pet.PhotoURLs = *params.PhotoURLs + } + return nil +} + +type FilterPetParams struct { + Categories *[]int `http:"query=categories"` + HasPhotos *bool `http:"query=hasPhotos"` +} + +// FilterPets returns a list of pets that match the parameters given. +func (h *PetStoreHandler) FilterPets(params FilterPetParams) []Pet { + res := []Pet{} + for _, pet := range h.pets { + if params.HasPhotos != nil && len(pet.PhotoURLs) == 0 { + continue + } else if params.Categories != nil && !pet.IsInCategories(*params.Categories) { + continue + } + res = append(res, *pet) + } + return res +} + +func (h *PetStoreHandler) ClearStore() error { + h.pets = map[string]*Pet{} + return nil +} + +func main() { + router := mux.NewRouter() + + handler := &PetStoreHandler{pets: map[string]*Pet{}} + wrapper := httpwrap.NewStandardWrapper().Before(checkAPICreds) + + router.Handle("/pets", wrapper.Wrap(handler.AddPet)).Methods("POST") + router.Handle("/pets", wrapper.Wrap(handler.GetPets)).Methods("GET") + router.Handle("/pets/filtered", wrapper.Wrap(handler.FilterPets)).Methods("GET") + router.Handle("/pets/{name}", wrapper.Wrap(handler.GetPetByName)).Methods("GET") + router.Handle("/pets/{name}", wrapper.Wrap(handler.UpdatePet)).Methods("PUT") + + router.Handle("/clear", wrapper.Wrap(handler.ClearStore)).Methods("POST") + + http.Handle("/", router) + log.Fatal(http.ListenAndServe(":3000", router)) +} diff --git a/go.mod b/go.mod index 4da03a9..c22e781 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,13 @@ module github.com/apourchet/httpwrap -go 1.12 +go 1.20 require ( github.com/gorilla/mux v1.7.2 github.com/stretchr/testify v1.3.0 ) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect +) diff --git a/internal/defaults.go b/internal/defaults.go index 74e5381..bcedf80 100644 --- a/internal/defaults.go +++ b/internal/defaults.go @@ -21,7 +21,7 @@ var ( // DecodeBody uses a json decoder to decode the body of the request // into the target object. -func DecodeBody(req *http.Request, obj interface{}) error { +func DecodeBody(req *http.Request, obj any) error { buf := &bytes.Buffer{} defer func() { req.Body = ioutil.NopCloser(buf) }() err := json.NewDecoder(io.TeeReader(req.Body, buf)).Decode(obj) diff --git a/internal/deref.go b/internal/deref.go index ce44394..89ee768 100644 --- a/internal/deref.go +++ b/internal/deref.go @@ -5,7 +5,7 @@ import "reflect" // DerefType dereferences the type of the object if it is // a pointer or an interface. // Returns whether the final type is a struct. -func DerefType(obj interface{}) (reflect.Type, bool) { +func DerefType(obj any) (reflect.Type, bool) { st := reflect.TypeOf(obj) for st.Kind() == reflect.Ptr || st.Kind() == reflect.Interface { st = st.Elem() @@ -17,7 +17,7 @@ func DerefType(obj interface{}) (reflect.Type, bool) { // DerefValue dereferences the value of the object until // it is no longer a pointer or an interface. Also returns // false if the underlying value is Nil. -func DerefValue(obj interface{}) (reflect.Value, bool) { +func DerefValue(obj any) (reflect.Value, bool) { v := reflect.ValueOf(obj) for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { if v.IsNil() { // If the chain ends in a nil, skip this diff --git a/internal/generate.go b/internal/generate.go index 2290878..f089336 100644 --- a/internal/generate.go +++ b/internal/generate.go @@ -8,7 +8,7 @@ import ( "strings" ) -// GenVal generates an interface{} from the string values given. +// GenVal generates an any from the string values given. func GenVal(t reflect.Type, value string, values ...string) (reflect.Value, error) { if len(values) > 0 || t.Kind() == reflect.Slice { return genVals(t, append([]string{value}, values...)) diff --git a/main.go b/main.go index 97bce72..be6e36b 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,7 @@ type mainFn struct { outTypes []reflect.Type } -func newMain(fn interface{}) (mainFn, error) { +func newMain(fn any) (mainFn, error) { val := reflect.ValueOf(fn) fnType := val.Type() inTypes, outTypes := []reflect.Type{}, []reflect.Type{} @@ -30,7 +30,7 @@ func newMain(fn interface{}) (mainFn, error) { }, nil } -func (fn mainFn) run(ctx *runctx) interface{} { +func (fn mainFn) run(ctx *runctx) any { inputs, err := ctx.generate(fn.inTypes) if err != nil { return nil diff --git a/main_test.go b/main_test.go index 5d195f4..2507f37 100644 --- a/main_test.go +++ b/main_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(t *testing.T) { +func TestInternals(t *testing.T) { t.Run("simple", func(t *testing.T) { req := httptest.NewRequest("GET", "/test", nil) rw := httptest.NewRecorder() @@ -23,7 +23,7 @@ func TestMain(t *testing.T) { }) t.Run("empty interface input", func(t *testing.T) { - _, err := newMain(func(in interface{}) error { + _, err := newMain(func(in any) error { return nil }) require.Error(t, err) diff --git a/standard.go b/standard.go new file mode 100644 index 0000000..922915d --- /dev/null +++ b/standard.go @@ -0,0 +1,73 @@ +package httpwrap + +import ( + "encoding/json" + "io" + "log" + "net/http" +) + +type HTTPResponse interface { + StatusCode() int + WriteBody(io.Writer) error +} + +// The StandardConstructor decodes the request using the following: +// - cookies +// - query params +// - path params +// - headers +// - JSON decoding of the body +func StandardConstructor() Constructor { + decoder := NewDecoder() + return func(rw http.ResponseWriter, req *http.Request, obj any) error { + return decoder.Decode(req, obj) + } +} + +// StandardResponseWriter will try to cast the error and response objects to the +// HTTPResponse interface and use them to send the response to the client. +// By default, it will send a 200 OK and encode the response object as JSON. +func StandardResponseWriter() func(w http.ResponseWriter, res any, err error) { + return func(w http.ResponseWriter, res any, err error) { + if err != nil { + if cast, ok := err.(HTTPResponse); ok { + w.WriteHeader(cast.StatusCode()) + if sendError := cast.WriteBody(w); sendError != nil { + log.Println("error writing response:", sendError) + } + } else { + w.WriteHeader(http.StatusInternalServerError) + if _, sendError := w.Write([]byte(err.Error() + "\n")); sendError != nil { + log.Println("error writing response:", sendError) + } + } + return + } + + if cast, ok := res.(HTTPResponse); ok { + w.WriteHeader(cast.StatusCode()) + if sendError := cast.WriteBody(w); sendError != nil { + log.Println("error writing response:", sendError) + } + return + } + + w.WriteHeader(http.StatusOK) + encoder := json.NewEncoder(w) + encoder.SetIndent("", " ") + if sendError := encoder.Encode(res); sendError != nil { + log.Println("Error writing response:", sendError) + } + } +} + +// NewStandardWrapper returns a new wrapper using the StandardConstructor and the +// StandardResponseWriter. +func NewStandardWrapper() Wrapper { + constructor := StandardConstructor() + responseWriter := StandardResponseWriter() + return New(). + WithConstruct(constructor). + Finally(responseWriter) +} diff --git a/utils.go b/utils.go index 6e95c42..8f99dd7 100644 --- a/utils.go +++ b/utils.go @@ -15,7 +15,7 @@ func isError(t reflect.Type) bool { return t.Implements(_errorType) } -func typesOf(fn interface{}) ([]reflect.Type, []reflect.Type) { +func typesOf(fn any) ([]reflect.Type, []reflect.Type) { val := reflect.ValueOf(fn) fnType := val.Type() inTypes, outTypes := []reflect.Type{}, []reflect.Type{} diff --git a/wrapper.go b/wrapper.go index 19e1bd3..dbc39c8 100644 --- a/wrapper.go +++ b/wrapper.go @@ -28,7 +28,7 @@ func (w Wrapper) WithConstruct(cons Constructor) Wrapper { // Before adds a new function that will execute before the main handler. The chain // of befores will end if a before returns a non-nil error value. -func (w Wrapper) Before(fns ...interface{}) Wrapper { +func (w Wrapper) Before(fns ...any) Wrapper { befores := make([]beforeFn, len(w.befores)+len(fns)) copy(befores, w.befores) for i, before := range fns { @@ -43,7 +43,7 @@ func (w Wrapper) Before(fns ...interface{}) Wrapper { } // Finally sets the last function that will execute during a request. -func (w Wrapper) Finally(fn interface{}) Wrapper { +func (w Wrapper) Finally(fn any) Wrapper { after, err := newAfter(fn) if err != nil { panic(err) @@ -54,7 +54,7 @@ func (w Wrapper) Finally(fn interface{}) Wrapper { // Wrap sets the main handling function to process requests. This Wrap function must // be called to get an `http.Handler` type. -func (w Wrapper) Wrap(fn interface{}) Handler { +func (w Wrapper) Wrap(fn any) Handler { main, err := newMain(fn) if err != nil { panic(err) @@ -72,13 +72,13 @@ type Handler struct { } // Before adds the before functions to the underlying Wrapper. -func (h Handler) Before(fns ...interface{}) Handler { +func (h Handler) Before(fns ...any) Handler { h.Wrapper = h.Wrapper.Before(fns...) return h } // Finally sets the `finally` function of the underlying Wrapper. -func (h Handler) Finally(fn interface{}) Handler { +func (h Handler) Finally(fn any) Handler { h.Wrapper = h.Wrapper.Finally(fn) return h } diff --git a/wrapper_test.go b/wrapper_test.go index ab6b398..0e27e29 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -48,7 +48,7 @@ func TestWrapper(t *testing.T) { type resp struct{ s string } handler := New(). WithConstruct(nopConstructor). - Finally(func(res interface{}, err error) { + Finally(func(res any, err error) { s := fmt.Sprintf("%v", res) require.True(t, strings.Contains(s, "response")) require.Error(t, err) @@ -91,7 +91,7 @@ func TestWrapper(t *testing.T) { Before(func(req *http.Request) (meta, error) { return meta{req.URL.Path}, fmt.Errorf("failed before") }). - Finally(func(rw http.ResponseWriter, m meta, res interface{}, err error) { + Finally(func(rw http.ResponseWriter, m meta, res any, err error) { require.NotNil(t, rw) require.Equal(t, "/test", m.path) require.Nil(t, res) @@ -119,7 +119,7 @@ func TestWrapper(t *testing.T) { require.NotNil(t, req.URL) return meta{req.URL.Path}, nil }). - Finally(func(rw http.ResponseWriter, m meta, res interface{}, err error) { + Finally(func(rw http.ResponseWriter, m meta, res any, err error) { require.NotNil(t, rw) require.Equal(t, "/test", m.path) require.Nil(t, res) @@ -217,7 +217,7 @@ func TestWrapper(t *testing.T) { Before(func() *myerr { return &myerr{} }). - Finally(func(rw http.ResponseWriter, res interface{}, err error) { + Finally(func(rw http.ResponseWriter, res any, err error) { require.NotNil(t, rw) require.Nil(t, res) require.Error(t, err) @@ -232,7 +232,7 @@ func TestWrapper(t *testing.T) { return nil }). Wrap(func() {}). - Finally(func(rw http.ResponseWriter, res interface{}, err error) { + Finally(func(rw http.ResponseWriter, res any, err error) { require.NotNil(t, rw) require.NoError(t, err) rw.WriteHeader(http.StatusCreated)