Description
I've been playing with Huma for the last couple of days, mostly stuck on getting middlewares to do what I want. This post contains some solutions that others might find helpful, and a proposal for a new type of middleware that would make life a lot easier.
I wanted to build these middlewares:
- Logging: Log what endpoints are being called, and in case of a >=400 status, also log request headers and body for debugging.
- DB: Start a DB transaction before handler is called, commit on successful result. Rollback on error (>=400) status.
- Caching: Cache output of API calls. Cache key should be dependent on input, cache should only be populated if the result has status 200.
Each of these middlewares needs to access the response status code, and some of them also need the input or output structs. Different solutions have been posted for how to get this; mine is based on calling r, rw := humago.Unwrap(ctx) on the request and recording the response with httptest.NewRecorder(). I'll attach the full middleware for your reference. It's quite complex and feels clunky. There are also some gotchas: it relies on the humago router, but humatest uses humaflow, so the middleware can't be used in tests.
While working on this, I was wishing we could have a middleware that is called after verification and before serialization, so that it can see the final input and output structs. It should also have access to the huma ctx, in order to see operation information like operationID and metadata. And it should be able to edit response objects or replace them with errors. Maybe the middleware could inject itself into the "Resolve" and "Transform" stages somehow?
After a lot of trial and error, the best I could come up with was to wrap the huma.Register() function itself. I'll post the result below. However, that's not perfect either: it only works when the handler is actually called. If input verification fails, or the router fails to match a route, these middlewares are not called. So I've had to keep the Logging middleware in place to catch those cases.
Going through the open issues here on GitHub, I found several related issues and proposals, which could all be handled by this proposal:
- Accessing validation results in the request pipeline #453
- Discussion: How best to catch and log 500s? #470
- Discussion: option to make all APIs registered with security scheme by default #490
- Logging Errors in Huma Middleware #679
- Middleware and handler function call stack #877
Wrapper for huma.Register():
// humaHandlerFunc is the signature for a Huma operation handler
type humaHandlerFunc[I, O any] func(context.Context, *I) (*O, error)

// Register is a wrapper around huma.Register that applies middlewares which need to be aware of input and output objects.
// the middlewares also have access to modify the huma.Operation object
func Register[I, O any](api huma.API, op huma.Operation, handler humaHandlerFunc[I, O]) {
	// NOTE: handler funcs will be called in REVERSE ORDER, op will be updated in NORMAL order
	op, handler = withDBTransaction(op, handler)
	op, handler = withCaching(op, handler)
	op, handler = withPanicCatcher(op, handler)
	huma.Register(api, op, handler)
}

// withDBTransaction wraps a handler with DB transaction logic
func withDBTransaction[I, O any](op huma.Operation, next humaHandlerFunc[I, O]) (huma.Operation, humaHandlerFunc[I, O]) {
	return op, func(ctx context.Context, input *I) (*O, error) {
		myCtx := GetMyCtx(ctx)
		tx, err := myCtx.Pool.Begin(ctx)
		if err != nil {
			return nil, huma.Error500InternalServerError("failed to begin DB TX", err)
		}
		myCtx.Tx = tx

		// Call the next handler
		output, err := next(ctx, input)

		// Commit or rollback based on handler result
		if err != nil {
			if rbErr := tx.Rollback(ctx); rbErr != nil {
				myCtx.Slog.Error("failed to rollback",
					"rbError", rbErr.Error(),
					"handlerError", err.Error(),
				)
			}
			return output, err
		}
		if cmErr := tx.Commit(ctx); cmErr != nil {
			return nil, huma.Error500InternalServerError("failed to commit", cmErr)
		}
		return output, nil
	}
}

Logging middleware with extra logs for failed requests:
// LogReqMiddleware logs each request with method, path, status code, duration, and operation ID.
// For 4xx/5xx errors, it also logs the full request/response for debugging.
// This middleware should be outermost to capture all errors including from inner middlewares.
// The code would be much simpler if we could add it to the register chain, like our DB middleware.
// But then it wouldn't capture Huma verification errors and generic 404s, so it needs to stay here.
func LogReqMiddleware(ctx huma.Context, next func(huma.Context)) {
	start := time.Now()

	// Get operation info
	op := ctx.Operation()
	operationID := "unknown_op_id"
	if op != nil {
		operationID = op.OperationID
	}

	// Unwrap to get the underlying http.Request and http.ResponseWriter
	// WARNING: this relies on the router being humago. Tests (humatest) use humaflow, so Unwrap will panic there.
	r, rw := humago.Unwrap(ctx)

	// Read and buffer the request body so it can be logged on error
	var requestBody []byte
	if r.Body != nil {
		requestBody, _ = io.ReadAll(ctx.BodyReader())
		r.Body = io.NopCloser(bytes.NewBuffer(requestBody))
	}

	// Use httptest.ResponseRecorder to capture response
	recorder := httptest.NewRecorder()
	wrappedCtx := humago.NewContext(op, r, recorder)
	wrappedCtx = huma.WithContext(wrappedCtx, ctx.Context())
	next(wrappedCtx)
	duration := int(time.Since(start).Milliseconds())

	// Only show whitelisted headers to reduce log spam
	filteredHeaders := util.FilterMap(r.Header, func(name string, vals []string) bool {
		return slices.Contains([]string{
			"authorization",
		}, strings.ToLower(name))
	})

	statusCode := recorder.Code
	if statusCode >= http.StatusBadRequest {
		// Log full details for unsuccessful requests to help reproduce the problem
		// Parse JSON bodies so they appear as nested objects, not escaped strings
		var reqBodyJSON, resBodyJSON any
		json.Unmarshal(requestBody, &reqBodyJSON)
		json.Unmarshal(recorder.Body.Bytes(), &resBodyJSON)
		slog.Error("request",
			slog.String("method", r.Method),
			slog.String("path", r.RequestURI),
			slog.String("operation", operationID),
			slog.Int("status", statusCode),
			slog.Int("duration_ms", duration),
			slog.Any("req_headers", filteredHeaders),
			slog.Any("req_body", reqBodyJSON),
			slog.Any("resp_body", resBodyJSON),
		)
	} else {
		slog.Info("request",
			slog.String("method", r.Method),
			slog.String("path", r.URL.Path),
			slog.String("operation", operationID),
			// slog.Int("status", statusCode),
			slog.Int("duration_ms", duration),
		)
	}

	// Copy recorded response to actual writer
	for k, v := range recorder.Header() {
		for _, val := range v {
			rw.Header().Add(k, val)
		}
	}
	rw.WriteHeader(statusCode)
	rw.Write(recorder.Body.Bytes())
	recorder.Flush()
}