diff --git a/server.go b/server.go index 11f6138295..27bcdd2b5f 100644 --- a/server.go +++ b/server.go @@ -256,7 +256,7 @@ type RequestCtx struct { // Copying Response by value is forbidden. Use pointer to Response instead. Response Response - userValues map[string]interface{} + userValues userData id uint64 @@ -317,31 +317,38 @@ func (ctx *RequestCtx) Hijack(handler HijackHandler) { // SetUserValue stores the given value (arbitrary object) // under the given key in ctx. // -// The value stored in ctx may be obtained by UserValue(). +// The value stored in ctx may be obtained by UserValue*. // // This functionality may be useful for passing arbitrary values between // functions involved in request processing. // // All the values stored in ctx are deleted after returning from RequestHandler. func (ctx *RequestCtx) SetUserValue(key string, value interface{}) { - if ctx.userValues == nil { - ctx.userValues = make(map[string]interface{}, 1) - } - ctx.userValues[key] = value + ctx.userValues.Set(key, value) } -// UserValue returns the value stored via SetUserValue under the given key. +// SetUserValueBytes stores the given value (arbitrary object) +// under the given key in ctx. +// +// The value stored in ctx may be obtained by UserValue*. +// +// This functionality may be useful for passing arbitrary values between +// functions involved in request processing. +// +// All the values stored in ctx are deleted after returning from RequestHandler. +func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) { + ctx.userValues.SetBytes(key, value) +} + +// UserValue returns the value stored via SetUserValue* under the given key. func (ctx *RequestCtx) UserValue(key string) interface{} { - if ctx.userValues == nil { - return nil - } - return ctx.userValues[key] + return ctx.userValues.Get(key) } -func (ctx *RequestCtx) resetUserValues() { - for k := range ctx.userValues { - delete(ctx.userValues, k) - } +// UserValueBytes returns the value stored via SetUserValue* +// under the given key. +func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} { + return ctx.userValues.GetBytes(key) } // IsTLS returns true if the underlying connection is tls.Conn. @@ -1073,7 +1080,7 @@ func (s *Server) serveConn(c net.Conn) error { hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil - ctx.resetUserValues() + ctx.userValues.Reset() // Remove temporary files, which may be uploaded during the request. ctx.Request.RemoveMultipartFormFiles() diff --git a/userdata.go b/userdata.go new file mode 100644 index 0000000000..ecc343e37c --- /dev/null +++ b/userdata.go @@ -0,0 +1,59 @@ +package fasthttp + +type userDataKV struct { + key []byte + value interface{} +} + +type userData []userDataKV + +func (d *userData) Set(key string, value interface{}) { + args := *d + n := len(args) + for i := 0; i < n; i++ { + kv := &args[i] + if string(kv.key) == key { + kv.value = value + return + } + } + + c := cap(args) + if c > n { + args = args[:n+1] + kv := &args[n] + kv.key = append(kv.key[:0], key...) + kv.value = value + *d = args + return + } + + kv := userDataKV{} + kv.key = append(kv.key[:0], key...) + kv.value = value + *d = append(args, kv) +} + +func (d *userData) SetBytes(key []byte, value interface{}) { + d.Set(unsafeBytesToStr(key), value) +} + +func (d *userData) Get(key string) interface{} { + args := *d + n := len(args) + for i := 0; i < n; i++ { + kv := &args[i] + if string(kv.key) == key { + return kv.value + } + } + return nil +} + +func (d *userData) GetBytes(key []byte) interface{} { + return d.Get(unsafeBytesToStr(key)) +} + +func (d *userData) Reset() { + *d = (*d)[:0] +} diff --git a/userdata_test.go b/userdata_test.go new file mode 100644 index 0000000000..e4eeaf2563 --- /dev/null +++ b/userdata_test.go @@ -0,0 +1,41 @@ +package fasthttp + +import ( + "fmt" + "reflect" + "testing" +) + +func TestUserData(t *testing.T) { + var u userData + + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key_%d", i)) + u.SetBytes(key, i+5) + testUserDataGet(t, &u, key, i+5) + u.SetBytes(key, i) + testUserDataGet(t, &u, key, i) + } + + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key_%d", i)) + testUserDataGet(t, &u, key, i) + } + + u.Reset() + + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key_%d", i)) + testUserDataGet(t, &u, key, nil) + } +} + +func testUserDataGet(t *testing.T, u *userData, key []byte, value interface{}) { + v := u.GetBytes(key) + if v == nil && value != nil { + t.Fatalf("cannot obtain value for key=%q", key) + } + if !reflect.DeepEqual(v, value) { + t.Fatalf("unexpected value for key=%q: %d. Expecting %d", v, value) + } +} diff --git a/userdata_timing_test.go b/userdata_timing_test.go new file mode 100644 index 0000000000..3822de3fdb --- /dev/null +++ b/userdata_timing_test.go @@ -0,0 +1,48 @@ +package fasthttp + +import ( + "testing" +) + +func BenchmarkUserDataCustom(b *testing.B) { + keys := []string{"foobar", "baz", "aaa", "bsdfs"} + b.RunParallel(func(pb *testing.PB) { + var u userData + var v interface{} = u + for pb.Next() { + for _, key := range keys { + u.Set(key, v) + } + for _, key := range keys { + vv := u.Get(key) + if _, ok := vv.(userData); !ok { + b.Fatalf("unexpected value %v for key %q", vv, key) + } + } + u.Reset() + } + }) +} + +func BenchmarkUserDataStdMap(b *testing.B) { + keys := []string{"foobar", "baz", "aaa", "bsdfs"} + b.RunParallel(func(pb *testing.PB) { + u := make(map[string]interface{}) + var v interface{} = u + for pb.Next() { + for _, key := range keys { + u[key] = v + } + for _, key := range keys { + vv := u[key] + if _, ok := vv.(map[string]interface{}); !ok { + b.Fatalf("unexpected value %v for key %q", vv, key) + } + } + + for k := range u { + delete(u, k) + } + } + }) +}