From 079276c6ffd4392d4bb84d869b27b5d72596f6cd Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 26 Jul 2024 18:07:53 +0800 Subject: [PATCH] fix: [2.4] Unify hook singleton implementation in proxy (#34888) Cherry-pick from master pr: #34887 Related to #34885 --------- Signed-off-by: Congqi Xia --- internal/distributed/proxy/service_test.go | 5 +- .../proxy/authentication_interceptor_test.go | 6 +- internal/proxy/hook_interceptor.go | 19 +----- internal/proxy/hook_interceptor_test.go | 10 ++-- internal/proxy/impl.go | 12 ++-- internal/proxy/proxy.go | 9 +-- internal/proxy/util.go | 5 +- internal/util/hookutil/default.go | 11 ---- internal/util/hookutil/hook.go | 58 ++++++++++++++++--- internal/util/hookutil/hook_test.go | 2 +- internal/util/hookutil/mock_hook.go | 54 +++++++++++++++++ tests/integration/minicluster_v2.go | 2 +- 12 files changed, 129 insertions(+), 64 deletions(-) create mode 100644 internal/util/hookutil/mock_hook.go diff --git a/internal/distributed/proxy/service_test.go b/internal/distributed/proxy/service_test.go index 56bf839d7f6c8..936785a88d27e 100644 --- a/internal/distributed/proxy/service_test.go +++ b/internal/distributed/proxy/service_test.go @@ -47,6 +47,7 @@ import ( "github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/util/hookutil" milvusmock "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) { } { - proxy.SetMockAPIHook("foo", nil) - defer proxy.SetMockAPIHook("", nil) + hookutil.SetMockAPIHook("foo", nil) + defer hookutil.SetMockAPIHook("", nil) ctx.Request.Header.Set("Authorization", "Bearer 123456") authenticate(ctx) ctxName, _ := ctx.Get(httpserver.ContextUsername) diff --git a/internal/proxy/authentication_interceptor_test.go b/internal/proxy/authentication_interceptor_test.go index be2863cd31663..9e237d3f5f902 100644 --- a/internal/proxy/authentication_interceptor_test.go +++ b/internal/proxy/authentication_interceptor_test.go @@ -119,7 +119,7 @@ func TestAuthenticationInterceptor(t *testing.T) { { // verify apikey error - SetMockAPIHook("", errors.New("err")) + hookutil.SetMockAPIHook("", errors.New("err")) md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) ctx = metadata.NewIncomingContext(ctx, md) _, err = AuthenticationInterceptor(ctx) @@ -127,7 +127,7 @@ func TestAuthenticationInterceptor(t *testing.T) { } { - SetMockAPIHook("mockUser", nil) + hookutil.SetMockAPIHook("mockUser", nil) md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) ctx = metadata.NewIncomingContext(ctx, md) authCtx, err := AuthenticationInterceptor(ctx) @@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) { user, _ := parseMD(rawToken) assert.Equal(t, "mockUser", user) } - hoo = hookutil.DefaultHook{} + hookutil.SetTestHook(hookutil.DefaultHook{}) } diff --git a/internal/proxy/hook_interceptor.go b/internal/proxy/hook_interceptor.go index 1d3c27a2e126b..7ba5b6aa7d920 100644 --- a/internal/proxy/hook_interceptor.go +++ b/internal/proxy/hook_interceptor.go @@ -8,15 +8,12 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" ) -var hoo hook.Hook - func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler) @@ -24,10 +21,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { } func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) { - if hoo == nil { - hookutil.InitOnceHook() - hoo = hookutil.Hoo - } + hoo := hookutil.GetHook() var ( newCtx context.Context isMock bool @@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string { } return username } - -func SetMockAPIHook(apiUser string, mockErr error) { - if apiUser == "" && mockErr == nil { - hoo = &hookutil.DefaultHook{} - return - } - hoo = &hookutil.MockAPIHook{ - MockErr: mockErr, - User: apiUser, - } -} diff --git a/internal/proxy/hook_interceptor_test.go b/internal/proxy/hook_interceptor_test.go index 3641f86d2541c..d14934bb66fc8 100644 --- a/internal/proxy/hook_interceptor_test.go +++ b/internal/proxy/hook_interceptor_test.go @@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) { err error ) - hoo = mockHoo + hookutil.SetTestHook(mockHoo) res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) @@ -95,7 +95,7 @@ func TestHookInterceptor(t *testing.T) { assert.Equal(t, res, mockHoo.mockRes) assert.Equal(t, err, mockHoo.mockErr) - hoo = beforeHoo + hookutil.SetTestHook(beforeHoo) _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) @@ -103,7 +103,7 @@ func TestHookInterceptor(t *testing.T) { assert.Equal(t, err, beforeHoo.err) beforeHoo.err = nil - hoo = beforeHoo + hookutil.SetTestHook(beforeHoo) _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey)) return nil, nil @@ -111,14 +111,14 @@ func TestHookInterceptor(t *testing.T) { assert.Equal(t, r.method, beforeHoo.method) assert.Equal(t, err, beforeHoo.err) - hoo = afterHoo + hookutil.SetTestHook(afterHoo) _, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) { return re, nil }) assert.Equal(t, re.method, afterHoo.method) assert.Equal(t, err, afterHoo.err) - hoo = &hookutil.DefaultHook{} + hookutil.SetTestHook(&hookutil.DefaultHook{}) res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) { return &resp{ method: r.(*req).method, diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index aadbaebc99827..c79250e319b99 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2591,7 +2591,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) dbName := request.DbName collectionName := request.CollectionName - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeInsert, hookutil.DatabaseKey: dbName, hookutil.UsernameKey: username, @@ -2689,7 +2689,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) username := GetCurUserFromContextOrDefault(ctx) collectionName := request.CollectionName - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeDelete, hookutil.DatabaseKey: dbName, hookutil.UsernameKey: username, @@ -2822,7 +2822,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) nodeID := paramtable.GetStringNodeID() dbName := request.DbName collectionName := request.CollectionName - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeUpsert, hookutil.DatabaseKey: request.DbName, hookutil.UsernameKey: username, @@ -3065,7 +3065,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest) if qt.result != nil { username := GetCurUserFromContextOrDefault(ctx) sentSize := proto.Size(qt.result) - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeSearch, hookutil.DatabaseKey: dbName, hookutil.UsernameKey: username, @@ -3262,7 +3262,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea if qt.result != nil { sentSize := proto.Size(qt.result) username := GetCurUserFromContextOrDefault(ctx) - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeHybridSearch, hookutil.DatabaseKey: dbName, hookutil.UsernameKey: username, @@ -3590,7 +3590,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* username := GetCurUserFromContextOrDefault(ctx) nodeID := paramtable.GetStringNodeID() - v := Extension.Report(map[string]any{ + v := hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeQuery, hookutil.DatabaseKey: request.DbName, hookutil.UsernameKey: username, diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 22d3dfbb9bcd8..2f6bfada529db 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -31,7 +31,6 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp var _ types.Proxy = (*Proxy)(nil) var ( - Params = paramtable.Get() - Extension hook.Extension - rateCol *ratelimitutil.RateCollector + Params = paramtable.Get() + rateCol *ratelimitutil.RateCollector ) // Proxy of milvus @@ -154,7 +152,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { node.UpdateStateCode(commonpb.StateCode_Abnormal) expr.Register("proxy", node) hookutil.InitOnceHook() - Extension = hookutil.Extension logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load())) return node, nil } @@ -418,7 +415,7 @@ func (node *Proxy) Start() error { cb() } - Extension.Report(map[string]any{ + hookutil.GetExtension().Report(map[string]any{ hookutil.OpTypeKey: hookutil.OpTypeNodeID, hookutil.NodeIDKey: paramtable.GetNodeID(), }) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 20c0372210d40..41669eccb62a3 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -36,6 +36,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/hookutil" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool { } func VerifyAPIKey(rawToken string) (string, error) { - if hoo == nil { - return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait") - } + hoo := hookutil.GetHook() user, err := hoo.VerifyAPIKey(rawToken) if err != nil { log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err)) diff --git a/internal/util/hookutil/default.go b/internal/util/hookutil/default.go index 6083e9d450959..7bbd467bb6f6a 100644 --- a/internal/util/hookutil/default.go +++ b/internal/util/hookutil/default.go @@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f return nil } -// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST -type MockAPIHook struct { - DefaultHook - MockErr error - User string -} - -func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) { - return m.User, m.MockErr -} - func (d DefaultHook) Release() {} type DefaultExtension struct{} diff --git a/internal/util/hookutil/hook.go b/internal/util/hookutil/hook.go index 1f1c9d89a666c..4d1acb1b11d0e 100644 --- a/internal/util/hookutil/hook.go +++ b/internal/util/hookutil/hook.go @@ -22,6 +22,7 @@ import ( "fmt" "plugin" "sync" + "sync/atomic" "go.uber.org/zap" @@ -32,14 +33,37 @@ import ( ) var ( - Hoo hook.Hook - Extension hook.Extension + hoo atomic.Value // hook.Hook + extension atomic.Value // hook.Extension initOnce sync.Once ) +// hookContainer is Container to wrap hook.Hook interface +// this struct is used to be stored in atomic.Value +// since different type stored in it will cause panicking. +type hookContainer struct { + hook hook.Hook +} + +// extensionContainer is Container to wrap hook.Extension interface +// this struct is used to be stored in atomic.Value +// since different type stored in it will cause panicking. +type extensionContainer struct { + extension hook.Extension +} + +func storeHook(hook hook.Hook) { + hoo.Store(hookContainer{hook: hook}) +} + +func storeExtension(ext hook.Extension) { + extension.Store(extensionContainer{extension: ext}) +} + func initHook() error { - Hoo = DefaultHook{} - Extension = DefaultExtension{} + // setup default hook & extension + storeHook(DefaultHook{}) + storeExtension(DefaultExtension{}) path := paramtable.Get().ProxyCfg.SoPath.GetValue() if path == "" { @@ -59,22 +83,26 @@ func initHook() error { return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error()) } + var hookVal hook.Hook var ok bool - Hoo, ok = h.(hook.Hook) + hookVal, ok = h.(hook.Hook) if !ok { return fmt.Errorf("fail to convert the `Hook` interface") } - if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil { + if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil { return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error()) } + storeHook((hookVal)) paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) { log.Info("receive the hook refresh event", zap.Any("event", event)) go func() { + hookVal := GetHook() soConfig := paramtable.GetHookParams().SoConfig.GetValue() log.Info("refresh hook configs", zap.Any("config", soConfig)) - if err = Hoo.Init(soConfig); err != nil { + if err = hookVal.Init(soConfig); err != nil { log.Panic("fail to init configs for the hook when refreshing", zap.Error(err)) } + storeHook(hookVal) }() }) @@ -82,10 +110,12 @@ func initHook() error { if err != nil { return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error()) } - Extension, ok = e.(hook.Extension) + var extVal hook.Extension + extVal, ok = e.(hook.Extension) if !ok { return fmt.Errorf("fail to convert the `Extension` interface") } + storeExtension(extVal) return nil } @@ -104,3 +134,15 @@ func InitOnceHook() { } }) } + +// GetHook returns singleton hook.Hook instance. +func GetHook() hook.Hook { + InitOnceHook() + return hoo.Load().(hookContainer).hook +} + +// GetHook returns singleton hook.Extension instance. +func GetExtension() hook.Extension { + InitOnceHook() + return extension.Load().(extensionContainer).extension +} diff --git a/internal/util/hookutil/hook_test.go b/internal/util/hookutil/hook_test.go index 1ac41d8b9682b..1cde9f9bf6c3e 100644 --- a/internal/util/hookutil/hook_test.go +++ b/internal/util/hookutil/hook_test.go @@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) { Params := paramtable.Get() paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "") initHook() - assert.IsType(t, DefaultHook{}, Hoo) + assert.IsType(t, DefaultHook{}, GetHook()) paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so") err := initHook() diff --git a/internal/util/hookutil/mock_hook.go b/internal/util/hookutil/mock_hook.go new file mode 100644 index 0000000000000..0808080c669f8 --- /dev/null +++ b/internal/util/hookutil/mock_hook.go @@ -0,0 +1,54 @@ +//go:build test +// +build test + +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hookutil + +import "github.com/milvus-io/milvus-proto/go-api/v2/hook" + +// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST +type MockAPIHook struct { + DefaultHook + MockErr error + User string +} + +func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) { + return m.User, m.MockErr +} + +func SetMockAPIHook(apiUser string, mockErr error) { + if apiUser == "" && mockErr == nil { + storeHook(&DefaultHook{}) + return + } + storeHook(&MockAPIHook{ + MockErr: mockErr, + User: apiUser, + }) +} + +func SetTestHook(hookVal hook.Hook) { + storeHook(hookVal) +} + +func SetTestExtension(extVal hook.Extension) { + storeExtension(extVal) +} diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go index 626cbbdafa81b..5c87bf09fe715 100644 --- a/tests/integration/minicluster_v2.go +++ b/tests/integration/minicluster_v2.go @@ -465,7 +465,7 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) { func InitReportExtension() *ReportChanExtension { e := NewReportChanExtension() hookutil.InitOnceHook() - hookutil.Extension = e + hookutil.SetTestExtension(e) return e }