From 9a5d037f99eee63ecbb9254d494a72aa922c6c58 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Mon, 13 Apr 2020 22:47:40 -0700 Subject: [PATCH] feat: functional options for resources/operations --- README.md | 101 ++---- benchmark/huma/main.go | 14 +- dependency.go | 143 +++++--- dependency_test.go | 80 +++-- examples/readme/main.go | 9 +- middelware.go | 54 ++-- middleware_test.go | 20 +- openapi.go | 413 +++++++++++------------- openapi_test.go | 104 +++--- options.go | 400 +++++++++++++++++++++-- resource.go | 188 ++++------- resource_test.go | 101 ++---- router.go | 55 ++-- router_test.go | 243 +++++--------- schema.go => schema/schema.go | 46 +-- schema_test.go => schema/schema_test.go | 113 ++++--- validate.go | 68 ++-- validate_test.go | 213 ++++-------- 18 files changed, 1197 insertions(+), 1168 deletions(-) rename schema.go => schema/schema.go (91%) rename schema_test.go => schema/schema_test.go (75%) diff --git a/README.md b/README.md index 6654d5bd..6fa9612c 100644 --- a/README.md +++ b/README.md @@ -294,29 +294,6 @@ notes.With( ).Get("Get a list of all notes", func () []*NoteSummary { // Implementation goes here }) - -// The above idiom is common enough when needing to change response codes -// or allow certain response headers that there is a shortcut: -notes. - JSON(http.StatusCreated, "Success", "expires"). - Get("Get a list of all notes", func () []*NoteSummary { - // Implementation goes here - }) -``` - -Alternatively you can provide a `*huma.Operation` instance to the resource if you want more flexibility or prefer this style over chaining: - -```go -// Create the operation -notes.Operation(http.MethodGet, &huma.Operation{ - Description: "Get a list of all notes" - Responses: []*huma.Response{ - huma.ResponseJSON(http.StatusOK, "List of notes", "expires"), - }, - Handler: func () []*NoteSummary { - // Implementation goes here - } -}) ``` > :whale: Operations map an HTTP action verb to a resource. You might `POST` a new note or `GET` a user. Sometimes the mapping is less obvious and you can consider using a sub-resource. For example, rather than unliking a post, maybe you `DELETE` the `/posts/{id}/likes` resource. @@ -371,15 +348,17 @@ Get("Get a note by its ID", func(id string) (*huma.ErrorModel, *Note) { You can also declare parameters with additional validation logic: ```go -huma.PathParam("id", "Note ID", &huma.Schema{ +s := &schema.Schema{ MinLength: 1, MaxLength: 32, -}) +} + +huma.PathParam("id", "Note ID", huma.Schema(s)) ``` Once a parameter is declared it will get parsed, validated, and then sent to your handler function. If parsing or validation fails, the client gets a 400-level HTTP error. -> :whale: If a proxy is providing e.g. authentication or rate-limiting and exposes additional internal-only information then use the internal parameters like `huma.HeaderParamInternal("UserID", "Parsed user from the auth system", "nobody")`. Internal parameters are never included in the generated OpenAPI 3 spec or documentation. +> :whale: If a proxy is providing e.g. authentication or rate-limiting and exposes additional internal-only information then use the internal parameters like `huma.HeaderParam("UserID", "Parsed user from the auth system", "nobody", huma.Internal())`. Internal parameters are never included in the generated OpenAPI 3 spec or documentation. ## Request & Response Models @@ -433,11 +412,11 @@ Get("Description", func() (*huma.ErrorModel, *Note) { Whichever model is not `nil` will get sent back to the client. -Empty responses, e.g. a `204 No Content` or `304 Not Modified` are also supported. Use `huma.ResponseEmpty` paired with a simple boolean to return a response without a body. Passing `false` acts like `nil` for models and prevents that response from being sent. +Empty responses, e.g. a `204 No Content` or `304 Not Modified` are also supported by setting a `ContentType` of `""`. Use `huma.Response` paired with a simple boolean to return a response without a body. Passing `false` acts like `nil` for models and prevents that response from being sent. ```go r.Resource("/notes", - huma.ResponseEmpty(http.StatusNoContent, "This should have no body")). + huma.Response(http.StatusNoContent, "This should have no body")). Get("description", func() bool { return true }) @@ -490,8 +469,8 @@ For example: ```go r.Resource("/notes", huma.Header("expires", "Expiration date for this content"), - huma.ResponseText(http.StatusOK, "Success", "expires")). -Get("description", func() (string, string) { + huma.ResponseText(http.StatusOK, "Success", huma.Headers("expires")) +).Get("description", func() (string, string) { expires := time.Now().Add(7 * 24 * time.Hour).MarshalText() return expires, "Hello!" }) @@ -521,9 +500,7 @@ Global dependencies are created by just setting some value, while contextual dep ```go // Register a new database connection dependency -db := &huma.Dependency{ - Value: db.NewConnection(), -} +db := huma.SimpleDependency(db.NewConnection()) // Register a new request logger dependency. This is contextual because we // will print out the requester's IP address with each log message. @@ -531,27 +508,25 @@ type MyLogger struct { Info: func(msg string), } -logger := &huma.Dependency{ - Dependencies: []*huma.Dependency{huma.ContextDependency()}, - Value: func(c *gin.Context) (*MyLogger, error) { +logger := huma.Dependency( + huma.GinContextDependency(), + func(c *gin.Context) (*MyLogger, error) { return &MyLogger{ Info: func(msg string) { fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr) }, }, nil }, -} +) // Use them in any handler by adding them to both `Depends` and the list of // handler function arguments. -r.Resource("/foo").Operation(http.MethodGet, &huma.Operation{ - // ... - Dependencies: []*huma.Dependency{db, logger}, - Handler: func(db *db.Connection, log *MyLogger) string { - log.Info("test") - item := db.Fetch("query") - return item.ID - } +r.Resource("/foo").With( + db, logger +).Get("doc", func(db *db.Connection, log *MyLogger) string { + log.Info("test") + item := db.Fetch("query") + return item.ID }) ``` @@ -624,7 +599,7 @@ r.Resource("/timeout", ### Request Body Timeouts -By default any handler which takes in a request body parameter will have a read timeout of 15 seconds set on it. If set to nonzero for a handler which does **not** take a body, then the timeout will be set on the underlying connection before calling your handler. The timeout value is configurable at the resource and operation level. +By default any handler which takes in a request body parameter will have a read timeout of 15 seconds set on it. If set to nonzero for a handler which does **not** take a body, then the timeout will be set on the underlying connection before calling your handler. When triggered, the server sends a 408 Request Timeout as JSON with a message containing the time waited. @@ -635,21 +610,11 @@ type Input struct { r := huma.NewRouter("My API", "1.0.0") -// Resource-level limit to 5 seconds -r.Resource("/foo").BodyReadTimeout(5 * time.Second).Post( +// Limit to 5 seconds +r.Resource("/foo", huma.BodyReadTimeout(5 * time.Second)).Post( "Create item", func(input *Input) string { return "Hello, " + input.ID }) - -// Operation-level limit -r.Resource("/foo").Operation(http.MethodPost, &huma.Operation{ - // ... - BodyReadTimeout: 5 * time.Second, - Handler: func(input *Input) string { - return "Hello, " + input.ID - }, - // ... -}) ``` You can also access the underlying TCP connection and set deadlines manually: @@ -675,22 +640,15 @@ r.Resource("/foo", huma.GinContextDependency()).Get(func (c *gin.Context) string ### Request Body Size Limits -By default each operation has a 1 MiB reqeuest body size limit. This value is configurable at the resource and operation level. +By default each operation has a 1 MiB reqeuest body size limit. When triggered, the server sends a 413 Request Entity Too Large as JSON with a message containing the maximum body size for this operation. ```go r := huma.NewRouter("My API", "1.0.0") -// Resource-level limit set to 10 MiB -r.Resource("/foo").MaxBodyBytes(10 * 1024 * 1024).Get(...) - -// Operation-level limit -r.Resource("/foo").Operation(http.MethodGet, &huma.Operation{ - // ... - MaxBodyBytes: 10 * 1024 * 1024, - // ... -}) +// Limit set to 10 MiB +r.Resource("/foo", MaxBodyBytes(10 * 1024 * 1024)).Get(...) ``` > :whale: Set to `-1` in order to disable the check, allowing for unlimited request body size for e.g. large streaming file uploads. @@ -702,8 +660,7 @@ Huma provides a Zap-based contextual structured logger built-in. You can access ```go r.Resource("/test", huma.LogDependency(), - huma.ResponseText(http.StatusOK, "Successful")). -Get("Logger test", func(log *zap.SugaredLogger) string { +).Get("Logger test", func(log *zap.SugaredLogger) string { log.Info("I'm using the logger!") return "Hello, world" }) @@ -818,10 +775,10 @@ You can access the root `cobra.Command` via `r.Root()` and add new custom comman ## Middleware -You can make use of any Gin-compatible middleware via the `Middleware()` router option. +You can make use of any Gin-compatible middleware via the `GinMiddleware()` router option. ```go -r := huma.NewRouter("My API", "1.0.0", huma.Middleware(gin.Logger())) +r := huma.NewRouter("My API", "1.0.0", huma.GinMiddleware(gin.Logger())) ``` ## HTTP/2 Setup diff --git a/benchmark/huma/main.go b/benchmark/huma/main.go index b724e984..09325db1 100644 --- a/benchmark/huma/main.go +++ b/benchmark/huma/main.go @@ -26,19 +26,17 @@ func main() { r := huma.NewRouter("Benchmark", "1.0.0", huma.WithGin(g)) - d := &huma.Dependency{ - Params: []*huma.Param{ - huma.HeaderParam("authorization", "Auth header", ""), - }, - Value: func(auth string) (string, error) { + d := huma.Dependency( + huma.HeaderParam("authorization", "Auth header", ""), + func(auth string) (string, error) { return strings.Split(auth, " ")[0], nil }, - } + ) r.Resource("/items", d, huma.PathParam("id", "The item's unique ID"), - huma.Header("x-authinfo", "..."), - huma.ResponseJSON(http.StatusOK, "Successful hello response", "x-authinfo"), + huma.ResponseHeader("x-authinfo", "..."), + huma.ResponseJSON(http.StatusOK, "Successful hello response", huma.Headers("x-authinfo")), ).Get("Huma benchmark test", func(authInfo string, id int) (string, *Item) { return authInfo, &Item{ ID: id, diff --git a/dependency.go b/dependency.go index 4ddd5d29..5fb4a8db 100644 --- a/dependency.go +++ b/dependency.go @@ -11,51 +11,94 @@ import ( // ErrDependencyInvalid is returned when registering a dependency fails. var ErrDependencyInvalid = errors.New("dependency invalid") -// Dependency represents a handler function dependency and its associated +// OpenAPIDependency represents a handler function dependency and its associated // inputs and outputs. Value can be either a struct pointer (global dependency) // or a `func(dependencies, params) (headers, struct pointer, error)` style // function. -type Dependency struct { - Dependencies []*Dependency - Params []*Param - ResponseHeaders []*ResponseHeader - Value interface{} +type OpenAPIDependency struct { + dependencies []*OpenAPIDependency + params []*OpenAPIParam + responseHeaders []*OpenAPIResponseHeader + handler interface{} } -var contextDependency Dependency -var ginContextDependency Dependency -var operationDependency Dependency +// Dependencies returns the dependencies associated with this dependency. +func (d *OpenAPIDependency) Dependencies() []*OpenAPIDependency { + return d.dependencies +} + +// Params returns the params associated with this dependency. +func (d *OpenAPIDependency) Params() []*OpenAPIParam { + return d.params +} + +// ResponseHeaders returns the params associated with this dependency. +func (d *OpenAPIDependency) ResponseHeaders() []*OpenAPIResponseHeader { + return d.responseHeaders +} + +// NewSimpleDependency returns a dependency with a function or value. +func NewSimpleDependency(value interface{}) *OpenAPIDependency { + return NewDependency(nil, value) +} + +// NewDependency returns a dependency with the given option and a handler +// function. +func NewDependency(option DependencyOption, handler interface{}) *OpenAPIDependency { + d := &OpenAPIDependency{ + dependencies: make([]*OpenAPIDependency, 0), + params: make([]*OpenAPIParam, 0), + responseHeaders: make([]*OpenAPIResponseHeader, 0), + handler: handler, + } + + if option != nil { + option.ApplyDependency(d) + } + + return d +} + +var contextDependency OpenAPIDependency +var ginContextDependency OpenAPIDependency +var operationDependency OpenAPIDependency // ContextDependency returns a dependency for the current request's // `context.Context`. This is useful for timeouts & cancellation. -func ContextDependency() *Dependency { - return &contextDependency +func ContextDependency() DependencyOption { + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, &contextDependency) + }} } // GinContextDependency returns a dependency for the current request's // `*gin.Context`. -func GinContextDependency() *Dependency { - return &ginContextDependency +func GinContextDependency() DependencyOption { + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, &ginContextDependency) + }} } // OperationDependency returns a dependency for the current `*huma.Operation`. -func OperationDependency() *Dependency { - return &operationDependency +func OperationDependency() DependencyOption { + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, &operationDependency) + }} } // validate that the dependency deps/params/headers match the function // signature or that the value is not a function. -func (d *Dependency) validate(returnType reflect.Type) { +func (d *OpenAPIDependency) validate(returnType reflect.Type) { if d == &contextDependency || d == &ginContextDependency || d == &operationDependency { // Hard-coded known dependencies. These are special and have no value. return } - if d.Value == nil { - panic(fmt.Errorf("value must be set: %w", ErrDependencyInvalid)) + if d.handler == nil { + panic(fmt.Errorf("handler must be set: %w", ErrDependencyInvalid)) } - v := reflect.ValueOf(d.Value) + v := reflect.ValueOf(d.handler) if v.Kind() != reflect.Func { if returnType != nil && returnType != v.Type() { @@ -63,11 +106,11 @@ func (d *Dependency) validate(returnType reflect.Type) { } // This is just a static value. It shouldn't have params/headers/etc. - if len(d.Params) > 0 { + if len(d.params) > 0 { panic(fmt.Errorf("global dependency should not have params: %w", ErrDependencyInvalid)) } - if len(d.ResponseHeaders) > 0 { + if len(d.responseHeaders) > 0 { panic(fmt.Errorf("global dependency should not set headers: %w", ErrDependencyInvalid)) } @@ -75,43 +118,43 @@ func (d *Dependency) validate(returnType reflect.Type) { } fn := v.Type() - lenArgs := len(d.Dependencies) + len(d.Params) + lenArgs := len(d.dependencies) + len(d.params) if fn.NumIn() != lenArgs { // TODO: generate suggested func signature panic(fmt.Errorf("function signature should have %d args but got %s: %w", lenArgs, fn, ErrDependencyInvalid)) } - for _, dep := range d.Dependencies { + for _, dep := range d.dependencies { dep.validate(nil) } - for i, p := range d.Params { - p.validate(fn.In(len(d.Dependencies) + i)) + for i, p := range d.params { + p.validate(fn.In(len(d.dependencies) + i)) } - lenReturn := len(d.ResponseHeaders) + 2 + lenReturn := len(d.responseHeaders) + 2 if fn.NumOut() != lenReturn { panic(fmt.Errorf("function should return %d values but got %d: %w", lenReturn, fn.NumOut(), ErrDependencyInvalid)) } - for i, h := range d.ResponseHeaders { + for i, h := range d.responseHeaders { h.validate(fn.Out(i)) } } -// AllParams returns all parameters for all dependencies in the graph of this +// allParams returns all parameters for all dependencies in the graph of this // dependency in depth-first order without duplicates. -func (d *Dependency) AllParams() []*Param { - params := []*Param{} - seen := map[*Param]bool{} +func (d *OpenAPIDependency) allParams() []*OpenAPIParam { + params := []*OpenAPIParam{} + seen := map[*OpenAPIParam]bool{} - for _, p := range d.Params { + for _, p := range d.params { seen[p] = true params = append(params, p) } - for _, d := range d.Dependencies { - for _, p := range d.AllParams() { + for _, d := range d.dependencies { + for _, p := range d.allParams() { if _, ok := seen[p]; !ok { seen[p] = true @@ -123,19 +166,19 @@ func (d *Dependency) AllParams() []*Param { return params } -// AllResponseHeaders returns all response headers for all dependencies in +// allResponseHeaders returns all response headers for all dependencies in // the graph of this dependency in depth-first order without duplicates. -func (d *Dependency) AllResponseHeaders() []*ResponseHeader { - headers := []*ResponseHeader{} - seen := map[*ResponseHeader]bool{} +func (d *OpenAPIDependency) allResponseHeaders() []*OpenAPIResponseHeader { + headers := []*OpenAPIResponseHeader{} + seen := map[*OpenAPIResponseHeader]bool{} - for _, h := range d.ResponseHeaders { + for _, h := range d.responseHeaders { seen[h] = true headers = append(headers, h) } - for _, d := range d.Dependencies { - for _, h := range d.AllResponseHeaders() { + for _, d := range d.dependencies { + for _, h := range d.allResponseHeaders() { if _, ok := seen[h]; !ok { seen[h] = true @@ -147,8 +190,8 @@ func (d *Dependency) AllResponseHeaders() []*ResponseHeader { return headers } -// Resolve the value of the dependency. Returns (response headers, value, error). -func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, interface{}, error) { +// resolve the value of the dependency. Returns (response headers, value, error). +func (d *OpenAPIDependency) resolve(c *gin.Context, op *OpenAPIOperation) (map[string]string, interface{}, error) { // Identity dependencies are first. Just return if it's one of them. if d == &contextDependency { return nil, c.Request.Context(), nil @@ -162,10 +205,10 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, return nil, op, nil } - v := reflect.ValueOf(d.Value) + v := reflect.ValueOf(d.handler) if v.Kind() != reflect.Func { // Not a function, just return the global value. - return nil, d.Value, nil + return nil, d.handler, nil } // Generate the input arguments @@ -173,8 +216,8 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, headers := map[string]string{} // Resolve each sub-dependency - for _, dep := range d.Dependencies { - dHeaders, dVal, err := dep.Resolve(c, op) + for _, dep := range d.dependencies { + dHeaders, dVal, err := dep.resolve(c, op) if err != nil { return nil, nil, err } @@ -187,7 +230,7 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, } // Get each input parameter - for _, param := range d.Params { + for _, param := range d.params { v, ok := getParamValue(c, param) if !ok { return nil, nil, fmt.Errorf("could not get param value") @@ -205,9 +248,9 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, } // Get the headers & response value. - for i, h := range d.ResponseHeaders { + for i, h := range d.responseHeaders { headers[h.Name] = out[i].Interface().(string) } - return headers, out[len(d.ResponseHeaders)].Interface(), nil + return headers, out[len(d.responseHeaders)].Interface(), nil } diff --git a/dependency_test.go b/dependency_test.go index caab0546..a993b819 100644 --- a/dependency_test.go +++ b/dependency_test.go @@ -12,7 +12,7 @@ import ( ) func TestGlobalDepEmpty(t *testing.T) { - d := Dependency{} + d := OpenAPIDependency{} typ := reflect.TypeOf(123) @@ -22,8 +22,8 @@ func TestGlobalDepEmpty(t *testing.T) { } func TestGlobalDepWrongType(t *testing.T) { - d := Dependency{ - Value: "test", + d := OpenAPIDependency{ + handler: "test", } typ := reflect.TypeOf(123) @@ -34,13 +34,12 @@ func TestGlobalDepWrongType(t *testing.T) { } func TestGlobalDepParams(t *testing.T) { - d := Dependency{ - Params: []*Param{ - HeaderParam("foo", "description", "hello"), - }, - Value: "test", + d := OpenAPIDependency{ + handler: "test", } + HeaderParam("foo", "description", "hello").ApplyDependency(&d) + typ := reflect.TypeOf("test") assert.Panics(t, func() { @@ -49,11 +48,12 @@ func TestGlobalDepParams(t *testing.T) { } func TestGlobalDepHeaders(t *testing.T) { - d := Dependency{ - ResponseHeaders: []*ResponseHeader{Header("foo", "description")}, - Value: "test", + d := OpenAPIDependency{ + handler: "test", } + ResponseHeader("foo", "description").ApplyDependency(&d) + typ := reflect.TypeOf("test") assert.Panics(t, func() { @@ -62,11 +62,11 @@ func TestGlobalDepHeaders(t *testing.T) { } func TestDepContext(t *testing.T) { - d := Dependency{ - Dependencies: []*Dependency{ - ContextDependency(), + d := OpenAPIDependency{ + dependencies: []*OpenAPIDependency{ + &contextDependency, }, - Value: func(ctx context.Context) (context.Context, error) { return ctx, nil }, + handler: func(ctx context.Context) (context.Context, error) { return ctx, nil }, } mock, _ := gin.CreateTestContext(nil) @@ -75,17 +75,17 @@ func TestDepContext(t *testing.T) { typ := reflect.TypeOf(mock) d.validate(typ) - _, v, err := d.Resolve(mock, &Operation{}) + _, v, err := d.resolve(mock, &OpenAPIOperation{}) assert.NoError(t, err) assert.Equal(t, v, mock.Request.Context()) } func TestDepGinContext(t *testing.T) { - d := Dependency{ - Dependencies: []*Dependency{ - GinContextDependency(), + d := OpenAPIDependency{ + dependencies: []*OpenAPIDependency{ + &ginContextDependency, }, - Value: func(c *gin.Context) (*gin.Context, error) { return c, nil }, + handler: func(c *gin.Context) (*gin.Context, error) { return c, nil }, } mock, _ := gin.CreateTestContext(nil) @@ -93,56 +93,54 @@ func TestDepGinContext(t *testing.T) { typ := reflect.TypeOf(mock) d.validate(typ) - _, v, err := d.Resolve(mock, &Operation{}) + _, v, err := d.resolve(mock, &OpenAPIOperation{}) assert.NoError(t, err) assert.Equal(t, v, mock) } func TestDepOperation(t *testing.T) { - d := Dependency{ - Dependencies: []*Dependency{ - OperationDependency(), + d := OpenAPIDependency{ + dependencies: []*OpenAPIDependency{ + &operationDependency, }, - Value: func(o *Operation) (*Operation, error) { return o, nil }, + handler: func(o *OpenAPIOperation) (*OpenAPIOperation, error) { return o, nil }, } - mock := &Operation{} + mock := &OpenAPIOperation{} typ := reflect.TypeOf(mock) d.validate(typ) - _, v, err := d.Resolve(&gin.Context{}, mock) + _, v, err := d.resolve(&gin.Context{}, mock) assert.NoError(t, err) assert.Equal(t, v, mock) } func TestDepFuncWrongArgs(t *testing.T) { - d := Dependency{ - Params: []*Param{ - HeaderParam("foo", "desc", ""), - }, - Value: func() (string, error) { + d := OpenAPIDependency{ + handler: func() (string, error) { return "", nil }, } + HeaderParam("foo", "desc", "").ApplyDependency(&d) + assert.Panics(t, func() { d.validate(reflect.TypeOf("")) }) } func TestDepFunc(t *testing.T) { - d := Dependency{ - Params: []*Param{ - HeaderParam("x-in", "desc", ""), - }, - ResponseHeaders: []*ResponseHeader{ - Header("x-out", "desc"), - }, - Value: func(xin string) (string, string, error) { + d := OpenAPIDependency{ + handler: func(xin string) (string, string, error) { return "xout", "value", nil }, } + DependencyOptions( + HeaderParam("x-in", "desc", ""), + ResponseHeader("x-out", "desc"), + ).ApplyDependency(&d) + c := &gin.Context{ Request: &http.Request{ Header: http.Header{ @@ -152,7 +150,7 @@ func TestDepFunc(t *testing.T) { } d.validate(reflect.TypeOf("")) - h, v, err := d.Resolve(c, &Operation{}) + h, v, err := d.resolve(c, &OpenAPIOperation{}) assert.NoError(t, err) assert.Equal(t, "xout", h["x-out"]) assert.Equal(t, "value", v) diff --git a/examples/readme/main.go b/examples/readme/main.go index ed3e9075..25a25ed9 100644 --- a/examples/readme/main.go +++ b/examples/readme/main.go @@ -6,6 +6,7 @@ import ( "time" "github.com/danielgtaylor/huma" + "github.com/danielgtaylor/huma/schema" ) // NoteSummary is used to list notes. It does not include the (potentially) @@ -46,9 +47,11 @@ func main() { }) // Add an `id` path parameter to create a note resource. - note := notes.With(huma.PathParam("id", "Note ID", &huma.Schema{ - Pattern: "^[a-zA-Z0-9._-]{1,32}$", - })) + note := notes.With(huma.PathParam("id", "Note ID", + huma.Schema(&schema.Schema{ + Pattern: "^[a-zA-Z0-9._-]{1,32}$", + }), + )) notFound := huma.ResponseError(http.StatusNotFound, "Note not found") diff --git a/middelware.go b/middelware.go index 02d6194a..bfa6d65d 100644 --- a/middelware.go +++ b/middelware.go @@ -24,18 +24,21 @@ var logLevel *zap.AtomicLevel // panic when using the recovery middleware. Defaults to 10KB. var MaxLogBodyBytes int64 = 10 * 1024 -// BufferedReadCloser will read and buffer up to max bytes into buf. Additional +// Middleware TODO ... +type Middleware = gin.HandlerFunc + +// bufferedReadCloser will read and buffer up to max bytes into buf. Additional // reads bypass the buffer. -type BufferedReadCloser struct { +type bufferedReadCloser struct { reader io.ReadCloser buf *bytes.Buffer max int64 } -// NewBufferedReadCloser returns a new BufferedReadCloser that wraps reader +// newBufferedReadCloser returns a new BufferedReadCloser that wraps reader // and reads up to max bytes into the buffer. -func NewBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64) *BufferedReadCloser { - return &BufferedReadCloser{ +func newBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64) *bufferedReadCloser { + return &bufferedReadCloser{ reader: reader, buf: buffer, max: max, @@ -43,7 +46,7 @@ func NewBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64 } // Read data into p. Returns number of bytes read and an error, if any. -func (r *BufferedReadCloser) Read(p []byte) (n int, err error) { +func (r *bufferedReadCloser) Read(p []byte) (n int, err error) { // Read from the underlying reader like normal. n, err = r.reader.Read(p) @@ -61,12 +64,12 @@ func (r *BufferedReadCloser) Read(p []byte) (n int, err error) { } // Close the underlying reader. -func (r *BufferedReadCloser) Close() error { +func (r *bufferedReadCloser) Close() error { return r.reader.Close() } // Recovery prints stack traces on panic when used with the logging middleware. -func Recovery() func(*gin.Context) { +func Recovery() Middleware { bufPool := sync.Pool{ New: func() interface{} { return new(bytes.Buffer) @@ -82,7 +85,7 @@ func Recovery() func(*gin.Context) { buf = bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buf) - c.Request.Body = NewBufferedReadCloser(c.Request.Body, buf, MaxLogBodyBytes) + c.Request.Body = newBufferedReadCloser(c.Request.Body, buf, MaxLogBodyBytes) } // Recovering comes *after* the above so the buffer is not returned to @@ -137,7 +140,7 @@ func NewLogger() (*zap.Logger, error) { // Gin context under the `log` key. It debug logs request info. If passed `nil` // for the logger, then it creates one. If the current terminal is a TTY, it // will try to use colored output automatically. -func LogMiddleware(l *zap.Logger, tags map[string]string) func(*gin.Context) { +func LogMiddleware(l *zap.Logger, tags map[string]string) Middleware { var err error if l == nil { if l, err = NewLogger(); err != nil { @@ -180,19 +183,22 @@ func LogMiddleware(l *zap.Logger, tags map[string]string) func(*gin.Context) { // LogDependency returns a dependency that resolves to a `*zap.SugaredLogger` // for the current request. This dependency *requires* the use of // `LogMiddleware` and will error if the logger is not in the request context. -func LogDependency() *Dependency { - return &Dependency{ - Dependencies: []*Dependency{ContextDependency(), OperationDependency()}, - Value: func(c *gin.Context, op *Operation) (*zap.SugaredLogger, error) { - l, ok := c.Get("log") - if !ok { - return nil, fmt.Errorf("missing logger in context") - } - sl := l.(*zap.SugaredLogger).With("operation", op.ID) - sl.Desugar() - return sl, nil - }, - } +func LogDependency() DependencyOption { + dep := NewDependency(DependencyOptions( + GinContextDependency(), + OperationDependency(), + ), func(c *gin.Context, op *OpenAPIOperation) (*zap.SugaredLogger, error) { + l, ok := c.Get("log") + if !ok { + return nil, fmt.Errorf("missing logger in context") + } + sl := l.(*zap.SugaredLogger).With("operation", op.id) + return sl, nil + }) + + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, dep) + }} } // Handler404 will return JSON responses for 404 errors. @@ -226,7 +232,7 @@ func (w *minimalWriter) WriteHeader(statusCode int) { // PreferMinimalMiddleware will remove the response body and return 204 No // Content for any 2xx response where the request had the Prefer: return=minimal // set on the request. -func PreferMinimalMiddleware() func(*gin.Context) { +func PreferMinimalMiddleware() Middleware { return func(c *gin.Context) { // Wrap the response writer if c.GetHeader("prefer") == "return=minimal" { diff --git a/middleware_test.go b/middleware_test.go index 5d6bc9de..f71e76e9 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -15,14 +15,8 @@ func TestRecoveryMiddleware(t *testing.T) { r := NewTestRouter(t) r.GinEngine().Use(Recovery()) - r.Register(http.MethodGet, "/panic", &Operation{ - Description: "Panic recovery test", - Responses: []*Response{ - ResponseText(http.StatusOK, "Success"), - }, - Handler: func() string { - panic(fmt.Errorf("Some error")) - }, + r.Resource("/panic").Get("Panic recovery test", func() string { + panic(fmt.Errorf("Some error")) }) w := httptest.NewRecorder() @@ -36,14 +30,8 @@ func TestRecoveryMiddlewareLogBody(t *testing.T) { r := NewTestRouter(t) r.GinEngine().Use(Recovery()) - r.Register(http.MethodPut, "/panic", &Operation{ - Description: "Panic recovery test", - Responses: []*Response{ - ResponseText(http.StatusOK, "Success"), - }, - Handler: func(in map[string]string) string { - panic(fmt.Errorf("Some error")) - }, + r.Resource("/panic").Put("Panic recovery test", func(in map[string]string) string { + panic(fmt.Errorf("Some error")) }) w := httptest.NewRecorder() diff --git a/openapi.go b/openapi.go index a41d14ba..8e408656 100644 --- a/openapi.go +++ b/openapi.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Jeffail/gabs" + "github.com/danielgtaylor/huma/schema" "github.com/gin-gonic/gin" ) @@ -21,216 +22,182 @@ const ( InHeader ParamLocation = "header" ) -// Param describes an OpenAPI 3 parameter -type Param struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - In ParamLocation `json:"in"` - Required bool `json:"required,omitempty"` - Schema *Schema `json:"schema,omitempty"` - Deprecated bool `json:"deprecated,omitempty"` - Example interface{} `json:"example,omitempty"` +// OpenAPIParam describes an OpenAPI 3 parameter +type OpenAPIParam struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + In ParamLocation `json:"in"` + Required bool `json:"required,omitempty"` + Schema *schema.Schema `json:"schema,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + Example interface{} `json:"example,omitempty"` // Internal params are excluded from the OpenAPI document and can set up // params sent between a load balander / proxy and the service internally. - internal bool - def interface{} - typ reflect.Type -} + Internal bool -// PathParam returns a new required path parameter -func PathParam(name string, description string, schema ...*Schema) *Param { - return PathParamExample(name, description, nil, schema...) + def interface{} + typ reflect.Type } -// PathParamExample returns a new required path parameter with example -func PathParamExample(name string, description string, example interface{}, schema ...*Schema) *Param { - p := &Param{ +// NewOpenAPIParam returns a new parameter instance. +func NewOpenAPIParam(name, description string, in ParamLocation, options ...ParamOption) *OpenAPIParam { + p := &OpenAPIParam{ Name: name, Description: description, - In: InPath, - Required: true, - Example: example, + In: in, } - if len(schema) > 0 { - p.Schema = schema[0] + for _, option := range options { + option.ApplyParam(p) } return p } -// QueryParam returns a new optional query string parameter -func QueryParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param { - return QueryParamExample(name, description, defaultValue, nil, schema...) +// OpenAPIResponse describes an OpenAPI 3 response +type OpenAPIResponse struct { + Description string + ContentType string + StatusCode int + Schema *schema.Schema + Headers []string } -// QueryParamExample returns a new optional query string parameter with example -func QueryParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param { - p := &Param{ - Name: name, +// NewOpenAPIResponse returns a new response instance. +func NewOpenAPIResponse(statusCode int, description string, options ...ResponseOption) *OpenAPIResponse { + r := &OpenAPIResponse{ + StatusCode: statusCode, Description: description, - In: InQuery, - Example: example, - def: defaultValue, } - if len(schema) > 0 { - p.Schema = schema[0] + for _, option := range options { + option.ApplyResponse(r) } - return p + return r } -// QueryParamInternal returns a new optional internal query string parameter -func QueryParamInternal(name string, description string, defaultValue interface{}) *Param { - return &Param{ - Name: name, - Description: description, - In: InQuery, - internal: true, - def: defaultValue, - } +// OpenAPIResponseHeader describes a response header +type OpenAPIResponseHeader struct { + Name string `json:"-"` + Description string `json:"description,omitempty"` + Schema *schema.Schema `json:"schema,omitempty"` } -// HeaderParam returns a new optional header parameter -func HeaderParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param { - return HeaderParamExample(name, description, defaultValue, nil, schema...) +// OpenAPISecurityRequirement defines the security schemes and scopes required to use +// an operation. +type OpenAPISecurityRequirement map[string][]string + +// OpenAPIOperation describes an OpenAPI 3 operation on a path +type OpenAPIOperation struct { + *OpenAPIDependency + id string + summary string + description string + tags []string + security []OpenAPISecurityRequirement + requestContentType string + requestSchema *schema.Schema + responses []*OpenAPIResponse + extra map[string]interface{} + + // maxBodyBytes limits the size of the request body that will be read before + // an error is returned. Defaults to 1MiB if set to zero. Set to -1 for + // unlimited. + maxBodyBytes int64 + + // bodyReadTimeout sets the duration until reading the body is given up and + // aborted with an error. Defaults to 15 seconds if the body is automatically + // read and parsed into a struct, otherwise unset. Set to -1 for unlimited. + bodyReadTimeout time.Duration } -// HeaderParamExample returns a new optional header parameter with example -func HeaderParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param { - p := &Param{ - Name: name, - Description: description, - In: InHeader, - Example: example, - def: defaultValue, - } +// ID returns the unique identifier for this operation. If not set manually, +// it is generated from the path and HTTP method. +func (o *OpenAPIOperation) ID() string { + return o.id +} - if len(schema) > 0 { - p.Schema = schema[0] +// NewOperation creates a new operation with the given options applied. +func NewOperation(options ...OperationOption) *OpenAPIOperation { + op := &OpenAPIOperation{ + OpenAPIDependency: &OpenAPIDependency{ + dependencies: make([]*OpenAPIDependency, 0), + params: make([]*OpenAPIParam, 0), + responseHeaders: make([]*OpenAPIResponseHeader, 0), + }, + tags: make([]string, 0), + security: make([]OpenAPISecurityRequirement, 0), + responses: make([]*OpenAPIResponse, 0), + extra: make(map[string]interface{}), } - return p -} - -// HeaderParamInternal returns a new optional internal header parameter -func HeaderParamInternal(name string, description string, defaultValue interface{}) *Param { - return &Param{ - Name: name, - Description: description, - In: InHeader, - internal: true, - def: defaultValue, + for _, option := range options { + option.ApplyOperation(op) } -} -// Response describes an OpenAPI 3 response -type Response struct { - Description string - ContentType string - StatusCode int - Schema *Schema - Headers []string - empty bool + return op } -// ResponseEmpty creates a new response with no content type. -func ResponseEmpty(statusCode int, description string, headers ...string) *Response { - return &Response{ - Description: description, - StatusCode: statusCode, - Headers: headers, - empty: true, - } -} +// Copy creates a new shallow copy of the operation. New arrays are created for +// e.g. parameters so they can be safely appended. Existing params are not +// deeply copied and should not be modified. +func (o *OpenAPIOperation) Copy() *OpenAPIOperation { + extraCopy := map[string]interface{}{} -// ResponseText creates a new string response model. -func ResponseText(statusCode int, description string, headers ...string) *Response { - return &Response{ - Description: description, - ContentType: "text/plain", - StatusCode: statusCode, - Headers: headers, + for k, v := range o.extra { + extraCopy[k] = v } -} -// ResponseJSON creates a new JSON response model. -func ResponseJSON(statusCode int, description string, headers ...string) *Response { - return &Response{ - Description: description, - ContentType: "application/json", - StatusCode: statusCode, - Headers: headers, + newOp := &OpenAPIOperation{ + OpenAPIDependency: &OpenAPIDependency{ + dependencies: append([]*OpenAPIDependency{}, o.dependencies...), + params: append([]*OpenAPIParam{}, o.params...), + responseHeaders: append([]*OpenAPIResponseHeader{}, o.responseHeaders...), + handler: o.handler, + }, + id: o.id, + summary: o.summary, + description: o.description, + tags: append([]string{}, o.tags...), + security: append([]OpenAPISecurityRequirement{}, o.security...), + requestContentType: o.requestContentType, + requestSchema: o.requestSchema, + responses: append([]*OpenAPIResponse{}, o.responses...), + extra: extraCopy, + maxBodyBytes: o.maxBodyBytes, + bodyReadTimeout: o.bodyReadTimeout, } -} -// ResponseError creates a new error response model. Alias for ResponseJSON -// mainly useful for documentation. -func ResponseError(status int, description string, headers ...string) *Response { - return ResponseJSON(status, description, headers...) + return newOp } -// ResponseHeader describes a response header -type ResponseHeader struct { - Name string `json:"-"` - Description string `json:"description,omitempty"` - Schema *Schema `json:"schema,omitempty"` -} +// With applies options to the operation. It makes it easy to set up new params, +// responese headers, responses, etc. It always creates a new copy. +func (o *OpenAPIOperation) With(options ...OperationOption) *OpenAPIOperation { + copy := o.Copy() -// Header returns a new header -func Header(name, description string) *ResponseHeader { - return &ResponseHeader{ - Name: name, - Description: description, + for _, option := range options { + option.ApplyOperation(copy) } -} - -// SecurityRequirement defines the security schemes and scopes required to use -// an operation. -type SecurityRequirement map[string][]string - -// Operation describes an OpenAPI 3 operation on a path -type Operation struct { - ID string - Summary string - Description string - Tags []string - Security []SecurityRequirement - Dependencies []*Dependency - Params []*Param - RequestContentType string - RequestSchema *Schema - ResponseHeaders []*ResponseHeader - Responses []*Response - Handler interface{} - Extra map[string]interface{} - - // MaxBodyBytes limits the size of the request body that will be read before - // an error is returned. Defaults to 1MiB if set to zero. Set to -1 for - // unlimited. - MaxBodyBytes int64 - // BodyReadTimeout sets the duration until reading the body is given up and - // aborted with an error. Defaults to 15 seconds if the body is automatically - // read and parsed into a struct, otherwise unset. Set to -1 for unlimited. - BodyReadTimeout time.Duration + return copy } -// AllParams returns a list of all the parameters for this operation, including +// allParams returns a list of all the parameters for this operation, including // those for dependencies. -func (o *Operation) AllParams() []*Param { - params := []*Param{} - seen := map[*Param]bool{} +func (o *OpenAPIOperation) allParams() []*OpenAPIParam { + params := []*OpenAPIParam{} + seen := map[*OpenAPIParam]bool{} - for _, p := range o.Params { + for _, p := range o.params { seen[p] = true params = append(params, p) } - for _, d := range o.Dependencies { - for _, p := range d.AllParams() { + for _, d := range o.dependencies { + for _, p := range d.allParams() { if _, ok := seen[p]; !ok { seen[p] = true @@ -242,19 +209,19 @@ func (o *Operation) AllParams() []*Param { return params } -// AllResponseHeaders returns a list of all the parameters for this operation, +// allResponseHeaders returns a list of all the parameters for this operation, // including those for dependencies. -func (o *Operation) AllResponseHeaders() []*ResponseHeader { - headers := []*ResponseHeader{} - seen := map[*ResponseHeader]bool{} +func (o *OpenAPIOperation) allResponseHeaders() []*OpenAPIResponseHeader { + headers := []*OpenAPIResponseHeader{} + seen := map[*OpenAPIResponseHeader]bool{} - for _, h := range o.ResponseHeaders { + for _, h := range o.responseHeaders { seen[h] = true headers = append(headers, h) } - for _, d := range o.Dependencies { - for _, h := range d.AllResponseHeaders() { + for _, d := range o.dependencies { + for _, h := range d.allResponseHeaders() { if _, ok := seen[h]; !ok { seen[h] = true @@ -266,57 +233,45 @@ func (o *Operation) AllResponseHeaders() []*ResponseHeader { return headers } -// Server describes an OpenAPI 3 API server location -type Server struct { +// OpenAPIServer describes an OpenAPI 3 API server location +type OpenAPIServer struct { URL string `json:"url"` Description string `json:"description,omitempty"` } -// Contact information for this API. -type Contact struct { +// OpenAPIContact information for this API. +type OpenAPIContact struct { Name string `json:"name"` URL string `json:"url"` Email string `json:"email"` } -// OAuthFlow describes the URLs and scopes to get tokens via a specific flow. -type OAuthFlow struct { +// OpenAPIOAuthFlow describes the URLs and scopes to get tokens via a specific flow. +type OpenAPIOAuthFlow struct { AuthorizationURL string `json:"authorizationUrl"` TokenURL string `json:"tokenUrl"` RefreshURL string `json:"refreshUrl,omitempty"` Scopes map[string]string `json:"scopes"` } -// OAuthFlows describes the configuration for each flow type. -type OAuthFlows struct { - Implicit *OAuthFlow `json:"implicit,omitempty"` - Password *OAuthFlow `json:"password,omitempty"` - ClientCredentials *OAuthFlow `json:"clientCredentials,omitempty"` - AuthorizationCode *OAuthFlow `json:"authorizationCode,omitempty"` -} - -// SecurityScheme describes the auth mechanism(s) for this API. -type SecurityScheme struct { - Type string `json:"type"` - Description string `json:"description,omitempty"` - Name string `json:"name,omitempty"` - In string `json:"in,omitempty"` - Scheme string `json:"scheme,omitempty"` - BearerFormat string `json:"bearerFormat,omitempty"` - Flows *OAuthFlows `json:"flows,omitempty"` - OpenIDConnectURL string `json:"openIdConnectUrl,omitempty"` +// OpenAPIOAuthFlows describes the configuration for each flow type. +type OpenAPIOAuthFlows struct { + Implicit *OpenAPIOAuthFlow `json:"implicit,omitempty"` + Password *OpenAPIOAuthFlow `json:"password,omitempty"` + ClientCredentials *OpenAPIOAuthFlow `json:"clientCredentials,omitempty"` + AuthorizationCode *OpenAPIOAuthFlow `json:"authorizationCode,omitempty"` } -// SecurityRef references a previously defined `SecurityScheme` by name along -// with any required scopes. -func SecurityRef(name string, scopes ...string) []SecurityRequirement { - if scopes == nil { - scopes = []string{} - } - - return []SecurityRequirement{ - {name: scopes}, - } +// OpenAPISecurityScheme describes the auth mechanism(s) for this API. +type OpenAPISecurityScheme struct { + Type string `json:"type"` + Description string `json:"description,omitempty"` + Name string `json:"name,omitempty"` + In string `json:"in,omitempty"` + Scheme string `json:"scheme,omitempty"` + BearerFormat string `json:"bearerFormat,omitempty"` + Flows *OpenAPIOAuthFlows `json:"flows,omitempty"` + OpenIDConnectURL string `json:"openIdConnectUrl,omitempty"` } // OpenAPI describes the OpenAPI 3 API @@ -324,11 +279,11 @@ type OpenAPI struct { Title string Version string Description string - Contact *Contact - Servers []*Server - SecuritySchemes map[string]*SecurityScheme - Security []SecurityRequirement - Paths map[string]map[string]*Operation + Contact *OpenAPIContact + Servers []*OpenAPIServer + SecuritySchemes map[string]*OpenAPISecurityScheme + Security []OpenAPISecurityRequirement + Paths map[string]map[string]*OpenAPIOperation // Extra allows setting extra keys in the OpenAPI root structure. Extra map[string]interface{} @@ -338,9 +293,9 @@ type OpenAPI struct { Hook func(*gabs.Container) } -// OpenAPIHandler returns a new handler function to generate an OpenAPI spec. -func OpenAPIHandler(api *OpenAPI) func(*gin.Context) { - respSchema400, _ := GenerateSchema(reflect.ValueOf(ErrorInvalidModel{}).Type()) +// openAPIHandler returns a new handler function to generate an OpenAPI spec. +func openAPIHandler(api *OpenAPI) gin.HandlerFunc { + respSchema400, _ := schema.Generate(reflect.ValueOf(ErrorInvalidModel{}).Type()) return func(c *gin.Context) { openapi := gabs.New() @@ -382,51 +337,51 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) { for method, op := range methods { method := strings.ToLower(method) - for k, v := range op.Extra { + for k, v := range op.extra { openapi.Set(v, "paths", path, method, k) } - openapi.Set(op.ID, "paths", path, method, "operationId") - if op.Summary != "" { - openapi.Set(op.Summary, "paths", path, method, "summary") + openapi.Set(op.id, "paths", path, method, "operationId") + if op.summary != "" { + openapi.Set(op.summary, "paths", path, method, "summary") } - openapi.Set(op.Description, "paths", path, method, "description") - if len(op.Tags) > 0 { - openapi.Set(op.Tags, "paths", path, method, "tags") + openapi.Set(op.description, "paths", path, method, "description") + if len(op.tags) > 0 { + openapi.Set(op.tags, "paths", path, method, "tags") } - if len(op.Security) > 0 { - openapi.Set(op.Security, "paths", path, method, "security") + if len(op.security) > 0 { + openapi.Set(op.security, "paths", path, method, "security") } - for _, param := range op.AllParams() { - if param.internal { + for _, param := range op.allParams() { + if param.Internal { // Skip internal-only parameters. continue } openapi.ArrayAppend(param, "paths", path, method, "parameters") } - if op.RequestSchema != nil { - ct := op.RequestContentType + if op.requestSchema != nil { + ct := op.requestContentType if ct == "" { ct = "application/json" } - openapi.Set(op.RequestSchema, "paths", path, method, "requestBody", "content", ct, "schema") + openapi.Set(op.requestSchema, "paths", path, method, "requestBody", "content", ct, "schema") } - responses := make([]*Response, 0, len(op.Responses)) + responses := make([]*OpenAPIResponse, 0, len(op.responses)) found400 := false - for _, resp := range op.Responses { + for _, resp := range op.responses { responses = append(responses, resp) if resp.StatusCode == http.StatusBadRequest { found400 = true } } - if op.RequestSchema != nil && !found400 { + if op.requestSchema != nil && !found400 { // Add a 400-level response in case parsing the request fails. - responses = append(responses, &Response{ + responses = append(responses, &OpenAPIResponse{ Description: "Invalid input", ContentType: "application/json", StatusCode: http.StatusBadRequest, @@ -434,12 +389,12 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) { }) } - headerMap := map[string]*ResponseHeader{} - for _, header := range op.AllResponseHeaders() { + headerMap := map[string]*OpenAPIResponseHeader{} + for _, header := range op.allResponseHeaders() { headerMap[header.Name] = header } - for _, resp := range op.Responses { + for _, resp := range op.responses { status := fmt.Sprintf("%v", resp.StatusCode) openapi.Set(resp.Description, "paths", path, method, "responses", status, "description") @@ -449,8 +404,8 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) { headers = append(headers, name) seen[name] = true } - for _, dep := range op.Dependencies { - for _, header := range dep.AllResponseHeaders() { + for _, dep := range op.dependencies { + for _, header := range dep.allResponseHeaders() { if _, ok := seen[header.Name]; !ok { headers = append(headers, header.Name) seen[header.Name] = true diff --git a/openapi_test.go b/openapi_test.go index f11508b5..e9e9c516 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -6,13 +6,14 @@ import ( "net/http/httptest" "testing" + "github.com/danielgtaylor/huma/schema" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/assert" ) var paramFuncsTable = []struct { n string - param *Param + param OperationOption name string description string in ParamLocation @@ -22,28 +23,30 @@ var paramFuncsTable = []struct { example interface{} }{ {"PathParam", PathParam("test", "desc"), "test", "desc", InPath, true, false, nil, nil}, - {"PathParamSchema", PathParam("test", "desc", &Schema{}), "test", "desc", InPath, true, false, nil, nil}, - {"PathParamExample", PathParamExample("test", "desc", 123), "test", "desc", InPath, true, false, nil, 123}, + {"PathParamSchema", PathParam("test", "desc", Schema(&schema.Schema{})), "test", "desc", InPath, true, false, nil, nil}, + {"PathParamExample", PathParam("test", "desc", Example(123)), "test", "desc", InPath, true, false, nil, 123}, {"QueryParam", QueryParam("test", "desc", "def"), "test", "desc", InQuery, false, false, "def", nil}, - {"QueryParamSchema", QueryParam("test", "desc", "def", &Schema{}), "test", "desc", InQuery, false, false, "def", nil}, - {"QueryParamExample", QueryParamExample("test", "desc", "def", "foo"), "test", "desc", InQuery, false, false, "def", "foo"}, - {"QueryParamInternal", QueryParamInternal("test", "desc", "def"), "test", "desc", InQuery, false, true, "def", nil}, + {"QueryParamSchema", QueryParam("test", "desc", "def", Schema(&schema.Schema{})), "test", "desc", InQuery, false, false, "def", nil}, + {"QueryParamExample", QueryParam("test", "desc", "def", Example("foo")), "test", "desc", InQuery, false, false, "def", "foo"}, + {"QueryParamInternal", QueryParam("test", "desc", "def", Internal()), "test", "desc", InQuery, false, true, "def", nil}, {"HeaderParam", HeaderParam("test", "desc", "def"), "test", "desc", InHeader, false, false, "def", nil}, - {"HeaderParamSchema", HeaderParam("test", "desc", "def", &Schema{}), "test", "desc", InHeader, false, false, "def", nil}, - {"HeaderParamExample", HeaderParamExample("test", "desc", "def", "foo"), "test", "desc", InHeader, false, false, "def", "foo"}, - {"HeaderParamInternal", HeaderParamInternal("test", "desc", "def"), "test", "desc", InHeader, false, true, "def", nil}, + {"HeaderParamSchema", HeaderParam("test", "desc", "def", Schema(&schema.Schema{})), "test", "desc", InHeader, false, false, "def", nil}, + {"HeaderParamExample", HeaderParam("test", "desc", "def", Example("foo")), "test", "desc", InHeader, false, false, "def", "foo"}, + {"HeaderParamInternal", HeaderParam("test", "desc", "def", Internal()), "test", "desc", InHeader, false, true, "def", nil}, } func TestParamFuncs(outer *testing.T) { for _, tt := range paramFuncsTable { local := tt outer.Run(fmt.Sprintf("%v", tt.n), func(t *testing.T) { - param := local.param + op := NewOperation() + local.param.ApplyOperation(op) + param := op.params[0] assert.Equal(t, local.name, param.Name) assert.Equal(t, local.description, param.Description) assert.Equal(t, local.in, param.In) assert.Equal(t, local.required, param.Required) - assert.Equal(t, local.internal, param.internal) + assert.Equal(t, local.internal, param.Internal) assert.Equal(t, local.def, param.def) assert.Equal(t, local.example, param.Example) }) @@ -52,23 +55,25 @@ func TestParamFuncs(outer *testing.T) { var responseFuncsTable = []struct { n string - resp *Response + resp OperationOption statusCode int description string headers []string contentType string }{ - {"ResponseEmpty", ResponseEmpty(204, "desc", "head1", "head2"), 204, "desc", []string{"head1", "head2"}, ""}, - {"ResponseText", ResponseText(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"}, - {"ResponseJSON", ResponseJSON(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"}, - {"ResponseError", ResponseJSON(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"}, + {"ResponseEmpty", Response(204, "desc", Headers("head1", "head2")), 204, "desc", []string{"head1", "head2"}, ""}, + {"ResponseText", ResponseText(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"}, + {"ResponseJSON", ResponseJSON(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"}, + {"ResponseError", ResponseJSON(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"}, } func TestResponseFuncs(outer *testing.T) { for _, tt := range responseFuncsTable { local := tt outer.Run(fmt.Sprintf("%v", tt.n), func(t *testing.T) { - resp := local.resp + op := NewOperation() + local.resp.ApplyOperation(op) + resp := op.responses[0] assert.Equal(t, local.statusCode, resp.StatusCode) assert.Equal(t, local.description, resp.Description) assert.Equal(t, local.headers, resp.Headers) @@ -141,52 +146,29 @@ func TestOpenAPIHandler(t *testing.T) { Extra("x-foo", "bar"), ) - dep1 := &Dependency{ - Params: []*Param{ - QueryParam("q", "Test query param", ""), - }, - ResponseHeaders: []*ResponseHeader{ - Header("dep", "description"), - }, - Value: func(q string) (string, string, error) { - return "header", "foo", nil - }, - } + dep1 := Dependency(DependencyOptions( + QueryParam("q", "Test query param", ""), + ResponseHeader("dep", "description"), + ), func(q string) (string, string, error) { + return "header", "foo", nil + }) - dep2 := &Dependency{ - Dependencies: []*Dependency{dep1}, - Value: func(q string) (string, error) { - return q, nil - }, - } + dep2 := Dependency(dep1, func(q string) (string, error) { + return q, nil + }) - r.Register(http.MethodPut, "/hello", &Operation{ - ID: "put-hello", - Summary: "Summary message", - Description: "Get a welcome message", - Tags: []string{"Messages"}, - Security: SecurityRef("basic"), - Dependencies: []*Dependency{ - dep2, - }, - Params: []*Param{ - QueryParam("greet", "Whether to greet or not", false), - HeaderParamInternal("user", "User from auth token", ""), - }, - ResponseHeaders: []*ResponseHeader{ - Header("etag", "Content hash for caching"), - }, - Responses: []*Response{ - ResponseJSON(200, "Successful response", "etag"), - }, - Extra: map[string]interface{}{ - "x-foo": "bar", - }, - Handler: func(q string, greet bool, user string, body *HelloRequest) (string, *HelloResponse) { - return "etag", &HelloResponse{ - Message: "Hello", - } - }, + r.Resource("/hello", + dep2, + SecurityRef("basic"), + QueryParam("greet", "Whether to greet or not", false), + HeaderParam("user", "User from auth token", "", Internal()), + ResponseHeader("etag", "Content hash for caching"), + ResponseJSON(200, "Successful response", Headers("etag")), + Extra("x-foo", "bar"), + ).Put("Get a welcome message", func(q string, greet bool, user string, body *HelloRequest) (string, *HelloResponse) { + return "etag", &HelloResponse{ + Message: "Hello", + } }) w := httptest.NewRecorder() diff --git a/options.go b/options.go index 455bd93f..75c57e8d 100644 --- a/options.go +++ b/options.go @@ -1,9 +1,12 @@ package huma import ( + "fmt" "net/http" + "time" "github.com/Jeffail/gabs" + "github.com/danielgtaylor/huma/schema" "github.com/gin-gonic/gin" ) @@ -12,95 +15,266 @@ type RouterOption interface { ApplyRouter(r *Router) } +// routerOption is a shorthand struct used to create API options easily. +type routerOption struct { + handler func(*Router) +} + +func (o *routerOption) ApplyRouter(router *Router) { + o.handler(router) +} + // ResourceOption sets an option on the resource to be used in sub-resources // and operations. type ResourceOption interface { ApplyResource(r *Resource) } -// SharedOption sets an option on either a router/API or resource. -type SharedOption interface { - RouterOption +// resourceOption is a shorthand struct used to create resource options easily. +type resourceOption struct { + handler func(*Resource) +} + +func (o *resourceOption) ApplyResource(r *Resource) { + o.handler(r) +} + +// OperationOption sets an option on an operation or resource object. +type OperationOption interface { ResourceOption + ApplyOperation(o *OpenAPIOperation) } -type extraOption struct { - extra map[string]interface{} +// operationOption is a shorthand struct used to create operation options +// easily. Options created with it can be applied to either operations or +// resources. +type operationOption struct { + handler func(*OpenAPIOperation) } -func (o *extraOption) ApplyRouter(r *Router) { - for k, v := range o.extra { - r.api.Extra[k] = v - } +func (o *operationOption) ApplyResource(r *Resource) { + o.handler(r.OpenAPIOperation) +} + +func (o *operationOption) ApplyOperation(op *OpenAPIOperation) { + o.handler(op) +} + +// DependencyOption sets an option on a dependency, operation, or resource +// object. +type DependencyOption interface { + OperationOption + ApplyDependency(d *OpenAPIDependency) +} + +// dependencyOption is a shorthand struct used to create dependency options +// easily. Options created with it can be applied to dependencies, operations, +// and resources. +type dependencyOption struct { + handler func(*OpenAPIDependency) +} + +func (o *dependencyOption) ApplyResource(r *Resource) { + o.handler(r.OpenAPIDependency) +} + +func (o *dependencyOption) ApplyOperation(op *OpenAPIOperation) { + o.handler(op.OpenAPIDependency) +} + +func (o *dependencyOption) ApplyDependency(d *OpenAPIDependency) { + o.handler(d) +} + +// DependencyOptions composes together a set of options into one. +func DependencyOptions(options ...DependencyOption) DependencyOption { + return &dependencyOption{func(d *OpenAPIDependency) { + for _, option := range options { + option.ApplyDependency(d) + } + }} +} + +// ParamOption sets an option on an OpenAPI parameter. +type ParamOption interface { + ApplyParam(*OpenAPIParam) +} + +type paramOption struct { + apply func(*OpenAPIParam) +} + +func (o *paramOption) ApplyParam(p *OpenAPIParam) { + o.apply(p) } -func (o *extraOption) ApplyResource(r *Resource) { - // for k, v := range o.extra { - // r.extra[k] = v - // } +// ResponseHeaderOption sets an option on an OpenAPI response header. +type ResponseHeaderOption interface { + ApplyResponseHeader(*OpenAPIResponseHeader) +} + +// ResponseOption sets an option on an OpenAPI response. +type ResponseOption interface { + ApplyResponse(*OpenAPIResponse) +} + +type responseOption struct { + apply func(*OpenAPIResponse) +} + +func (o *responseOption) ApplyResponse(r *OpenAPIResponse) { + o.apply(r) +} + +// sharedOption sets an option on any combination of objects. +type sharedOption struct { + Set func(v interface{}) +} + +func (o *sharedOption) ApplyRouter(r *Router) { + o.Set(r) +} + +func (o *sharedOption) ApplyResource(r *Resource) { + o.Set(r) +} + +func (o *sharedOption) ApplyOperation(op *OpenAPIOperation) { + o.Set(op) +} + +func (o *sharedOption) ApplyParam(p *OpenAPIParam) { + o.Set(p) +} + +func (o *sharedOption) ApplyResponseHeader(r *OpenAPIResponseHeader) { + o.Set(r) +} + +func (o *sharedOption) ApplyResponse(r *OpenAPIResponse) { + o.Set(r) +} + +// Schema manually sets a JSON Schema on the object. If the top-level `type` is +// blank then the type will be guessed from the handler function. +func Schema(s *schema.Schema) interface { + ParamOption + ResponseHeaderOption + ResponseOption +} { + return &sharedOption{func(v interface{}) { + switch cast := v.(type) { + case *OpenAPIParam: + cast.Schema = s + case *OpenAPIResponseHeader: + cast.Schema = s + case *OpenAPIResponse: + cast.Schema = s + } + }} +} + +// SecurityRef adds a security reference by name with optional scopes. +func SecurityRef(name string, scopes ...string) interface { + RouterOption + OperationOption +} { + if scopes == nil { + scopes = []string{} + } + + return &sharedOption{ + Set: func(v interface{}) { + req := OpenAPISecurityRequirement{name: scopes} + + switch cast := v.(type) { + case *Router: + cast.api.Security = append(cast.api.Security, req) + case *Resource: + cast.security = append(cast.security, req) + case *OpenAPIOperation: + cast.security = append(cast.security, req) + } + }, + } } // Extra sets extra values in the generated OpenAPI 3 spec. -func Extra(pairs ...interface{}) SharedOption { +func Extra(pairs ...interface{}) interface { + RouterOption + OperationOption +} { extra := map[string]interface{}{} + if len(pairs)%2 > 0 { + panic(fmt.Errorf("requires key-value pairs but got: %v", pairs)) + } + for i := 0; i < len(pairs); i += 2 { k := pairs[i].(string) v := pairs[i+1] extra[k] = v } - return &extraOption{extra} -} + return &sharedOption{ + Set: func(v interface{}) { + var x map[string]interface{} -// routerOption is a shorthand struct used to create API options easily. -type routerOption struct { - handler func(*Router) -} + switch cast := v.(type) { + case *Router: + x = cast.api.Extra + case *Resource: + x = cast.extra + case *OpenAPIOperation: + x = cast.extra + } -func (o *routerOption) ApplyRouter(router *Router) { - o.handler(router) + for k, v := range extra { + x[k] = v + } + }, + } } // ProdServer sets the production server URL on the API. func ProdServer(url string) RouterOption { return &routerOption{func(r *Router) { - r.api.Servers = append(r.api.Servers, &Server{url, "Production server"}) + r.api.Servers = append(r.api.Servers, &OpenAPIServer{url, "Production server"}) }} } // DevServer sets the development server URL on the API. func DevServer(url string) RouterOption { return &routerOption{func(r *Router) { - r.api.Servers = append(r.api.Servers, &Server{url, "Development server"}) + r.api.Servers = append(r.api.Servers, &OpenAPIServer{url, "Development server"}) }} } // ContactFull sets the API contact information. func ContactFull(name, url, email string) RouterOption { return &routerOption{func(r *Router) { - r.api.Contact = &Contact{name, url, email} + r.api.Contact = &OpenAPIContact{name, url, email} }} } // ContactURL sets the API contact name & URL information. func ContactURL(name, url string) RouterOption { return &routerOption{func(r *Router) { - r.api.Contact = &Contact{Name: name, URL: url} + r.api.Contact = &OpenAPIContact{Name: name, URL: url} }} } // ContactEmail sets the API contact name & email information. func ContactEmail(name, email string) RouterOption { return &routerOption{func(r *Router) { - r.api.Contact = &Contact{Name: name, Email: email} + r.api.Contact = &OpenAPIContact{Name: name, Email: email} }} } // BasicAuth adds a named HTTP Basic Auth security scheme. func BasicAuth(name string) RouterOption { return &routerOption{func(r *Router) { - r.api.SecuritySchemes[name] = &SecurityScheme{ + r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{ Type: "http", Scheme: "basic", } @@ -112,7 +286,7 @@ func BasicAuth(name string) RouterOption { // `header`, or `cookie`. func APIKeyAuth(name, keyName, in string) RouterOption { return &routerOption{func(r *Router) { - r.api.SecuritySchemes[name] = &SecurityScheme{ + r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{ Type: "apiKey", Name: keyName, In: in, @@ -124,7 +298,7 @@ func APIKeyAuth(name, keyName, in string) RouterOption { // header. func JWTBearerAuth(name string) RouterOption { return &routerOption{func(r *Router) { - r.api.SecuritySchemes[name] = &SecurityScheme{ + r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{ Type: "http", Scheme: "bearer", BearerFormat: "JWT", @@ -177,3 +351,169 @@ func OpenAPIHook(f func(*gabs.Container)) RouterOption { r.api.Hook = f }} } + +// SimpleDependency adds a new dependency with just a value or function. +func SimpleDependency(handler interface{}) DependencyOption { + dep := &OpenAPIDependency{ + handler: handler, + } + + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, dep) + }} +} + +// Dependency adds a dependency. +func Dependency(option DependencyOption, handler interface{}) DependencyOption { + dep := NewDependency(option, handler) + return &dependencyOption{func(d *OpenAPIDependency) { + d.dependencies = append(d.dependencies, dep) + }} +} + +// Example sets an example value, used for documentation and mocks. +func Example(value interface{}) ParamOption { + return ¶mOption{func(p *OpenAPIParam) { + p.Example = value + }} +} + +// Internal marks this parameter as internal-only, meaning it will not be +// included in the OpenAPI 3 JSON. Useful for things like auth headers set +// by a load balancer / gateway. +func Internal() ParamOption { + return ¶mOption{func(p *OpenAPIParam) { + p.Internal = true + }} +} + +// Deprecated marks this parameter as deprecated. +func Deprecated() ParamOption { + return ¶mOption{func(p *OpenAPIParam) { + p.Deprecated = true + }} +} + +func newParamOption(name, description string, required bool, def interface{}, in ParamLocation, options ...ParamOption) DependencyOption { + p := NewOpenAPIParam(name, description, in, options...) + p.Required = required + p.def = def + + return &dependencyOption{func(d *OpenAPIDependency) { + d.params = append(d.params, p) + }} +} + +// PathParam adds a new required path parameter +func PathParam(name string, description string, options ...ParamOption) DependencyOption { + return newParamOption(name, description, true, nil, InPath, options...) +} + +// QueryParam returns a new optional query string parameter +func QueryParam(name string, description string, defaultValue interface{}, options ...ParamOption) DependencyOption { + return newParamOption(name, description, false, defaultValue, InQuery, options...) +} + +// HeaderParam returns a new optional header parameter +func HeaderParam(name string, description string, defaultValue interface{}, options ...ParamOption) DependencyOption { + return newParamOption(name, description, false, defaultValue, InHeader, options...) +} + +// ResponseHeader returns a new response header +func ResponseHeader(name, description string) DependencyOption { + r := &OpenAPIResponseHeader{ + Name: name, + Description: description, + } + + return &dependencyOption{func(d *OpenAPIDependency) { + d.responseHeaders = append(d.responseHeaders, r) + }} +} + +// OperationID manually sets the operation's unique ID. If not set, it will +// be auto-generated from the resource path and operation verb. +func OperationID(id string) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.id = id + }} +} + +// Tags sets one or more text tags on the operation. +func Tags(values ...string) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.tags = append(o.tags, values...) + }} +} + +// RequestContentType sets the request content type on the operation. +func RequestContentType(name string) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.requestContentType = name + }} +} + +// RequestSchema sets the request body schema on the operation. +func RequestSchema(schema *schema.Schema) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.requestSchema = schema + }} +} + +// ContentType sets the content type for this response. If blank, an empty +// response is returned. +func ContentType(value string) ResponseOption { + return &responseOption{func(r *OpenAPIResponse) { + r.ContentType = value + }} +} + +// Headers sets a list of allowed response headers. +func Headers(values ...string) ResponseOption { + return &responseOption{func(r *OpenAPIResponse) { + r.Headers = values + }} +} + +// Response adds a new response to the operation. +func Response(statusCode int, description string, options ...ResponseOption) OperationOption { + r := NewOpenAPIResponse(statusCode, description, options...) + + return &operationOption{func(o *OpenAPIOperation) { + o.responses = append(o.responses, r) + }} +} + +// ResponseText adds a new string response to the operation. +func ResponseText(statusCode int, description string, options ...ResponseOption) OperationOption { + options = append(options, ContentType("text/plain")) + return Response(statusCode, description, options...) +} + +// ResponseJSON adds a new JSON response model to the operation. +func ResponseJSON(statusCode int, description string, options ...ResponseOption) OperationOption { + options = append(options, ContentType("application/json")) + return Response(statusCode, description, options...) +} + +// ResponseError adds a new error response model. Alias for ResponseJSON +// mainly useful for documentation purposes. +func ResponseError(statusCode int, description string, options ...ResponseOption) OperationOption { + return ResponseJSON(statusCode, description, options...) +} + +// MaxBodyBytes sets the max number of bytes read from a request body before +// the handler aborts and returns an error. Applies to all sub-resources. +func MaxBodyBytes(value int64) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.maxBodyBytes = value + }} +} + +// BodyReadTimeout sets the duration after which the read is aborted and an +// error is returned. +func BodyReadTimeout(value time.Duration) OperationOption { + return &operationOption{func(o *OpenAPIOperation) { + o.bodyReadTimeout = value + }} +} diff --git a/resource.go b/resource.go index 932327ba..11c22b98 100644 --- a/resource.go +++ b/resource.go @@ -1,40 +1,34 @@ package huma import ( - "fmt" "net/http" "reflect" "strings" - "time" ) // Resource describes a REST resource at a given URI path. Resources are // typically created from a router or as a sub-resource of an existing resource. type Resource struct { - router *Router - path string - deps []*Dependency - security []SecurityRequirement - params []*Param - responseHeaders []*ResponseHeader - responses []*Response - maxBodyBytes int64 - bodyReadTimeout time.Duration + *OpenAPIOperation + router *Router + path string } // NewResource creates a new resource with the given router and path. All // dependencies, security requirements, params, headers, and responses are // empty. -func NewResource(router *Router, path string) *Resource { - return &Resource{ - router: router, - path: path, - deps: make([]*Dependency, 0), - security: make([]SecurityRequirement, 0), - params: make([]*Param, 0), - responseHeaders: make([]*ResponseHeader, 0), - responses: make([]*Response, 0), +func NewResource(router *Router, path string, options ...ResourceOption) *Resource { + r := &Resource{ + OpenAPIOperation: NewOperation(), + router: router, + path: path, + } + + for _, option := range options { + option.ApplyResource(r) } + + return r } // Copy the resource. New arrays are created for dependencies, security @@ -42,60 +36,24 @@ func NewResource(router *Router, path string) *Resource { // pointer values themselves are the same. func (r *Resource) Copy() *Resource { return &Resource{ - router: r.router, - path: r.path, - deps: append([]*Dependency{}, r.deps...), - security: append([]SecurityRequirement{}, r.security...), - params: append([]*Param{}, r.params...), - responseHeaders: append([]*ResponseHeader{}, r.responseHeaders...), - responses: append([]*Response{}, r.responses...), - maxBodyBytes: r.maxBodyBytes, - bodyReadTimeout: r.bodyReadTimeout, + OpenAPIOperation: r.OpenAPIOperation.Copy(), + router: r.router, + path: r.path, } } // With returns a copy of this resource with the given dependencies, security // requirements, params, response headers, or responses added to it. -func (r *Resource) With(depsParamHeadersOrResponses ...interface{}) *Resource { +func (r *Resource) With(options ...ResourceOption) *Resource { c := r.Copy() - // For each input, determine which type it is and store it. - for _, dph := range depsParamHeadersOrResponses { - switch v := dph.(type) { - case *Dependency: - c.deps = append(c.deps, v) - case []SecurityRequirement: - c.security = v - case SecurityRequirement: - c.security = append(c.security, v) - case *Param: - c.params = append(c.params, v) - case *ResponseHeader: - c.responseHeaders = append(c.responseHeaders, v) - case *Response: - c.responses = append(c.responses, v) - default: - panic(fmt.Errorf("unsupported type %v", v)) - } + for _, option := range options { + option.ApplyResource(c) } return c } -// MaxBodyBytes sets the max number of bytes read from a request body before -// the handler aborts and returns an error. Applies to all sub-resources. -func (r *Resource) MaxBodyBytes(value int64) *Resource { - r.maxBodyBytes = value - return r -} - -// BodyReadTimeout sets the duration after which the read is aborted and an -// error is returned. -func (r *Resource) BodyReadTimeout(value time.Duration) *Resource { - r.bodyReadTimeout = value - return r -} - // Path returns the generated path including any path parameters. func (r *Resource) Path() string { generated := r.path @@ -117,7 +75,7 @@ func (r *Resource) Path() string { // SubResource creates a new resource at the given path, which is appended // to the existing resource path after adding any existing path parameters. -func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...interface{}) *Resource { +func (r *Resource) SubResource(path string, options ...ResourceOption) *Resource { // Apply all existing params to the path. newPath := r.Path() @@ -131,7 +89,7 @@ func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...inter newPath += path // Clone the resource and update the path. - c := r.With(depsParamHeadersOrResponses...) + c := r.With(options...) c.path = newPath return c @@ -139,16 +97,16 @@ func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...inter // Operation adds the operation to this resource's router with all the // combined deps, security requirements, params, headers, responses, etc. -func (r *Resource) Operation(method string, op *Operation) { +func (r *Resource) operation(method string, op *OpenAPIOperation) { // Set params, etc - allDeps := append([]*Dependency{}, r.deps...) - allDeps = append(allDeps, op.Dependencies...) - op.Dependencies = allDeps + allDeps := append([]*OpenAPIDependency{}, r.dependencies...) + allDeps = append(allDeps, op.dependencies...) + op.dependencies = allDeps // Combine resource and operation params. Update path with any required // path parameters if they are not yet present. - allParams := append([]*Param{}, r.params...) - allParams = append(allParams, op.Params...) + allParams := append([]*OpenAPIParam{}, r.params...) + allParams = append(allParams, op.params...) path := r.path for _, p := range allParams { if p.In == "path" { @@ -161,67 +119,47 @@ func (r *Resource) Operation(method string, op *Operation) { } } } - op.Params = allParams + op.params = allParams - allHeaders := append([]*ResponseHeader{}, r.responseHeaders...) - allHeaders = append(allHeaders, op.ResponseHeaders...) - op.ResponseHeaders = allHeaders + allHeaders := append([]*OpenAPIResponseHeader{}, r.responseHeaders...) + allHeaders = append(allHeaders, op.responseHeaders...) + op.responseHeaders = allHeaders - allResponses := append([]*Response{}, r.responses...) - allResponses = append(allResponses, op.Responses...) - op.Responses = allResponses + allResponses := append([]*OpenAPIResponse{}, r.responses...) + allResponses = append(allResponses, op.responses...) + op.responses = allResponses - if op.Handler != nil { - t := reflect.TypeOf(op.Handler) - if t.NumOut() == len(op.ResponseHeaders)+len(op.Responses)+1 { + if op.handler != nil { + t := reflect.TypeOf(op.handler) + if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 { rtype := t.Out(t.NumOut() - 1) switch rtype.Kind() { case reflect.Bool: - op.Responses = append(op.Responses, ResponseEmpty(http.StatusNoContent, "Success")) + op = op.With(Response(http.StatusNoContent, "Success")) case reflect.String: - op.Responses = append(op.Responses, ResponseText(http.StatusOK, "Success")) + op = op.With(ResponseText(http.StatusOK, "Success")) default: - op.Responses = append(op.Responses, ResponseJSON(http.StatusOK, "Success")) + op = op.With(ResponseJSON(http.StatusOK, "Success")) } } } - if op.MaxBodyBytes == 0 { - op.MaxBodyBytes = r.maxBodyBytes + if op.maxBodyBytes == 0 { + op.maxBodyBytes = r.maxBodyBytes } - if op.BodyReadTimeout == 0 { - op.BodyReadTimeout = r.bodyReadTimeout + if op.bodyReadTimeout == 0 { + op.bodyReadTimeout = r.bodyReadTimeout } r.router.Register(method, path, op) } -// Text is shorthand for `r.With(huma.ResponseText(...))`. -func (r *Resource) Text(statusCode int, description string, headers ...string) *Resource { - return r.With(ResponseText(statusCode, description, headers...)) -} - -// JSON is shorthand for `r.With(huma.ResponseJSON(...))`. -func (r *Resource) JSON(statusCode int, description string, headers ...string) *Resource { - return r.With(ResponseJSON(statusCode, description, headers...)) -} - -// NoContent is shorthand for `r.With(huma.ResponseEmpty(http.StatusNoContent, ...)` -func (r *Resource) NoContent(description string, headers ...string) *Resource { - return r.With(ResponseEmpty(http.StatusNoContent, description, headers...)) -} - -// Empty is shorthand for `r.With(huma.ResponseEmpty(...))`. -func (r *Resource) Empty(statusCode int, description string, headers ...string) *Resource { - return r.With(ResponseEmpty(statusCode, description, headers...)) -} - // Head creates an HTTP HEAD operation on the resource. func (r *Resource) Head(description string, handler interface{}) { - r.Operation(http.MethodHead, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodHead, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } @@ -232,40 +170,40 @@ func (r *Resource) List(description string, handler interface{}) { // Get creates an HTTP GET operation on the resource. func (r *Resource) Get(description string, handler interface{}) { - r.Operation(http.MethodGet, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodGet, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } // Post creates an HTTP POST operation on the resource. func (r *Resource) Post(description string, handler interface{}) { - r.Operation(http.MethodPost, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodPost, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } // Put creates an HTTP PUT operation on the resource. func (r *Resource) Put(description string, handler interface{}) { - r.Operation(http.MethodPut, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodPut, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } // Patch creates an HTTP PATCH operation on the resource. func (r *Resource) Patch(description string, handler interface{}) { - r.Operation(http.MethodPatch, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodPatch, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } // Delete creates an HTTP DELETE operation on the resource. func (r *Resource) Delete(description string, handler interface{}) { - r.Operation(http.MethodDelete, &Operation{ - Description: description, - Handler: handler, + r.operation(http.MethodDelete, &OpenAPIOperation{ + description: description, + OpenAPIDependency: &OpenAPIDependency{handler: handler}, }) } diff --git a/resource_test.go b/resource_test.go index d61a751e..ce285c6b 100644 --- a/resource_test.go +++ b/resource_test.go @@ -13,44 +13,38 @@ func TestResourceCopy(t *testing.T) { r1 := NewResource(nil, "/test") r2 := r1.Copy() - assert.NotSame(t, r1.deps, r2.deps) + assert.NotSame(t, r1.dependencies, r2.dependencies) assert.NotSame(t, r1.params, r2.params) assert.NotSame(t, r1.responseHeaders, r2.responseHeaders) assert.NotSame(t, r1.responses, r2.responses) } -func TestResourceWithBadInput(t *testing.T) { - assert.Panics(t, func() { - NewResource(nil, "/test").With("bad-value") - }) -} - func TestResourceWithDep(t *testing.T) { - dep1 := &Dependency{Value: "dep1"} - dep2 := &Dependency{Value: "dep2"} + dep1 := SimpleDependency("dep1") + dep2 := SimpleDependency("dep2") r1 := NewResource(nil, "/test") r2 := r1.With(dep1) r3 := r1.With(dep2) - assert.Contains(t, r2.deps, dep1) - assert.NotContains(t, r2.deps, dep2) - assert.Contains(t, r3.deps, dep2) - assert.NotContains(t, r3.deps, dep1) + assert.NotEmpty(t, r2.dependencies) + assert.NotEmpty(t, r3.dependencies) + + assert.NotSame(t, r2.dependencies[0], r3.dependencies[0]) } func TestResourceWithSecurity(t *testing.T) { sec1 := SecurityRef("sec1") - sec2 := SecurityRef("sec2")[0] + sec2 := SecurityRef("sec2") r1 := NewResource(nil, "/test") r2 := r1.With(sec1) r3 := r1.With(sec2) - assert.Equal(t, r2.security, sec1) - assert.NotContains(t, r2.security, sec2) - assert.Contains(t, r3.security, sec2) - assert.NotEqual(t, r3.security, sec1) + assert.NotEmpty(t, r2.security) + assert.NotEmpty(t, r3.security) + + assert.NotSame(t, r2.security[0], r3.security[0]) } func TestResourceWithParam(t *testing.T) { @@ -61,27 +55,27 @@ func TestResourceWithParam(t *testing.T) { r2 := r1.With(param1) r3 := r1.With(param2) - assert.Contains(t, r2.params, param1) - assert.NotContains(t, r2.params, param2) - assert.Contains(t, r3.params, param2) - assert.NotContains(t, r3.params, param1) + assert.NotEmpty(t, r2.params) + assert.NotEmpty(t, r3.params) + + assert.NotSame(t, r2.params[0], r3.params[0]) assert.Equal(t, "/test/{p1}", r2.Path()) assert.Equal(t, "/test/{p2}", r3.Path()) } func TestResourceWithHeader(t *testing.T) { - header1 := Header("h1", "desc") - header2 := Header("h2", "desc") + header1 := ResponseHeader("h1", "desc") + header2 := ResponseHeader("h2", "desc") r1 := NewResource(nil, "/test") r2 := r1.With(header1) r3 := r1.With(header2) - assert.Contains(t, r2.responseHeaders, header1) - assert.NotContains(t, r2.responseHeaders, header2) - assert.Contains(t, r3.responseHeaders, header2) - assert.NotContains(t, r3.responseHeaders, header1) + assert.NotEmpty(t, r2.responseHeaders) + assert.NotEmpty(t, r3.responseHeaders) + + assert.NotSame(t, r2.responseHeaders[0], r3.responseHeaders[0]) } func TestResourceWithResponse(t *testing.T) { @@ -92,10 +86,10 @@ func TestResourceWithResponse(t *testing.T) { r2 := r1.With(resp1) r3 := r1.With(resp2) - assert.Contains(t, r2.responses, resp1) - assert.NotContains(t, r2.responses, resp2) - assert.Contains(t, r3.responses, resp2) - assert.NotContains(t, r3.responses, resp1) + assert.NotEmpty(t, r2.responses) + assert.NotEmpty(t, r3.responses) + + assert.NotSame(t, r2.responses[0], r3.responses[0]) } func TestSubResource(t *testing.T) { @@ -137,7 +131,7 @@ func TestResourceFuncs(outer *testing.T) { local := tt outer.Run(fmt.Sprintf("%v", tt), func(t *testing.T) { r := NewTestRouter(t) - res := NewResource(r, "/test").Text(http.StatusOK, "desc") + res := NewResource(r, "/test") var f func(string, interface{}) @@ -180,34 +174,6 @@ var resourceShorthandFuncs = []struct { {"Empty", http.StatusNotModified, "", "desc"}, } -func TestResourceShorthandFuncs(outer *testing.T) { - for _, tt := range resourceShorthandFuncs { - local := tt - outer.Run(fmt.Sprintf("%v", local.n), func(t *testing.T) { - r := NewTestRouter(t) - res := NewResource(r, "/test") - - switch local.n { - case "Text": - res = res.Text(local.statusCode, local.desc, "header") - case "JSON": - res = res.JSON(local.statusCode, local.desc, "header") - case "NoContent": - res = res.NoContent(local.desc, "header") - case "Empty": - res = res.Empty(local.statusCode, local.desc, "header") - default: - panic("invalid case " + local.n) - } - - resp := res.responses[0] - assert.Equal(t, local.statusCode, resp.StatusCode) - assert.Equal(t, local.contentType, resp.ContentType) - assert.Equal(t, local.desc, resp.Description) - }) - } -} - func TestResourceAutoJSON(t *testing.T) { r := NewTestRouter(t) @@ -218,8 +184,8 @@ func TestResourceAutoJSON(t *testing.T) { return &MyResponse{} }) - assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode) - assert.Equal(t, "application/json", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType) + assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode) + assert.Equal(t, "application/json", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType) } func TestResourceAutoText(t *testing.T) { @@ -230,8 +196,8 @@ func TestResourceAutoText(t *testing.T) { return "Hello, world" }) - assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode) - assert.Equal(t, "text/plain", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType) + assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode) + assert.Equal(t, "text/plain", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType) } func TestResourceAutoNoContent(t *testing.T) { @@ -242,7 +208,6 @@ func TestResourceAutoNoContent(t *testing.T) { return true }) - assert.Equal(t, http.StatusNoContent, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode) - assert.Equal(t, "", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType) - assert.Equal(t, true, r.api.Paths["/test"][http.MethodGet].Responses[0].empty) + assert.Equal(t, http.StatusNoContent, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode) + assert.Equal(t, "", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType) } diff --git a/router.go b/router.go index 44c9afdf..89cb52bc 100644 --- a/router.go +++ b/router.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/danielgtaylor/huma/schema" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" "github.com/spf13/cobra" @@ -30,6 +31,8 @@ var ErrInvalidParamLocation = errors.New("invalid parameter location") // context value. var ConnContextKey = struct{}{} +var timeType = reflect.TypeOf(time.Time{}) + // GetConn gets the underlying `net.Conn` from a request. func GetConn(r *http.Request) net.Conn { conn := r.Context().Value(ConnContextKey) @@ -40,7 +43,7 @@ func GetConn(r *http.Request) net.Conn { } // Checks if data validates against the given schema. Returns false on failure. -func validAgainstSchema(c *gin.Context, label string, schema *Schema, data []byte) bool { +func validAgainstSchema(c *gin.Context, label string, schema *schema.Schema, data []byte) bool { defer func() { // Catch panics from the `gojsonschema` library. if err := recover(); err != nil { @@ -149,7 +152,7 @@ func parseParamValue(c *gin.Context, name string, typ reflect.Type, pstr string) return pv, true } -func getParamValue(c *gin.Context, param *Param) (interface{}, bool) { +func getParamValue(c *gin.Context, param *OpenAPIParam) (interface{}, bool) { var pstr string switch param.In { case InPath: @@ -198,18 +201,18 @@ func getParamValue(c *gin.Context, param *Param) (interface{}, bool) { return pv, true } -func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{}, bool) { +func getRequestBody(c *gin.Context, t reflect.Type, op *OpenAPIOperation) (interface{}, bool) { val := reflect.New(t).Interface() - if op.RequestSchema != nil { + if op.requestSchema != nil { body, err := ioutil.ReadAll(c.Request.Body) if err != nil { if strings.Contains(err.Error(), "request body too large") { c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, ErrorModel{ - Message: fmt.Sprintf("Request body too large, limit = %d bytes", op.MaxBodyBytes), + Message: fmt.Sprintf("Request body too large, limit = %d bytes", op.maxBodyBytes), }) } else if e, ok := err.(net.Error); ok && e.Timeout() { c.AbortWithStatusJSON(http.StatusRequestTimeout, ErrorModel{ - Message: fmt.Sprintf("Request body took too long to read: timed out after %v", op.BodyReadTimeout), + Message: fmt.Sprintf("Request body took too long to read: timed out after %v", op.bodyReadTimeout), }) } else { panic(err) @@ -219,7 +222,7 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{}, c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body)) - if !validAgainstSchema(c, "request body", op.RequestSchema, body) { + if !validAgainstSchema(c, "request body", op.requestSchema, body) { // Error already handled, just return. return nil, false } @@ -274,10 +277,10 @@ func NewRouter(docs, version string, options ...RouterOption) *Router { Title: title, Description: desc, Version: version, - Servers: make([]*Server, 0), - SecuritySchemes: make(map[string]*SecurityScheme, 0), - Security: make([]SecurityRequirement, 0), - Paths: make(map[string]map[string]*Operation), + Servers: make([]*OpenAPIServer, 0), + SecuritySchemes: make(map[string]*OpenAPISecurityScheme, 0), + Security: make([]OpenAPISecurityRequirement, 0), + Paths: make(map[string]map[string]*OpenAPIOperation), Extra: make(map[string]interface{}), }, engine: g, @@ -298,7 +301,7 @@ func NewRouter(docs, version string, options ...RouterOption) *Router { } // Set up handlers for the auto-generated spec and docs. - r.engine.GET("/openapi.json", OpenAPIHandler(r.api)) + r.engine.GET("/openapi.json", openAPIHandler(r.api)) r.engine.GET("/docs", func(c *gin.Context) { r.docsHandler(c, r.api) @@ -330,19 +333,19 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Resource creates a new resource at the given path with the given // dependencies, parameters, response headers, and responses defined. -func (r *Router) Resource(path string, depsParamsHeadersOrResponses ...interface{}) *Resource { - return NewResource(r, path).With(depsParamsHeadersOrResponses...) +func (r *Router) Resource(path string, options ...ResourceOption) *Resource { + return NewResource(r, path).With(options...) } // Register a new operation. -func (r *Router) Register(method, path string, op *Operation) { +func (r *Router) Register(method, path string, op *OpenAPIOperation) { // First, make sure the operation and handler make sense, as well as pre- // generating any schemas for use later during request handling. op.validate(method, path) // Add the operation to the list of operations for the path entry. if r.api.Paths[path] == nil { - r.api.Paths[path] = make(map[string]*Operation) + r.api.Paths[path] = make(map[string]*OpenAPIOperation) } r.api.Paths[path][method] = op @@ -376,12 +379,12 @@ func (r *Router) Register(method, path string, op *Operation) { // Then call it to register our handler function. f(path, func(c *gin.Context) { - method := reflect.ValueOf(op.Handler) + method := reflect.ValueOf(op.handler) in := make([]reflect.Value, 0, method.Type().NumIn()) // Limit the body size if c.Request.Body != nil { - maxBody := op.MaxBodyBytes + maxBody := op.maxBodyBytes if maxBody == 0 { // 1 MiB default maxBody = 1024 * 1024 @@ -394,8 +397,8 @@ func (r *Router) Register(method, path string, op *Operation) { } // Process any dependencies first. - for _, dep := range op.Dependencies { - headers, value, err := dep.Resolve(c, op) + for _, dep := range op.dependencies { + headers, value, err := dep.resolve(c, op) if err != nil { if !c.IsAborted() { // Nothing else has handled the error, so treat it like a general @@ -413,7 +416,7 @@ func (r *Router) Register(method, path string, op *Operation) { in = append(in, reflect.ValueOf(value)) } - for _, param := range op.Params { + for _, param := range op.params { pv, ok := getParamValue(c, param) if !ok { // Error has already been handled. @@ -423,7 +426,7 @@ func (r *Router) Register(method, path string, op *Operation) { in = append(in, reflect.ValueOf(pv)) } - readTimeout := op.BodyReadTimeout + readTimeout := op.bodyReadTimeout if len(in) != method.Type().NumIn() { if readTimeout == 0 { // Default to 15s when reading/parsing/validating automatically. @@ -458,14 +461,14 @@ func (r *Router) Register(method, path string, op *Operation) { // from the registered `huma.Response` struct. // This breaks down with scalar types... so they need to be passed // as a pointer and we'll dereference it automatically. - for i, o := range out[len(op.ResponseHeaders):] { + for i, o := range out[len(op.responseHeaders):] { if !o.IsZero() { body := o.Interface() - r := op.Responses[i] + r := op.responses[i] // Set response headers - for j, header := range op.ResponseHeaders { + for j, header := range op.responseHeaders { value := out[j] found := false @@ -498,7 +501,7 @@ func (r *Router) Register(method, path string, op *Operation) { } } - if r.empty { + if r.ContentType == "" { // No body allowed, e.g. for HTTP 204. c.Status(r.StatusCode) break diff --git a/router_test.go b/router_test.go index fc012481..1560d104 100644 --- a/router_test.go +++ b/router_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/danielgtaylor/huma/schema" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -53,16 +54,10 @@ func BenchmarkGin(b *testing.B) { func BenchmarkHuma(b *testing.B) { r := NewRouter("Benchmark test", "1.0.0", WithGin(gin.New())) - r.Register(http.MethodGet, "/hello", &Operation{ - Description: "Greet the world", - Responses: []*Response{ - ResponseJSON(200, "Return a greeting"), - }, - Handler: func() *helloResponse { - return &helloResponse{ - Message: "Hello, world", - } - }, + r.Resource("/hello").Get("Greet the world", func() *helloResponse { + return &helloResponse{ + Message: "Hello, world", + } }) b.ResetTimer() @@ -120,56 +115,35 @@ func BenchmarkGinComplex(b *testing.B) { func BenchmarkHumaComplex(b *testing.B) { r := NewRouter("Benchmark test", "1.0.0", WithGin(gin.New())) - dep1 := &Dependency{ - Value: "dep1", - } + dep1 := SimpleDependency("dep1") - dep2 := &Dependency{ - Dependencies: []*Dependency{ContextDependency(), dep1}, - Params: []*Param{ - HeaderParam("x-foo", "desc", ""), - }, - Value: func(c *gin.Context, d1 string, xfoo string) (string, error) { - return "dep2", nil - }, - } + dep2 := Dependency(DependencyOptions( + ContextDependency(), dep1, HeaderParam("x-foo", "desc", ""), + ), func(c *gin.Context, d1 string, xfoo string) (string, error) { + return "dep2", nil + }) - dep3 := &Dependency{ - Dependencies: []*Dependency{dep1}, - ResponseHeaders: []*ResponseHeader{ - Header("x-bar", "desc"), - }, - Value: func(d1 string) (string, string, error) { - return "xbar", "dep3", nil - }, - } + dep3 := Dependency(DependencyOptions( + dep1, ResponseHeader("x-bar", "desc"), + ), func(d1 string) (string, string, error) { + return "xbar", "dep3", nil + }) - r.Register(http.MethodGet, "/hello", &Operation{ - Description: "Greet the world", - Dependencies: []*Dependency{ - ContextDependency(), dep2, dep3, - }, - Params: []*Param{ - QueryParam("name", "desc", "world"), - }, - ResponseHeaders: []*ResponseHeader{ - Header("x-baz", "desc"), - }, - Responses: []*Response{ - ResponseJSON(200, "Return a greeting", "x-baz"), - ResponseError(500, "desc"), - }, - Handler: func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) { - if name == "test" { - return "", nil, &ErrorModel{ - Message: "Name cannot be test", - } + r.Resource("/hello", dep1, dep2, dep3, + QueryParam("name", "desc", "world"), + ResponseHeader("x-baz", "desc"), + ResponseJSON(200, "Return a greeting", Headers("x-baz")), + ResponseError(500, "desc"), + ).Get("Greet the world", func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) { + if name == "test" { + return "", nil, &ErrorModel{ + Message: "Name cannot be test", } + } - return "xbaz", &helloResponse{ - Message: "Hello, " + name, - }, nil - }, + return "xbaz", &helloResponse{ + Message: "Hello, " + name, + }, nil }) b.ResetTimer() @@ -193,28 +167,22 @@ func TestRouter(t *testing.T) { r := NewTestRouter(t) - r.Register(http.MethodPut, "/echo/{word}", &Operation{ - Description: "Echo back an input word.", - Params: []*Param{ - PathParam("word", "The word to echo back"), - QueryParam("greet", "Return a greeting", false), - }, - Responses: []*Response{ - ResponseJSON(http.StatusOK, "Successful echo response"), - ResponseError(http.StatusBadRequest, "Invalid input"), - }, - Handler: func(word string, greet bool) (*EchoResponse, *ErrorModel) { - if word == "test" { - return nil, &ErrorModel{Message: "Value not allowed: test"} - } + r.Resource("/echo", + PathParam("word", "The word to echo back"), + QueryParam("greet", "Return a greeting", false), + ResponseJSON(http.StatusOK, "Successful echo response"), + ResponseError(http.StatusBadRequest, "Invalid input"), + ).Put("Echo back an input word.", func(word string, greet bool) (*EchoResponse, *ErrorModel) { + if word == "test" { + return nil, &ErrorModel{Message: "Value not allowed: test"} + } - v := word - if greet { - v = "Hello, " + word - } + v := word + if greet { + v = "Hello, " + word + } - return &EchoResponse{Value: v}, nil - }, + return &EchoResponse{Value: v}, nil }) w := httptest.NewRecorder() @@ -257,14 +225,8 @@ func TestRouterRequestBody(t *testing.T) { r := NewTestRouter(t) - r.Register(http.MethodPut, "/echo", &Operation{ - Description: "Echo back an input word.", - Responses: []*Response{ - ResponseJSON(http.StatusOK, "Successful echo response"), - }, - Handler: func(in *EchoRequest) *EchoResponse { - return &EchoResponse{Value: in.Value} - }, + r.Resource("/echo").Put("Echo back an input word.", func(in *EchoRequest) *EchoResponse { + return &EchoResponse{Value: in.Value} }) w := httptest.NewRecorder() @@ -283,14 +245,8 @@ func TestRouterRequestBody(t *testing.T) { func TestRouterScalarResponse(t *testing.T) { r := NewTestRouter(t) - r.Register(http.MethodPut, "/hello", &Operation{ - Description: "Say hello.", - Responses: []*Response{ - ResponseText(http.StatusOK, "Successful hello response"), - }, - Handler: func() string { - return "hello" - }, + r.Resource("/hello").Put("Say hello", func() string { + return "hello" }) w := httptest.NewRecorder() @@ -304,15 +260,9 @@ func TestRouterScalarResponse(t *testing.T) { func TestRouterZeroScalarResponse(t *testing.T) { r := NewTestRouter(t) - r.Register(http.MethodPut, "/bool", &Operation{ - Description: "Say hello.", - Responses: []*Response{ - ResponseText(http.StatusOK, "Successful zero bool response"), - }, - Handler: func() *bool { - resp := false - return &resp - }, + r.Resource("/bool").Put("Bool response", func() *bool { + resp := false + return &resp }) w := httptest.NewRecorder() @@ -320,27 +270,21 @@ func TestRouterZeroScalarResponse(t *testing.T) { r.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "false", w.Body.String()) + assert.Equal(t, "false\n", w.Body.String()) } func TestRouterResponseHeaders(t *testing.T) { r := NewTestRouter(t) - r.Register(http.MethodGet, "/test", &Operation{ - Description: "Test operation", - ResponseHeaders: []*ResponseHeader{ - Header("Etag", "Identifies a specific version of this resource"), - Header("X-Test", "Custom test header"), - Header("X-Missing", "Won't get sent"), - }, - Responses: []*Response{ - ResponseText(http.StatusOK, "Successful test", "Etag", "X-Test", "X-Missing"), - ResponseError(http.StatusBadRequest, "Error example", "X-Test"), - }, - Handler: func() (etag string, xTest *string, xMissing string, success string, fail string) { - test := "test" - return "\"abc123\"", &test, "", "hello", "" - }, + r.Resource("/test", + ResponseHeader("Etag", "Identifies a specific version of this resource"), + ResponseHeader("X-Test", "Custom test header"), + ResponseHeader("X-Missing", "Won't get sent"), + ResponseText(http.StatusOK, "Successful test", Headers("Etag", "X-Test", "X-Missing")), + ResponseError(http.StatusBadRequest, "Error example", Headers("X-Test")), + ).Get("Test operation", func() (etag string, xTest *string, xMissing string, success string, fail string) { + test := "test" + return "\"abc123\"", &test, "", "hello", "" }) w := httptest.NewRecorder() @@ -362,11 +306,9 @@ func TestRouterDependencies(t *testing.T) { } // Datastore is a global dependency, set by value. - db := &Dependency{ - Value: &DB{ - Get: func() string { - return "Hello, " - }, + db := &DB{ + Get: func() string { + return "Hello, " }, } @@ -376,35 +318,25 @@ func TestRouterDependencies(t *testing.T) { // Logger is a contextual instance from the gin request context. captured := "" - log := &Dependency{ - Dependencies: []*Dependency{ - GinContextDependency(), - }, - Value: func(c *gin.Context) (*Logger, error) { - return &Logger{ - Log: func(msg string) { - captured = fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()) - }, - }, nil - }, - } + log := Dependency(GinContextDependency(), func(c *gin.Context) (*Logger, error) { + return &Logger{ + Log: func(msg string) { + captured = fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()) + }, + }, nil + }) - r.Register(http.MethodGet, "/hello", &Operation{ - Description: "Basic hello world", - Dependencies: []*Dependency{GinContextDependency(), db, log}, - Params: []*Param{ - QueryParam("name", "Your name", ""), - }, - Responses: []*Response{ - ResponseText(http.StatusOK, "Successful hello response"), - }, - Handler: func(c *gin.Context, db *DB, l *Logger, name string) string { - if name == "" { - name = c.Request.RemoteAddr - } - l.Log("Hello logger!") - return db.Get() + name - }, + r.Resource("/hello", + GinContextDependency(), + SimpleDependency(db), + log, + QueryParam("name", "Your name", ""), + ).Get("Basic hello world", func(c *gin.Context, db *DB, l *Logger, name string) string { + if name == "" { + name = c.Request.RemoteAddr + } + l.Log("Hello logger!") + return db.Get() + name }) w := httptest.NewRecorder() @@ -421,7 +353,7 @@ func TestRouterBadHeader(t *testing.T) { g := gin.New() g.Use(LogMiddleware(l, nil)) r := NewRouter("Test API", "1.0.0", WithGin(g)) - r.Resource("/test", Header("foo", "desc"), ResponseError(http.StatusBadRequest, "desc", "foo")).Get("desc", func() (string, *ErrorModel, string) { + r.Resource("/test", ResponseHeader("foo", "desc"), ResponseError(http.StatusBadRequest, "desc", Headers("foo"))).Get("desc", func() (string, *ErrorModel, string) { return "header-value", nil, "response" }) @@ -441,7 +373,7 @@ func TestRouterParams(t *testing.T) { QueryParam("i", "desc", int16(0)), QueryParam("f32", "desc", float32(0.0)), QueryParam("f64", "desc", 0.0), - QueryParam("schema", "desc", "test", &Schema{Pattern: "^a-z+$"}), + QueryParam("schema", "desc", "test", Schema(&schema.Schema{Pattern: "^a-z+$"})), QueryParam("items", "desc", []int{}), QueryParam("start", "desc", time.Time{}), ).Get("desc", func(id string, i int16, f32 float32, f64 float64, schema string, items []int, start time.Time) string { @@ -501,8 +433,11 @@ func TestRouterParams(t *testing.T) { func TestInvalidParamLocation(t *testing.T) { r := NewTestRouter(t) + test := r.Resource("/test", PathParam("name", "desc")) + test.params[len(test.params)-1].In = "bad'" + assert.Panics(t, func() { - r.Resource("/test", &Param{Name: "test", In: "bad"}).Get("desc", func(test string) string { + test.Get("desc", func(test string) string { return "Hello, test!" }) }) @@ -523,7 +458,7 @@ func TestTooBigBody(t *testing.T) { ID string } - r.Resource("/test").MaxBodyBytes(5).Put("desc", func(input *Input) string { + r.Resource("/test", MaxBodyBytes(5)).Put("desc", func(input *Input) string { return "hello, " + input.ID }) @@ -561,7 +496,7 @@ func TestBodySlow(t *testing.T) { ID string } - r.Resource("/test").BodyReadTimeout(1).Put("desc", func(input *Input) string { + r.Resource("/test", BodyReadTimeout(1)).Put("desc", func(input *Input) string { return "hello, " + input.ID }) diff --git a/schema.go b/schema/schema.go similarity index 91% rename from schema.go rename to schema/schema.go index 83d89b0d..e2311c0e 100644 --- a/schema.go +++ b/schema/schema.go @@ -1,4 +1,6 @@ -package huma +// Package schema implements OpenAPI 3 compatible JSON Schema which can be +// generated from structs. +package schema import ( "encoding/json" @@ -16,18 +18,18 @@ import ( // ErrSchemaInvalid is sent when there is a problem building the schema. var ErrSchemaInvalid = errors.New("schema is invalid") -// SchemaMode defines whether the schema is being generated for read or +// Mode defines whether the schema is being generated for read or // write mode. Read-only fields are dropped when in write mode, for example. -type SchemaMode int +type Mode int const ( - // SchemaModeAll is for general purpose use and includes all fields. - SchemaModeAll SchemaMode = iota - // SchemaModeRead is for HTTP HEAD & GET and will hide write-only fields. - SchemaModeRead - // SchemaModeWrite is for HTTP POST, PUT, PATCH, DELETE and will hide + // ModeAll is for general purpose use and includes all fields. + ModeAll Mode = iota + // ModeRead is for HTTP HEAD & GET and will hide write-only fields. + ModeRead + // ModeWrite is for HTTP POST, PUT, PATCH, DELETE and will hide // read-only fields. - SchemaModeWrite + ModeWrite ) var ( @@ -120,20 +122,20 @@ func (s *Schema) HasValidation() bool { return false } -// GenerateSchema creates a JSON schema for a Go type. Struct field tags +// Generate creates a JSON schema for a Go type. Struct field tags // can be used to provide additional metadata such as descriptions and // validation. -func GenerateSchema(t reflect.Type) (*Schema, error) { - return GenerateSchemaWithMode(t, SchemaModeAll, nil) +func Generate(t reflect.Type) (*Schema, error) { + return GenerateWithMode(t, ModeAll, nil) } -// GenerateSchemaWithMode creates a JSON schema for a Go type. Struct field +// GenerateWithMode creates a JSON schema for a Go type. Struct field // tags can be used to provide additional metadata such as descriptions and // validation. The mode can be all, read, or write. In read or write mode // any field that is marked as the opposite will be excluded, e.g. a // write-only field would not be included in read mode. If a schema is given // as input, add to it, otherwise creates a new schema. -func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*Schema, error) { +func GenerateWithMode(t reflect.Type, mode Mode, schema *Schema) (*Schema, error) { if schema == nil { schema = &Schema{} } @@ -167,7 +169,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S name = jsonTags[0] } - s, err := GenerateSchemaWithMode(f.Type, mode, nil) + s, err := GenerateWithMode(f.Type, mode, nil) if err != nil { return nil, err } @@ -177,6 +179,10 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S s.Description = tag } + if tag, ok := f.Tag.Lookup("doc"); ok { + s.Description = tag + } + if tag, ok := f.Tag.Lookup("format"); ok { s.Format = tag } @@ -326,7 +332,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S } s.ReadOnly = tag == "true" - if s.ReadOnly && mode == SchemaModeWrite { + if s.ReadOnly && mode == ModeWrite { delete(properties, name) continue } @@ -338,7 +344,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S } s.WriteOnly = tag == "true" - if s.WriteOnly && mode == SchemaModeRead { + if s.WriteOnly && mode == ModeRead { delete(properties, name) continue } @@ -372,14 +378,14 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S case reflect.Map: schema.Type = "object" - s, err := GenerateSchemaWithMode(t.Elem(), mode, nil) + s, err := GenerateWithMode(t.Elem(), mode, nil) if err != nil { return nil, err } schema.AdditionalProperties = s case reflect.Slice, reflect.Array: schema.Type = "array" - s, err := GenerateSchemaWithMode(t.Elem(), mode, nil) + s, err := GenerateWithMode(t.Elem(), mode, nil) if err != nil { return nil, err } @@ -410,7 +416,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S case reflect.String: schema.Type = "string" case reflect.Ptr: - return GenerateSchemaWithMode(t.Elem(), mode, schema) + return GenerateWithMode(t.Elem(), mode, schema) default: return nil, fmt.Errorf("unsupported type %s from %s", t.Kind(), t) } diff --git a/schema_test.go b/schema/schema_test.go similarity index 75% rename from schema_test.go rename to schema/schema_test.go index 41ef04ce..01189e90 100644 --- a/schema_test.go +++ b/schema/schema_test.go @@ -1,4 +1,4 @@ -package huma +package schema import ( "fmt" @@ -11,6 +11,21 @@ import ( "github.com/stretchr/testify/assert" ) +func Example() { + type MyObject struct { + ID string `doc:"Object ID" readOnly:"true"` + Rate float64 `doc:"Rate of change" minimum:"0"` + Coords []int `doc:"X,Y coordinates" minItems:"2" maxItems:"2"` + } + + generated, err := Generate(reflect.TypeOf(MyObject{})) + if err != nil { + panic(err) + } + fmt.Println(generated.Properties["id"].ReadOnly) + // output: true +} + var types = []struct { in interface{} out string @@ -34,7 +49,7 @@ func TestSchemaTypes(outer *testing.T) { local := tt outer.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) { t.Parallel() - s, err := GenerateSchema(reflect.ValueOf(local.in).Type()) + s, err := Generate(reflect.ValueOf(local.in).Type()) assert.NoError(t, err) assert.Equal(t, local.out, s.Type) assert.Equal(t, local.format, s.Format) @@ -48,7 +63,7 @@ func TestSchemaRequiredFields(t *testing.T) { Required string `json:"required"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Len(t, s.Properties, 2) assert.NotContains(t, s.Required, "optional") @@ -60,7 +75,7 @@ func TestSchemaRenameField(t *testing.T) { Foo string `json:"bar"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Empty(t, s.Properties["foo"]) assert.NotEmpty(t, s.Properties["bar"]) @@ -71,7 +86,7 @@ func TestSchemaDescription(t *testing.T) { Foo string `json:"foo" description:"I am a test"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "I am a test", s.Properties["foo"].Description) } @@ -81,7 +96,7 @@ func TestSchemaFormat(t *testing.T) { Foo string `json:"foo" format:"date-time"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "date-time", s.Properties["foo"].Format) } @@ -91,7 +106,7 @@ func TestSchemaEnum(t *testing.T) { Foo string `json:"foo" enum:"one,two,three"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, []interface{}{"one", "two", "three"}, s.Properties["foo"].Enum) } @@ -101,7 +116,7 @@ func TestSchemaDefault(t *testing.T) { Foo string `json:"foo" default:"def"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "def", s.Properties["foo"].Default) } @@ -111,7 +126,7 @@ func TestSchemaExample(t *testing.T) { Foo string `json:"foo" example:"ex"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "ex", s.Properties["foo"].Example) } @@ -121,7 +136,7 @@ func TestSchemaNullable(t *testing.T) { Foo string `json:"foo" nullable:"true"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, true, s.Properties["foo"].Nullable) } @@ -131,7 +146,7 @@ func TestSchemaNullableError(t *testing.T) { Foo string `json:"foo" nullable:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -140,7 +155,7 @@ func TestSchemaReadOnly(t *testing.T) { Foo string `json:"foo" readOnly:"true"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, true, s.Properties["foo"].ReadOnly) } @@ -150,7 +165,7 @@ func TestSchemaReadOnlyError(t *testing.T) { Foo string `json:"foo" readOnly:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -159,7 +174,7 @@ func TestSchemaWriteOnly(t *testing.T) { Foo string `json:"foo" writeOnly:"true"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, true, s.Properties["foo"].WriteOnly) } @@ -169,7 +184,7 @@ func TestSchemaWriteOnlyError(t *testing.T) { Foo string `json:"foo" writeOnly:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -178,7 +193,7 @@ func TestSchemaDeprecated(t *testing.T) { Foo string `json:"foo" deprecated:"true"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, true, s.Properties["foo"].Deprecated) } @@ -188,7 +203,7 @@ func TestSchemaDeprecatedError(t *testing.T) { Foo string `json:"foo" deprecated:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -197,7 +212,7 @@ func TestSchemaMinimum(t *testing.T) { Foo float64 `json:"foo" minimum:"1"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, 1.0, *s.Properties["foo"].Minimum) } @@ -207,7 +222,7 @@ func TestSchemaMinimumError(t *testing.T) { Foo float64 `json:"foo" minimum:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -216,7 +231,7 @@ func TestSchemaExclusiveMinimum(t *testing.T) { Foo float64 `json:"foo" exclusiveMinimum:"1"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, 1.0, *s.Properties["foo"].ExclusiveMinimum) } @@ -226,7 +241,7 @@ func TestSchemaExclusiveMinimumError(t *testing.T) { Foo float64 `json:"foo" exclusiveMinimum:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -235,7 +250,7 @@ func TestSchemaMaximum(t *testing.T) { Foo float64 `json:"foo" maximum:"0"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, 0.0, *s.Properties["foo"].Maximum) } @@ -245,7 +260,7 @@ func TestSchemaMaximumError(t *testing.T) { Foo float64 `json:"foo" maximum:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -254,7 +269,7 @@ func TestSchemaExclusiveMaximum(t *testing.T) { Foo float64 `json:"foo" exclusiveMaximum:"0"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, 0.0, *s.Properties["foo"].ExclusiveMaximum) } @@ -264,7 +279,7 @@ func TestSchemaExclusiveMaximumError(t *testing.T) { Foo float64 `json:"foo" exclusiveMaximum:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -273,7 +288,7 @@ func TestSchemaMultipleOf(t *testing.T) { Foo float64 `json:"foo" multipleOf:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, 10.0, s.Properties["foo"].MultipleOf) } @@ -283,7 +298,7 @@ func TestSchemaMultipleOfError(t *testing.T) { Foo float64 `json:"foo" multipleOf:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -292,7 +307,7 @@ func TestSchemaMinLength(t *testing.T) { Foo string `json:"foo" minLength:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MinLength) } @@ -302,7 +317,7 @@ func TestSchemaMinLengthError(t *testing.T) { Foo string `json:"foo" minLength:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -311,7 +326,7 @@ func TestSchemaMaxLength(t *testing.T) { Foo string `json:"foo" maxLength:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MaxLength) } @@ -321,7 +336,7 @@ func TestSchemaMaxLengthError(t *testing.T) { Foo string `json:"foo" maxLength:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -330,7 +345,7 @@ func TestSchemaPattern(t *testing.T) { Foo string `json:"foo" pattern:"a-z+"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "a-z+", s.Properties["foo"].Pattern) } @@ -340,7 +355,7 @@ func TestSchemaPatternError(t *testing.T) { Foo string `json:"foo" pattern:"(.*"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -349,7 +364,7 @@ func TestSchemaMinItems(t *testing.T) { Foo []string `json:"foo" minItems:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MinItems) } @@ -359,7 +374,7 @@ func TestSchemaMinItemsError(t *testing.T) { Foo []string `json:"foo" minItems:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -368,7 +383,7 @@ func TestSchemaMaxItems(t *testing.T) { Foo []string `json:"foo" maxItems:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MaxItems) } @@ -378,7 +393,7 @@ func TestSchemaMaxItemsError(t *testing.T) { Foo []string `json:"foo" maxItems:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -387,7 +402,7 @@ func TestSchemaUniqueItems(t *testing.T) { Foo []string `json:"foo" uniqueItems:"true"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, true, s.Properties["foo"].UniqueItems) } @@ -397,7 +412,7 @@ func TestSchemaUniqueItemsError(t *testing.T) { Foo []string `json:"foo" uniqueItems:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -406,7 +421,7 @@ func TestSchemaMinProperties(t *testing.T) { Foo []string `json:"foo" minProperties:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MinProperties) } @@ -416,7 +431,7 @@ func TestSchemaMinPropertiesError(t *testing.T) { Foo []string `json:"foo" minProperties:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -425,7 +440,7 @@ func TestSchemaMaxProperties(t *testing.T) { Foo []string `json:"foo" maxProperties:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint64(10), *s.Properties["foo"].MaxProperties) } @@ -435,12 +450,12 @@ func TestSchemaMaxPropertiesError(t *testing.T) { Foo []string `json:"foo" maxProperties:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } func TestSchemaMap(t *testing.T) { - s, err := GenerateSchema(reflect.TypeOf(map[string]string{})) + s, err := Generate(reflect.TypeOf(map[string]string{})) assert.NoError(t, err) assert.Equal(t, &Schema{ Type: "object", @@ -451,7 +466,7 @@ func TestSchemaMap(t *testing.T) { } func TestSchemaSlice(t *testing.T) { - s, err := GenerateSchema(reflect.TypeOf([]string{})) + s, err := Generate(reflect.TypeOf([]string{})) assert.NoError(t, err) assert.Equal(t, &Schema{ Type: "array", @@ -462,7 +477,7 @@ func TestSchemaSlice(t *testing.T) { } func TestSchemaUnsigned(t *testing.T) { - s, err := GenerateSchema(reflect.TypeOf(uint(10))) + s, err := Generate(reflect.TypeOf(uint(10))) assert.NoError(t, err) min := 0.0 assert.Equal(t, &Schema{ @@ -477,7 +492,7 @@ func TestSchemaNonStringExample(t *testing.T) { Foo uint32 `json:"foo" example:"10"` } - s, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, uint32(10), s.Properties["foo"].Example) } @@ -487,7 +502,7 @@ func TestSchemaNonStringExampleErrorUnmarshal(t *testing.T) { Foo uint32 `json:"foo" example:"bad"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } @@ -496,7 +511,7 @@ func TestSchemaNonStringExampleErrorCast(t *testing.T) { Foo bool `json:"foo" example:"1"` } - _, err := GenerateSchema(reflect.ValueOf(Example{}).Type()) + _, err := Generate(reflect.ValueOf(Example{}).Type()) assert.Error(t, err) } diff --git a/validate.go b/validate.go index 78695ec5..2a2bfd4a 100644 --- a/validate.go +++ b/validate.go @@ -8,6 +8,7 @@ import ( "regexp" "strings" + "github.com/danielgtaylor/huma/schema" "github.com/gosimple/slug" ) @@ -37,7 +38,7 @@ func (a *OpenAPI) validate() error { } // validate the parameter and generate schemas -func (p *Param) validate(t reflect.Type) { +func (p *OpenAPIParam) validate(t reflect.Type) { switch p.In { case InPath, InQuery, InHeader: default: @@ -65,7 +66,7 @@ func (p *Param) validate(t reflect.Type) { p.typ = t if p.Schema == nil || p.Schema.Type == "" { - s, err := GenerateSchemaWithMode(p.typ, SchemaModeWrite, p.Schema) + s, err := schema.GenerateWithMode(p.typ, schema.ModeWrite, p.Schema) if err != nil { panic(fmt.Errorf("parameter %s schema generation error: %w", p.Name, err)) } @@ -84,10 +85,10 @@ func (p *Param) validate(t reflect.Type) { } // validate the header and generate schemas -func (h *ResponseHeader) validate(t reflect.Type) { +func (h *OpenAPIResponseHeader) validate(t reflect.Type) { if h.Schema == nil { // Generate the schema from the handler function types. - s, err := GenerateSchemaWithMode(t, SchemaModeRead, nil) + s, err := schema.GenerateWithMode(t, schema.ModeRead, nil) if err != nil { panic(fmt.Errorf("response header %s schema generation error: %w", h.Name, err)) } @@ -97,39 +98,39 @@ func (h *ResponseHeader) validate(t reflect.Type) { // validate checks that the operation is well-formed (e.g. handler signature // matches the given params) and generates schemas if needed. -func (o *Operation) validate(method, path string) { +func (o *OpenAPIOperation) validate(method, path string) { prefix := method + " " + path + ":" - if o.Description == "" { + if o.description == "" { panic(fmt.Errorf("%s description field required: %w", prefix, ErrOperationInvalid)) } - if len(o.Responses) == 0 { + if len(o.responses) == 0 { panic(fmt.Errorf("%s at least one response is required: %w", prefix, ErrOperationInvalid)) } - if o.Handler == nil { + if o.handler == nil { panic(fmt.Errorf("%s handler is required: %w", prefix, ErrOperationInvalid)) } - handler := reflect.ValueOf(o.Handler).Type() + handler := reflect.ValueOf(o.handler).Type() - totalIn := len(o.Dependencies) + len(o.Params) - totalOut := len(o.ResponseHeaders) + len(o.Responses) + totalIn := len(o.dependencies) + len(o.params) + totalOut := len(o.responseHeaders) + len(o.responses) if !(handler.NumIn() == totalIn || (method != http.MethodGet && handler.NumIn() == totalIn+1)) || handler.NumOut() != totalOut { expected := "func(" - for _, dep := range o.Dependencies { - expected += "? " + reflect.ValueOf(dep.Value).Type().String() + ", " + for _, dep := range o.dependencies { + expected += "? " + reflect.ValueOf(dep.handler).Type().String() + ", " } - for _, param := range o.Params { + for _, param := range o.params { expected += param.Name + " ?, " } expected = strings.TrimRight(expected, ", ") expected += ") (" - for _, h := range o.ResponseHeaders { + for _, h := range o.responseHeaders { expected += h.Name + " ?, " } - for _, r := range o.Responses { + for _, r := range o.responses { expected += fmt.Sprintf("*Response%d, ", r.StatusCode) } expected = strings.TrimRight(expected, ", ") @@ -138,7 +139,7 @@ func (o *Operation) validate(method, path string) { panic(fmt.Errorf("%s expected handler %s but found %s: %w", prefix, expected, handler, ErrOperationInvalid)) } - if o.ID == "" { + if o.id == "" { verb := method // Try to detect calls returning lists of things. @@ -152,10 +153,10 @@ func (o *Operation) validate(method, path string) { // Remove variables from path so they aren't in the generated name. path := paramRe.ReplaceAllString(path, "") - o.ID = slug.Make(verb + path) + o.id = slug.Make(verb + path) } - for i, dep := range o.Dependencies { + for i, dep := range o.dependencies { paramType := handler.In(i) // Catch common errors. @@ -163,7 +164,7 @@ func (o *Operation) validate(method, path string) { panic(fmt.Errorf("%s gin.Context should be pointer *gin.Context: %w", prefix, ErrOperationInvalid)) } - if paramType.String() == "huma.Operation" { + if paramType.String() == "huma.OpenAPIOperation" { panic(fmt.Errorf("%s huma.Operation should be pointer *huma.Operation: %w", prefix, ErrOperationInvalid)) } @@ -171,13 +172,13 @@ func (o *Operation) validate(method, path string) { } types := []reflect.Type{} - for i := len(o.Dependencies); i < handler.NumIn(); i++ { + for i := len(o.dependencies); i < handler.NumIn(); i++ { paramType := handler.In(i) switch paramType.String() { case "gin.Context", "*gin.Context": panic(fmt.Errorf("%s expected param but found gin.Context: %w", prefix, ErrOperationInvalid)) - case "huma.Operation", "*huma.Operation": + case "huma.Operation", "*huma.OpenAPIOperation": panic(fmt.Errorf("%s expected param but found huma.Operation: %w", prefix, ErrOperationInvalid)) } @@ -185,37 +186,38 @@ func (o *Operation) validate(method, path string) { } requestBody := false - if len(types) == len(o.Params)+1 { + if len(types) == len(o.params)+1 { requestBody = true } for i, paramType := range types { if i == len(types)-1 && requestBody { // The last item has no associated param. It is a request body. - if o.RequestSchema == nil { - s, err := GenerateSchemaWithMode(paramType, SchemaModeWrite, nil) + if o.requestSchema == nil { + s, err := schema.GenerateWithMode(paramType, schema.ModeWrite, nil) if err != nil { panic(fmt.Errorf("%s request body schema generation error: %w", prefix, err)) } - o.RequestSchema = s + o.requestSchema = s } continue } - p := o.Params[i] + p := o.params[i] p.validate(paramType) } - for i, header := range o.ResponseHeaders { + for i, header := range o.responseHeaders { header.validate(handler.Out(i)) } - for i, resp := range o.Responses { - respType := handler.Out(len(o.ResponseHeaders) + i) - // HTTP 204 explicitly forbids a response body. - if !resp.empty && resp.Schema == nil { + for i, resp := range o.responses { + respType := handler.Out(len(o.responseHeaders) + i) + // HTTP 204 explicitly forbids a response body. We model this with an + // empty content type. + if resp.ContentType != "" && resp.Schema == nil { // Generate the schema from the handler function types. - s, err := GenerateSchemaWithMode(respType, SchemaModeRead, nil) + s, err := schema.GenerateWithMode(respType, schema.ModeRead, nil) if err != nil { panic(fmt.Errorf("%s response %d schema generation error: %w", prefix, resp.StatusCode, err)) } diff --git a/validate_test.go b/validate_test.go index 0fe7f105..33d42ca3 100644 --- a/validate_test.go +++ b/validate_test.go @@ -12,7 +12,7 @@ func TestOperationDescriptionRequired(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{}) + r.Register(http.MethodGet, "/", &OpenAPIOperation{}) }) } @@ -20,21 +20,8 @@ func TestOperationResponseRequired(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - }) - }) -} - -func TestOperationHandlerMissing(t *testing.T) { - r := NewTestRouter(t) - - assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Responses: []*Response{ - ResponseText(200, "Test"), - }, + r.Register(http.MethodGet, "/", &OpenAPIOperation{ + description: "Test", }) }) } @@ -42,26 +29,13 @@ func TestOperationHandlerMissing(t *testing.T) { func TestOperationHandlerInput(t *testing.T) { r := NewTestRouter(t) - d := &Dependency{ - Value: func() (string, error) { - return "test", nil - }, - } - assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Dependencies: []*Dependency{d}, - Params: []*Param{ - QueryParam("foo", "Test", ""), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func() string { - // Wrong number of inputs! - return "fails" - }, + r.Resource("/", + SimpleDependency("test"), + ResponseText(200, "Test"), + ).Get("Test", func() string { + // Wrong number of inputs! + return "fails" }) }) } @@ -70,18 +44,12 @@ func TestOperationHandlerOutput(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - ResponseHeaders: []*ResponseHeader{ - Header("x-test", "Test"), - }, - Responses: []*Response{ - ResponseText(200, "Test", "x-test"), - }, - Handler: func() string { - // Wrong number of outputs! - return "fails" - }, + r.Resource("/", + ResponseHeader("x-test", "Test"), + ResponseText(200, "Test", Headers("x-test")), + ).Get("Test", func() string { + // Wrong number of outputs! + return "fails" }) }) } @@ -89,36 +57,23 @@ func TestOperationHandlerOutput(t *testing.T) { func TestOperationListAutoID(t *testing.T) { r := NewTestRouter(t) - o := &Operation{ - Description: "Test", - Responses: []*Response{ - ResponseJSON(200, "Test"), - }, - Handler: func() []string { - return []string{"test"} - }, - } + r.Resource("/items").Get("Test", func() []string { + return []string{"test"} + }) - r.Register(http.MethodGet, "/items", o) + o := r.OpenAPI().Paths["/items"][http.MethodGet] - assert.Equal(t, "list-items", o.ID) + assert.Equal(t, "list-items", o.id) } func TestOperationContextPointer(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Dependencies: []*Dependency{ - ContextDependency(), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(c gin.Context) string { - return "test" - }, + r.Resource("/", + GinContextDependency(), + ).Get("Test", func(c gin.Context) string { + return "test" }) }) } @@ -127,17 +82,10 @@ func TestOperationOperationPointer(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Dependencies: []*Dependency{ - OperationDependency(), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(o Operation) string { - return "test" - }, + r.Resource("/", + OperationDependency(), + ).Get("Test", func(o OpenAPIOperation) string { + return "test" }) }) } @@ -146,17 +94,10 @@ func TestOperationInvalidDep(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Dependencies: []*Dependency{ - &Dependency{}, - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(string) string { - return "test" - }, + r.Resource("/", + SimpleDependency(nil), + ).Get("Test", func(o OpenAPIOperation) string { + return "test" }) }) } @@ -165,32 +106,18 @@ func TestOperationParamDep(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{ - QueryParam("foo", "Test", ""), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(c *gin.Context) string { - return "test" - }, + r.Resource("/", + QueryParam("foo", "Test", ""), + ).Get("Test", func(c *gin.Context) string { + return "test" }) }) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{ - QueryParam("foo", "Test", ""), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(o *Operation) string { - return "test" - }, + r.Resource("/", + QueryParam("foo", "Test", ""), + ).Get("Test", func(c *OpenAPIOperation) string { + return "test" }) }) } @@ -198,31 +125,13 @@ func TestOperationParamDep(t *testing.T) { func TestOperationParamRedeclare(t *testing.T) { r := NewTestRouter(t) - p := QueryParam("foo", "Test", 0) + param := QueryParam("foo", "Test", 0) - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{p}, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(p int) string { - return "test" - }, - }) + r.Resource("/a", param).Get("Test", func(p int) string { return "a" }) - // Param p was declared as `int` above but is `string` here. + // Redeclare param `p` as a string while it was an int above. assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{p}, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(p string) string { - return "test" - }, - }) + r.Resource("/b", param).Get("Test", func(p string) string { return "b" }) }) } @@ -230,17 +139,10 @@ func TestOperationParamExampleType(t *testing.T) { r := NewTestRouter(t) assert.Panics(t, func() { - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{ - QueryParamExample("foo", "Test", "", 123), - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(p string) string { - return "test" - }, + r.Resource("/", + QueryParam("foo", "Test", "", Example(123)), + ).Get("Test", func(p string) string { + return "test" }) }) } @@ -248,20 +150,13 @@ func TestOperationParamExampleType(t *testing.T) { func TestOperationParamExampleSchema(t *testing.T) { r := NewTestRouter(t) - p := QueryParamExample("foo", "Test", 0, 123) + p := QueryParam("foo", "Test", 0, Example(123)) - r.Register(http.MethodGet, "/", &Operation{ - Description: "Test", - Params: []*Param{ - p, - }, - Responses: []*Response{ - ResponseText(200, "Test"), - }, - Handler: func(p int) string { - return "test" - }, + r.Resource("/", p).Get("Test", func(p int) string { + return "test" }) - assert.Equal(t, 123, p.Schema.Example) + param := r.OpenAPI().Paths["/"][http.MethodGet].params[0] + + assert.Equal(t, 123, param.Schema.Example) }