Skip to content

Commit

Permalink
fix: improved tests, fixed a few small issues
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Sep 6, 2023
1 parent b8d2334 commit 2014487
Show file tree
Hide file tree
Showing 10 changed files with 525 additions and 60 deletions.
8 changes: 6 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ func (r *api) Unmarshal(contentType string, data []byte, v any) error {
if end == -1 {
end = len(contentType)
}
f, ok := r.formats[contentType[start:end]]
ct := contentType[start:end]
if ct == "" {
// Default to assume JSON since this is an API.
ct = "application/json"
}
f, ok := r.formats[ct]
if !ok {
return fmt.Errorf("unknown content type: %s", contentType)
}
Expand All @@ -177,7 +182,6 @@ func (r *api) Negotiate(accept string) (string, error) {
}

func (a *api) Marshal(ctx Context, respKey string, ct string, v any) error {
// fmt.Println("marshaling", ct)
var err error

for _, t := range a.transformers {
Expand Down
16 changes: 16 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package huma

import (
"testing"

"github.com/go-chi/chi"
"github.com/stretchr/testify/assert"
)

func TestBlankConfig(t *testing.T) {
adapter := &testAdapter{chi.NewMux()}

assert.NotPanics(t, func() {
NewAPI(Config{}, adapter)
})
}
2 changes: 1 addition & 1 deletion conditional/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (p *Params) PreconditionFailed(etag string, modified time.Time) huma.Status
)
}

return huma.Status304NotModied()
return huma.Status304NotModified()
}

return nil
Expand Down
10 changes: 3 additions & 7 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ type StatusError interface {
Error() string
}

// Ensure the default error model satisfies these interfaces.
var _ StatusError = (*ErrorModel)(nil)
var _ ContentTypeFilter = (*ErrorModel)(nil)

// NewError creates a new instance of an error model with the given status code,
// message, and errors. If the error implements the `ErrorDetailer` interface,
// the error details will be used. Otherwise, the error message will be used.
Expand Down Expand Up @@ -149,9 +145,9 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) {
api.Marshal(ctx, strconv.Itoa(status), ct, err)
}

// Status304NotModied returns a 304. This is not really an error, but provides
// a way to send non-default responses.
func Status304NotModied() StatusError {
// Status304NotModified returns a 304. This is not really an error, but
// provides a way to send non-default responses.
func Status304NotModified() StatusError {
return NewError(http.StatusNotModified, "")
}

Expand Down
70 changes: 70 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package huma

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

// Ensure the default error model satisfies these interfaces.
var _ StatusError = (*ErrorModel)(nil)
var _ ContentTypeFilter = (*ErrorModel)(nil)

func TestError(t *testing.T) {
err := &ErrorModel{
Status: 400,
Detail: "test err",
}

// Add some children.
err.Add(&ErrorDetail{
Message: "test detail",
Location: "body.foo",
Value: "bar",
})

err.Add(fmt.Errorf("plain error"))

// Confirm errors were added.
assert.Equal(t, "test err", err.Error())
assert.Len(t, err.Errors, 2)
assert.Equal(t, "test detail (body.foo: bar)", err.Errors[0].Error())
assert.Equal(t, "plain error", err.Errors[1].Error())

// Ensure problem content types.
assert.Equal(t, "application/problem+json", err.ContentType("application/json"))
assert.Equal(t, "application/problem+cbor", err.ContentType("application/cbor"))
assert.Equal(t, "other", err.ContentType("other"))
}

func TestErrorResponses(t *testing.T) {
// NotModified has a slightly different signature.
assert.Equal(t, 304, Status304NotModified().GetStatus())

for _, item := range []struct {
constructor func(msg string, errs ...error) StatusError
expected int
}{
{Error400BadRequest, 400},
{Error401Unauthorized, 401},
{Error403Forbidden, 403},
{Error404NotFound, 404},
{Error405MethodNotAllowed, 405},
{Error406NotAcceptable, 406},
{Error409Conflict, 409},
{Error410Gone, 410},
{Error412PreconditionFailed, 412},
{Error415UnsupportedMediaType, 415},
{Error422UnprocessableEntity, 422},
{Error429TooManyRequests, 429},
{Error500InternalServerError, 500},
{Error501NotImplemented, 501},
{Error502BadGateway, 502},
{Error503ServiceUnavailable, 503},
{Error504GatewayTimeout, 504},
} {
err := item.constructor("test")
assert.Equal(t, item.expected, err.GetStatus())
}
}
21 changes: 10 additions & 11 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ var bodyCallbackType = reflect.TypeOf(func(Context) {})
// if possible. If not, it will not incur any allocations (unlike the stdlib
// `http.ResponseController`).
func SetReadDeadline(w http.ResponseWriter, deadline time.Time) error {
rw := w
for {
switch t := rw.(type) {
switch t := w.(type) {
case interface{ SetReadDeadline(time.Time) error }:
return t.SetReadDeadline(deadline)
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
w = t.Unwrap()
default:
return errDeadlineUnsupported
}
Expand Down Expand Up @@ -170,12 +169,13 @@ type findResult[T comparable] struct {
}

func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(reflect.Value, T)) {
if len(path) == 0 {
f(current, v)
return
}

switch current.Kind() {
case reflect.Struct:
if len(path) == 0 {
f(current, v)
return
}
r.every(reflect.Indirect(current.Field(path[0])), path[1:], v, f)
case reflect.Slice:
for j := 0; j < current.Len(); j++ {
Expand All @@ -186,10 +186,6 @@ func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(ref
r.every(reflect.Indirect(current.MapIndex(k)), path, v, f)
}
default:
if len(path) == 0 {
f(current, v)
return
}
panic("unsupported")
}
}
Expand Down Expand Up @@ -619,6 +615,9 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)

buf := bufPool.Get().(*bytes.Buffer)
reader := ctx.BodyReader()
if reader == nil {
reader = bytes.NewReader(nil)
}
if closer, ok := reader.(io.Closer); ok {
defer closer.Close()
}
Expand Down
Loading

0 comments on commit 2014487

Please sign in to comment.