Skip to content

Commit

Permalink
fix: validation bugs, add a bunch of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Aug 25, 2023
1 parent 7f396d0 commit 2ce9dfa
Show file tree
Hide file tree
Showing 5 changed files with 908 additions and 33 deletions.
146 changes: 146 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,160 @@
package huma

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"

"github.com/danielgtaylor/huma/v2/queryparam"
"github.com/go-chi/chi"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/assert"
)

type testContext struct {
r *http.Request
w http.ResponseWriter
}

func (ctx *testContext) GetMatched() string {
return chi.RouteContext(ctx.r.Context()).RoutePattern()
}

func (ctx *testContext) GetContext() context.Context {
return ctx.r.Context()
}

func (ctx *testContext) GetURL() url.URL {
return *ctx.r.URL
}

func (ctx *testContext) GetParam(name string) string {
return chi.URLParam(ctx.r, name)
}

func (ctx *testContext) GetQuery(name string) string {
return queryparam.Get(ctx.r.URL.RawQuery, name)
}

func (ctx *testContext) GetHeader(name string) string {
return ctx.r.Header.Get(name)
}

func (ctx *testContext) GetBody() ([]byte, error) {
return io.ReadAll(ctx.r.Body)
}

func (ctx *testContext) GetBodyReader() io.Reader {
return ctx.r.Body
}

func (ctx *testContext) WriteStatus(code int) {
ctx.w.WriteHeader(code)
}

func (ctx *testContext) AppendHeader(name string, value string) {
ctx.w.Header().Add(name, value)
}

func (ctx *testContext) WriteHeader(name string, value string) {
ctx.w.Header().Set(name, value)
}

func (ctx *testContext) BodyWriter() io.Writer {
return ctx.w
}

type testAdapter struct {
router chi.Router
}

func (a *testAdapter) Handle(method, path string, handler func(Context)) {
a.router.MethodFunc(method, path, func(w http.ResponseWriter, r *http.Request) {
handler(&testContext{r: r, w: w})
})
}

func NewTestAdapter(r chi.Router, config Config) API {
return NewAPI(config, &testAdapter{router: r})
}

type ExhaustiveErrorsInputBody struct {
Name string `json:"name" maxLength:"10"`
Count int `json:"count" minimum:"1"`
}

func (b *ExhaustiveErrorsInputBody) Resolve(ctx Context) []error {
return []error{fmt.Errorf("body resolver error")}
}

type ExhaustiveErrorsInput struct {
ID string `path:"id" maxLength:"5"`
Body ExhaustiveErrorsInputBody `json:"body"`
}

func (i *ExhaustiveErrorsInput) Resolve(ctx Context) []error {
return []error{&ErrorDetail{
Location: "path.id",
Message: "input resolver error",
Value: i.ID,
}}
}

type ExhaustiveErrorsOutput struct {
}

func TestExhaustiveErrors(t *testing.T) {
r := chi.NewRouter()
app := NewTestAdapter(r, DefaultConfig("Test API", "1.0.0"))
Register(app, Operation{
OperationID: "test",
Method: http.MethodPut,
Path: "/errors/{id}",
}, func(ctx context.Context, input *ExhaustiveErrorsInput) (*ExhaustiveErrorsOutput, error) {
return &ExhaustiveErrorsOutput{}, nil
})

req, _ := http.NewRequest(http.MethodPut, "/errors/123456", strings.NewReader(`{"name": "12345678901", "count": 0}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnprocessableEntity, w.Code)
assert.JSONEq(t, `{
"$schema": "https:///schemas/ErrorModel.json",
"title": "Unprocessable Entity",
"status": 422,
"detail": "validation failed",
"errors": [
{
"message": "expected length <= 5",
"location": "path.id",
"value": "123456"
}, {
"message": "expected length <= 10",
"location": "body.name",
"value": "12345678901"
}, {
"message": "expected number >= 1",
"location": "body.count",
"value": 0
}, {
"message": "input resolver error",
"location": "path.id",
"value": "123456"
}, {
"message": "body resolver error"
}
]
}`, w.Body.String())
}

func BenchmarkSecondDecode(b *testing.B) {
type MediumSized struct {
ID int `json:"id"`
Expand Down
4 changes: 4 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ type mapRegistry struct {
func (r *mapRegistry) Schema(t reflect.Type, allowRef bool, hint string) *Schema {
t = deref(t)
getsRef := t.Kind() == reflect.Struct
if t == timeType {
// Special case: time.Time is always a string.
getsRef = false
}

name := r.namer(t, hint)

Expand Down
24 changes: 14 additions & 10 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,23 @@ type Schema struct {
}

func (s *Schema) PrecomputeMessages() {
s.msgEnum = "expected string to be one of \"" + strings.Join(mapTo(s.Enum, func(v any) string {
s.msgEnum = "expected value to be one of \"" + strings.Join(mapTo(s.Enum, func(v any) string {
return fmt.Sprintf("%v", v)
}), ", ") + "\""
if s.Minimum != nil {
s.msgMinimum = fmt.Sprintf("expected number >= %f", *s.Minimum)
s.msgMinimum = fmt.Sprintf("expected number >= %v", *s.Minimum)
}
if s.ExclusiveMinimum != nil {
s.msgExclusiveMinimum = fmt.Sprintf("expected number < %f", *s.ExclusiveMinimum)
s.msgExclusiveMinimum = fmt.Sprintf("expected number > %v", *s.ExclusiveMinimum)
}
if s.Maximum != nil {
s.msgMaximum = fmt.Sprintf("expected number <= %f", *s.Maximum)
s.msgMaximum = fmt.Sprintf("expected number <= %v", *s.Maximum)
}
if s.ExclusiveMaximum != nil {
s.msgExclusiveMaximum = fmt.Sprintf("expected number < %f", *s.ExclusiveMaximum)
s.msgExclusiveMaximum = fmt.Sprintf("expected number < %v", *s.ExclusiveMaximum)
}
if s.MultipleOf != nil {
s.msgMultipleOf = fmt.Sprintf("expected number to be a multiple of %f", *s.MultipleOf)
s.msgMultipleOf = fmt.Sprintf("expected number to be a multiple of %v", *s.MultipleOf)
}
if s.MinLength != nil {
s.msgMinLength = fmt.Sprintf("expected length >= %d", *s.MinLength)
Expand All @@ -128,10 +128,10 @@ func (s *Schema) PrecomputeMessages() {
s.msgPattern = "expected string to match pattern " + s.Pattern
}
if s.MinItems != nil {
s.msgMinItems = fmt.Sprintf("expected array with at least %d items", *s.MinItems)
s.msgMinItems = fmt.Sprintf("expected array length >= %d", *s.MinItems)
}
if s.MaxItems != nil {
s.msgMaxItems = fmt.Sprintf("expected array with at most %d items", *s.MaxItems)
s.msgMaxItems = fmt.Sprintf("expected array length <= %d", *s.MaxItems)
}
if s.MinProperties != nil {
s.msgMinProperties = fmt.Sprintf("expected object with at least %d properties", *s.MinProperties)
Expand Down Expand Up @@ -252,8 +252,12 @@ func SchemaFromField(registry Registry, parent reflect.Type, f reflect.StructFie
}
fs := registry.Schema(f.Type, true, parentName+f.Name+"Struct")
fs.Description = f.Tag.Get("doc")
fs.Format = f.Tag.Get("format")
fs.ContentEncoding = f.Tag.Get("encoding")
if fmt := f.Tag.Get("format"); fmt != "" {
fs.Format = fmt
}
if enc := f.Tag.Get("encoding"); enc != "" {
fs.ContentEncoding = enc
}
fs.Default = jsonTag(f, "default", false)
fs.Example = jsonTag(f, "example", false)

Expand Down
50 changes: 27 additions & 23 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"regexp"
"strconv"
"time"
"unsafe"

"github.com/google/uuid"
"golang.org/x/net/idna"
Expand Down Expand Up @@ -129,63 +130,64 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
}
if !found {
res.Add(path, str, "expected string to be RFC3339 date-time")
res.Add(path, str, "expected string to be RFC 3339 date-time")
}
case "date":
if _, err := time.Parse("2006-01-02", str); err != nil {
res.Add(path, str, "expected string to be RFC3339 date")
res.Add(path, str, "expected string to be RFC 3339 date")
}
case "time":
if _, err := time.Parse("15:04:05", str); err != nil {
if _, err := time.Parse("15:04:05Z07:00", str); err != nil {
res.Add(path, str, "expected string to be RFC3339 time")
res.Add(path, str, "expected string to be RFC 3339 time")
}
}
// TODO: duration
case "email", "idn-email":
if _, err := mail.ParseAddress(str); err != nil {
res.Addf(path, str, "expected string to be RFC5322 email: %v", err)
res.Addf(path, str, "expected string to be RFC 5322 email: %v", err)
}
case "hostname":
if !(rxHostname.MatchString(str) && len(str) < 256) {
res.Add(path, str, "expected string to be RFC5890 hostname")
res.Add(path, str, "expected string to be RFC 5890 hostname")
}
case "idn-hostname":
if _, err := idna.ToASCII(str); err != nil {
res.Addf(path, str, "expected string to be RFC5890 hostname: %v", err)
res.Addf(path, str, "expected string to be RFC 5890 hostname: %v", err)
}
case "ipv4":
if ip := net.ParseIP(str); ip == nil || ip.To4() == nil {
res.Add(path, str, "expected string to be RFC2673 ipv4")
res.Add(path, str, "expected string to be RFC 2673 ipv4")
}
case "ipv6":
if ip := net.ParseIP(str); ip == nil || ip.To16() == nil {
res.Add(path, str, "expected string to be RFC2373 ipv6")
res.Add(path, str, "expected string to be RFC 2373 ipv6")
}
case "uri", "uri-reference", "iri", "iri-reference":
if _, err := url.Parse(str); err != nil {
res.Addf(path, str, "expected string to be RFC3986 uri: %v", err)
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
}
// TODO: check if it's actually a reference?
case "uuid":
if _, err := uuid.Parse(str); err != nil {
res.Addf(path, str, "expected string to be RFC4122 uuid: %v", err)
res.Addf(path, str, "expected string to be RFC 4122 uuid: %v", err)
}
case "uri-template":
u, err := url.Parse(str)
if err != nil {
res.Addf(path, str, "expected string to be RFC3986 uri: %v", err)
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
return
}
if !rxURITemplate.MatchString(u.Path) {
res.Add(path, str, "expected string to be RFC6570 uri-template")
res.Add(path, str, "expected string to be RFC 6570 uri-template")
}
case "json-pointer":
if !rxJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC6901 json-pointee")
res.Add(path, str, "expected string to be RFC 6901 json-pointer")
}
case "relative-json-pointer":
if !rxRelJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC6901 relative-json-pointer")
res.Add(path, str, "expected string to be RFC 6901 relative-json-pointer")
}
case "regex":
if _, err := regexp.Compile(str); err != nil {
Expand Down Expand Up @@ -253,8 +255,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
case TypeString:
str, ok := v.(string)
if !ok {
res.Add(path, v, "expected string")
return
if b, ok := v.([]byte); ok {
str = *(*string)(unsafe.Pointer(&b))
} else {
res.Add(path, v, "expected string")
return
}
}

if s.MinLength != nil {
Expand Down Expand Up @@ -301,7 +307,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
}

if s.UniqueItems {
seen := make(map[any]struct{})
seen := make(map[any]struct{}, len(arr))
for _, item := range arr {
if _, ok := seen[item]; ok {
res.Add(path, v, "expected array items to be unique")
Expand Down Expand Up @@ -359,11 +365,9 @@ func handleMapString(r Registry, s *Schema, path *PathBuffer, mode ValidateMode,
// We should be permissive by default to enable easy round-trips for the
// client without needing to remove read-only values.
// TODO: should we make this configurable?
if mode == ModeWriteToServer && s.ReadOnly {
continue
}

if mode == ModeReadFromServer && s.WriteOnly && m[k] == nil && !reflect.ValueOf(m[k]).IsZero() {
// Be stricter for responses, enabling validation of the server if desired.
if mode == ModeReadFromServer && v.WriteOnly && m[k] != nil && !reflect.ValueOf(m[k]).IsZero() {
res.Add(path, m[k], "write only property is non-zero")
continue
}
Expand All @@ -372,8 +376,8 @@ func handleMapString(r Registry, s *Schema, path *PathBuffer, mode ValidateMode,
if !s.requiredMap[k] {
continue
}
if (mode == ModeWriteToServer && s.ReadOnly) ||
(mode == ModeReadFromServer && s.WriteOnly) {
if (mode == ModeWriteToServer && v.ReadOnly) ||
(mode == ModeReadFromServer && v.WriteOnly) {
// These are not required for the current mode.
continue
}
Expand Down
Loading

0 comments on commit 2ce9dfa

Please sign in to comment.