diff --git a/context_test.go b/context_test.go index 9122845896..2a4d21856e 100644 --- a/context_test.go +++ b/context_test.go @@ -2059,15 +2059,54 @@ func TestRemoteIPFail(t *testing.T) { } func TestContextWithFallbackValueFromRequestContext(t *testing.T) { - var key struct{} - c := &Context{} - c.Request, _ = http.NewRequest("POST", "/", nil) - c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) - - assert.Equal(t, "value", c.Value(key)) - - c2 := &Context{} - c2.Request, _ = http.NewRequest("POST", "/", nil) - c2.Request = c2.Request.WithContext(context.WithValue(context.TODO(), "key", "value2")) - assert.Equal(t, "value2", c2.Value("key")) + tests := []struct { + name string + getContextAndKey func() (*Context, interface{}) + value interface{} + }{ + { + name: "c with struct context key", + getContextAndKey: func() (*Context, interface{}) { + var key struct{} + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) + return c, key + }, + value: "value", + }, + { + name: "c with string context key", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request = c.Request.WithContext(context.WithValue(context.TODO(), "key", "value")) + return c, "key" + }, + value: "value", + }, + { + name: "c with nil http.Request", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + return c, "key" + }, + value: nil, + }, + { + name: "c with nil http.Request.Context()", + getContextAndKey: func() (*Context, interface{}) { + c := &Context{} + c.Request, _ = http.NewRequest("POST", "/", nil) + return c, "key" + }, + value: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, key := tt.getContextAndKey() + assert.Equal(t, tt.value, c.Value(key)) + }) + } }