diff --git a/validate.go b/validate.go index 47da6c65..ac9f1d8d 100644 --- a/validate.go +++ b/validate.go @@ -275,7 +275,7 @@ func validateOneOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v subRes.Reset() } if !found { - res.Add(path, v, "expected value to match exactly one schema but matched none") + res.Add(path, v, validation.MsgExpectedMatchExactlyOneSchema) } } @@ -291,10 +291,49 @@ func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v } if matches == 0 { - res.Add(path, v, validation.MsgExpectedMatchSchema) + res.Add(path, v, validation.MsgExpectedMatchAtLeastOneSchema) } } +func validateDiscriminator(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, res *ValidateResult) { + var kk any + found := true + + if vv, ok := v.(map[string]any); ok { + kk, found = vv[s.Discriminator.PropertyName] + } + + if vv, ok := v.(map[any]any); ok { + kk, found = vv[s.Discriminator.PropertyName] + } + + if !found { + path.Push(s.Discriminator.PropertyName) + res.Add(path, v, validation.MsgExpectedPropertyNameInObject) + return + } + + if kk == nil { + // Either `v` is not a map or the property is set to null. Return so that + // type and enum checks on the field can complete elsewhere. + return + } + + key, ok := kk.(string) + if !ok { + path.Push(s.Discriminator.PropertyName) + return + } + + ref, found := s.Discriminator.Mapping[key] + if !found { + validateOneOf(r, s, path, mode, v, res) + return + } + + Validate(r, r.SchemaFromRef(ref), path, mode, v, res) +} + // Validate an input value against a schema, collecting errors in the validation // result object. If successful, `res.Errors` will be empty. It is suggested // to use a `sync.Pool` to reuse the PathBuffer and ValidateResult objects, @@ -318,7 +357,11 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any, } if s.OneOf != nil { - validateOneOf(r, s, path, mode, v, res) + if s.Discriminator != nil { + validateDiscriminator(r, s, path, mode, v, res) + } else { + validateOneOf(r, s, path, mode, v, res) + } } if s.AnyOf != nil { diff --git a/validate_test.go b/validate_test.go index eeb3c62b..bf4e66b6 100644 --- a/validate_test.go +++ b/validate_test.go @@ -9,8 +9,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/validation" ) func Ptr[T any](v T) *T { @@ -1254,6 +1256,88 @@ var validateTests = []struct { input: map[string]any{}, errs: []string{"expected required property field to be present"}, }, + { + name: "discriminator: input expected to be an object", + s: &huma.Schema{ + Type: huma.TypeObject, + OneOf: []*huma.Schema{ + {Type: huma.TypeString}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "inputType", + }, + }, + input: "test", + errs: []string{validation.MsgExpectedObject}, + }, + { + name: "discriminator: propertyName expected to be present in object", + s: &huma.Schema{ + Type: huma.TypeObject, + OneOf: []*huma.Schema{ + {Type: huma.TypeString}, + }, + Properties: map[string]*huma.Schema{ + "inputType": {Type: huma.TypeString}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "inputType", + }, + }, + input: map[string]any{"undefined": ""}, + errs: []string{validation.MsgExpectedPropertyNameInObject}, + }, + { + name: "discriminator: propertyName expected to be present in any object", + s: &huma.Schema{ + Type: huma.TypeObject, + OneOf: []*huma.Schema{ + {Type: huma.TypeString}, + }, + Properties: map[string]*huma.Schema{ + "inputType": {Type: huma.TypeString}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "inputType", + }, + }, + input: map[any]any{"undefined": ""}, + errs: []string{validation.MsgExpectedPropertyNameInObject}, + }, + { + name: "discriminator: propertyName expected to be string", + s: &huma.Schema{ + Type: huma.TypeObject, + OneOf: []*huma.Schema{ + {Type: huma.TypeString}, + }, + Properties: map[string]*huma.Schema{ + "inputType": {Type: huma.TypeString}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "inputType", + }, + }, + input: map[string]any{"inputType": 1}, + errs: []string{validation.MsgExpectedString}, + }, + { + name: "discriminator: propertyName not explicitly mapped", + s: &huma.Schema{ + Type: huma.TypeObject, + OneOf: []*huma.Schema{ + {Type: huma.TypeString}, + }, + Properties: map[string]*huma.Schema{ + "inputType": {Type: huma.TypeString}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "inputType", + }, + }, + input: map[string]any{"inputType": "test"}, + errs: []string{validation.MsgExpectedMatchExactlyOneSchema}, + }, } func TestValidate(t *testing.T) { @@ -1389,3 +1473,125 @@ func BenchmarkValidate(b *testing.B) { }) } } + +type Cat struct { + Name string `json:"name" minLength:"2" maxLength:"10"` + Kind string `json:"kind" enum:"cat"` +} + +type Dog struct { + Color string `json:"color" enum:"black,white,brown"` + Kind string `json:"kind" enum:"dog"` +} + +func Test_validateWithDiscriminator(t *testing.T) { + registry := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer) + catSchema := registry.Schema(reflect.TypeOf(Cat{}), true, "Cat") + dogSchema := registry.Schema(reflect.TypeOf(Dog{}), true, "Dog") + + s := &huma.Schema{ + Type: huma.TypeObject, + Description: "Animal", + OneOf: []*huma.Schema{ + {Ref: catSchema.Ref}, + {Ref: dogSchema.Ref}, + }, + Discriminator: &huma.Discriminator{ + PropertyName: "kind", + Mapping: map[string]string{ + "cat": catSchema.Ref, + "dog": dogSchema.Ref, + }, + }, + } + + pb := huma.NewPathBuffer([]byte(""), 0) + res := &huma.ValidateResult{} + + tests := []struct { + name string + input any + wantErrs []string + }{ + { + name: "cat - minLength case", + input: map[string]any{ + "kind": "cat", + "name": "c", + }, + wantErrs: []string{"expected length >= 2"}, + }, + { + name: "cat - maxLength case", + input: map[string]any{ + "kind": "cat", + "name": "aaaaaaaaaaa", + }, + wantErrs: []string{"expected length <= 10"}, + }, + { + name: "cat - invalid schema", + input: map[string]any{ + "kind": "dog", + "name": "cat", + }, + wantErrs: []string{ + "expected required property color to be present", + "unexpected property", + }, + }, + { + name: "cat - any invalid schema", + input: map[any]any{ + "kind": "dog", + "name": "cat", + }, + wantErrs: []string{ + "expected required property color to be present", + "unexpected property", + }, + }, + { + name: "cat - ok", + input: map[string]any{ + "kind": "cat", + "name": "meow", + }, + }, + { + name: "cat - any ok", + input: map[any]any{ + "kind": "cat", + "name": "meow", + }, + }, + { + name: "dog - wrong color", + input: map[string]any{ + "kind": "dog", + "color": "red", + }, + wantErrs: []string{"expected value to be one of \"black, white, brown\""}, + }, + { + name: "unknown kind", + input: map[string]any{ + "kind": "unknown", + "foo": "bar", + }, + wantErrs: []string{validation.MsgExpectedMatchExactlyOneSchema}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + pb.Reset() + res.Reset() + huma.Validate(registry, s, pb, huma.ModeWriteToServer, tc.input, res) + require.Len(t, res.Errors, len(tc.wantErrs)) + for i, wantErr := range tc.wantErrs { + assert.Contains(t, res.Errors[i].Error(), wantErr) + } + }) + } +} diff --git a/validation/messages.go b/validation/messages.go index cb68214c..ea1160d3 100644 --- a/validation/messages.go +++ b/validation/messages.go @@ -17,8 +17,10 @@ var ( MsgExpectedRFC6901JSONPointer = "expected string to be RFC 6901 json-pointer" MsgExpectedRFC6901RelativeJSONPointer = "expected string to be RFC 6901 relative-json-pointer" MsgExpectedRegexp = "expected string to be regex: %v" - MsgExpectedMatchSchema = "expected value to match at least one schema but matched none" + MsgExpectedMatchAtLeastOneSchema = "expected value to match at least one schema but matched none" + MsgExpectedMatchExactlyOneSchema = "expected value to match exactly one schema but matched none" MsgExpectedNotMatchSchema = "expected value to not match schema" + MsgExpectedPropertyNameInObject = "expected propertyName value to be present in object" MsgExpectedBoolean = "expected boolean" MsgExpectedNumber = "expected number" MsgExpectedString = "expected string"