diff --git a/adapters/humaflow/flow/flow.go b/adapters/humaflow/flow/flow.go new file mode 100644 index 00000000..4a9ceed7 --- /dev/null +++ b/adapters/humaflow/flow/flow.go @@ -0,0 +1,251 @@ +// Package flow is a delightfully simple, readable, and tiny HTTP router for Go web applications. Its features include: +// +// * Use named parameters, wildcards and (optionally) regexp patterns in your routes. +// * Create route groups which use different middleware (a bit like chi). +// * Customizable handlers for 404 Not Found and 405 Method Not Allowed responses. +// * Automatic handling of OPTIONS and HEAD requests. +// * Works with http.Handler, http.HandlerFunc, and standard Go middleware. +// +// Example code: +// +// package main +// +// import ( +// "fmt" +// "log" +// "net/http" +// +// "github.com/alexedwards/flow" +// ) +// +// func main() { +// mux := flow.New() +// +// // The Use() method can be used to register middleware. Middleware declared at +// // the top level will used on all routes (including error handlers and OPTIONS +// // responses). +// mux.Use(exampleMiddleware1) +// +// // Routes can use multiple HTTP methods. +// mux.HandleFunc("/profile/:name", exampleHandlerFunc1, "GET", "POST") +// +// // Optionally, regular expressions can be used to enforce a specific pattern +// // for a named parameter. +// mux.HandleFunc("/profile/:name/:age|^[0-9]{1,3}$", exampleHandlerFunc2, "GET") +// +// // The wildcard ... can be used to match the remainder of a request path. +// // Notice that HTTP methods are also optional (if not provided, all HTTP +// // methods will match the route). +// mux.Handle("/static/...", exampleHandler) +// +// // You can create route 'groups'. +// mux.Group(func(mux *flow.Mux) { +// // Middleware declared within in the group will only be used on the routes +// // in the group. +// mux.Use(exampleMiddleware2) +// +// mux.HandleFunc("/admin", exampleHandlerFunc3, "GET") +// +// // Groups can be nested. +// mux.Group(func(mux *flow.Mux) { +// mux.Use(exampleMiddleware3) +// +// mux.HandleFunc("/admin/passwords", exampleHandlerFunc4, "GET") +// }) +// }) +// +// err := http.ListenAndServe(":2323", mux) +// log.Fatal(err) +// } +package flow + +import ( + "context" + "net/http" + "regexp" + "slices" + "strings" +) + +// AllMethods is a slice containing all HTTP request methods. +var AllMethods = []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodConnect, http.MethodOptions, http.MethodTrace} + +var compiledRXPatterns = map[string]*regexp.Regexp{} + +type contextKey string + +// Param is used to retrieve the value of a named parameter or wildcard from the +// request context. It returns the empty string if no matching parameter is +// found. +func Param(ctx context.Context, param string) string { + s, ok := ctx.Value(contextKey(param)).(string) + if !ok { + return "" + } + + return s +} + +// Mux is a http.Handler which dispatches requests to different handlers. +type Mux struct { + NotFound http.Handler + MethodNotAllowed http.Handler + Options http.Handler + routes *[]route + middlewares []func(http.Handler) http.Handler +} + +// New returns a new initialized Mux instance. +func New() *Mux { + return &Mux{ + NotFound: http.NotFoundHandler(), + MethodNotAllowed: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + }), + Options: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + routes: &[]route{}, + } +} + +// Handle registers a new handler for the given request path pattern and HTTP +// methods. +func (m *Mux) Handle(pattern string, handler http.Handler, methods ...string) { + if slices.Contains(methods, http.MethodGet) && !slices.Contains(methods, http.MethodHead) { + methods = append(methods, http.MethodHead) + } + + if len(methods) == 0 { + methods = AllMethods + } + + for _, method := range methods { + route := route{ + method: strings.ToUpper(method), + segments: strings.Split(pattern, "/"), + wildcard: strings.HasSuffix(pattern, "/..."), + handler: m.wrap(handler), + } + + *m.routes = append(*m.routes, route) + } + + // Compile any regular expression patterns and add them to the + // compiledRXPatterns map. + for _, segment := range strings.Split(pattern, "/") { + if strings.HasPrefix(segment, ":") { + _, rxPattern, containsRx := strings.Cut(segment, "|") + if containsRx { + compiledRXPatterns[rxPattern] = regexp.MustCompile(rxPattern) + } + } + } +} + +// HandleFunc is an adapter which allows using a http.HandlerFunc as a handler. +func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, methods ...string) { + m.Handle(pattern, fn, methods...) +} + +// Use registers middleware with the Mux instance. Middleware must have the +// signature `func(http.Handler) http.Handler`. +func (m *Mux) Use(mw ...func(http.Handler) http.Handler) { + m.middlewares = append(m.middlewares, mw...) +} + +// Group is used to create 'groups' of routes in a Mux. Middleware registered +// inside the group will only be used on the routes in that group. See the +// example code at the start of the package documentation for how to use this +// feature. +func (m *Mux) Group(fn func(*Mux)) { + mm := *m + fn(&mm) +} + +// ServeHTTP makes the router implement the http.Handler interface. +func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + urlSegments := strings.Split(r.URL.Path, "/") + allowedMethods := []string{} + + for _, route := range *m.routes { + ctx, ok := route.match(r.Context(), urlSegments) + if ok { + if r.Method == route.method { + route.handler.ServeHTTP(w, r.WithContext(ctx)) + return + } + if !slices.Contains(allowedMethods, route.method) { + allowedMethods = append(allowedMethods, route.method) + } + } + } + + if len(allowedMethods) > 0 { + w.Header().Set("Allow", strings.Join(append(allowedMethods, http.MethodOptions), ", ")) + if r.Method == http.MethodOptions { + m.wrap(m.Options).ServeHTTP(w, r) + } else { + m.wrap(m.MethodNotAllowed).ServeHTTP(w, r) + } + return + } + + m.wrap(m.NotFound).ServeHTTP(w, r) +} + +func (m *Mux) wrap(handler http.Handler) http.Handler { + for i := len(m.middlewares) - 1; i >= 0; i-- { + handler = m.middlewares[i](handler) + } + + return handler +} + +type route struct { + method string + segments []string + wildcard bool + handler http.Handler +} + +func (r *route) match(ctx context.Context, urlSegments []string) (context.Context, bool) { + if !r.wildcard && len(urlSegments) != len(r.segments) { + return ctx, false + } + + for i, routeSegment := range r.segments { + if i > len(urlSegments)-1 { + return ctx, false + } + + if routeSegment == "..." { + ctx = context.WithValue(ctx, contextKey("..."), strings.Join(urlSegments[i:], "/")) + return ctx, true + } + + if strings.HasPrefix(routeSegment, ":") { + key, rxPattern, containsRx := strings.Cut(strings.TrimPrefix(routeSegment, ":"), "|") + + if containsRx { + if compiledRXPatterns[rxPattern].MatchString(urlSegments[i]) { + ctx = context.WithValue(ctx, contextKey(key), urlSegments[i]) + continue + } + } + + if !containsRx && urlSegments[i] != "" { + ctx = context.WithValue(ctx, contextKey(key), urlSegments[i]) + continue + } + + return ctx, false + } + + if urlSegments[i] != routeSegment { + return ctx, false + } + } + + return ctx, true +} diff --git a/adapters/humaflow/flow/flow_test.go b/adapters/humaflow/flow/flow_test.go new file mode 100644 index 00000000..d6b51f55 --- /dev/null +++ b/adapters/humaflow/flow/flow_test.go @@ -0,0 +1,547 @@ +package flow + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestMatching(t *testing.T) { + var tests = []struct { + RouteMethods []string + RoutePattern string + + RequestMethod string + RequestPath string + + ExpectedStatus int + ExpectedParams map[string]string + ExpectedAllowHeader string + }{ + // simple path matching + { + []string{"GET"}, "/one", + "GET", "/one", + http.StatusOK, nil, "", + }, + { + []string{"GET"}, "/one", + "GET", "/two", + http.StatusNotFound, nil, "", + }, + // nested + { + []string{"GET"}, "/parent/child/one", + "GET", "/parent/child/one", + http.StatusOK, nil, "", + }, + { + []string{"GET"}, "/parent/child/one", + "GET", "/parent/child/two", + http.StatusNotFound, nil, "", + }, + // misc no matches + { + []string{"GET"}, "/not/enough", + "GET", "/not/enough/items", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/not/enough/items", + "GET", "/not/enough", + http.StatusNotFound, nil, "", + }, + // wilcards + { + []string{"GET"}, "/prefix/...", + "GET", "/prefix/anything/else", + http.StatusOK, map[string]string{"...": "anything/else"}, "", + }, + { + []string{"GET"}, "/prefix/...", + "GET", "/prefix/", + http.StatusOK, map[string]string{"...": ""}, "", + }, + { + []string{"GET"}, "/prefix/...", + "GET", "/prefix", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/prefix", + "GET", "/prefix/anything/else", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/prefix/", + "GET", "/prefix/anything/else", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/prefix...", + "GET", "/prefix/anything/else", + http.StatusNotFound, nil, "", + }, + // path params + { + []string{"GET"}, "/path-params/:era/:group/:member", + "GET", "/path-params/60/beatles/lennon", + http.StatusOK, map[string]string{"era": "60", "group": "beatles", "member": "lennon"}, "", + }, + { + []string{"GET"}, "/path-params/:era/:group/:member/foo", + "GET", "/path-params/60/beatles/lennon/bar", + http.StatusNotFound, map[string]string{"era": "60", "group": "beatles", "member": "lennon"}, "", + }, + // regexp + { + []string{"GET"}, "/path-params/:era|^[0-9]{2}$/:group|^[a-z].+$", + "GET", "/path-params/60/beatles", + http.StatusOK, map[string]string{"era": "60", "group": "beatles"}, "", + }, + { + []string{"GET"}, "/path-params/:era|^[0-9]{2}$/:group|^[a-z].+$", + "GET", "/path-params/abc/123", + http.StatusNotFound, nil, "", + }, + // kitchen sink + { + []string{"GET"}, "/path-params/:id/:era|^[0-9]{2}$/...", + "GET", "/path-params/abc/12/foo/bar/baz", + http.StatusOK, map[string]string{"id": "abc", "era": "12", "...": "foo/bar/baz"}, "", + }, + { + []string{"GET"}, "/path-params/:id/:era|^[0-9]{2}$/...", + "GET", "/path-params/abc/12", + http.StatusNotFound, nil, "", + }, + // leading and trailing slashes + { + []string{"GET"}, "slashes/one", + "GET", "/slashes/one", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/slashes/two", + "GET", "slashes/two", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/slashes/three/", + "GET", "/slashes/three", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/slashes/four", + "GET", "/slashes/four/", + http.StatusNotFound, nil, "", + }, + // empty segments + { + []string{"GET"}, "/baz/:id/:age", + "GET", "/baz/123/", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/baz/:id/:age/", + "GET", "/baz/123//", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/baz/:id/:age", + "GET", "/baz//21", + http.StatusNotFound, nil, "", + }, + { + []string{"GET"}, "/baz//:age", + "GET", "/baz//21", + http.StatusOK, nil, "", + }, + { + // with a regexp to specifically allow empty segments + []string{"GET"}, "/baz/:id|^$/:age/", + "GET", "/baz//21/", + http.StatusOK, nil, "", + }, + // methods + { + []string{"POST"}, "/one", + "POST", "/one", + http.StatusOK, nil, "", + }, + { + []string{"GET"}, "/one", + "POST", "/one", + http.StatusMethodNotAllowed, nil, "", + }, + // multiple methods + { + []string{"GET", "POST", "PUT"}, "/one", + "POST", "/one", + http.StatusOK, nil, "", + }, + { + []string{"GET", "POST", "PUT"}, "/one", + "PUT", "/one", + http.StatusOK, nil, "", + }, + { + []string{"GET", "POST", "PUT"}, "/one", + "DELETE", "/one", + http.StatusMethodNotAllowed, nil, "", + }, + // all methods + { + []string{}, "/one", + "GET", "/one", + http.StatusOK, nil, "", + }, + { + []string{}, "/one", + "DELETE", "/one", + http.StatusOK, nil, "", + }, + // method casing + { + []string{"gEt"}, "/one", + "GET", "/one", + http.StatusOK, nil, "", + }, + // head requests + { + []string{"GET"}, "/one", + "HEAD", "/one", + http.StatusOK, nil, "", + }, + { + []string{"HEAD"}, "/one", + "HEAD", "/one", + http.StatusOK, nil, "", + }, + { + []string{"HEAD"}, "/one", + "GET", "/one", + http.StatusMethodNotAllowed, nil, "", + }, + // allow header + { + []string{"GET", "PUT"}, "/one", + "DELETE", "/one", + http.StatusMethodNotAllowed, nil, "GET, PUT, HEAD, OPTIONS", + }, + // options + { + []string{"GET", "PUT"}, "/one", + "OPTIONS", "/one", + http.StatusNoContent, nil, "GET, PUT, HEAD, OPTIONS", + }, + } + + for _, test := range tests { + m := New() + + var ctx context.Context + + hf := func(w http.ResponseWriter, r *http.Request) { + ctx = r.Context() + } + + m.HandleFunc(test.RoutePattern, hf, test.RouteMethods...) + + r, err := http.NewRequest(test.RequestMethod, test.RequestPath, nil) + if err != nil { + t.Errorf("NewRequest: %s", err) + } + + rr := httptest.NewRecorder() + m.ServeHTTP(rr, r) + + rs := rr.Result() + + if rs.StatusCode != test.ExpectedStatus { + t.Errorf("%s %s: expected status %d but was %d", test.RequestMethod, test.RequestPath, test.ExpectedStatus, rr.Code) + continue + } + + if rs.StatusCode == http.StatusOK && len(test.ExpectedParams) > 0 { + for expK, expV := range test.ExpectedParams { + actualValStr := Param(ctx, expK) + if actualValStr != expV { + t.Errorf("Param: context value %s expected \"%s\" but was \"%s\"", expK, expV, actualValStr) + } + } + } + + if test.ExpectedAllowHeader != "" { + actualAllowHeader := rs.Header.Get("Allow") + if actualAllowHeader != test.ExpectedAllowHeader { + t.Errorf("%s %s: expected Allow header %q but was %q", test.RequestMethod, test.RequestPath, test.ExpectedAllowHeader, actualAllowHeader) + } + } + + } +} + +func TestMiddleware(t *testing.T) { + used := "" + + mw1 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "1" + next.ServeHTTP(w, r) + }) + } + + mw2 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "2" + next.ServeHTTP(w, r) + }) + } + + mw3 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "3" + next.ServeHTTP(w, r) + }) + } + + mw4 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "4" + next.ServeHTTP(w, r) + }) + } + + mw5 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "5" + next.ServeHTTP(w, r) + }) + } + + mw6 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + used += "6" + next.ServeHTTP(w, r) + }) + } + + hf := func(w http.ResponseWriter, r *http.Request) {} + + m := New() + m.Use(mw1) + m.Use(mw2) + + m.HandleFunc("/", hf, "GET") + + m.Group(func(m *Mux) { + m.Use(mw3, mw4) + m.HandleFunc("/foo", hf, "GET") + + m.Group(func(m *Mux) { + m.Use(mw5) + m.HandleFunc("/nested/foo", hf, "GET") + }) + }) + + m.Group(func(m *Mux) { + m.Use(mw6) + m.HandleFunc("/bar", hf, "GET") + }) + + m.HandleFunc("/baz", hf, "GET") + + var tests = []struct { + RequestMethod string + RequestPath string + ExpectedUsed string + ExpectedStatus int + }{ + { + RequestMethod: "GET", + RequestPath: "/", + ExpectedUsed: "12", + ExpectedStatus: http.StatusOK, + }, + { + RequestMethod: "GET", + RequestPath: "/foo", + ExpectedUsed: "1234", + ExpectedStatus: http.StatusOK, + }, + { + RequestMethod: "GET", + RequestPath: "/nested/foo", + ExpectedUsed: "12345", + ExpectedStatus: http.StatusOK, + }, + { + RequestMethod: "GET", + RequestPath: "/bar", + ExpectedUsed: "126", + ExpectedStatus: http.StatusOK, + }, + { + RequestMethod: "GET", + RequestPath: "/baz", + ExpectedUsed: "12", + ExpectedStatus: http.StatusOK, + }, + // Check top-level middleware used on errors and OPTIONS + { + RequestMethod: "GET", + RequestPath: "/notfound", + ExpectedUsed: "12", + ExpectedStatus: http.StatusNotFound, + }, + { + RequestMethod: "POST", + RequestPath: "/nested/foo", + ExpectedUsed: "12", + ExpectedStatus: http.StatusMethodNotAllowed, + }, + { + RequestMethod: "OPTIONS", + RequestPath: "/nested/foo", + ExpectedUsed: "12", + ExpectedStatus: http.StatusNoContent, + }, + } + + for _, test := range tests { + used = "" + + r, err := http.NewRequest(test.RequestMethod, test.RequestPath, nil) + if err != nil { + t.Errorf("NewRequest: %s", err) + } + + rr := httptest.NewRecorder() + m.ServeHTTP(rr, r) + + rs := rr.Result() + + if rs.StatusCode != test.ExpectedStatus { + t.Errorf("%s %s: expected status %d but was %d", test.RequestMethod, test.RequestPath, test.ExpectedStatus, rs.StatusCode) + } + + if used != test.ExpectedUsed { + t.Errorf("%s %s: middleware used: expected %q; got %q", test.RequestMethod, test.RequestPath, test.ExpectedUsed, used) + } + } +} + +func TestCustomHandlers(t *testing.T) { + hf := func(w http.ResponseWriter, r *http.Request) {} + + m := New() + m.NotFound = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom not found handler")) + }) + m.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom method not allowed handler")) + }) + m.Options = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom options handler")) + }) + + m.HandleFunc("/", hf, "GET") + + var tests = []struct { + RequestMethod string + RequestPath string + + ExpectedBody string + }{ + { + RequestMethod: "GET", + RequestPath: "/notfound", + ExpectedBody: "custom not found handler", + }, + { + RequestMethod: "POST", + RequestPath: "/", + ExpectedBody: "custom method not allowed handler", + }, + { + RequestMethod: "OPTIONS", + RequestPath: "/", + ExpectedBody: "custom options handler", + }, + } + + for _, test := range tests { + r, err := http.NewRequest(test.RequestMethod, test.RequestPath, nil) + if err != nil { + t.Errorf("NewRequest: %s", err) + } + + rr := httptest.NewRecorder() + m.ServeHTTP(rr, r) + + rs := rr.Result() + + defer rs.Body.Close() + body, err := io.ReadAll(rs.Body) + if err != nil { + t.Fatal(err) + } + + if string(body) != test.ExpectedBody { + t.Errorf("%s %s: expected body %q; got %q", test.RequestMethod, test.RequestPath, test.ExpectedBody, string(body)) + } + } +} + +func TestParams(t *testing.T) { + var tests = []struct { + RouteMethods []string + RoutePattern string + + RequestMethod string + RequestPath string + + ParamName string + HasParam bool + ParamValue string + }{ + { + []string{"GET"}, "/foo/:id", + "GET", "/foo/123", + "id", true, "123", + }, + { + []string{"GET"}, "/foo/:id", + "GET", "/foo/123", + "missing", false, "", + }, + } + + for _, test := range tests { + m := New() + + var ctx context.Context + + hf := func(w http.ResponseWriter, r *http.Request) { + ctx = r.Context() + } + + m.HandleFunc(test.RoutePattern, hf, test.RouteMethods...) + + r, err := http.NewRequest(test.RequestMethod, test.RequestPath, nil) + if err != nil { + t.Errorf("NewRequest: %s", err) + } + + rr := httptest.NewRecorder() + m.ServeHTTP(rr, r) + + actualValStr := Param(ctx, test.ParamName) + if actualValStr != test.ParamValue { + t.Errorf("expected \"%s\" but was \"%s\"", test.ParamValue, actualValStr) + } + } +} diff --git a/adapters/humaflow/humaflow.go b/adapters/humaflow/humaflow.go new file mode 100644 index 00000000..87f83905 --- /dev/null +++ b/adapters/humaflow/humaflow.go @@ -0,0 +1,146 @@ +package humaflow + +import ( + "context" + "io" + "mime/multipart" + "net/http" + "net/url" + "strings" + "time" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humaflow/flow" + "github.com/danielgtaylor/huma/v2/queryparam" +) + +// MultipartMaxMemory is the maximum memory to use when parsing multipart +// form data. +var MultipartMaxMemory int64 = 8 * 1024 + +type goContext struct { + op *huma.Operation + r *http.Request + w http.ResponseWriter +} + +func (c *goContext) Operation() *huma.Operation { + return c.op +} + +func (c *goContext) Context() context.Context { + return c.r.Context() +} + +func (c *goContext) Method() string { + return c.r.Method +} + +func (c *goContext) Host() string { + return c.r.Host +} + +func (c *goContext) URL() url.URL { + return *c.r.URL +} + +func (c *goContext) Param(name string) string { + return flow.Param(c.r.Context(), name) +} + +func (c *goContext) Query(name string) string { + return queryparam.Get(c.r.URL.RawQuery, name) +} + +func (c *goContext) Header(name string) string { + return c.r.Header.Get(name) +} + +func (c *goContext) EachHeader(cb func(name, value string)) { + for name, values := range c.r.Header { + for _, value := range values { + cb(name, value) + } + } +} + +func (c *goContext) BodyReader() io.Reader { + return c.r.Body +} + +func (c *goContext) GetMultipartForm() (*multipart.Form, error) { + err := c.r.ParseMultipartForm(MultipartMaxMemory) + return c.r.MultipartForm, err +} + +func (c *goContext) SetReadDeadline(deadline time.Time) error { + return huma.SetReadDeadline(c.w, deadline) +} + +func (c *goContext) SetStatus(code int) { + c.w.WriteHeader(code) +} + +func (c *goContext) AppendHeader(name string, value string) { + c.w.Header().Add(name, value) +} + +func (c *goContext) SetHeader(name string, value string) { + c.w.Header().Set(name, value) +} + +func (c *goContext) BodyWriter() io.Writer { + return c.w +} + +// NewContext creates a new Huma context from an HTTP request and response. +func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context { + return &goContext{op: op, r: r, w: w} +} + +type Mux interface { + HandleFunc(pattern string, handler http.HandlerFunc, methods ...string) + ServeHTTP(http.ResponseWriter, *http.Request) +} + +type goAdapter struct { + Mux + prefix string +} + +func (a *goAdapter) Handle(op *huma.Operation, handler func(huma.Context)) { + // Convert {param} to :param + path := op.Path + path = strings.ReplaceAll(path, "{", ":") + path = strings.ReplaceAll(path, "}", "") + a.HandleFunc(a.prefix+path, func(w http.ResponseWriter, r *http.Request) { + handler(&goContext{op: op, r: r, w: w}) + }, op.Method) +} + +// NewAdapter creates a new adapter for the given chi router. +func NewAdapter(m Mux) huma.Adapter { + return &goAdapter{Mux: m} +} + +// New creates a new Huma API using an HTTP mux. +// +// mux := http.NewServeMux() +// api := humago.New(mux, huma.DefaultConfig("My API", "1.0.0")) +func New(m Mux, config huma.Config) huma.API { + return huma.NewAPI(config, &goAdapter{m, ""}) +} + +// NewWithPrefix creates a new Huma API using an HTTP mux with a URL prefix. +// This behaves similar to other router's group functionality, adding the prefix +// before each route path (but not in the OpenAPI). The prefix should be used in +// combination with the `OpenAPI().Servers` base path to ensure the correct URLs +// are generated in the OpenAPI spec. +// +// mux := flow.New() +// config := huma.DefaultConfig("My API", "1.0.0") +// config.Servers = []*huma.Server{{URL: "http://example.com/api"}} +// api := humago.NewWithPrefix(mux, "/api", config) +func NewWithPrefix(m Mux, prefix string, config huma.Config) huma.API { + return huma.NewAPI(config, &goAdapter{m, prefix}) +} diff --git a/api_test.go b/api_test.go index a2b5f51b..3eeca4d9 100644 --- a/api_test.go +++ b/api_test.go @@ -13,7 +13,7 @@ import ( ) func TestBlankConfig(t *testing.T) { - adapter := humatest.NewAdapter(chi.NewMux()) + adapter := humatest.NewAdapter() assert.NotPanics(t, func() { huma.NewAPI(huma.Config{}, adapter) @@ -28,7 +28,7 @@ func TestBlankConfig(t *testing.T) { // including the parameter and response definitions & schemas. func ExampleAdapter_handle() { // Create an adapter for your chosen router. - adapter := NewExampleAdapter(chi.NewMux()) + adapter := NewExampleAdapter() // Register an operation with a custom handler. adapter.Handle(&huma.Operation{ @@ -89,7 +89,7 @@ func TestContextValue(t *testing.T) { assert.Equal(t, http.StatusNoContent, resp.Code) } -func TestRouterPrefix(t *testing.T) { +func TestChiRouterPrefix(t *testing.T) { mux := chi.NewMux() var api huma.API mux.Route("/api", func(r chi.Router) { @@ -110,7 +110,7 @@ func TestRouterPrefix(t *testing.T) { }) // Create a test API around the underlying router to make easier requests. - tapi := humatest.NewTestAPI(t, mux, huma.Config{}) + tapi := humatest.Wrap(t, humachi.New(mux, huma.DefaultConfig("Test", "1.0.0"))) // The top-level router should respond to the full path even though the // operation was registered with just `/test`. diff --git a/huma_test.go b/huma_test.go index cd1be7cf..453eaa1d 100644 --- a/huma_test.go +++ b/huma_test.go @@ -1080,7 +1080,7 @@ Content of example2.txt. if feature.Transformers != nil { config.Transformers = append(config.Transformers, feature.Transformers...) } - api := humatest.NewTestAPI(t, r, config) + api := humatest.Wrap(t, humachi.New(r, config)) feature.Register(t, api) var body io.Reader = nil diff --git a/humatest/humatest.go b/humatest/humatest.go index 4711e68b..e186144a 100644 --- a/humatest/humatest.go +++ b/humatest/humatest.go @@ -1,21 +1,22 @@ -// Package humatest provides testing utilities for Huma services. It is based -// on the `chi` router and the standard library `http.Request` & -// `http.ResponseWriter` types. +// Package humatest provides testing utilities for Huma services. It is based on +// the standard library `http.Request` & `http.ResponseWriter` types. package humatest import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" "net/http/httputil" "reflect" "strings" + "testing/iotest" "github.com/danielgtaylor/huma/v2" - "github.com/danielgtaylor/huma/v2/adapters/humachi" - "github.com/go-chi/chi/v5" + "github.com/danielgtaylor/huma/v2/adapters/humaflow" + "github.com/danielgtaylor/huma/v2/adapters/humaflow/flow" ) // TB is a subset of the `testing.TB` interface used by the test API and @@ -28,12 +29,12 @@ type TB interface { // NewContext creates a new test context from an HTTP request and response. func NewContext(op *huma.Operation, r *http.Request, w http.ResponseWriter) huma.Context { - return humachi.NewContext(op, r, w) + return humaflow.NewContext(op, r, w) } -// NewAdapter creates a new test adapter from a chi router. -func NewAdapter(r chi.Router) huma.Adapter { - return humachi.NewAdapter(r) +// NewAdapter creates a new test adapter from a router. +func NewAdapter() huma.Adapter { + return humaflow.NewAdapter(flow.New()) } // TestAPI is a `huma.API` with additional methods specifically for testing. @@ -153,12 +154,12 @@ func (a *testAPI) Do(method, path string, args ...any) *httptest.ResponseRecorde } resp := httptest.NewRecorder() - bytes, _ := httputil.DumpRequest(req, b != nil) + bytes, _ := DumpRequest(req) a.tb.Log("Making request:\n" + strings.TrimSpace(string(bytes))) a.Adapter().ServeHTTP(resp, req) - bytes, _ = httputil.DumpResponse(resp.Result(), resp.Body.Len() > 0) + bytes, _ = DumpResponse(resp.Result()) a.tb.Log("Got response:\n" + strings.TrimSpace(string(bytes))) return resp @@ -189,11 +190,6 @@ func (a *testAPI) Delete(path string, args ...any) *httptest.ResponseRecorder { return a.Do(http.MethodDelete, path, args...) } -// NewTestAPI creates a new test API from a chi router and API config. -func NewTestAPI(tb TB, r chi.Router, config huma.Config) TestAPI { - return Wrap(tb, humachi.New(r, config)) -} - // Wrap returns a `TestAPI` wrapping the given API. func Wrap(tb TB, api huma.API) TestAPI { return &testAPI{api, tb} @@ -203,7 +199,7 @@ func Wrap(tb TB, api huma.API) TestAPI { // and perform requests against them. Optionally takes a configuration object // to customize how the API is created. If no configuration is provided then // a simple default configuration supporting `application/json` is used. -func New(tb TB, configs ...huma.Config) (chi.Router, TestAPI) { +func New(tb TB, configs ...huma.Config) (http.Handler, TestAPI) { if len(configs) == 0 { configs = append(configs, huma.Config{ OpenAPI: &huma.OpenAPI{ @@ -219,6 +215,72 @@ func New(tb TB, configs ...huma.Config) (chi.Router, TestAPI) { DefaultFormat: "application/json", }) } - r := chi.NewRouter() - return r, NewTestAPI(tb, r, configs[0]) + r := flow.New() + return r, Wrap(tb, humaflow.New(r, configs[0])) +} + +func dumpBody(body io.ReadCloser, buf *bytes.Buffer) (io.ReadCloser, error) { + if body == nil { + return nil, nil + } + + b, err := io.ReadAll(body) + if err != nil { + return io.NopCloser(iotest.ErrReader(err)), err + } + body.Close() + if strings.Contains(buf.String(), "json") { + json.Indent(buf, b, "", " ") + } else { + buf.Write(b) + } + return io.NopCloser(bytes.NewReader(b)), nil +} + +// DumpRequest returns a string representation of an HTTP request, automatically +// pretty printing JSON bodies for readability. +func DumpRequest(req *http.Request) ([]byte, error) { + var buf bytes.Buffer + b, err := httputil.DumpRequest(req, false) + + if err == nil { + buf.Write(b) + req.Body, err = dumpBody(req.Body, &buf) + } + + return buf.Bytes(), err +} + +// DumpResponse returns a string representation of an HTTP response, +// automatically pretty printing JSON bodies for readability. +func DumpResponse(resp *http.Response) ([]byte, error) { + var buf bytes.Buffer + b, err := httputil.DumpResponse(resp, false) + + if err == nil { + buf.Write(b) + resp.Body, err = dumpBody(resp.Body, &buf) + } + + return buf.Bytes(), err +} + +// PrintRequest prints a string representation of an HTTP request to stdout, +// automatically pretty printing JSON bodies for readability. +func PrintRequest(req *http.Request) { + b, _ := DumpRequest(req) + // Turn `/r/n` into `/n` for more straightforward output that is also + // compatible with Go's testable examples. + b = bytes.ReplaceAll(b, []byte("\r"), []byte("")) + fmt.Println(string(b)) +} + +// PrintResponse prints a string representation of an HTTP response to stdout, +// automatically pretty printing JSON bodies for readability. +func PrintResponse(resp *http.Response) { + b, _ := DumpResponse(resp) + // Turn `/r/n` into `/n` for more straightforward output that is also + // compatible with Go's testable examples. + b = bytes.ReplaceAll(b, []byte("\r"), []byte("")) + fmt.Println(string(b)) } diff --git a/humatest/humatest_test.go b/humatest/humatest_test.go index cafa2999..8744a274 100644 --- a/humatest/humatest_test.go +++ b/humatest/humatest_test.go @@ -2,16 +2,49 @@ package humatest import ( "context" + "io" "net/http" "net/http/httptest" "strings" "testing" + "testing/iotest" "github.com/danielgtaylor/huma/v2" - "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func ExamplePrintRequest() { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/foo?bar=baz", nil) + req.Header.Set("Foo", "bar") + req.Host = "example.com" + PrintRequest(req) + // Output: GET /foo?bar=baz HTTP/1.1 + // Host: example.com + // Foo: bar +} + +func ExamplePrintResponse() { + resp := &http.Response{ + StatusCode: http.StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + ContentLength: -1, + Body: io.NopCloser(strings.NewReader(`{"foo": "bar"}`)), + } + PrintResponse(resp) + // Output: HTTP/1.1 200 OK + // Connection: close + // Content-Type: application/json + // + // { + // "foo": "bar" + // } +} + type Response struct { MyHeader string `header:"My-Header"` Body struct { @@ -96,12 +129,12 @@ func TestContext(t *testing.T) { } func TestAdapter(t *testing.T) { - var _ huma.Adapter = NewAdapter(chi.NewMux()) + var _ huma.Adapter = NewAdapter() } func TestNewAPI(t *testing.T) { - r := chi.NewMux() - var api huma.API = NewTestAPI(t, r, huma.DefaultConfig("Test", "1.0.0")) + var api huma.API + _, api = New(t, huma.DefaultConfig("Test", "1.0.0")) // Should be able to wrap and call utility methods. wrapped := Wrap(t, api) @@ -116,3 +149,18 @@ func TestNewAPI(t *testing.T) { wrapped.Post("/", 1234) }) } + +func TestDumpBodyError(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/foo?bar=baz", nil) + req.Header.Set("Foo", "bar") + req.Host = "example.com" + req.Body = io.NopCloser(iotest.ErrReader(io.ErrUnexpectedEOF)) + + // Error should return. + _, err := DumpRequest(req) + require.Error(t, err) + + // Error should be passed through. + _, err = io.ReadAll(req.Body) + require.Error(t, err) +}