Skip to content

Commit

Permalink
refactor(multi-services): refactoring service and method routing for …
Browse files Browse the repository at this point in the history
…multi-services (#1439)
  • Loading branch information
jayantxie authored Jul 22, 2024
1 parent ec5a564 commit bdd7536
Show file tree
Hide file tree
Showing 23 changed files with 241 additions and 336 deletions.
64 changes: 64 additions & 0 deletions internal/mocks/remote/servicesearcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package remote

import (
"github.com/cloudwego/kitex/internal/mocks"
"github.com/cloudwego/kitex/pkg/serviceinfo"
)

type MockSvcSearcher struct {
svcMap map[string]*serviceinfo.ServiceInfo
methodSvcMap map[string]*serviceinfo.ServiceInfo
}

func NewMockSvcSearcher(svcMap, methodSvcMap map[string]*serviceinfo.ServiceInfo) *MockSvcSearcher {
return &MockSvcSearcher{svcMap: svcMap, methodSvcMap: methodSvcMap}
}

func NewDefaultSvcSearcher() *MockSvcSearcher {
svcInfo := mocks.ServiceInfo()
s := map[string]*serviceinfo.ServiceInfo{
mocks.MockServiceName: svcInfo,
}
m := map[string]*serviceinfo.ServiceInfo{
mocks.MockMethod: svcInfo,
mocks.MockExceptionMethod: svcInfo,
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
return &MockSvcSearcher{svcMap: s, methodSvcMap: m}
}

func (s *MockSvcSearcher) SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo {
if strict {
if svc := s.svcMap[svcName]; svc != nil {
return svc
}
return nil
}
var svc *serviceinfo.ServiceInfo
if svcName == "" {
svc = s.methodSvcMap[methodName]
} else {
svc = s.svcMap[svcName]
}
if svc != nil {
return svc
}
return nil
}
15 changes: 3 additions & 12 deletions pkg/remote/codec/header_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ import (
"github.com/bytedance/gopkg/cloud/metainfo"

"github.com/cloudwego/kitex/internal/mocks"
mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/discovery"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
tm "github.com/cloudwego/kitex/pkg/transmeta"
"github.com/cloudwego/kitex/transport"
)
Expand Down Expand Up @@ -315,17 +315,8 @@ var (

func initServerRecvMsg() remote.Message {
svcInfo := mocks.ServiceInfo()
svcSearchMap := map[string]*serviceinfo.ServiceInfo{
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo,
mocks.MockMethod: svcInfo,
mocks.MockExceptionMethod: svcInfo,
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server, false)
svcSearcher := mocksremote.NewDefaultSvcSearcher()
msg := remote.NewMessageWithNewer(svcInfo, svcSearcher, mockSvrRPCInfo, remote.Call, remote.Server)
return msg
}

Expand Down
17 changes: 5 additions & 12 deletions pkg/remote/codec/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/cloudwego/kitex/internal/mocks"
mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/rpcinfo"
Expand All @@ -30,17 +31,8 @@ func TestSetOrCheckMethodName(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, rpcinfo.NewEndpointInfo("", "mock", nil, nil),
rpcinfo.NewServerInvocation(), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
svcInfo := mocks.ServiceInfo()
svcSearchMap := map[string]*serviceinfo.ServiceInfo{
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo,
mocks.MockMethod: svcInfo,
mocks.MockExceptionMethod: svcInfo,
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
svcSearcher := mocksremote.NewDefaultSvcSearcher()
msg := remote.NewMessageWithNewer(svcInfo, svcSearcher, ri, remote.Call, remote.Server)
err := SetOrCheckMethodName("mock", msg)
test.Assert(t, err == nil)
ri = msg.RPCInfo()
Expand All @@ -49,7 +41,8 @@ func TestSetOrCheckMethodName(t *testing.T) {
test.Assert(t, ri.Invocation().MethodName() == "mock")
test.Assert(t, ri.To().Method() == "mock")

msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server, false)
m := map[string]*serviceinfo.ServiceInfo{}
msg = remote.NewMessageWithNewer(svcInfo, mocksremote.NewMockSvcSearcher(m, m), ri, remote.Call, remote.Server)
err = SetOrCheckMethodName("dummy", msg)
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "unknown method dummy")
Expand Down
65 changes: 21 additions & 44 deletions pkg/remote/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (

"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/utils"
"github.com/cloudwego/kitex/transport"
)

Expand Down Expand Up @@ -81,6 +80,12 @@ func NewProtocolInfo(tp transport.Protocol, ct serviceinfo.PayloadCodec) Protoco
}
}

// ServiceSearcher is used to search the service info by service name and method name,
// strict equals to true means the service name must match the registered service name.
type ServiceSearcher interface {
SearchService(svcName, methodName string, strict bool) *serviceinfo.ServiceInfo
}

// Message is the core abstraction for Kitex message.
type Message interface {
RPCInfo() rpcinfo.RPCInfo
Expand Down Expand Up @@ -115,15 +120,14 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R
}

// NewMessageWithNewer creates a new Message and set data later.
func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole, refuseTrafficWithoutServiceName bool) Message {
func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearcher ServiceSearcher, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole) Message {
msg := messagePool.Get().(*message)
msg.rpcInfo = ri
msg.targetSvcInfo = targetSvcInfo
msg.svcSearchMap = svcSearchMap
msg.svcSearcher = svcSearcher
msg.msgType = msgType
msg.rpcRole = rpcRole
msg.transInfo = transInfoPool.Get().(*transInfo)
msg.refuseTrafficWithoutServiceName = refuseTrafficWithoutServiceName
return msg
}

Expand All @@ -139,19 +143,18 @@ func newMessage() interface{} {
}

type message struct {
msgType MessageType
data interface{}
rpcInfo rpcinfo.RPCInfo
targetSvcInfo *serviceinfo.ServiceInfo
svcSearchMap map[string]*serviceinfo.ServiceInfo
rpcRole RPCRole
compressType CompressType
payloadSize int
transInfo TransInfo
tags map[string]interface{}
protocol ProtocolInfo
payloadCodec PayloadCodec
refuseTrafficWithoutServiceName bool
msgType MessageType
data interface{}
rpcInfo rpcinfo.RPCInfo
targetSvcInfo *serviceinfo.ServiceInfo
svcSearcher ServiceSearcher
rpcRole RPCRole
compressType CompressType
payloadSize int
transInfo TransInfo
tags map[string]interface{}
protocol ProtocolInfo
payloadCodec PayloadCodec
}

func (m *message) zero() {
Expand Down Expand Up @@ -190,19 +193,7 @@ func (m *message) SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.S
}
return m.targetSvcInfo, nil
}
if svcName == "" && m.refuseTrafficWithoutServiceName {
return nil, NewTransErrorWithMsg(NoServiceName, "no service name while the server has WithRefuseTrafficWithoutServiceName option enabled")
}
var key string
// when client does not pass svcName or passes a special service name, fallback to searching for svcInfo by method name alone
// note: This special name fallback logic shouldn't ideally be here, but it's needed to keep compatibility with older versions (<= v0.10.1)
// due to a mistake in the first release of the multi-service feature
if svcName == "" || isSpecialName(svcName) {
key = methodName
} else {
key = BuildMultiServiceKey(svcName, methodName)
}
svcInfo := m.svcSearchMap[key]
svcInfo := m.svcSearcher.SearchService(svcName, methodName, false)
if svcInfo == nil {
return nil, NewTransErrorWithMsg(UnknownService, fmt.Sprintf("unknown service %s, method %s", svcName, methodName))
}
Expand Down Expand Up @@ -290,10 +281,6 @@ func (m *message) Recycle() {
messagePool.Put(m)
}

func isSpecialName(svcName string) bool {
return svcName == serviceinfo.GenericService || svcName == serviceinfo.CombineService || svcName == serviceinfo.CombineService_
}

// TransInfo contains transport information.
type TransInfo interface {
TransStrInfo() map[string]string
Expand Down Expand Up @@ -373,13 +360,3 @@ func FillSendMsgFromRecvMsg(recvMsg, sendMsg Message) {
sendMsg.SetProtocolInfo(recvMsg.ProtocolInfo())
sendMsg.SetPayloadCodec(recvMsg.PayloadCodec())
}

// BuildMultiServiceKey is used to create a key to search svcInfo from svcSearchMap.
func BuildMultiServiceKey(serviceName, methodName string) string {
var builder utils.StringBuilder
builder.Grow(len(serviceName) + len(methodName) + 1)
builder.WriteString(serviceName)
builder.WriteString(".")
builder.WriteString(methodName)
return builder.String()
}
5 changes: 1 addition & 4 deletions pkg/remote/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (o *Option) AppendBoundHandler(h BoundHandler) {
type ServerOption struct {
TargetSvcInfo *serviceinfo.ServiceInfo

SvcSearchMap map[string]*serviceinfo.ServiceInfo
SvcSearcher ServiceSearcher

TransServerFactory TransServerFactory

Expand Down Expand Up @@ -113,9 +113,6 @@ type ServerOption struct {

GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error

// RefuseTrafficWithoutServiceName is used for a server with multi services
RefuseTrafficWithoutServiceName bool

Option

// invoking chain with recv/send middlewares for streaming APIs
Expand Down
6 changes: 3 additions & 3 deletions pkg/remote/trans/default_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.
svrHdlr := &svrTransHandler{
opt: opt,
codec: opt.Codec,
svcSearchMap: opt.SvcSearchMap,
svcSearcher: opt.SvcSearcher,
targetSvcInfo: opt.TargetSvcInfo,
ext: ext,
}
Expand All @@ -50,7 +50,7 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.

type svrTransHandler struct {
opt *remote.ServerOption
svcSearchMap map[string]*serviceinfo.ServiceInfo
svcSearcher remote.ServiceSearcher
targetSvcInfo *serviceinfo.ServiceInfo
inkHdlFunc endpoint.Endpoint
codec remote.Codec
Expand Down Expand Up @@ -167,7 +167,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
}()
ctx = t.startTracer(ctx, ri)
ctx = t.startProfiler(ctx)
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName)
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearcher, ri, remote.Call, remote.Server)
recvMsg.SetPayloadCodec(t.opt.PayloadCodec)
ctx, err = t.transPipe.Read(ctx, conn, recvMsg)
if err != nil {
Expand Down
22 changes: 7 additions & 15 deletions pkg/remote/trans/default_server_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/golang/mock/gomock"

"github.com/cloudwego/kitex/internal/mocks"
remotemocks "github.com/cloudwego/kitex/internal/mocks/remote"
"github.com/cloudwego/kitex/internal/mocks/stats"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/kerrors"
Expand All @@ -34,17 +35,8 @@ import (
)

var (
svcInfo = mocks.ServiceInfo()
svcSearchMap = map[string]*serviceinfo.ServiceInfo{
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo,
remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo,
mocks.MockMethod: svcInfo,
mocks.MockExceptionMethod: svcInfo,
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
svcInfo = mocks.ServiceInfo()
svcSearcher = remotemocks.NewDefaultSvcSearcher()
)

func TestDefaultSvrTransHandler(t *testing.T) {
Expand Down Expand Up @@ -72,7 +64,7 @@ func TestDefaultSvrTransHandler(t *testing.T) {
return nil
},
},
SvcSearchMap: svcSearchMap,
SvcSearcher: svcSearcher,
TargetSvcInfo: svcInfo,
}

Expand Down Expand Up @@ -139,7 +131,7 @@ func TestSvrTransHandlerBizError(t *testing.T) {
return nil
},
},
SvcSearchMap: svcSearchMap,
SvcSearcher: svcSearcher,
TargetSvcInfo: svcInfo,
TracerCtl: tracerCtl,
InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
Expand Down Expand Up @@ -198,7 +190,7 @@ func TestSvrTransHandlerReadErr(t *testing.T) {
return mockErr
},
},
SvcSearchMap: svcSearchMap,
SvcSearcher: svcSearcher,
TargetSvcInfo: svcInfo,
TracerCtl: tracerCtl,
InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
Expand Down Expand Up @@ -256,7 +248,7 @@ func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) {
return nil
},
},
SvcSearchMap: svcSearchMap,
SvcSearcher: svcSearcher,
TargetSvcInfo: svcInfo,
TracerCtl: tracerCtl,
InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
Expand Down
Loading

0 comments on commit bdd7536

Please sign in to comment.