diff --git a/examples/unsafe/main.go b/examples/unsafe/main.go new file mode 100644 index 00000000..f60b1d43 --- /dev/null +++ b/examples/unsafe/main.go @@ -0,0 +1,39 @@ +package main + +import ( + "net/http" + "reflect" + + "github.com/danielgtaylor/huma" + "github.com/danielgtaylor/huma/schema" +) + +// Item stores some value. +type Item struct { + ID string `json:"id"` + Value int32 `json:"value"` +} + +func main() { + r := huma.NewRouter("Unsafe Test", "1.0.0") + + // Generate response schema for docs. + s, _ := schema.Generate(reflect.TypeOf(Item{})) + + r.Resource("/unsafe", + huma.PathParam("id", "desc"), + huma.ResponseJSON(http.StatusOK, "doc", huma.Schema(*s)), + ).Get("doc", huma.UnsafeHandler(func(inputs ...interface{}) []interface{} { + // Get the ID, which is the first input and will be a string since it's + // a path parameter. + id := inputs[0].(string) + + // Return an item with the passed in ID. + return []interface{}{&Item{ + ID: id, + Value: 123, + }} + })) + + r.Run() +} diff --git a/openapi.go b/openapi.go index 0b9e66ef..4aeae0f4 100644 --- a/openapi.go +++ b/openapi.go @@ -232,6 +232,15 @@ func (o *openAPIOperation) allResponseHeaders() []*openAPIResponseHeader { return headers } +// unsafe returns true if the operation's handler was made with UnsafeHandler. +func (o *openAPIOperation) unsafe() bool { + if _, ok := o.handler.(*unsafeHandler); ok { + return true + } + + return false +} + // openAPIServer describes an OpenAPI 3 API server location type openAPIServer struct { URL string `json:"url"` diff --git a/resource.go b/resource.go index 98b56cdd..5067f939 100644 --- a/resource.go +++ b/resource.go @@ -118,16 +118,19 @@ func (r *Resource) operation(method string, docs string, handler interface{}) { op.handler = handler if op.handler != nil { - t := reflect.TypeOf(op.handler) - if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 { - rtype := t.Out(t.NumOut() - 1) - switch rtype.Kind() { - case reflect.Bool: - op = op.With(Response(http.StatusNoContent, "Success")) - case reflect.String: - op = op.With(ResponseText(http.StatusOK, "Success")) - default: - op = op.With(ResponseJSON(http.StatusOK, "Success")) + // Only apply auto-response if it's *not* an unsafe handler. + if !op.unsafe() { + t := reflect.TypeOf(op.handler) + if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 { + rtype := t.Out(t.NumOut() - 1) + switch rtype.Kind() { + case reflect.Bool: + op = op.With(Response(http.StatusNoContent, "Success")) + case reflect.String: + op = op.With(ResponseText(http.StatusOK, "Success")) + default: + op = op.With(ResponseJSON(http.StatusOK, "Success")) + } } } } diff --git a/resource_test.go b/resource_test.go index 1c167e91..70904dab 100644 --- a/resource_test.go +++ b/resource_test.go @@ -213,3 +213,21 @@ func TestResourceGetPathParams(t *testing.T) { assert.Equal(t, []string{"foo", "bar"}, res.PathParams()) } + +func TestResourceUnsafeHandler(t *testing.T) { + r := NewTestRouter(t) + + assert.Panics(t, func() { + r.Resource("/unsafe").Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} { + return []interface{}{true} + })) + }) + + assert.NotPanics(t, func() { + r.Resource("/unsafe", + Response(http.StatusNoContent, "doc"), + ).Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} { + return []interface{}{true} + })) + }) +} diff --git a/router.go b/router.go index 641e0101..6d5fa648 100644 --- a/router.go +++ b/router.go @@ -81,6 +81,27 @@ var connContextKey = struct{}{} var timeType = reflect.TypeOf(time.Time{}) +type unsafeHandler struct { + handler func(inputs ...interface{}) []interface{} +} + +// UnsafeHandler is used to register programmatic handlers without argument +// count and type checking. This is useful for libraries that want to +// programmatically create new resources/operations. Using UnsafeHandler outside +// of that use-case is discouraged. +// +// The function's inputs are the ordered resolved dependencies, parsed +// parameters, and potentially an input body for PUT/POST requests that have +// a request schema defined. The output is a slice of response headers and +// response models. +// +// When using UnsafeHandler, you must manually define schemas for request +// and response bodies. They will be unmarshalled as `interface{}` when +// passed to the handler. +func UnsafeHandler(handler func(inputs ...interface{}) []interface{}) interface{} { + return &unsafeHandler{handler} +} + // getConn gets the underlying `net.Conn` from a request. func getConn(r *http.Request) net.Conn { conn := r.Context().Value(connContextKey) @@ -250,7 +271,14 @@ func getParamValue(c *gin.Context, param *openAPIParam) (interface{}, bool) { } func getRequestBody(c *gin.Context, t reflect.Type, op *openAPIOperation) (interface{}, bool) { - val := reflect.New(t).Interface() + var val interface{} + + if t != nil { + // If we have a type, then use it. Otherwise the body will unmarshal into + // a generic `map[string]interface{}` or `[]interface{}`. + val = reflect.New(t).Interface() + } + if op.requestSchema != nil { body, err := ioutil.ReadAll(c.Request.Body) if err != nil { @@ -416,8 +444,14 @@ func (r *Router) register(method, path string, op *openAPIOperation) { // Then call it to register our handler function. f(path, func(c *gin.Context) { - method := reflect.ValueOf(op.handler) - in := make([]reflect.Value, 0, method.Type().NumIn()) + var method reflect.Value + if op.unsafe() { + method = reflect.ValueOf(op.handler.(*unsafeHandler).handler) + } else { + method = reflect.ValueOf(op.handler) + } + + in := make([]reflect.Value, 0, len(op.dependencies)+len(op.params)+1) // Limit the body size if c.Request.Body != nil { @@ -464,7 +498,7 @@ func (r *Router) register(method, path string, op *openAPIOperation) { } readTimeout := op.bodyReadTimeout - if len(in) != method.Type().NumIn() { + if op.requestSchema != nil { if readTimeout == 0 { // Default to 15s when reading/parsing/validating automatically. readTimeout = 15 * time.Second @@ -476,15 +510,24 @@ func (r *Router) register(method, path string, op *openAPIOperation) { // Parse body i := len(in) - val, success := getRequestBody(c, method.Type().In(i), op) + + var bodyType reflect.Type + if op.unsafe() { + bodyType = reflect.TypeOf(map[string]interface{}{}) + } else { + bodyType = method.Type().In(i) + } + + b, success := getRequestBody(c, bodyType, op) if !success { // Error was already handled in `getRequestBody`. return } - in = append(in, reflect.ValueOf(val)) - if in[i].Kind() == reflect.Ptr { - in[i] = in[i].Elem() + bval := reflect.ValueOf(b) + if bval.Kind() == reflect.Ptr { + bval = bval.Elem() } + in = append(in, bval) } else if readTimeout > 0 { // We aren't processing the input, but still set the timeout. if conn := getConn(c.Request); conn != nil { @@ -494,6 +537,18 @@ func (r *Router) register(method, path string, op *openAPIOperation) { out := method.Call(in) + if op.unsafe() { + // Normal handlers return multiple values. Unsafe handlers return one + // single list of response values. Here we convert. + newOut := make([]reflect.Value, out[0].Len()) + + for i := 0; i < out[0].Len(); i++ { + newOut[i] = out[0].Index(i) + } + + out = newOut + } + // Find and return the first non-zero response. The status code comes // from the registered `huma.Response` struct. // This breaks down with scalar types... so they need to be passed diff --git a/router_test.go b/router_test.go index d2795433..104aa5cd 100644 --- a/router_test.go +++ b/router_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "time" @@ -523,3 +524,56 @@ func TestBodySlow(t *testing.T) { assert.Equal(t, http.StatusRequestTimeout, w.Code) assert.Contains(t, w.Body.String(), "timed out") } + +func TestRouterUnsafeHandler(t *testing.T) { + r := NewTestRouter(t) + + type Item struct { + ID string `json:"id" readOnly:"true"` + Value int `json:"value"` + } + + readSchema, _ := schema.GenerateWithMode(reflect.TypeOf(Item{}), schema.ModeRead, nil) + writeSchema, _ := schema.GenerateWithMode(reflect.TypeOf(Item{}), schema.ModeWrite, nil) + + items := map[string]Item{} + + res := r.Resource("/test", PathParam("id", "doc")) + + // Write handler + res.With( + RequestSchema(writeSchema), + Response(http.StatusNoContent, "doc"), + ).Put("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} { + id := inputs[0].(string) + item := inputs[1].(map[string]interface{}) + + items[id] = Item{ + ID: id, + Value: int(item["value"].(float64)), + } + + return []interface{}{true} + })) + + // Read handler + res.With( + ResponseJSON(http.StatusOK, "doc", Schema(*readSchema)), + ).Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} { + id := inputs[0].(string) + + return []interface{}{items[id]} + })) + + // Create an item + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPut, "/test/some-id", strings.NewReader(`{"value": 123}`)) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusNoContent, w.Code, w.Body.String()) + + // Read the item + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, "/test/some-id", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) +} diff --git a/validate.go b/validate.go index 5d81fdce..d1bcfca0 100644 --- a/validate.go +++ b/validate.go @@ -45,6 +45,16 @@ func (p *openAPIParam) validate(t reflect.Type) { panic(fmt.Errorf("parameter %s location invalid: %s", p.Name, p.In)) } + if t == nil { + // Unknown type for unsafe handlers defaults to `string` for path params + // and to the given default value's type for everything else. + if p.def != nil { + t = reflect.TypeOf(p.def) + } else { + t = reflect.TypeOf("") + } + } + if p.typ != nil && p.typ != t { panic(fmt.Errorf("parameter %s declared as %s was previously declared as %s: %w", p.Name, t, p.typ, ErrParamInvalid)) } @@ -86,6 +96,11 @@ func (p *openAPIParam) validate(t reflect.Type) { // validate the header and generate schemas func (h *openAPIResponseHeader) validate(t reflect.Type) { + if t == nil { + // Unsafe handlers default to string headers + t = reflect.TypeOf("") + } + if h.Schema == nil { // Generate the schema from the handler function types. s, err := schema.GenerateWithMode(t, schema.ModeRead, nil) @@ -109,41 +124,49 @@ func (o *openAPIOperation) validate(method, path string) { panic(fmt.Errorf("%s at least one response is required: %w", prefix, ErrOperationInvalid)) } + validateHandler := true if o.handler == nil { panic(fmt.Errorf("%s handler is required: %w", prefix, ErrOperationInvalid)) + } else { + if _, ok := o.handler.(*unsafeHandler); ok { + // This is an unsafe handler, so skip validation. + validateHandler = false + } } handler := reflect.ValueOf(o.handler).Type() - totalIn := len(o.dependencies) + len(o.params) - totalOut := len(o.responseHeaders) + len(o.responses) - if !(handler.NumIn() == totalIn || (method != http.MethodGet && handler.NumIn() == totalIn+1)) || handler.NumOut() != totalOut { - expected := "func(" - for _, dep := range o.dependencies { - expected += "? " + reflect.ValueOf(dep.handler).Type().String() + ", " - } - for _, param := range o.params { - expected += param.Name + " ?, " - } - expected = strings.TrimRight(expected, ", ") - expected += ") (" - for _, h := range o.responseHeaders { - expected += h.Name + " ?, " - } - for _, r := range o.responses { - expected += fmt.Sprintf("*Response%d, ", r.StatusCode) - } - expected = strings.TrimRight(expected, ", ") - expected += ")" + if validateHandler { + totalIn := len(o.dependencies) + len(o.params) + totalOut := len(o.responseHeaders) + len(o.responses) + if !(handler.NumIn() == totalIn || (method != http.MethodGet && handler.NumIn() == totalIn+1)) || handler.NumOut() != totalOut { + expected := "func(" + for _, dep := range o.dependencies { + expected += "? " + reflect.ValueOf(dep.handler).Type().String() + ", " + } + for _, param := range o.params { + expected += param.Name + " ?, " + } + expected = strings.TrimRight(expected, ", ") + expected += ") (" + for _, h := range o.responseHeaders { + expected += h.Name + " ?, " + } + for _, r := range o.responses { + expected += fmt.Sprintf("*Response%d, ", r.StatusCode) + } + expected = strings.TrimRight(expected, ", ") + expected += ")" - panic(fmt.Errorf("%s expected handler %s but found %s: %w", prefix, expected, handler, ErrOperationInvalid)) + panic(fmt.Errorf("%s expected handler %s but found %s: %w", prefix, expected, handler, ErrOperationInvalid)) + } } if o.id == "" { verb := method // Try to detect calls returning lists of things. - if handler.NumOut() > 0 { + if validateHandler && handler.NumOut() > 0 { k := handler.Out(0).Kind() if k == reflect.Array || k == reflect.Slice { verb = "list" @@ -157,32 +180,36 @@ func (o *openAPIOperation) validate(method, path string) { } for i, dep := range o.dependencies { - paramType := handler.In(i) + if validateHandler { + paramType := handler.In(i) - // Catch common errors. - if paramType.String() == "gin.Context" { - panic(fmt.Errorf("%s gin.Context should be pointer *gin.Context: %w", prefix, ErrOperationInvalid)) - } + // Catch common errors. + if paramType.String() == "gin.Context" { + panic(fmt.Errorf("%s gin.Context should be pointer *gin.Context: %w", prefix, ErrOperationInvalid)) + } - if paramType.String() == "huma.OpenAPIOperation" { - panic(fmt.Errorf("%s huma.Operation should be pointer *huma.Operation: %w", prefix, ErrOperationInvalid)) + dep.validate(paramType) + } else { + dep.validate(nil) } - - dep.validate(paramType) } types := []reflect.Type{} - for i := len(o.dependencies); i < handler.NumIn(); i++ { - paramType := handler.In(i) + if validateHandler { + for i := len(o.dependencies); i < handler.NumIn(); i++ { + paramType := handler.In(i) - switch paramType.String() { - case "gin.Context", "*gin.Context": - panic(fmt.Errorf("%s expected param but found gin.Context: %w", prefix, ErrOperationInvalid)) - case "huma.Operation", "*huma.OpenAPIOperation": - panic(fmt.Errorf("%s expected param but found huma.Operation: %w", prefix, ErrOperationInvalid)) - } + switch paramType.String() { + case "gin.Context", "*gin.Context": + panic(fmt.Errorf("%s expected param but found gin.Context: %w", prefix, ErrOperationInvalid)) + } - types = append(types, paramType) + types = append(types, paramType) + } + } else { + for i := 0; i < len(o.params); i++ { + types = append(types, nil) + } } requestBody := false @@ -208,20 +235,26 @@ func (o *openAPIOperation) validate(method, path string) { } for i, header := range o.responseHeaders { - header.validate(handler.Out(i)) + if validateHandler { + header.validate(handler.Out(i)) + } else { + header.validate(nil) + } } for i, resp := range o.responses { - respType := handler.Out(len(o.responseHeaders) + i) - // HTTP 204 explicitly forbids a response body. We model this with an - // empty content type. - if resp.ContentType != "" && resp.Schema == nil { - // Generate the schema from the handler function types. - s, err := schema.GenerateWithMode(respType, schema.ModeRead, nil) - if err != nil { - panic(fmt.Errorf("%s response %d schema generation error: %w", prefix, resp.StatusCode, err)) + if validateHandler { + respType := handler.Out(len(o.responseHeaders) + i) + // HTTP 204 explicitly forbids a response body. We model this with an + // empty content type. + if resp.ContentType != "" && resp.Schema == nil { + // Generate the schema from the handler function types. + s, err := schema.GenerateWithMode(respType, schema.ModeRead, nil) + if err != nil { + panic(fmt.Errorf("%s response %d schema generation error: %w", prefix, resp.StatusCode, err)) + } + resp.Schema = s } - resp.Schema = s } } } diff --git a/validate_test.go b/validate_test.go index 56d800a7..7cbf7796 100644 --- a/validate_test.go +++ b/validate_test.go @@ -100,14 +100,6 @@ func TestOperationParamDep(t *testing.T) { return "test" }) }) - - assert.Panics(t, func() { - r.Resource("/", - QueryParam("foo", "Test", ""), - ).Get("Test", func(c *openAPIOperation) string { - return "test" - }) - }) } func TestOperationParamRedeclare(t *testing.T) {