Skip to content

Commit

Permalink
hotfix: multi service registry issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantxie committed Jul 26, 2024
1 parent eb99c3f commit 4528947
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 60 deletions.
14 changes: 1 addition & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,19 +431,7 @@ func (s *server) buildLimiterWithOpt() (handler remote.InboundHandler) {
}

func (s *server) check() error {
if len(s.svcs.svcMap) == 0 {
return errors.New("run: no service. Use RegisterService to set one")
}
if s.opt.RefuseTrafficWithoutServiceName {
s.svcs.refuseTrafficWithoutServiceName = true
return nil
}
for name, conflict := range s.svcs.conflictingMethodMap {
if conflict {
return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name)
}
}
return nil
return s.svcs.check(s.opt.RefuseTrafficWithoutServiceName)
}

func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) {
Expand Down
60 changes: 35 additions & 25 deletions server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package server

import (
"errors"
"fmt"

"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand All @@ -32,19 +33,17 @@ func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) *service
}

type services struct {
methodSvcMap map[string]*service // key: method name, value: svcInfo
svcMap map[string]*service // key: service name, value: svcInfo
conflictingMethodMap map[string]bool
fallbackSvc *service
methodSvcsMap map[string][]*service // key: method name
svcMap map[string]*service // key: service name
fallbackSvc *service

refuseTrafficWithoutServiceName bool
}

func newServices() *services {
return &services{
methodSvcMap: map[string]*service{},
svcMap: map[string]*service{},
conflictingMethodMap: map[string]bool{},
methodSvcsMap: map[string][]*service{},
svcMap: map[string]*service{},
}
}

Expand All @@ -59,26 +58,17 @@ func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interfac
s.svcMap[svcInfo.ServiceName] = svc
// method search map
for methodName := range svcInfo.Methods {
if _, ok := s.methodSvcMap[methodName]; ok {
s.handleConflictingMethod(svc, methodName, registerOpts)
svcs := s.methodSvcsMap[methodName]
if registerOpts.IsFallbackService {
svcs = append([]*service{svc}, svcs...)
} else {
s.methodSvcMap[methodName] = svc
svcs = append(svcs, svc)
}
s.methodSvcsMap[methodName] = svcs
}
return nil
}

func (s *services) handleConflictingMethod(svc *service, methodName string, registerOpts *RegisterOptions) {
// true means has conflicting method
if _, ok := s.conflictingMethodMap[methodName]; !ok {
s.conflictingMethodMap[methodName] = true
}
if registerOpts.IsFallbackService {
s.conflictingMethodMap[methodName] = false
s.methodSvcMap[methodName] = svc
}
}

func (s *services) getSvcInfoMap() map[string]*serviceinfo.ServiceInfo {
svcInfoMap := map[string]*serviceinfo.ServiceInfo{}
for name, svc := range s.svcMap {
Expand All @@ -87,6 +77,22 @@ func (s *services) getSvcInfoMap() map[string]*serviceinfo.ServiceInfo {
return svcInfoMap
}

func (s *services) check(refuseTrafficWithoutServiceName bool) error {
if len(s.svcMap) == 0 {
return errors.New("run: no service. Use RegisterService to set one")
}
if refuseTrafficWithoutServiceName {
s.refuseTrafficWithoutServiceName = true
return nil
}
for name, svcs := range s.methodSvcsMap {
if len(svcs) > 1 && svcs[0] != s.fallbackSvc {
return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name)
}
}
return nil
}

func (s *services) SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo {
if strict || s.refuseTrafficWithoutServiceName {
if svc := s.svcMap[svcName]; svc != nil {
Expand All @@ -96,13 +102,17 @@ func (s *services) SearchService(svcName, methodName string, strict bool) *servi
}
var svc *service
if svcName == "" {
svc = s.methodSvcMap[methodName]
if svcs := s.methodSvcsMap[methodName]; len(svcs) > 0 {
svc = svcs[0]
}
} else {
svc = s.svcMap[svcName]
if svc == nil {
if _, ok := s.conflictingMethodMap[methodName]; !ok {
// no conflicting method
svc = s.methodSvcMap[methodName]
svcs := s.methodSvcsMap[methodName]
// 1. no conflicting method, allow method routing
// 2. $ generally means generic service, maybe mismatch the real service name
if len(svcs) == 1 || len(svcs) > 1 && svcName[0] == '$' {
svc = svcs[0]
}
}
}
Expand Down
207 changes: 185 additions & 22 deletions server/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,200 @@
package server

import (
"fmt"
"testing"

"github.com/cloudwego/kitex/internal/mocks"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/serviceinfo"
)

func TestAddService(t *testing.T) {
svcs := newServices()
err := svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{})
test.Assert(t, err == nil)
test.Assert(t, len(svcs.svcMap) == 1)
fmt.Println(svcs.methodSvcMap)
test.Assert(t, len(svcs.methodSvcMap) == 5)
test.Assert(t, len(svcs.conflictingMethodMap) == 0)
test.Assert(t, svcs.fallbackSvc == nil)
type svc struct {
svcInfo *serviceinfo.ServiceInfo
isFallbackService bool
}
testcases := []struct {
svcs []svc
refuseTrafficWithoutServiceName bool

err = svcs.addService(mocks.Service3Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true})
test.Assert(t, err == nil)
test.Assert(t, len(svcs.svcMap) == 2)
test.Assert(t, len(svcs.methodSvcMap) == 5)
test.Assert(t, svcs.SearchService("", "mock", false) == mocks.Service3Info())
test.Assert(t, svcs.SearchService("", "mock", true) == nil)
svcs.refuseTrafficWithoutServiceName = true
test.Assert(t, svcs.SearchService("", "mock", false) == nil)
test.Assert(t, len(svcs.conflictingMethodMap) == 1)
test.Assert(t, !svcs.conflictingMethodMap["mock"])
serviceName, methodName string
strict bool
expectSvcInfo *serviceinfo.ServiceInfo

err = svcs.addService(mocks.Service2Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true})
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService3] is already registered as a fallback service")
expectErr bool
}{
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: false,
},
{
svcInfo: mocks.Service2Info(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: true,
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: false,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: true,
},
},
expectErr: false,
serviceName: mocks.MockServiceName,
methodName: mocks.MockMethod,
expectSvcInfo: mocks.ServiceInfo(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: false,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: true,
},
},
expectErr: false,
serviceName: "",
methodName: mocks.MockMethod,
expectSvcInfo: mocks.Service3Info(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: false,
serviceName: mocks.MockService3Name,
methodName: mocks.MockMethod,
expectSvcInfo: mocks.Service3Info(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: false,
serviceName: "",
methodName: mocks.MockMethod,
expectSvcInfo: mocks.ServiceInfo(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: false,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: true,
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: false,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
refuseTrafficWithoutServiceName: true,
expectErr: false,
serviceName: "",
methodName: mocks.MockMethod,
expectSvcInfo: nil,
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: false,
serviceName: serviceinfo.GenericService,
methodName: mocks.MockMethod,
expectSvcInfo: mocks.ServiceInfo(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: false,
serviceName: "xxxxxx",
methodName: mocks.MockExceptionMethod,
expectSvcInfo: mocks.ServiceInfo(),
},
{
svcs: []svc{
{
svcInfo: mocks.ServiceInfo(),
isFallbackService: true,
},
{
svcInfo: mocks.Service3Info(),
isFallbackService: false,
},
},
expectErr: false,
serviceName: "xxxxxx",
methodName: mocks.MockMethod,
expectSvcInfo: nil,
},
}
for _, tcase := range testcases {
svcs := newServices()
for _, svc := range tcase.svcs {
svcs.addService(svc.svcInfo, mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: svc.isFallbackService})
}
if tcase.expectErr {
test.Assert(t, svcs.check(tcase.refuseTrafficWithoutServiceName) != nil)
} else {
test.Assert(t, svcs.check(tcase.refuseTrafficWithoutServiceName) == nil)
test.Assert(t, svcs.SearchService(tcase.serviceName, tcase.methodName, tcase.strict) == tcase.expectSvcInfo)
}
}
}

func TestCheckMultipleFallbackService(t *testing.T) {
Expand Down

0 comments on commit 4528947

Please sign in to comment.