From 96c26cfa9fabe99bb794add8ffbd898523a09ec7 Mon Sep 17 00:00:00 2001 From: Tit Petric Date: Thu, 28 Dec 2023 17:19:16 +0100 Subject: [PATCH] Fix tests, allocate BaseMiddleware ptr --- gateway/coprocess_test.go | 2 +- gateway/mw_go_plugin_test.go | 4 +++- gateway/mw_ip_blacklist_test.go | 4 ++-- gateway/mw_ip_whitelist_test.go | 4 ++-- gateway/mw_jwt_test.go | 4 ++-- gateway/mw_modify_headers_test.go | 2 +- gateway/mw_redis_cache_test.go | 4 ++-- gateway/mw_strip_auth_test.go | 10 +++++----- gateway/mw_url_rewrite_test.go | 2 +- 9 files changed, 19 insertions(+), 17 deletions(-) diff --git a/gateway/coprocess_test.go b/gateway/coprocess_test.go index d661acc8251..4f3c190e178 100644 --- a/gateway/coprocess_test.go +++ b/gateway/coprocess_test.go @@ -245,7 +245,7 @@ func equalHeaders(h1, h2 []*coprocess.Header) bool { func TestCoProcessMiddlewareName(t *testing.T) { // Initialize the CoProcessMiddleware - m := &CoProcessMiddleware{} + m := &CoProcessMiddleware{BaseMiddleware: &BaseMiddleware{}} // Get the name using the method name := m.Name() diff --git a/gateway/mw_go_plugin_test.go b/gateway/mw_go_plugin_test.go index 08366271597..a1ffcb38754 100644 --- a/gateway/mw_go_plugin_test.go +++ b/gateway/mw_go_plugin_test.go @@ -20,7 +20,9 @@ func TestLoadPlugin(t *testing.T) { } func TestGoPluginMiddleware_EnabledForSpec(t *testing.T) { - gpm := GoPluginMiddleware{} + gpm := GoPluginMiddleware{ + BaseMiddleware: &BaseMiddleware{}, + } apiSpec := &APISpec{APIDefinition: &apidef.APIDefinition{}} gpm.Spec = apiSpec diff --git a/gateway/mw_ip_blacklist_test.go b/gateway/mw_ip_blacklist_test.go index bbc2f3b41df..1321e724d53 100644 --- a/gateway/mw_ip_blacklist_test.go +++ b/gateway/mw_ip_blacklist_test.go @@ -42,7 +42,7 @@ func TestIPBlacklistMiddleware(t *testing.T) { req.Header.Set(header.XRealIP, tc.xRealIP) } - mw := &IPBlackListMiddleware{} + mw := &IPBlackListMiddleware{BaseMiddleware: &BaseMiddleware{}} mw.Spec = spec _, code := mw.ProcessRequest(rec, req, nil) @@ -58,7 +58,7 @@ func BenchmarkIPBlacklistMiddleware(b *testing.B) { spec := testPrepareIPBlacklistMiddleware() - mw := &IPBlackListMiddleware{} + mw := &IPBlackListMiddleware{BaseMiddleware: &BaseMiddleware{}} mw.Spec = spec rec := httptest.NewRecorder() diff --git a/gateway/mw_ip_whitelist_test.go b/gateway/mw_ip_whitelist_test.go index 429aade08ba..e011f216c40 100644 --- a/gateway/mw_ip_whitelist_test.go +++ b/gateway/mw_ip_whitelist_test.go @@ -42,7 +42,7 @@ func TestIPMiddlewarePass(t *testing.T) { req.Header.Set(header.XRealIP, tc.xRealIP) } - mw := &IPWhiteListMiddleware{} + mw := &IPWhiteListMiddleware{BaseMiddleware: &BaseMiddleware{}} mw.Spec = spec _, code := mw.ProcessRequest(rec, req, nil) @@ -57,7 +57,7 @@ func BenchmarkIPMiddlewarePass(b *testing.B) { b.ReportAllocs() spec := testPrepareIPMiddlewarePass() - mw := &IPWhiteListMiddleware{} + mw := &IPWhiteListMiddleware{BaseMiddleware: &BaseMiddleware{}} mw.Spec = spec rec := httptest.NewRecorder() diff --git a/gateway/mw_jwt_test.go b/gateway/mw_jwt_test.go index 5a6e4b1e741..523793e5468 100644 --- a/gateway/mw_jwt_test.go +++ b/gateway/mw_jwt_test.go @@ -2306,7 +2306,7 @@ func TestGetUserIDFromClaim(t *testing.T) { func TestJWTMiddleware_getSecretToVerifySignature_JWKNoKID(t *testing.T) { const jwkURL = "https://jwk.com" - m := JWTMiddleware{} + m := JWTMiddleware{BaseMiddleware: &BaseMiddleware{}} api := &apidef.APIDefinition{JWTSource: jwkURL} m.Spec = &APISpec{APIDefinition: api} @@ -2408,7 +2408,7 @@ func Test_getOAuthClientIDFromClaim(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - j := JWTMiddleware{} + j := JWTMiddleware{BaseMiddleware: &BaseMiddleware{}} j.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} oauthClientID := j.getOAuthClientIDFromClaim(tc.claims) diff --git a/gateway/mw_modify_headers_test.go b/gateway/mw_modify_headers_test.go index 1790a647d65..7093be7afd5 100644 --- a/gateway/mw_modify_headers_test.go +++ b/gateway/mw_modify_headers_test.go @@ -17,7 +17,7 @@ func TestTransformHeaders_EnabledForSpec(t *testing.T) { "Default": versionInfo, } - th := TransformHeaders{} + th := TransformHeaders{BaseMiddleware: &BaseMiddleware{}} th.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} th.Spec.VersionData.Versions = versions diff --git a/gateway/mw_redis_cache_test.go b/gateway/mw_redis_cache_test.go index 9ca99991a22..2eb6e03859a 100644 --- a/gateway/mw_redis_cache_test.go +++ b/gateway/mw_redis_cache_test.go @@ -27,7 +27,7 @@ func TestRedisCacheMiddlewareUnit(t *testing.T) { { Name: "isTimeStampExpired", Fn: func(t *testing.T) { - mw := &RedisCacheMiddleware{} + mw := &RedisCacheMiddleware{BaseMiddleware: &BaseMiddleware{}} assert.True(t, mw.isTimeStampExpired("invalid")) assert.True(t, mw.isTimeStampExpired("1")) @@ -38,7 +38,7 @@ func TestRedisCacheMiddlewareUnit(t *testing.T) { { Name: "decodePayload", Fn: func(t *testing.T) { - mw := &RedisCacheMiddleware{} + mw := &RedisCacheMiddleware{BaseMiddleware: &BaseMiddleware{}} if data, expire, err := mw.decodePayload("dGVzdGluZwo=|123"); true { assert.Equal(t, "testing\n", data) diff --git a/gateway/mw_strip_auth_test.go b/gateway/mw_strip_auth_test.go index 63152099155..cd59ac7bf3f 100644 --- a/gateway/mw_strip_auth_test.go +++ b/gateway/mw_strip_auth_test.go @@ -38,7 +38,7 @@ func TestStripAuth_stripFromHeaders(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("stripping %+v", tc), func(t *testing.T) { - sa := StripAuth{} + sa := StripAuth{BaseMiddleware: &BaseMiddleware{}} sa.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} req, err := http.NewRequest("GET", "http://example.com", nil) @@ -76,7 +76,7 @@ func TestStripAuth_stripFromHeaders(t *testing.T) { if err != nil { t.Fatal(err) } - sa := StripAuth{} + sa := StripAuth{BaseMiddleware: &BaseMiddleware{}} sa.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} key := "Authorization" @@ -114,7 +114,7 @@ func BenchmarkStripAuth_stripFromHeaders(b *testing.B) { for i := 0; i < b.N; i++ { for _, tc := range testCases { - sa := StripAuth{} + sa := StripAuth{BaseMiddleware: &BaseMiddleware{}} sa.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} req, err := http.NewRequest("GET", "http://example.com", nil) @@ -165,7 +165,7 @@ func TestStripAuth_stripFromParams(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("stripping %s", tc.QueryParam), func(t *testing.T) { - sa := StripAuth{} + sa := StripAuth{BaseMiddleware: &BaseMiddleware{}} sa.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} rawUrl := "http://example.com/abc" @@ -207,7 +207,7 @@ func BenchmarkStripAuth_stripFromParams(b *testing.B) { for i := 0; i < b.N; i++ { for _, tc := range testCases { - sa := StripAuth{} + sa := StripAuth{BaseMiddleware: &BaseMiddleware{}} sa.Spec = &APISpec{APIDefinition: &apidef.APIDefinition{}} req, err := http.NewRequest("GET", "http://example.com/abc", nil) diff --git a/gateway/mw_url_rewrite_test.go b/gateway/mw_url_rewrite_test.go index ce8d3165e38..1da2fbc3072 100644 --- a/gateway/mw_url_rewrite_test.go +++ b/gateway/mw_url_rewrite_test.go @@ -1376,7 +1376,7 @@ func TestURLRewriteMiddleware_CheckHostRewrite(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := &URLRewriteMiddleware{} + m := &URLRewriteMiddleware{BaseMiddleware: &BaseMiddleware{}} r := &http.Request{} err := m.CheckHostRewrite(tt.args.oldPath, tt.args.newTarget, r) assert.Equal(t, tt.errExpected, err != nil)