Skip to content

Commit

Permalink
fix: support non-addressable resolver values
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Sep 19, 2024
1 parent 3649df3 commit 40fddbd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
17 changes: 15 additions & 2 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -1307,11 +1307,24 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

resolvers.EveryPB(pb, v, func(item reflect.Value, _ bool) {
if resolver, ok := item.Addr().Interface().(Resolver); ok {
if item.CanAddr() {
item = item.Addr()
} else {
// If the item is non-addressable (example: primitive custom type with
// a resolver as a map value), then we need to create a new pointer to
// the value to ensure the resolver can be called, regardless of whether
// is is a value or pointer resolver type.
// TODO: this is inefficient and could be improved in the future.
ptr := reflect.New(item.Type())
elem := ptr.Elem()
elem.Set(item)
item = ptr
}
if resolver, ok := item.Interface().(Resolver); ok {
if errs := resolver.Resolve(ctx); len(errs) > 0 {
res.Errors = append(res.Errors, errs...)
}
} else if resolver, ok := item.Addr().Interface().(ResolverWithPath); ok {
} else if resolver, ok := item.Interface().(ResolverWithPath); ok {
if errs := resolver.Resolve(ctx, pb); len(errs) > 0 {
res.Errors = append(res.Errors, errs...)
}
Expand Down
21 changes: 21 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2277,6 +2277,27 @@ func TestBodyRace(t *testing.T) {
}
}

type CustomMapValue string

func (v *CustomMapValue) Resolve(ctx huma.Context) []error {
return nil
}

func TestResolverCustomTypePrimitive(t *testing.T) {
_, api := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))
huma.Post(api, "/test", func(ctx context.Context, input *struct {
Body struct {
Tags map[string]CustomMapValue `json:"tags"`
}
}) (*struct{}, error) {
return nil, nil
})

assert.NotPanics(t, func() {
api.Post("/test", map[string]any{"tags": map[string]string{"foo": "bar"}})
})
}

// func BenchmarkSecondDecode(b *testing.B) {
// //nolint: musttag
// type MediumSized struct {
Expand Down

0 comments on commit 40fddbd

Please sign in to comment.