Skip to content

Commit 3e2188a

Browse files
committed
fixed work with fiber.Ctx, fiber.UserContext - fixed graceful shutdown and race condition on access to huma.Context outside handler
1 parent f9ffb6a commit 3e2188a

File tree

2 files changed

+472
-114
lines changed

2 files changed

+472
-114
lines changed

adapters/humafiber/humafiber.go

+95-114
Original file line numberDiff line numberDiff line change
@@ -15,168 +15,118 @@ import (
1515
"github.com/gofiber/fiber/v2"
1616
)
1717

18-
type fiberCtx struct {
18+
type fiberAdapter struct {
19+
tester requestTester
20+
router router
21+
}
22+
23+
type fiberWrapper struct {
1924
op *huma.Operation
2025
status int
21-
22-
/*
23-
* Web framework "fiber" https://gofiber.io/ uses high-performance zero-allocation "fasthttp" server https://github.com/valyala/fasthttp
24-
*
25-
* The underlying fasthttp server prohibits to use or refer to `*fasthttp.RequestCtx` outside handler
26-
* The quote from documentation to fasthttp https://github.com/valyala/fasthttp/blob/master/README.md
27-
*
28-
* > VERY IMPORTANT! Fasthttp disallows holding references to RequestCtx or to its' members after returning from RequestHandler. Otherwise data races are inevitable. Carefully inspect all the net/http request handlers converted to fasthttp whether they retain references to RequestCtx or to its' members after returning
29-
*
30-
* As the result "fiber" prohibits to use or refer to `*fiber.Ctx` outside handler
31-
* The quote from documentation to fiber https://docs.gofiber.io/#zero-allocation
32-
*
33-
* > Because fiber is optimized for high-performance, values returned from fiber.Ctx are not immutable by default and will be re-used across requests. As a rule of thumb, you must only use context values within the handler, and you must not keep any references. As soon as you return from the handler, any values you have obtained from the context will be re-used in future requests and will change below your feet
34-
*
35-
* To deal with these limitations, the contributor of to this adapter @excavador (Oleg Tsarev, email: oleg@tsarev.id, telegram: @oleg_tsarev) is clear variable explicitly in the end of huma.Adapter methods Handle and ServeHTTP
36-
*
37-
* You must NOT use member `unsafeFiberCtx` directly in adapter, but instead use `orig()` private method
38-
*/
39-
unsafeFiberCtx *fiber.Ctx
40-
unsafeGolangCtx context.Context
26+
orig *fiber.Ctx
27+
ctx context.Context
4128
}
4229

4330
// check that fiberCtx implements huma.Context
44-
var _ huma.Context = &fiberCtx{}
45-
var _ context.Context = &fiberCtx{}
46-
47-
func (c *fiberCtx) orig() *fiber.Ctx {
48-
var result = c.unsafeFiberCtx
49-
select {
50-
case <-c.unsafeGolangCtx.Done():
51-
panic("handler was done already")
52-
default:
53-
return result
54-
}
55-
}
56-
57-
func (c *fiberCtx) Deadline() (deadline time.Time, ok bool) {
58-
return c.unsafeGolangCtx.Deadline()
59-
}
60-
61-
func (c *fiberCtx) Done() <-chan struct{} {
62-
return c.unsafeGolangCtx.Done()
63-
}
64-
65-
func (c *fiberCtx) Err() error {
66-
return c.unsafeGolangCtx.Err()
67-
}
68-
69-
func (c *fiberCtx) Value(key any) any {
70-
var orig = c.unsafeFiberCtx
71-
select {
72-
case <-c.unsafeGolangCtx.Done():
73-
return nil
74-
default:
75-
var value = orig.UserContext().Value(key)
76-
if value != nil {
77-
return value
78-
}
79-
return orig.Context().Value(key)
80-
}
81-
}
31+
var _ huma.Context = &fiberWrapper{}
8232

83-
func (c *fiberCtx) Operation() *huma.Operation {
33+
func (c *fiberWrapper) Operation() *huma.Operation {
8434
return c.op
8535
}
8636

87-
func (c *fiberCtx) Matched() string {
88-
return c.orig().Route().Path
37+
func (c *fiberWrapper) Matched() string {
38+
return c.orig.Route().Path
8939
}
9040

91-
func (c *fiberCtx) Context() context.Context {
92-
return c
41+
func (c *fiberWrapper) Context() context.Context {
42+
return c.ctx
9343
}
9444

95-
func (c *fiberCtx) Method() string {
96-
return c.orig().Method()
45+
func (c *fiberWrapper) Method() string {
46+
return c.orig.Method()
9747
}
9848

99-
func (c *fiberCtx) Host() string {
100-
return c.orig().Hostname()
49+
func (c *fiberWrapper) Host() string {
50+
return c.orig.Hostname()
10151
}
10252

103-
func (c *fiberCtx) RemoteAddr() string {
104-
return c.orig().Context().RemoteAddr().String()
53+
func (c *fiberWrapper) RemoteAddr() string {
54+
return c.orig.Context().RemoteAddr().String()
10555
}
10656

107-
func (c *fiberCtx) URL() url.URL {
108-
u, _ := url.Parse(string(c.orig().Request().RequestURI()))
57+
func (c *fiberWrapper) URL() url.URL {
58+
u, _ := url.Parse(string(c.orig.Request().RequestURI()))
10959
return *u
11060
}
11161

112-
func (c *fiberCtx) Param(name string) string {
113-
return c.orig().Params(name)
62+
func (c *fiberWrapper) Param(name string) string {
63+
return c.orig.Params(name)
11464
}
11565

116-
func (c *fiberCtx) Query(name string) string {
117-
return c.orig().Query(name)
66+
func (c *fiberWrapper) Query(name string) string {
67+
return c.orig.Query(name)
11868
}
11969

120-
func (c *fiberCtx) Header(name string) string {
121-
return c.orig().Get(name)
70+
func (c *fiberWrapper) Header(name string) string {
71+
return c.orig.Get(name)
12272
}
12373

124-
func (c *fiberCtx) EachHeader(cb func(name, value string)) {
125-
c.orig().Request().Header.VisitAll(func(k, v []byte) {
74+
func (c *fiberWrapper) EachHeader(cb func(name, value string)) {
75+
c.orig.Request().Header.VisitAll(func(k, v []byte) {
12676
cb(string(k), string(v))
12777
})
12878
}
12979

130-
func (c *fiberCtx) BodyReader() io.Reader {
131-
var orig = c.orig()
80+
func (c *fiberWrapper) BodyReader() io.Reader {
81+
var orig = c.orig
13282
if orig.App().Server().StreamRequestBody {
13383
// Streaming is enabled, so send the reader.
13484
return orig.Request().BodyStream()
13585
}
13686
return bytes.NewReader(orig.BodyRaw())
13787
}
13888

139-
func (c *fiberCtx) GetMultipartForm() (*multipart.Form, error) {
140-
return c.orig().MultipartForm()
89+
func (c *fiberWrapper) GetMultipartForm() (*multipart.Form, error) {
90+
return c.orig.MultipartForm()
14191
}
14292

143-
func (c *fiberCtx) SetReadDeadline(deadline time.Time) error {
93+
func (c *fiberWrapper) SetReadDeadline(deadline time.Time) error {
14494
// Note: for this to work properly you need to do two things:
14595
// 1. Set the Fiber app's `StreamRequestBody` to `true`
14696
// 2. Set the Fiber app's `BodyLimit` to some small value like `1`
14797
// Fiber will only call the request handler for streaming once the limit is
14898
// reached. This is annoying but currently how things work.
149-
return c.orig().Context().Conn().SetReadDeadline(deadline)
99+
return c.orig.Context().Conn().SetReadDeadline(deadline)
150100
}
151101

152-
func (c *fiberCtx) SetStatus(code int) {
153-
var orig = c.orig()
102+
func (c *fiberWrapper) SetStatus(code int) {
103+
var orig = c.orig
154104
c.status = code
155105
orig.Status(code)
156106
}
157107

158-
func (c *fiberCtx) Status() int {
108+
func (c *fiberWrapper) Status() int {
159109
return c.status
160110
}
161-
func (c *fiberCtx) AppendHeader(name string, value string) {
162-
c.orig().Append(name, value)
111+
func (c *fiberWrapper) AppendHeader(name string, value string) {
112+
c.orig.Append(name, value)
163113
}
164114

165-
func (c *fiberCtx) SetHeader(name string, value string) {
166-
c.orig().Set(name, value)
115+
func (c *fiberWrapper) SetHeader(name string, value string) {
116+
c.orig.Set(name, value)
167117
}
168118

169-
func (c *fiberCtx) BodyWriter() io.Writer {
170-
return c.orig().Context()
119+
func (c *fiberWrapper) BodyWriter() io.Writer {
120+
return c.orig.Context()
171121
}
172122

173-
func (c *fiberCtx) TLS() *tls.ConnectionState {
174-
return c.orig().Context().TLSConnectionState()
123+
func (c *fiberWrapper) TLS() *tls.ConnectionState {
124+
return c.orig.Context().TLSConnectionState()
175125
}
176126

177-
func (c *fiberCtx) Version() huma.ProtoVersion {
127+
func (c *fiberWrapper) Version() huma.ProtoVersion {
178128
return huma.ProtoVersion{
179-
Proto: c.orig().Protocol(),
129+
Proto: c.orig.Protocol(),
180130
}
181131
}
182132

@@ -188,9 +138,31 @@ type requestTester interface {
188138
Test(*http.Request, ...int) (*http.Response, error)
189139
}
190140

191-
type fiberAdapter struct {
192-
tester requestTester
193-
router router
141+
type contextWrapperValue struct {
142+
Key any
143+
Value any
144+
}
145+
146+
type contextWrapper struct {
147+
values []*contextWrapperValue
148+
context.Context
149+
}
150+
151+
var (
152+
_ context.Context = &contextWrapper{}
153+
)
154+
155+
func (c *contextWrapper) Value(key any) any {
156+
var raw = c.Context.Value(key)
157+
if raw != nil {
158+
return raw
159+
}
160+
for _, pair := range c.values {
161+
if pair.Key == key {
162+
return pair.Value
163+
}
164+
}
165+
return nil
194166
}
195167

196168
func (a *fiberAdapter) Handle(op *huma.Operation, handler func(huma.Context)) {
@@ -199,17 +171,21 @@ func (a *fiberAdapter) Handle(op *huma.Operation, handler func(huma.Context)) {
199171
path = strings.ReplaceAll(path, "{", ":")
200172
path = strings.ReplaceAll(path, "}", "")
201173
a.router.Add(op.Method, path, func(c *fiber.Ctx) error {
202-
var ctx, cancel = context.WithCancel(context.Background())
203-
var fc = &fiberCtx{
204-
op: op,
205-
unsafeFiberCtx: c,
206-
unsafeGolangCtx: ctx,
207-
}
208-
defer func() {
209-
cancel()
210-
fc.unsafeFiberCtx = nil
211-
}()
212-
handler(fc)
174+
var values []*contextWrapperValue
175+
c.Context().VisitUserValuesAll(func(key, value any) {
176+
values = append(values, &contextWrapperValue{
177+
Key: key,
178+
Value: value,
179+
})
180+
})
181+
handler(&fiberWrapper{
182+
op: op,
183+
orig: c,
184+
ctx: &contextWrapper{
185+
values: values,
186+
Context: c.UserContext(),
187+
},
188+
})
213189
return nil
214190
})
215191
}
@@ -218,6 +194,11 @@ func (a *fiberAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
218194
// b, _ := httputil.DumpRequest(r, true)
219195
// fmt.Println(string(b))
220196
resp, err := a.tester.Test(r)
197+
if resp != nil && resp.Body != nil {
198+
defer func() {
199+
_ = resp.Body.Close()
200+
}()
201+
}
221202
if err != nil {
222203
panic(err)
223204
}
@@ -228,7 +209,7 @@ func (a *fiberAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
228209
}
229210
}
230211
w.WriteHeader(resp.StatusCode)
231-
io.Copy(w, resp.Body)
212+
_, _ = io.Copy(w, resp.Body)
232213
}
233214

234215
func New(r *fiber.App, config huma.Config) huma.API {

0 commit comments

Comments
 (0)