Skip to content

Commit

Permalink
feat: unsafe handlers for library use
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Apr 26, 2020
1 parent 626a54b commit 26a54c3
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 76 deletions.
39 changes: 39 additions & 0 deletions examples/unsafe/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"net/http"
"reflect"

"github.com/danielgtaylor/huma"
"github.com/danielgtaylor/huma/schema"
)

// Item stores some value.
type Item struct {
ID string `json:"id"`
Value int32 `json:"value"`
}

func main() {
r := huma.NewRouter("Unsafe Test", "1.0.0")

// Generate response schema for docs.
s, _ := schema.Generate(reflect.TypeOf(Item{}))

r.Resource("/unsafe",
huma.PathParam("id", "desc"),
huma.ResponseJSON(http.StatusOK, "doc", huma.Schema(*s)),
).Get("doc", huma.UnsafeHandler(func(inputs ...interface{}) []interface{} {
// Get the ID, which is the first input and will be a string since it's
// a path parameter.
id := inputs[0].(string)

// Return an item with the passed in ID.
return []interface{}{&Item{
ID: id,
Value: 123,
}}
}))

r.Run()
}
9 changes: 9 additions & 0 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ func (o *openAPIOperation) allResponseHeaders() []*openAPIResponseHeader {
return headers
}

// unsafe returns true if the operation's handler was made with UnsafeHandler.
func (o *openAPIOperation) unsafe() bool {
if _, ok := o.handler.(*unsafeHandler); ok {
return true
}

return false
}

// openAPIServer describes an OpenAPI 3 API server location
type openAPIServer struct {
URL string `json:"url"`
Expand Down
23 changes: 13 additions & 10 deletions resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,19 @@ func (r *Resource) operation(method string, docs string, handler interface{}) {

op.handler = handler
if op.handler != nil {
t := reflect.TypeOf(op.handler)
if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 {
rtype := t.Out(t.NumOut() - 1)
switch rtype.Kind() {
case reflect.Bool:
op = op.With(Response(http.StatusNoContent, "Success"))
case reflect.String:
op = op.With(ResponseText(http.StatusOK, "Success"))
default:
op = op.With(ResponseJSON(http.StatusOK, "Success"))
// Only apply auto-response if it's *not* an unsafe handler.
if !op.unsafe() {
t := reflect.TypeOf(op.handler)
if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 {
rtype := t.Out(t.NumOut() - 1)
switch rtype.Kind() {
case reflect.Bool:
op = op.With(Response(http.StatusNoContent, "Success"))
case reflect.String:
op = op.With(ResponseText(http.StatusOK, "Success"))
default:
op = op.With(ResponseJSON(http.StatusOK, "Success"))
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,21 @@ func TestResourceGetPathParams(t *testing.T) {

assert.Equal(t, []string{"foo", "bar"}, res.PathParams())
}

func TestResourceUnsafeHandler(t *testing.T) {
r := NewTestRouter(t)

assert.Panics(t, func() {
r.Resource("/unsafe").Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} {
return []interface{}{true}
}))
})

assert.NotPanics(t, func() {
r.Resource("/unsafe",
Response(http.StatusNoContent, "doc"),
).Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} {
return []interface{}{true}
}))
})
}
71 changes: 63 additions & 8 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,27 @@ var connContextKey = struct{}{}

var timeType = reflect.TypeOf(time.Time{})

type unsafeHandler struct {
handler func(inputs ...interface{}) []interface{}
}

// UnsafeHandler is used to register programmatic handlers without argument
// count and type checking. This is useful for libraries that want to
// programmatically create new resources/operations. Using UnsafeHandler outside
// of that use-case is discouraged.
//
// The function's inputs are the ordered resolved dependencies, parsed
// parameters, and potentially an input body for PUT/POST requests that have
// a request schema defined. The output is a slice of response headers and
// response models.
//
// When using UnsafeHandler, you must manually define schemas for request
// and response bodies. They will be unmarshalled as `interface{}` when
// passed to the handler.
func UnsafeHandler(handler func(inputs ...interface{}) []interface{}) interface{} {
return &unsafeHandler{handler}
}

// getConn gets the underlying `net.Conn` from a request.
func getConn(r *http.Request) net.Conn {
conn := r.Context().Value(connContextKey)
Expand Down Expand Up @@ -250,7 +271,14 @@ func getParamValue(c *gin.Context, param *openAPIParam) (interface{}, bool) {
}

func getRequestBody(c *gin.Context, t reflect.Type, op *openAPIOperation) (interface{}, bool) {
val := reflect.New(t).Interface()
var val interface{}

if t != nil {
// If we have a type, then use it. Otherwise the body will unmarshal into
// a generic `map[string]interface{}` or `[]interface{}`.
val = reflect.New(t).Interface()
}

if op.requestSchema != nil {
body, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
Expand Down Expand Up @@ -416,8 +444,14 @@ func (r *Router) register(method, path string, op *openAPIOperation) {

// Then call it to register our handler function.
f(path, func(c *gin.Context) {
method := reflect.ValueOf(op.handler)
in := make([]reflect.Value, 0, method.Type().NumIn())
var method reflect.Value
if op.unsafe() {
method = reflect.ValueOf(op.handler.(*unsafeHandler).handler)
} else {
method = reflect.ValueOf(op.handler)
}

in := make([]reflect.Value, 0, len(op.dependencies)+len(op.params)+1)

// Limit the body size
if c.Request.Body != nil {
Expand Down Expand Up @@ -464,7 +498,7 @@ func (r *Router) register(method, path string, op *openAPIOperation) {
}

readTimeout := op.bodyReadTimeout
if len(in) != method.Type().NumIn() {
if op.requestSchema != nil {
if readTimeout == 0 {
// Default to 15s when reading/parsing/validating automatically.
readTimeout = 15 * time.Second
Expand All @@ -476,15 +510,24 @@ func (r *Router) register(method, path string, op *openAPIOperation) {

// Parse body
i := len(in)
val, success := getRequestBody(c, method.Type().In(i), op)

var bodyType reflect.Type
if op.unsafe() {
bodyType = reflect.TypeOf(map[string]interface{}{})
} else {
bodyType = method.Type().In(i)
}

b, success := getRequestBody(c, bodyType, op)
if !success {
// Error was already handled in `getRequestBody`.
return
}
in = append(in, reflect.ValueOf(val))
if in[i].Kind() == reflect.Ptr {
in[i] = in[i].Elem()
bval := reflect.ValueOf(b)
if bval.Kind() == reflect.Ptr {
bval = bval.Elem()
}
in = append(in, bval)
} else if readTimeout > 0 {
// We aren't processing the input, but still set the timeout.
if conn := getConn(c.Request); conn != nil {
Expand All @@ -494,6 +537,18 @@ func (r *Router) register(method, path string, op *openAPIOperation) {

out := method.Call(in)

if op.unsafe() {
// Normal handlers return multiple values. Unsafe handlers return one
// single list of response values. Here we convert.
newOut := make([]reflect.Value, out[0].Len())

for i := 0; i < out[0].Len(); i++ {
newOut[i] = out[0].Index(i)
}

out = newOut
}

// Find and return the first non-zero response. The status code comes
// from the registered `huma.Response` struct.
// This breaks down with scalar types... so they need to be passed
Expand Down
54 changes: 54 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -523,3 +524,56 @@ func TestBodySlow(t *testing.T) {
assert.Equal(t, http.StatusRequestTimeout, w.Code)
assert.Contains(t, w.Body.String(), "timed out")
}

func TestRouterUnsafeHandler(t *testing.T) {
r := NewTestRouter(t)

type Item struct {
ID string `json:"id" readOnly:"true"`
Value int `json:"value"`
}

readSchema, _ := schema.GenerateWithMode(reflect.TypeOf(Item{}), schema.ModeRead, nil)
writeSchema, _ := schema.GenerateWithMode(reflect.TypeOf(Item{}), schema.ModeWrite, nil)

items := map[string]Item{}

res := r.Resource("/test", PathParam("id", "doc"))

// Write handler
res.With(
RequestSchema(writeSchema),
Response(http.StatusNoContent, "doc"),
).Put("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} {
id := inputs[0].(string)
item := inputs[1].(map[string]interface{})

items[id] = Item{
ID: id,
Value: int(item["value"].(float64)),
}

return []interface{}{true}
}))

// Read handler
res.With(
ResponseJSON(http.StatusOK, "doc", Schema(*readSchema)),
).Get("doc", UnsafeHandler(func(inputs ...interface{}) []interface{} {
id := inputs[0].(string)

return []interface{}{items[id]}
}))

// Create an item
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/test/some-id", strings.NewReader(`{"value": 123}`))
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNoContent, w.Code, w.Body.String())

// Read the item
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/test/some-id", nil)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
Loading

0 comments on commit 26a54c3

Please sign in to comment.