Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Unify hook singleton implementation in proxy #34887

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/distributed/proxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/util/hookutil"
milvusmock "github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
Expand Down Expand Up @@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) {
}

{
proxy.SetMockAPIHook("foo", nil)
defer proxy.SetMockAPIHook("", nil)
hookutil.SetMockAPIHook("foo", nil)
defer hookutil.SetMockAPIHook("", nil)
ctx.Request.Header.Set("Authorization", "Bearer 123456")
authenticate(ctx)
ctxName, _ := ctx.Get(httpserver.ContextUsername)
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/authentication_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ func TestAuthenticationInterceptor(t *testing.T) {

{
// verify apikey error
SetMockAPIHook("", errors.New("err"))
hookutil.SetMockAPIHook("", errors.New("err"))
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx)
assert.Error(t, err)
}

{
SetMockAPIHook("mockUser", nil)
hookutil.SetMockAPIHook("mockUser", nil)
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md)
authCtx, err := AuthenticationInterceptor(ctx)
Expand All @@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
user, _ := parseMD(rawToken)
assert.Equal(t, "mockUser", user)
}
hoo = hookutil.DefaultHook{}
hookutil.SetTestHook(hookutil.DefaultHook{})
}
19 changes: 1 addition & 18 deletions internal/proxy/hook_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,20 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"

"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)

var hoo hook.Hook

func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
}
}

func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
if hoo == nil {
hookutil.InitOnceHook()
hoo = hookutil.Hoo
}
hoo := hookutil.GetHook()
var (
newCtx context.Context
isMock bool
Expand Down Expand Up @@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string {
}
return username
}

func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
hoo = &hookutil.DefaultHook{}
return
}
hoo = &hookutil.MockAPIHook{
MockErr: mockErr,
User: apiUser,
}
}
10 changes: 5 additions & 5 deletions internal/proxy/hook_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) {
err error
)

hoo = mockHoo
hookutil.SetTestHook(mockHoo)
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
Expand All @@ -95,30 +95,30 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)

hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)

beforeHoo.err = nil
hoo = beforeHoo
hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)

hoo = afterHoo
hookutil.SetTestHook(afterHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil
})
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)

hoo = &hookutil.DefaultHook{}
hookutil.SetTestHook(&hookutil.DefaultHook{})
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,
Expand Down
12 changes: 6 additions & 6 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2592,7 +2592,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
dbName := request.DbName
collectionName := request.CollectionName

v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeInsert,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -2696,7 +2696,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)

username := GetCurUserFromContextOrDefault(ctx)
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeDelete,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -2829,7 +2829,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
nodeID := paramtable.GetStringNodeID()
dbName := request.DbName
collectionName := request.CollectionName
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeUpsert,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3072,7 +3072,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if qt.result != nil {
username := GetCurUserFromContextOrDefault(ctx)
sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3269,7 +3269,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
if qt.result != nil {
sentSize := proto.Size(qt.result)
username := GetCurUserFromContextOrDefault(ctx)
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username,
Expand Down Expand Up @@ -3595,7 +3595,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*

username := GetCurUserFromContextOrDefault(ctx)
nodeID := paramtable.GetStringNodeID()
v := Extension.Report(map[string]any{
v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeQuery,
hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username,
Expand Down
9 changes: 3 additions & 6 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/internalpb"
Expand Down Expand Up @@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp
var _ types.Proxy = (*Proxy)(nil)

var (
Params = paramtable.Get()
Extension hook.Extension
rateCol *ratelimitutil.RateCollector
Params = paramtable.Get()
rateCol *ratelimitutil.RateCollector
)

// Proxy of milvus
Expand Down Expand Up @@ -157,7 +155,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
hookutil.InitOnceHook()
Extension = hookutil.Extension
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
return node, nil
}
Expand Down Expand Up @@ -422,7 +419,7 @@ func (node *Proxy) Start() error {
cb()
}

Extension.Report(map[string]any{
hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeNodeID,
hookutil.NodeIDKey: paramtable.GetNodeID(),
})
Expand Down
5 changes: 2 additions & 3 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/hookutil"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
Expand Down Expand Up @@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
}

func VerifyAPIKey(rawToken string) (string, error) {
if hoo == nil {
return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait")
}
hoo := hookutil.GetHook()
user, err := hoo.VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))
Expand Down
11 changes: 0 additions & 11 deletions internal/util/hookutil/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f
return nil
}

// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}

func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}

func (d DefaultHook) Release() {}

type DefaultExtension struct{}
Expand Down
58 changes: 50 additions & 8 deletions internal/util/hookutil/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"fmt"
"plugin"
"sync"
"sync/atomic"

"go.uber.org/zap"

Expand All @@ -32,14 +33,37 @@
)

var (
Hoo hook.Hook
Extension hook.Extension
hoo atomic.Value // hook.Hook
extension atomic.Value // hook.Extension
initOnce sync.Once
)

// hookContainer is Container to wrap hook.Hook interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type hookContainer struct {
hook hook.Hook
}

// extensionContainer is Container to wrap hook.Extension interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type extensionContainer struct {
extension hook.Extension
}

func storeHook(hook hook.Hook) {
hoo.Store(hookContainer{hook: hook})
}

func storeExtension(ext hook.Extension) {
extension.Store(extensionContainer{extension: ext})
}

func initHook() error {
Hoo = DefaultHook{}
Extension = DefaultExtension{}
// setup default hook & extension
storeHook(DefaultHook{})
storeExtension(DefaultExtension{})

path := paramtable.Get().ProxyCfg.SoPath.GetValue()
if path == "" {
Expand All @@ -59,33 +83,39 @@
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
}

var hookVal hook.Hook

Check warning on line 86 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L86

Added line #L86 was not covered by tests
var ok bool
Hoo, ok = h.(hook.Hook)
hookVal, ok = h.(hook.Hook)

Check warning on line 88 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L88

Added line #L88 was not covered by tests
if !ok {
return fmt.Errorf("fail to convert the `Hook` interface")
}
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {

Check warning on line 92 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L92

Added line #L92 was not covered by tests
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
}
storeHook((hookVal))

Check warning on line 95 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L95

Added line #L95 was not covered by tests
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event))
go func() {
hookVal := GetHook()

Check warning on line 99 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L99

Added line #L99 was not covered by tests
soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = Hoo.Init(soConfig); err != nil {
if err = hookVal.Init(soConfig); err != nil {

Check warning on line 102 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L102

Added line #L102 was not covered by tests
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
}
storeHook(hookVal)

Check warning on line 105 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L105

Added line #L105 was not covered by tests
}()
})

e, err := p.Lookup("MilvusExtension")
if err != nil {
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
}
Extension, ok = e.(hook.Extension)
var extVal hook.Extension
extVal, ok = e.(hook.Extension)

Check warning on line 114 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L113-L114

Added lines #L113 - L114 were not covered by tests
if !ok {
return fmt.Errorf("fail to convert the `Extension` interface")
}
storeExtension(extVal)

Check warning on line 118 in internal/util/hookutil/hook.go

View check run for this annotation

Codecov / codecov/patch

internal/util/hookutil/hook.go#L118

Added line #L118 was not covered by tests

return nil
}
Expand All @@ -104,3 +134,15 @@
}
})
}

// GetHook returns singleton hook.Hook instance.
func GetHook() hook.Hook {
InitOnceHook()
return hoo.Load().(hookContainer).hook
}

// GetHook returns singleton hook.Extension instance.
func GetExtension() hook.Extension {
InitOnceHook()
return extension.Load().(extensionContainer).extension
}
2 changes: 1 addition & 1 deletion internal/util/hookutil/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) {
Params := paramtable.Get()
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook()
assert.IsType(t, DefaultHook{}, Hoo)
assert.IsType(t, DefaultHook{}, GetHook())

paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook()
Expand Down
Loading
Loading