Skip to content

Commit

Permalink
Merge pull request #575 from danielgtaylor/fix-transform-schema-overw…
Browse files Browse the repository at this point in the history
…rite

fix: prevent overwriting schema validations
  • Loading branch information
danielgtaylor authored Sep 18, 2024
2 parents cce4569 + ca4b5f6 commit 3649df3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 68 deletions.
2 changes: 1 addition & 1 deletion formdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
continue
}

if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required") {
if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) {
requiredFields[i] = name
schema.requiredMap[name] = true
}
Expand Down
2 changes: 1 addition & 1 deletion huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
pfi.TimeFormat = timeFormat
}

if !boolTag(f, "hidden") {
if !boolTag(f, "hidden", false) {
desc := ""
if pfi.Schema != nil {
// If the schema has a description, use it. Some tools will not show
Expand Down
114 changes: 48 additions & 66 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (s *Schema) PrecomputeMessages() {
}
}

func boolTag(f reflect.StructField, tag string) bool {
func boolTag(f reflect.StructField, tag string, def bool) bool {
if v := f.Tag.Get(tag); v != "" {
if v == "true" {
return true
Expand All @@ -339,29 +339,36 @@ func boolTag(f reflect.StructField, tag string) bool {
panic(fmt.Errorf("invalid bool tag '%s' for field '%s': %v", tag, f.Name, v))
}
}
return false
return def
}

func intTag(f reflect.StructField, tag string) *int {
func intTag(f reflect.StructField, tag string, def *int) *int {
if v := f.Tag.Get(tag); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return &i
} else {
panic(fmt.Errorf("invalid int tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err))
}
}
return nil
return def
}

func floatTag(f reflect.StructField, tag string) *float64 {
func floatTag(f reflect.StructField, tag string, def *float64) *float64 {
if v := f.Tag.Get(tag); v != "" {
if i, err := strconv.ParseFloat(v, 64); err == nil {
return &i
} else {
panic(fmt.Errorf("invalid float tag '%s' for field '%s': %v (%w)", tag, f.Name, v, err))
}
}
return nil
return def
}

func stringTag(f reflect.StructField, tag string, def string) string {
if v := f.Tag.Get(tag); v != "" {
return v
}
return def
}

// ensureType panics if the given value does not match the JSON Schema type.
Expand Down Expand Up @@ -508,18 +515,14 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
if fs == nil {
return fs
}
if doc := f.Tag.Get("doc"); doc != "" {
fs.Description = doc
}
fs.Description = stringTag(f, "doc", fs.Description)
if fs.Format == "date-time" && f.Tag.Get("header") != "" {
// Special case: this is a header and uses a different date/time format.
// Note that it can still be overridden by the `format` or `timeFormat`
// tags later.
fs.Format = "date-time-http"
}
if format := f.Tag.Get("format"); format != "" {
fs.Format = format
}
fs.Format = stringTag(f, "format", fs.Format)
if timeFmt := f.Tag.Get("timeFormat"); timeFmt != "" {
switch timeFmt {
case "2006-01-02":
Expand All @@ -530,9 +533,7 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
fs.Format = timeFmt
}
}
if enc := f.Tag.Get("encoding"); enc != "" {
fs.ContentEncoding = enc
}
fs.ContentEncoding = stringTag(f, "encoding", fs.ContentEncoding)
if defaultValue := jsonTag(registry, f, fs, "default"); defaultValue != nil {
fs.Default = defaultValue
}
Expand All @@ -559,56 +560,37 @@ func SchemaFromField(registry Registry, f reflect.StructField, hint string) *Sch
}
}

if _, ok := f.Tag.Lookup("nullable"); ok {
fs.Nullable = boolTag(f, "nullable")
if fs.Nullable && fs.Ref != "" {
// Nullability is only supported for scalar types for now. Objects are
// much more complicated because the `null` type lives within the object
// definition (requiring multiple copies of the object) or needs to use
// `anyOf` or `not` which is not supported by all code generators, or is
// supported poorly & generates hard-to-use code. This is less than ideal
// but a compromise for now to support some nullability built-in.
panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref))
}
}

if _, ok := f.Tag.Lookup("minimum"); ok {
fs.Minimum = floatTag(f, "minimum")
}

fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum")

if _, ok := f.Tag.Lookup("maximum"); ok {
fs.Maximum = floatTag(f, "maximum")
}
fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum")
fs.MultipleOf = floatTag(f, "multipleOf")
if _, ok := f.Tag.Lookup("minLength"); ok {
fs.MinLength = intTag(f, "minLength")
}

if _, ok := f.Tag.Lookup("maxLength"); ok {
fs.MaxLength = intTag(f, "maxLength")
}
fs.Pattern = f.Tag.Get("pattern")
fs.PatternDescription = f.Tag.Get("patternDescription")
if _, ok := f.Tag.Lookup("minItems"); ok {
fs.MinItems = intTag(f, "minItems")
}
if _, ok := f.Tag.Lookup("maxItems"); ok {
fs.MaxItems = intTag(f, "maxItems")
}
fs.UniqueItems = boolTag(f, "uniqueItems")
fs.MinProperties = intTag(f, "minProperties")
fs.MaxProperties = intTag(f, "maxProperties")
fs.ReadOnly = boolTag(f, "readOnly")
fs.WriteOnly = boolTag(f, "writeOnly")
fs.Deprecated = boolTag(f, "deprecated")
fs.Nullable = boolTag(f, "nullable", fs.Nullable)
if fs.Nullable && fs.Ref != "" {
// Nullability is only supported for scalar types for now. Objects are
// much more complicated because the `null` type lives within the object
// definition (requiring multiple copies of the object) or needs to use
// `anyOf` or `not` which is not supported by all code generators, or is
// supported poorly & generates hard-to-use code. This is less than ideal
// but a compromise for now to support some nullability built-in.
panic(fmt.Errorf("nullable is not supported for field '%s' which is type '%s'", f.Name, fs.Ref))
}

fs.Minimum = floatTag(f, "minimum", fs.Minimum)
fs.ExclusiveMinimum = floatTag(f, "exclusiveMinimum", fs.ExclusiveMinimum)
fs.Maximum = floatTag(f, "maximum", fs.Maximum)
fs.ExclusiveMaximum = floatTag(f, "exclusiveMaximum", fs.ExclusiveMaximum)
fs.MultipleOf = floatTag(f, "multipleOf", fs.MultipleOf)
fs.MinLength = intTag(f, "minLength", fs.MinLength)
fs.MaxLength = intTag(f, "maxLength", fs.MaxLength)
fs.Pattern = stringTag(f, "pattern", fs.Pattern)
fs.PatternDescription = stringTag(f, "patternDescription", fs.PatternDescription)
fs.MinItems = intTag(f, "minItems", fs.MinItems)
fs.MaxItems = intTag(f, "maxItems", fs.MaxItems)
fs.UniqueItems = boolTag(f, "uniqueItems", fs.UniqueItems)
fs.MinProperties = intTag(f, "minProperties", fs.MinProperties)
fs.MaxProperties = intTag(f, "maxProperties", fs.MaxProperties)
fs.ReadOnly = boolTag(f, "readOnly", fs.ReadOnly)
fs.WriteOnly = boolTag(f, "writeOnly", fs.WriteOnly)
fs.Deprecated = boolTag(f, "deprecated", fs.Deprecated)
fs.PrecomputeMessages()

if v := f.Tag.Get("hidden"); v != "" {
fs.hidden = boolTag(f, "hidden")
}
fs.hidden = boolTag(f, "hidden", fs.hidden)

return fs
}
Expand Down Expand Up @@ -830,7 +812,7 @@ func schemaFromType(r Registry, t reflect.Type) *Schema {
}

if _, ok := f.Tag.Lookup("required"); ok {
fieldRequired = boolTag(f, "required")
fieldRequired = boolTag(f, "required", false)
}

if dr := f.Tag.Get("dependentRequired"); strings.TrimSpace(dr) != "" {
Expand Down Expand Up @@ -885,12 +867,12 @@ func schemaFromType(r Registry, t reflect.Type) *Schema {
additionalProps := false
if f, ok := t.FieldByName("_"); ok {
if _, ok = f.Tag.Lookup("additionalProperties"); ok {
additionalProps = boolTag(f, "additionalProperties")
additionalProps = boolTag(f, "additionalProperties", false)
}

if _, ok := f.Tag.Lookup("nullable"); ok {
// Allow overriding nullability per struct.
s.Nullable = boolTag(f, "nullable")
s.Nullable = boolTag(f, "nullable", false)
}
}
s.AdditionalProperties = additionalProps
Expand Down
3 changes: 3 additions & 0 deletions schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,7 @@ type ExampleInputStruct struct {
Email string `json:"email" format:"email" doc:"Contact e-mail address"`
Age *int `json:"age,omitempty" minimum:"0"`
Comment string `json:"comment,omitempty" maxLength:"256"`
Pattern string `json:"pattern" pattern:"^[a-z]+$"`
}

// Implements SchemaTransformer interface, reusing parts of the schema from `ExampleInputStruct`
Expand All @@ -1419,6 +1420,7 @@ type ExampleUpdateStruct struct {
Email *string `json:"email" doc:"Override doc for email"`
Age OmittableNullable[int] `json:"age"`
Comment OmittableNullable[string] `json:"comment"`
Pattern string `json:"pattern"`
}

func (u *ExampleUpdateStruct) TransformSchema(r huma.Registry, s *huma.Schema) *huma.Schema {
Expand Down Expand Up @@ -1449,6 +1451,7 @@ func TestSchemaTransformer(t *testing.T) {
assert.True(t, s.Properties["age"].Nullable)
assert.Equal(t, inputSchema.Properties["comment"].MaxLength, s.Properties["comment"].MaxLength)
assert.True(t, s.Properties["comment"].Nullable)
assert.Equal(t, inputSchema.Properties["pattern"].Pattern, s.Properties["pattern"].Pattern)
}
updateSchema1 := r.Schema(reflect.TypeOf(ExampleUpdateStruct{}), false, "")
validateSchema(updateSchema1)
Expand Down

0 comments on commit 3649df3

Please sign in to comment.