Skip to content

Commit

Permalink
fix: Unify hook singleton implementation in proxy (milvus-io#34887)
Browse files Browse the repository at this point in the history
Related to milvus-io#34885

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
  • Loading branch information
congqixia authored and sumitd2 committed Aug 6, 2024
1 parent e97a0ba commit 5d78f98
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 64 deletions.
5 changes: 3 additions & 2 deletions internal/distributed/proxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/util/hookutil"
milvusmock "github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
Expand Down Expand Up @@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) {
}

{
proxy.SetMockAPIHook("foo", nil)
defer proxy.SetMockAPIHook("", nil)
hookutil.SetMockAPIHook("foo", nil)
defer hookutil.SetMockAPIHook("", nil)
ctx.Request.Header.Set("Authorization", "Bearer 123456")
authenticate(ctx)
ctxName, _ := ctx.Get(httpserver.ContextUsername)
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/authentication_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ func TestAuthenticationInterceptor(t *testing.T) {

{
// verify apikey error
SetMockAPIHook("", errors.New("err"))
hookutil.SetMockAPIHook("", errors.New("err"))
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
assert.Error(t, err)
}

{
SetMockAPIHook("mockUser", nil)
hookutil.SetMockAPIHook("mockUser", nil)
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
authCtx, err := AuthenticationInterceptor(ctx)
Expand All @@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
user, _ := parseMD(rawToken)
assert.Equal(t, "mockUser", user)
}
hoo = hookutil.DefaultHook{}
hookutil.SetTestHook(hookutil.DefaultHook{})
}
19 changes: 1 addition & 18 deletions internal/proxy/hook_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,20 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"

"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)

var hoo hook.Hook

func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
}
}

func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
if hoo == nil {
hookutil.InitOnceHook()
hoo = hookutil.Hoo
}
hoo := hookutil.GetHook()
var (
newCtx context.Context
isMock bool
Expand Down Expand Up @@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string {
}
return username
}

func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
hoo = &hookutil.DefaultHook{}
return
}
hoo = &hookutil.MockAPIHook{
MockErr: mockErr,
User: apiUser,
}
}
10 changes: 5 additions & 5 deletions internal/proxy/hook_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) {
err error
)

hoo = mockHoo
hookutil.SetTestHook(mockHoo)
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
Expand All @@ -95,30 +95,30 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)

hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)

beforeHoo.err = nil
hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)

hoo = afterHoo
hookutil.SetTestHook(afterHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil
})
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)

hoo = &hookutil.DefaultHook{}
hookutil.SetTestHook(&hookutil.DefaultHook{})
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,
Expand Down
12 changes: 6 additions & 6 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2592,7 +2592,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
dbName := request.DbName
collectionName := request.CollectionName

v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeInsert,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -2696,7 +2696,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)

username := GetCurUserFromContextOrDefault(ctx)
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeDelete,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -2829,7 +2829,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
nodeID := paramtable.GetStringNodeID()
dbName := request.DbName
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3072,7 +3072,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if qt.result != nil {
username := GetCurUserFromContextOrDefault(ctx)
sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3269,7 +3269,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
if qt.result != nil {
sentSize := proto.Size(qt.result)
username := GetCurUserFromContextOrDefault(ctx)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3595,7 +3595,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*

username := GetCurUserFromContextOrDefault(ctx)
nodeID := paramtable.GetStringNodeID()
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeQuery,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,
Expand Down
9 changes: 3 additions & 6 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/internalpb"
Expand Down Expand Up @@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp
var _ types.Proxy = (*Proxy)(nil)

var (
Params = paramtable.Get()
Extension hook.Extension
rateCol *ratelimitutil.RateCollector
Params = paramtable.Get()
rateCol *ratelimitutil.RateCollector
)

// Proxy of milvus
Expand Down Expand Up @@ -157,7 +155,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
hookutil.InitOnceHook()
Extension = hookutil.Extension
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
return node, nil
}
Expand Down Expand Up @@ -422,7 +419,7 @@ func (node *Proxy) Start() error {
cb()
}

Extension.Report(map[string]any{
hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
hookutil.NodeIDKey: paramtable.GetNodeID(),
})
Expand Down
5 changes: 2 additions & 3 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/hookutil"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
Expand Down Expand Up @@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
}

func VerifyAPIKey(rawToken string) (string, error) {
if hoo == nil {
return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait")
}
hoo := hookutil.GetHook()
user, err := hoo.VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))
Expand Down
11 changes: 0 additions & 11 deletions internal/util/hookutil/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f
return nil
}

// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}

func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}

func (d DefaultHook) Release() {}

type DefaultExtension struct{}
Expand Down
58 changes: 50 additions & 8 deletions internal/util/hookutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"plugin"
"sync"
"sync/atomic"

"go.uber.org/zap"

Expand All @@ -32,14 +33,37 @@ import (
)

var (
Hoo hook.Hook
Extension hook.Extension
hoo atomic.Value // hook.Hook
extension atomic.Value // hook.Extension
initOnce sync.Once
)

// hookContainer is Container to wrap hook.Hook interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type hookContainer struct {
hook hook.Hook
}

// extensionContainer is Container to wrap hook.Extension interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type extensionContainer struct {
extension hook.Extension
}

func storeHook(hook hook.Hook) {
hoo.Store(hookContainer{hook: hook})
}

func storeExtension(ext hook.Extension) {
extension.Store(extensionContainer{extension: ext})
}

func initHook() error {
Hoo = DefaultHook{}
Extension = DefaultExtension{}
// setup default hook & extension
storeHook(DefaultHook{})
storeExtension(DefaultExtension{})

path := paramtable.Get().ProxyCfg.SoPath.GetValue()
if path == "" {
Expand All @@ -59,33 +83,39 @@ func initHook() error {
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
}

var hookVal hook.Hook
var ok bool
Hoo, ok = h.(hook.Hook)
hookVal, ok = h.(hook.Hook)
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
storeHook((hookVal))
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event))
go func() {
hookVal := GetHook()
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = Hoo.Init(soConfig); err != nil {
if err = hookVal.Init(soConfig); err != nil {
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
}
storeHook(hookVal)
}()
})

e, err := p.Lookup("MilvusExtension")
if err != nil {
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
}
Extension, ok = e.(hook.Extension)
var extVal hook.Extension
extVal, ok = e.(hook.Extension)
if !ok {
return fmt.Errorf("fail to convert the `Extension` interface")
}
storeExtension(extVal)

return nil
}
Expand All @@ -104,3 +134,15 @@ func InitOnceHook() {
}
})
}

// GetHook returns singleton hook.Hook instance.
func GetHook() hook.Hook {
InitOnceHook()
return hoo.Load().(hookContainer).hook
}

// GetHook returns singleton hook.Extension instance.
func GetExtension() hook.Extension {
InitOnceHook()
return extension.Load().(extensionContainer).extension
}
2 changes: 1 addition & 1 deletion internal/util/hookutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) {
Params := paramtable.Get()
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook()
assert.IsType(t, DefaultHook{}, Hoo)
assert.IsType(t, DefaultHook{}, GetHook())

paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook()
Expand Down
Loading

0 comments on commit 5d78f98

Please sign in to comment.