Skip to content

Commit

Permalink
Added some helpful default behavior and types (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
apourchet authored Jan 9, 2024
1 parent 9e1f967 commit efaad80
Show file tree
Hide file tree
Showing 21 changed files with 291 additions and 52 deletions.
2 changes: 1 addition & 1 deletion after.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type afterFn struct {
outTypes []reflect.Type
}

func newAfter(fn interface{}) (afterFn, error) {
func newAfter(fn any) (afterFn, error) {
val := reflect.ValueOf(fn)
fnType := val.Type()
inTypes, outTypes := []reflect.Type{}, []reflect.Type{}
Expand Down
4 changes: 2 additions & 2 deletions after_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAfter(t *testing.T) {
rw := httptest.NewRecorder()
ctx := newRunCtx(rw, req, nopConstructor)

after, err := newAfter(func(w http.ResponseWriter, res interface{}) {
after, err := newAfter(func(w http.ResponseWriter, res any) {
require.NotNil(t, w)
w.WriteHeader(http.StatusOK)
})
Expand All @@ -31,7 +31,7 @@ func TestAfter(t *testing.T) {
ctx := newRunCtx(rw, req, nopConstructor)
ctx.response = reflect.ValueOf(1)

after, err := newAfter(func(w http.ResponseWriter, res interface{}) {
after, err := newAfter(func(w http.ResponseWriter, res any) {
require.NotNil(t, w)
require.Equal(t, res, 1)
})
Expand Down
2 changes: 1 addition & 1 deletion before.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type beforeFn struct {
outTypes []reflect.Type
}

func newBefore(fn interface{}) (beforeFn, error) {
func newBefore(fn any) (beforeFn, error) {
val := reflect.ValueOf(fn)
fnType := val.Type()
inTypes, outTypes := []reflect.Type{}, []reflect.Type{}
Expand Down
2 changes: 1 addition & 1 deletion before_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestBefore(t *testing.T) {
})

t.Run("empty interface", func(t *testing.T) {
_, err := newBefore(func(in interface{}) error {
_, err := newBefore(func(in any) error {
return fmt.Errorf("error")
})
require.Error(t, err)
Expand Down
17 changes: 2 additions & 15 deletions constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,8 @@ import "net/http"

// Constructor is the function signature for unmarshalling an http request into
// an object.
type Constructor func(http.ResponseWriter, *http.Request, interface{}) error
type Constructor func(http.ResponseWriter, *http.Request, any) error

// EmptyConstructor is the default constructor for new wrappers.
// It is a no-op.
func EmptyConstructor(http.ResponseWriter, *http.Request, interface{}) error { return nil }

// The StandardConstructor decodes the request using the following:
// - cookies
// - query params
// - path params
// - headers
// - JSON decoding of the body
func StandardConstructor() Constructor {
decoder := NewDecoder()
return func(rw http.ResponseWriter, req *http.Request, obj interface{}) error {
return decoder.Decode(req, obj)
}
}
func EmptyConstructor(http.ResponseWriter, *http.Request, any) error { return nil }
6 changes: 3 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type runctx struct {
type param struct {
t reflect.Type
v reflect.Value
i interface{}
i any
}

func newRunCtx(
Expand All @@ -30,7 +30,7 @@ func newRunCtx(
req: req,
rw: rw,
cons: cons,
response: reflect.Zero(reflect.TypeOf((*interface{})(nil)).Elem()),
response: reflect.Zero(reflect.TypeOf((*any)(nil)).Elem()),
results: map[reflect.Type]param{},
resultSlice: []param{},
}
Expand All @@ -39,7 +39,7 @@ func newRunCtx(
return ctx
}

func (ctx *runctx) provide(i interface{}) {
func (ctx *runctx) provide(i any) {
if i == nil {
return
}
Expand Down
12 changes: 6 additions & 6 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@ import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
)

func nopConstructor(http.ResponseWriter, *http.Request, interface{}) error { return nil }
func nopConstructor(http.ResponseWriter, *http.Request, any) error { return nil }

func jsonBodyConstructor(_ http.ResponseWriter, req *http.Request, obj interface{}) error {
body, err := ioutil.ReadAll(req.Body)
func jsonBodyConstructor(_ http.ResponseWriter, req *http.Request, obj any) error {
body, err := io.ReadAll(req.Body)
if err != nil {
return err
}
err = json.Unmarshal(body, obj)
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
req.Body = io.NopCloser(bytes.NewBuffer(body))
return err
}

func failedConstructor(http.ResponseWriter, *http.Request, interface{}) error {
func failedConstructor(http.ResponseWriter, *http.Request, any) error {
return fmt.Errorf("error")
}

Expand Down
4 changes: 2 additions & 2 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Decoder struct {

// DecodeFunc is the function signature for decoding a request into an
// object.
type DecodeFunc func(req *http.Request, obj interface{}) error
type DecodeFunc func(req *http.Request, obj any) error

// NewDecoder returns a new decoder with sensible defaults for the
// DecodeBody, Header and Query functions.
Expand All @@ -60,7 +60,7 @@ func NewDecoder() *Decoder {
// request struct
// The Limit field will come from the query string
// The Resource field will come from the resource value of the path
func (d *Decoder) Decode(req *http.Request, obj interface{}) error {
func (d *Decoder) Decode(req *http.Request, obj any) error {
if err := d.DecodeBody(req, obj); err != nil {
return err
}
Expand Down
31 changes: 31 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package httpwrap

import (
"fmt"
"io"
)

// HTTPError implements both the HTTPResponse interface and the standard error
// interface.
type HTTPError struct {
code int
body string
}

func NewHTTPError(code int, format string, args ...any) HTTPError {
return HTTPError{
code: code,
body: fmt.Sprintf(format, args...),
}
}

func (err HTTPError) Error() string {
return fmt.Sprintf("http error: %d: %s", err.code, err.body)
}

func (err HTTPError) StatusCode() int { return err.code }

func (err HTTPError) WriteBody(writer io.Writer) error {
_, writeError := io.WriteString(writer, err.body)
return writeError
}
2 changes: 1 addition & 1 deletion examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (mw *Middlewares) checkAPICreds(creds APICredentials) error {

// sendResponse writes out the response to the client given the output
// of the handler.
func (mw *Middlewares) sendResponse(w http.ResponseWriter, res interface{}, err error) {
func (mw *Middlewares) sendResponse(w http.ResponseWriter, res any, err error) {
switch err {
case ErrBadAPICreds:
w.WriteHeader(http.StatusUnauthorized)
Expand Down
143 changes: 143 additions & 0 deletions examples/standard/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package main

import (
"log"
"net/http"

"github.com/apourchet/httpwrap"
"github.com/gorilla/mux"
)

// ***** Type Definitions *****
type APICredentials struct {
Key string `http:"header=X-PETSTORE-KEY"`
}

type PetStoreHandler struct {
pets map[string]*Pet
}

type Pet struct {
Name string `json:"name"`
Category int `json:"category"`
PhotoURLs []string `json:"photoUrls"`
}

func (pet Pet) IsInCategories(categories []int) bool {
for _, c := range categories {
if pet.Category == c {
return true
}
}
return false
}

var ErrBadAPICreds = httpwrap.NewHTTPError(http.StatusUnauthorized, "bad API credentials")
var ErrPetConflict = httpwrap.NewHTTPError(http.StatusConflict, "duplicate pet")
var ErrPetNotFound = httpwrap.NewHTTPError(http.StatusNotFound, "pet not found")

// ***** Middleware Definitions *****
// checkAPICreds checks the api credentials passed into the request.
func checkAPICreds(creds APICredentials) error {
if creds.Key == "my-secret-key" {
return nil
}
return ErrBadAPICreds
}

// ***** Handler Methods *****
// AddPet adds a new pet to the store.
func (h *PetStoreHandler) AddPet(pet Pet) error {
if _, found := h.pets[pet.Name]; found {
return ErrPetConflict
}
h.pets[pet.Name] = &pet
return nil
}

// GetPets returns the list of pets in the store.
func (h *PetStoreHandler) GetPets() (res []Pet, err error) {
res = make([]Pet, 0, len(h.pets))
for _, pet := range h.pets {
res = append(res, *pet)
}
return res, nil
}

type GetByNameParams struct {
Name string `http:"segment=name"`
}

// GetPetByName returns a pet given its name.
func (h *PetStoreHandler) GetPetByName(params GetByNameParams) (pet *Pet, err error) {
pet, found := h.pets[params.Name]
if !found {
return nil, ErrPetNotFound
}
return pet, nil
}

type UpdateParams struct {
Name string `http:"segment=name"`

Category *int `json:"category"`
PhotoURLs *[]string `json:"photoUrls"`
}

// UpdatePet updates a pet given its name.
func (h *PetStoreHandler) UpdatePet(params UpdateParams) error {
pet, found := h.pets[params.Name]
if !found {
return ErrPetNotFound
}

if params.Category != nil {
pet.Category = *params.Category
}
if params.PhotoURLs != nil {
pet.PhotoURLs = *params.PhotoURLs
}
return nil
}

type FilterPetParams struct {
Categories *[]int `http:"query=categories"`
HasPhotos *bool `http:"query=hasPhotos"`
}

// FilterPets returns a list of pets that match the parameters given.
func (h *PetStoreHandler) FilterPets(params FilterPetParams) []Pet {
res := []Pet{}
for _, pet := range h.pets {
if params.HasPhotos != nil && len(pet.PhotoURLs) == 0 {
continue
} else if params.Categories != nil && !pet.IsInCategories(*params.Categories) {
continue
}
res = append(res, *pet)
}
return res
}

func (h *PetStoreHandler) ClearStore() error {
h.pets = map[string]*Pet{}
return nil
}

func main() {
router := mux.NewRouter()

handler := &PetStoreHandler{pets: map[string]*Pet{}}
wrapper := httpwrap.NewStandardWrapper().Before(checkAPICreds)

router.Handle("/pets", wrapper.Wrap(handler.AddPet)).Methods("POST")
router.Handle("/pets", wrapper.Wrap(handler.GetPets)).Methods("GET")
router.Handle("/pets/filtered", wrapper.Wrap(handler.FilterPets)).Methods("GET")
router.Handle("/pets/{name}", wrapper.Wrap(handler.GetPetByName)).Methods("GET")
router.Handle("/pets/{name}", wrapper.Wrap(handler.UpdatePet)).Methods("PUT")

router.Handle("/clear", wrapper.Wrap(handler.ClearStore)).Methods("POST")

http.Handle("/", router)
log.Fatal(http.ListenAndServe(":3000", router))
}
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
module github.com/apourchet/httpwrap

go 1.12
go 1.20

require (
github.com/gorilla/mux v1.7.2
github.com/stretchr/testify v1.3.0
)

require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
)
2 changes: 1 addition & 1 deletion internal/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var (

// DecodeBody uses a json decoder to decode the body of the request
// into the target object.
func DecodeBody(req *http.Request, obj interface{}) error {
func DecodeBody(req *http.Request, obj any) error {
buf := &bytes.Buffer{}
defer func() { req.Body = ioutil.NopCloser(buf) }()
err := json.NewDecoder(io.TeeReader(req.Body, buf)).Decode(obj)
Expand Down
4 changes: 2 additions & 2 deletions internal/deref.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "reflect"
// DerefType dereferences the type of the object if it is
// a pointer or an interface.
// Returns whether the final type is a struct.
func DerefType(obj interface{}) (reflect.Type, bool) {
func DerefType(obj any) (reflect.Type, bool) {
st := reflect.TypeOf(obj)
for st.Kind() == reflect.Ptr || st.Kind() == reflect.Interface {
st = st.Elem()
Expand All @@ -17,7 +17,7 @@ func DerefType(obj interface{}) (reflect.Type, bool) {
// DerefValue dereferences the value of the object until
// it is no longer a pointer or an interface. Also returns
// false if the underlying value is Nil.
func DerefValue(obj interface{}) (reflect.Value, bool) {
func DerefValue(obj any) (reflect.Value, bool) {
v := reflect.ValueOf(obj)
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
if v.IsNil() { // If the chain ends in a nil, skip this
Expand Down
2 changes: 1 addition & 1 deletion internal/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
)

// GenVal generates an interface{} from the string values given.
// GenVal generates an any from the string values given.
func GenVal(t reflect.Type, value string, values ...string) (reflect.Value, error) {
if len(values) > 0 || t.Kind() == reflect.Slice {
return genVals(t, append([]string{value}, values...))
Expand Down
Loading

0 comments on commit efaad80

Please sign in to comment.