From daf5749b76d2fd601653af86b4cc68ae4bad94e6 Mon Sep 17 00:00:00 2001 From: tianjipeng Date: Sun, 24 Jan 2021 22:23:27 +0800 Subject: [PATCH] ajust cache middleware config.Next position --- middleware/cache/cache.go | 10 ++++---- middleware/cache/cache_test.go | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 6dfe3d35672..ee9570b6013 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -43,11 +43,6 @@ func New(config ...Config) fiber.Handler { // Return new handler return func(c *fiber.Ctx) error { - // Don't execute middleware if Next returns true - if cfg.Next != nil && cfg.Next(c) { - return c.Next() - } - // Only cache GET methods if c.Method() != fiber.MethodGet { return c.Next() @@ -105,6 +100,11 @@ func New(config ...Config) fiber.Handler { return err } + // Don't cache response if Next returns true + if cfg.Next != nil && cfg.Next(c) { + return nil + } + // Cache response e.body = utils.SafeBytes(c.Response().Body()) e.status = c.Response().StatusCode() diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 38196314ced..73054f15094 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -231,6 +231,49 @@ func Test_Cache_NothingToCache(t *testing.T) { } } +func Test_Cache_CustomNext(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{ + Next: func(c *fiber.Ctx) bool { + return !(c.Response().StatusCode() == fiber.StatusOK) + }, + CacheControl: true, + })) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString(time.Now().String()) + }) + + app.Get("/error", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String()) + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + + respCached, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + bodyCached, err := ioutil.ReadAll(respCached.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, bytes.Equal(body, bodyCached)) + utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "") + + errResp, err := app.Test(httptest.NewRequest("GET", "/error", nil)) + utils.AssertEqual(t, nil, err) + errBody, err := ioutil.ReadAll(errResp.Body) + utils.AssertEqual(t, nil, err) + + errRespCached, err := app.Test(httptest.NewRequest("GET", "/error", nil)) + utils.AssertEqual(t, nil, err) + errBodyCached, err := ioutil.ReadAll(errRespCached.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, false, bytes.Equal(errBody, errBodyCached)) + utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "") +} + func Test_CustomKey(t *testing.T) { app := fiber.New() var called bool